diff --git a/cli.py b/cli.py index ea76991ac..8aa8bb03f 100644 --- a/cli.py +++ b/cli.py @@ -5720,6 +5720,30 @@ class HermesCLI: _cprint(f" Queued for the next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}") else: _cprint(f" Queued: {payload[:80]}{'...' if len(payload) > 80 else ''}") + elif canonical == "steer": + # Inject a message after the next tool call without interrupting. + # If the agent is actively running, push the text into the agent's + # pending_steer slot — the drain hook in _execute_tool_calls_* + # will append it to the next tool result's content. If no agent + # is running, fall back to queue semantics (same as /queue). + parts = cmd_original.split(None, 1) + payload = parts[1].strip() if len(parts) > 1 else "" + if not payload: + _cprint(" Usage: /steer ") + elif self._agent_running and self.agent is not None and hasattr(self.agent, "steer"): + try: + accepted = self.agent.steer(payload) + except Exception as exc: + _cprint(f" Steer failed: {exc}") + else: + if accepted: + _cprint(f" ⏩ Steer queued — arrives after the next tool call: {payload[:80]}{'...' if len(payload) > 80 else ''}") + else: + _cprint(" Steer rejected (empty payload).") + else: + # No active run — treat as a normal next-turn message. + self._pending_input.put(payload) + _cprint(f" No agent running; queued as next turn: {payload[:80]}{'...' if len(payload) > 80 else ''}") elif canonical == "skin": self._handle_skin_command(cmd_original) elif canonical == "voice": @@ -8244,7 +8268,15 @@ class HermesCLI: else: print(f"\n⚡ Sending after interrupt: '{preview}'") self._pending_input.put(combined) - + + # If a /steer was left over (agent finished before another tool + # batch could absorb it), deliver it as the next user turn. + _leftover_steer = result.get("pending_steer") if result else None + if _leftover_steer and hasattr(self, '_pending_input'): + preview = _leftover_steer[:60] + ("..." if len(_leftover_steer) > 60 else "") + print(f"\n⏩ Delivering leftover /steer as next turn: '{preview}'") + self._pending_input.put(_leftover_steer) + return response except Exception as e: diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 5cad956a3..31973b962 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -1994,6 +1994,11 @@ class DiscordAdapter(BasePlatformAdapter): async def slash_stop(interaction: discord.Interaction): await self._run_simple_slash(interaction, "/stop", "Stop requested~") + @tree.command(name="steer", description="Inject a message after the next tool call (no interrupt)") + @discord.app_commands.describe(prompt="Text to inject into the agent's next tool result") + async def slash_steer(interaction: discord.Interaction, prompt: str): + await self._run_simple_slash(interaction, f"/steer {prompt}".strip()) + @tree.command(name="compress", description="Compress conversation context") async def slash_compress(interaction: discord.Interaction): await self._run_simple_slash(interaction, "/compress") diff --git a/gateway/run.py b/gateway/run.py index 62b813f0d..1525ad147 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3019,6 +3019,54 @@ class GatewayRunner: adapter._pending_messages[_quick_key] = queued_event return "Queued for the next turn." + # /steer — inject mid-run after the next tool call. + # Unlike /queue (turn boundary), /steer lands BETWEEN tool-call + # iterations inside the same agent run, by appending to the + # last tool result's content. No interrupt, no new user turn, + # no role-alternation violation. + if _cmd_def_inner and _cmd_def_inner.name == "steer": + steer_text = event.get_command_args().strip() + if not steer_text: + return "Usage: /steer " + running_agent = self._running_agents.get(_quick_key) + if running_agent is _AGENT_PENDING_SENTINEL: + # Agent hasn't started yet — queue as turn-boundary fallback. + adapter = self.adapters.get(source.platform) + if adapter: + from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT + queued_event = _ME( + text=steer_text, + message_type=_MT.TEXT, + source=event.source, + message_id=event.message_id, + channel_prompt=event.channel_prompt, + ) + adapter._pending_messages[_quick_key] = queued_event + return "Agent still starting — /steer queued for the next turn." + if running_agent and hasattr(running_agent, "steer"): + try: + accepted = running_agent.steer(steer_text) + except Exception as exc: + logger.warning("Steer failed for session %s: %s", _quick_key[:20], exc) + return f"⚠️ Steer failed: {exc}" + if accepted: + preview = steer_text[:60] + ("..." if len(steer_text) > 60 else "") + return f"⏩ Steer queued — arrives after the next tool call: '{preview}'" + return "Steer rejected (empty payload)." + # Running agent is missing or lacks steer() — fall back to queue. + adapter = self.adapters.get(source.platform) + if adapter: + from gateway.platforms.base import MessageEvent as _ME, MessageType as _MT + queued_event = _ME( + text=steer_text, + message_type=_MT.TEXT, + source=event.source, + message_id=event.message_id, + channel_prompt=event.channel_prompt, + ) + adapter._pending_messages[_quick_key] = queued_event + return "No active agent — /steer queued for the next turn." + # /model must not be used while the agent is running. if _cmd_def_inner and _cmd_def_inner.name == "model": return "Agent is running — wait or /stop first, then switch models." @@ -3260,6 +3308,21 @@ class GatewayRunner: if canonical == "btw": return await self._handle_btw_command(event) + if canonical == "steer": + # No active agent — /steer has no tool call to inject into. + # Strip the prefix so downstream treats it as a normal user + # message. If the payload is empty, surface the usage hint. + steer_payload = event.get_command_args().strip() + if not steer_payload: + return "Usage: /steer (no agent is running; sending as a normal message)" + try: + event.text = steer_payload + except Exception: + pass + # Do NOT return — fall through to _handle_message_with_agent + # at the end of this function so the rewritten text is sent + # to the agent as a regular user turn. + if canonical == "voice": return await self._handle_voice_command(event) diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index ce257b0d7..681e6f9b2 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -91,6 +91,8 @@ COMMAND_REGISTRY: list[CommandDef] = [ aliases=("tasks",)), CommandDef("queue", "Queue a prompt for the next turn (doesn't interrupt)", "Session", aliases=("q",), args_hint=""), + CommandDef("steer", "Inject a message after the next tool call without interrupting", "Session", + args_hint=""), CommandDef("status", "Show session info", "Session"), CommandDef("profile", "Show active profile name and home directory", "Info"), CommandDef("sethome", "Set this chat as the home channel", "Session", @@ -275,6 +277,7 @@ ACTIVE_SESSION_BYPASS_COMMANDS: frozenset[str] = frozenset( "queue", "restart", "status", + "steer", "stop", "update", } diff --git a/run_agent.py b/run_agent.py index d5ff125e3..a47455e53 100644 --- a/run_agent.py +++ b/run_agent.py @@ -832,6 +832,16 @@ class AIAgent: self._interrupt_thread_signal_pending = False self._client_lock = threading.RLock() + # /steer mechanism — inject a user note into the next tool result + # without interrupting the agent. Unlike interrupt(), steer() does + # NOT set _interrupt_requested; it waits for the current tool batch + # to finish naturally, then the drain hook appends the text to the + # last tool result's content so the model sees it on its next + # iteration. Message-role alternation is preserved (we modify an + # existing tool message rather than inserting a new user turn). + self._pending_steer: Optional[str] = None + self._pending_steer_lock = threading.Lock() + # Concurrent-tool worker thread tracking. `_execute_tool_calls_concurrent` # runs each tool on its own ThreadPoolExecutor worker — those worker # threads have tids distinct from `_execution_thread_id`, so @@ -3265,6 +3275,129 @@ class AIAgent: _set_interrupt(False, _wtid) except Exception: pass + # A hard interrupt supersedes any pending /steer — the steer was + # meant for the agent's next tool-call iteration, which will no + # longer happen. Drop it instead of surprising the user with a + # late injection on the post-interrupt turn. + _steer_lock = getattr(self, "_pending_steer_lock", None) + if _steer_lock is not None: + with _steer_lock: + self._pending_steer = None + + def steer(self, text: str) -> bool: + """ + Inject a user message into the next tool result without interrupting. + + Unlike interrupt(), this does NOT stop the current tool call. The + text is stashed and the agent loop appends it to the LAST tool + result's content once the current tool batch finishes. The model + sees the steer as part of the tool output on its next iteration. + + Thread-safe: callable from gateway/CLI/TUI threads. Multiple calls + before the drain point concatenate with newlines. + + Args: + text: The user text to inject. Empty strings are ignored. + + Returns: + True if the steer was accepted, False if the text was empty. + """ + if not text or not text.strip(): + return False + cleaned = text.strip() + _lock = getattr(self, "_pending_steer_lock", None) + if _lock is None: + # Test stubs that built AIAgent via object.__new__ skip __init__. + # Fall back to direct attribute set; no concurrent callers expected + # in those stubs. + existing = getattr(self, "_pending_steer", None) + self._pending_steer = (existing + "\n" + cleaned) if existing else cleaned + return True + with _lock: + if self._pending_steer: + self._pending_steer = self._pending_steer + "\n" + cleaned + else: + self._pending_steer = cleaned + return True + + def _drain_pending_steer(self) -> Optional[str]: + """Return the pending steer text (if any) and clear the slot. + + Safe to call from the agent execution thread after appending tool + results. Returns None when no steer is pending. + """ + _lock = getattr(self, "_pending_steer_lock", None) + if _lock is None: + text = getattr(self, "_pending_steer", None) + self._pending_steer = None + return text + with _lock: + text = self._pending_steer + self._pending_steer = None + return text + + def _apply_pending_steer_to_tool_results(self, messages: list, num_tool_msgs: int) -> None: + """Append any pending /steer text to the last tool result in this turn. + + Called at the end of a tool-call batch, before the next API call. + The steer is appended to the last ``role:"tool"`` message's content + with a clear marker so the model understands it came from the user + and NOT from the tool itself. Role alternation is preserved — + nothing new is inserted, we only modify existing content. + + Args: + messages: The running messages list. + num_tool_msgs: Number of tool results appended in this batch; + used to locate the tail slice safely. + """ + if num_tool_msgs <= 0 or not messages: + return + steer_text = self._drain_pending_steer() + if not steer_text: + return + # Find the last tool-role message in the recent tail. Skipping + # non-tool messages defends against future code appending + # something else at the boundary. + target_idx = None + for j in range(len(messages) - 1, max(len(messages) - num_tool_msgs - 1, -1), -1): + msg = messages[j] + if isinstance(msg, dict) and msg.get("role") == "tool": + target_idx = j + break + if target_idx is None: + # No tool result in this batch (e.g. all skipped by interrupt); + # put the steer back so the caller's fallback path can deliver + # it as a normal next-turn user message. + _lock = getattr(self, "_pending_steer_lock", None) + if _lock is not None: + with _lock: + if self._pending_steer: + self._pending_steer = self._pending_steer + "\n" + steer_text + else: + self._pending_steer = steer_text + else: + existing = getattr(self, "_pending_steer", None) + self._pending_steer = (existing + "\n" + steer_text) if existing else steer_text + return + marker = f"\n\n[USER STEER (injected mid-run, not tool output): {steer_text}]" + existing_content = messages[target_idx].get("content", "") + if not isinstance(existing_content, str): + # Anthropic multimodal content blocks — preserve them and append + # a text block at the end. + try: + blocks = list(existing_content) if existing_content else [] + blocks.append({"type": "text", "text": marker.lstrip()}) + messages[target_idx]["content"] = blocks + except Exception: + # Fall back to string replacement if content shape is unexpected. + messages[target_idx]["content"] = f"{existing_content}{marker}" + else: + messages[target_idx]["content"] = existing_content + marker + logger.info( + "Delivered /steer to agent after tool batch (%d chars): %s", + len(steer_text), + steer_text[:120] + ("..." if len(steer_text) > 120 else ""), + ) def _touch_activity(self, desc: str) -> None: """Update the last-activity timestamp and description (thread-safe).""" @@ -7951,6 +8084,13 @@ class AIAgent: turn_tool_msgs = messages[-num_tools:] enforce_turn_budget(turn_tool_msgs, env=get_active_env(effective_task_id)) + # ── /steer injection ────────────────────────────────────────────── + # Append any pending user steer text to the last tool result so the + # agent sees it on its next iteration. Runs AFTER budget enforcement + # so the steer marker is never truncated. See steer() for details. + if num_tools > 0: + self._apply_pending_steer_to_tool_results(messages, num_tools) + def _execute_tool_calls_sequential(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: """Execute tool calls sequentially (original behavior). Used for single calls or interactive tools.""" for i, tool_call in enumerate(assistant_message.tool_calls, 1): @@ -8330,6 +8470,12 @@ class AIAgent: if num_tools_seq > 0: enforce_turn_budget(messages[-num_tools_seq:], env=get_active_env(effective_task_id)) + # ── /steer injection ────────────────────────────────────────────── + # See _execute_tool_calls_parallel for the rationale. Same hook, + # applied to sequential execution as well. + if num_tools_seq > 0: + self._apply_pending_steer_to_tool_results(messages, num_tools_seq) + def _handle_max_iterations(self, messages: list, api_call_count: int) -> str: @@ -11610,6 +11756,12 @@ class AIAgent: "cost_status": self.session_cost_status, "cost_source": self.session_cost_source, } + # If a /steer landed after the final assistant turn (no more tool + # batches to drain into), hand it back to the caller so it can be + # delivered as the next user turn instead of being silently lost. + _leftover_steer = self._drain_pending_steer() + if _leftover_steer: + result["pending_steer"] = _leftover_steer self._response_was_previewed = False # Include interrupt message if one triggered the interrupt diff --git a/tests/gateway/test_command_bypass_active_session.py b/tests/gateway/test_command_bypass_active_session.py index 10ff06212..c45624394 100644 --- a/tests/gateway/test_command_bypass_active_session.py +++ b/tests/gateway/test_command_bypass_active_session.py @@ -200,6 +200,25 @@ class TestCommandBypassActiveSession: "/background response was not sent back to the user" ) + @pytest.mark.asyncio + async def test_steer_bypasses_guard(self): + """/steer must bypass the Level-1 active-session guard so it reaches + the gateway runner's /steer handler and injects into the running + agent instead of being queued as user text for the next turn. + """ + adapter = _make_adapter() + sk = _session_key() + adapter._active_sessions[sk] = asyncio.Event() + + await adapter.handle_message(_make_event("/steer also check auth.log")) + + assert sk not in adapter._pending_messages, ( + "/steer was queued as a pending message instead of being dispatched" + ) + assert any("handled:steer" in r for r in adapter.sent_responses), ( + "/steer response was not sent back to the user" + ) + @pytest.mark.asyncio async def test_help_bypasses_guard(self): """/help must bypass so it is not silently dropped as pending slash text.""" diff --git a/tests/gateway/test_steer_command.py b/tests/gateway/test_steer_command.py new file mode 100644 index 000000000..b756ff096 --- /dev/null +++ b/tests/gateway/test_steer_command.py @@ -0,0 +1,191 @@ +"""Tests for the gateway /steer command handler. + +/steer injects a user message into the agent's next tool result without +interrupting. The gateway runner must: + + 1. When an agent IS running → call ``agent.steer(text)``, do NOT set + ``_interrupt_requested``, do NOT touch ``_pending_messages``. + 2. When the agent is the PENDING sentinel → fall back to /queue + semantics (store in ``adapter._pending_messages``). + 3. When no agent is active → strip the slash prefix and let the normal + prompt pipeline handle it as a regular user message. +""" +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import MessageEvent +from gateway.session import SessionEntry, SessionSource, build_session_key + + +def _make_source() -> SessionSource: + return SessionSource( + platform=Platform.TELEGRAM, + user_id="u1", + chat_id="c1", + user_name="tester", + chat_type="dm", + ) + + +def _make_event(text: str) -> MessageEvent: + return MessageEvent( + text=text, + source=_make_source(), + message_id="m1", + ) + + +def _make_runner(session_entry: SessionEntry): + from gateway.run import GatewayRunner + + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")} + ) + adapter = MagicMock() + adapter.send = AsyncMock() + adapter._pending_messages = {} + runner.adapters = {Platform.TELEGRAM: adapter} + runner._voice_mode = {} + runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False) + runner.session_store = MagicMock() + runner.session_store.get_or_create_session.return_value = session_entry + runner.session_store.load_transcript.return_value = [] + runner.session_store.has_any_sessions.return_value = True + runner._running_agents = {} + runner._running_agents_ts = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._session_db = MagicMock() + runner._session_db.get_session_title.return_value = None + runner._reasoning_config = None + runner._provider_routing = {} + runner._fallback_model = None + runner._show_reasoning = False + runner._is_user_authorized = lambda _source: True + runner._set_session_env = lambda _context: None + runner._should_send_voice_reply = lambda *_args, **_kwargs: False + runner._send_voice_reply = AsyncMock() + runner._capture_gateway_honcho_if_configured = lambda *args, **kwargs: None + runner._emit_gateway_run_progress = AsyncMock() + return runner, adapter + + +def _session_entry() -> SessionEntry: + return SessionEntry( + session_key=build_session_key(_make_source()), + session_id="sess-1", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + total_tokens=0, + ) + + +@pytest.mark.asyncio +async def test_steer_calls_agent_steer_and_does_not_interrupt(): + """When an agent is running, /steer must call agent.steer(text) and + leave interrupt state untouched.""" + runner, adapter = _make_runner(_session_entry()) + sk = build_session_key(_make_source()) + + running_agent = MagicMock() + running_agent.steer.return_value = True + runner._running_agents[sk] = running_agent + + result = await runner._handle_message(_make_event("/steer also check auth.log")) + + # The handler replied with a confirmation + assert result is not None + assert "steer" in result.lower() or "queued" in result.lower() + # The agent's steer() was called with the payload (prefix stripped) + running_agent.steer.assert_called_once_with("also check auth.log") + # Critically: interrupt was NOT called + running_agent.interrupt.assert_not_called() + # And no user-text queueing happened — the steer doesn't go into + # _pending_messages (that would be turn-boundary /queue semantics). + assert runner._pending_messages == {} + assert adapter._pending_messages == {} + + +@pytest.mark.asyncio +async def test_steer_without_payload_returns_usage(): + runner, _adapter = _make_runner(_session_entry()) + sk = build_session_key(_make_source()) + running_agent = MagicMock() + runner._running_agents[sk] = running_agent + + result = await runner._handle_message(_make_event("/steer")) + + assert result is not None + assert "Usage" in result or "usage" in result + running_agent.steer.assert_not_called() + running_agent.interrupt.assert_not_called() + + +@pytest.mark.asyncio +async def test_steer_with_pending_sentinel_falls_back_to_queue(): + """When the agent hasn't finished booting (sentinel), /steer should + queue as a turn-boundary follow-up instead of crashing.""" + from gateway.run import _AGENT_PENDING_SENTINEL + + runner, adapter = _make_runner(_session_entry()) + sk = build_session_key(_make_source()) + runner._running_agents[sk] = _AGENT_PENDING_SENTINEL + + result = await runner._handle_message(_make_event("/steer wait up")) + + assert result is not None + assert "queued" in result.lower() or "starting" in result.lower() + # The fallback put the text into the adapter's pending queue. + assert sk in adapter._pending_messages + assert adapter._pending_messages[sk].text == "wait up" + + +@pytest.mark.asyncio +async def test_steer_agent_without_steer_method_falls_back(): + """If the running agent somehow lacks the steer() method (older build, + test stub), the handler must not explode — fall back to /queue.""" + runner, adapter = _make_runner(_session_entry()) + sk = build_session_key(_make_source()) + + # A bare object that does NOT have steer() — use a spec'd Mock so + # hasattr(agent, "steer") returns False. + running_agent = MagicMock(spec=[]) + runner._running_agents[sk] = running_agent + + result = await runner._handle_message(_make_event("/steer fallback")) + + assert result is not None + # Must mention queueing since steer wasn't available + assert "queued" in result.lower() + assert sk in adapter._pending_messages + assert adapter._pending_messages[sk].text == "fallback" + + +@pytest.mark.asyncio +async def test_steer_rejected_payload_returns_rejection_message(): + """If agent.steer() returns False (e.g. empty after strip — though + the gateway already guards this), surface a rejection message.""" + runner, _adapter = _make_runner(_session_entry()) + sk = build_session_key(_make_source()) + + running_agent = MagicMock() + running_agent.steer.return_value = False + runner._running_agents[sk] = running_agent + + result = await runner._handle_message(_make_event("/steer hello")) + + assert result is not None + assert "rejected" in result.lower() or "empty" in result.lower() + + +if __name__ == "__main__": # pragma: no cover + pytest.main([__file__, "-v"]) diff --git a/tests/run_agent/test_steer.py b/tests/run_agent/test_steer.py new file mode 100644 index 000000000..a298ede8c --- /dev/null +++ b/tests/run_agent/test_steer.py @@ -0,0 +1,228 @@ +"""Tests for AIAgent.steer() — mid-run user message injection. + +/steer lets the user add a note to the agent's next tool result without +interrupting the current tool call. The agent sees the note inline with +tool output on its next iteration, preserving message-role alternation +and prompt-cache integrity. +""" +from __future__ import annotations + +import threading + +import pytest + +from run_agent import AIAgent + + +def _bare_agent() -> AIAgent: + """Build an AIAgent without running __init__, then install the steer + state manually — matches the existing object.__new__ stub pattern + used elsewhere in the test suite. + """ + agent = object.__new__(AIAgent) + agent._pending_steer = None + agent._pending_steer_lock = threading.Lock() + return agent + + +class TestSteerAcceptance: + def test_accepts_non_empty_text(self): + agent = _bare_agent() + assert agent.steer("go ahead and check the logs") is True + assert agent._pending_steer == "go ahead and check the logs" + + def test_rejects_empty_string(self): + agent = _bare_agent() + assert agent.steer("") is False + assert agent._pending_steer is None + + def test_rejects_whitespace_only(self): + agent = _bare_agent() + assert agent.steer(" \n\t ") is False + assert agent._pending_steer is None + + def test_rejects_none(self): + agent = _bare_agent() + assert agent.steer(None) is False # type: ignore[arg-type] + assert agent._pending_steer is None + + def test_strips_surrounding_whitespace(self): + agent = _bare_agent() + assert agent.steer(" hello world \n") is True + assert agent._pending_steer == "hello world" + + def test_concatenates_multiple_steers_with_newlines(self): + agent = _bare_agent() + agent.steer("first note") + agent.steer("second note") + agent.steer("third note") + assert agent._pending_steer == "first note\nsecond note\nthird note" + + +class TestSteerDrain: + def test_drain_returns_and_clears(self): + agent = _bare_agent() + agent.steer("hello") + assert agent._drain_pending_steer() == "hello" + assert agent._pending_steer is None + + def test_drain_on_empty_returns_none(self): + agent = _bare_agent() + assert agent._drain_pending_steer() is None + + +class TestSteerInjection: + def test_appends_to_last_tool_result(self): + agent = _bare_agent() + agent.steer("please also check auth.log") + messages = [ + {"role": "user", "content": "what's in /var/log?"}, + {"role": "assistant", "tool_calls": [{"id": "a"}, {"id": "b"}]}, + {"role": "tool", "content": "ls output A", "tool_call_id": "a"}, + {"role": "tool", "content": "ls output B", "tool_call_id": "b"}, + ] + agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2) + # The LAST tool result is modified; earlier ones are untouched. + assert messages[2]["content"] == "ls output A" + assert "ls output B" in messages[3]["content"] + assert "[USER STEER" in messages[3]["content"] + assert "please also check auth.log" in messages[3]["content"] + # And pending_steer is consumed. + assert agent._pending_steer is None + + def test_no_op_when_no_steer_pending(self): + agent = _bare_agent() + messages = [ + {"role": "assistant", "tool_calls": [{"id": "a"}]}, + {"role": "tool", "content": "output", "tool_call_id": "a"}, + ] + agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1) + assert messages[-1]["content"] == "output" # unchanged + + def test_no_op_when_num_tool_msgs_zero(self): + agent = _bare_agent() + agent.steer("steer") + messages = [{"role": "user", "content": "hi"}] + agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=0) + # Steer should remain pending (nothing to drain into) + assert agent._pending_steer == "steer" + + def test_marker_is_unambiguous_about_origin(self): + """The injection marker must make clear the text is from the user + and not tool output — this is the cache-safe way to signal + provenance without violating message-role alternation. + """ + agent = _bare_agent() + agent.steer("stop after next step") + messages = [{"role": "tool", "content": "x", "tool_call_id": "1"}] + agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1) + content = messages[-1]["content"] + assert "USER STEER" in content + assert "not tool output" in content.lower() or "injected mid-run" in content.lower() + + def test_multimodal_content_list_preserved(self): + """Anthropic-style list content should be preserved, with the steer + appended as a text block.""" + agent = _bare_agent() + agent.steer("extra note") + original_blocks = [{"type": "text", "text": "existing output"}] + messages = [ + {"role": "tool", "content": list(original_blocks), "tool_call_id": "1"} + ] + agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=1) + new_content = messages[-1]["content"] + assert isinstance(new_content, list) + assert len(new_content) == 2 + assert new_content[0] == {"type": "text", "text": "existing output"} + assert new_content[1]["type"] == "text" + assert "extra note" in new_content[1]["text"] + + def test_restashed_when_no_tool_result_in_batch(self): + """If the 'batch' contains no tool-role messages (e.g. all skipped + after an interrupt), the steer should be put back into the pending + slot so the caller's fallback path can deliver it.""" + agent = _bare_agent() + agent.steer("ping") + messages = [ + {"role": "user", "content": "x"}, + {"role": "assistant", "content": "y"}, + ] + # Claim there were N tool msgs, but the tail has none — simulates + # the interrupt-cancelled case. + agent._apply_pending_steer_to_tool_results(messages, num_tool_msgs=2) + # Messages untouched + assert messages[-1]["content"] == "y" + # And the steer is back in pending so the fallback can grab it + assert agent._pending_steer == "ping" + + +class TestSteerThreadSafety: + def test_concurrent_steer_calls_preserve_all_text(self): + agent = _bare_agent() + N = 200 + + def worker(idx: int) -> None: + agent.steer(f"note-{idx}") + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(N)] + for t in threads: + t.start() + for t in threads: + t.join() + + text = agent._drain_pending_steer() + assert text is not None + # Every single note must be preserved — none dropped by the lock. + lines = text.split("\n") + assert len(lines) == N + assert set(lines) == {f"note-{i}" for i in range(N)} + + +class TestSteerClearedOnInterrupt: + def test_clear_interrupt_drops_pending_steer(self): + """A hard interrupt supersedes any pending steer — the agent's + next tool iteration won't happen, so delivering the steer later + would be surprising.""" + agent = _bare_agent() + # Minimal surface needed by clear_interrupt() + agent._interrupt_requested = True + agent._interrupt_message = None + agent._interrupt_thread_signal_pending = False + agent._execution_thread_id = None + agent._tool_worker_threads = None + agent._tool_worker_threads_lock = None + + agent.steer("will be dropped") + assert agent._pending_steer == "will be dropped" + + agent.clear_interrupt() + assert agent._pending_steer is None + + +class TestSteerCommandRegistry: + def test_steer_in_command_registry(self): + """The /steer slash command must be registered so it reaches all + platforms (CLI, gateway, TUI autocomplete, Telegram/Slack menus). + """ + from hermes_cli.commands import resolve_command, ACTIVE_SESSION_BYPASS_COMMANDS + + cmd = resolve_command("steer") + assert cmd is not None + assert cmd.name == "steer" + assert cmd.category == "Session" + assert cmd.args_hint == "" + + def test_steer_in_bypass_set(self): + """When the agent is running, /steer MUST bypass the Level-1 + base-adapter queue so it reaches the gateway runner's /steer + handler. Otherwise it would be queued as user text and only + delivered at turn end — defeating the whole point. + """ + from hermes_cli.commands import ACTIVE_SESSION_BYPASS_COMMANDS, should_bypass_active_session + + assert "steer" in ACTIVE_SESSION_BYPASS_COMMANDS + assert should_bypass_active_session("steer") is True + + +if __name__ == "__main__": # pragma: no cover + pytest.main([__file__, "-v"]) diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index e7681b784..ea231e626 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -438,3 +438,74 @@ def test_rollback_restore_resolves_number_and_file_path(): assert resp["result"]["success"] is True assert calls["args"][1] == "bbb222" assert calls["args"][2] == "src/app.tsx" + + +# ── session.steer ──────────────────────────────────────────────────── + + +def test_session_steer_calls_agent_steer_when_agent_supports_it(): + """The TUI RPC method must call agent.steer(text) and return a + queued status without touching interrupt state. + """ + calls = {} + + class _Agent: + def steer(self, text): + calls["steer_text"] = text + return True + + def interrupt(self, *args, **kwargs): + calls["interrupt_called"] = True + + server._sessions["sid"] = _session(agent=_Agent()) + try: + resp = server.handle_request( + { + "id": "1", + "method": "session.steer", + "params": {"session_id": "sid", "text": "also check auth.log"}, + } + ) + finally: + server._sessions.pop("sid", None) + + assert "result" in resp, resp + assert resp["result"]["status"] == "queued" + assert resp["result"]["text"] == "also check auth.log" + assert calls["steer_text"] == "also check auth.log" + assert "interrupt_called" not in calls # must NOT interrupt + + +def test_session_steer_rejects_empty_text(): + server._sessions["sid"] = _session(agent=types.SimpleNamespace(steer=lambda t: True)) + try: + resp = server.handle_request( + { + "id": "1", + "method": "session.steer", + "params": {"session_id": "sid", "text": " "}, + } + ) + finally: + server._sessions.pop("sid", None) + + assert "error" in resp, resp + assert resp["error"]["code"] == 4002 + + +def test_session_steer_errors_when_agent_has_no_steer_method(): + server._sessions["sid"] = _session(agent=types.SimpleNamespace()) # no steer() + try: + resp = server.handle_request( + { + "id": "1", + "method": "session.steer", + "params": {"session_id": "sid", "text": "hi"}, + } + ) + finally: + server._sessions.pop("sid", None) + + assert "error" in resp, resp + assert resp["error"]["code"] == 4010 + diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 3ef76a0f0..a7dae9e5c 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -1340,6 +1340,31 @@ def _(rid, params: dict) -> dict: return _ok(rid, {"status": "interrupted"}) +@method("session.steer") +def _(rid, params: dict) -> dict: + """Inject a user message into the next tool result without interrupting. + + Mirrors AIAgent.steer(). Safe to call while a turn is running — the text + lands on the last tool result of the next tool batch and the model sees + it on its next iteration. No interrupt, no new user turn, no role + alternation violation. + """ + text = (params.get("text") or "").strip() + if not text: + return _err(rid, 4002, "text is required") + session, err = _sess_nowait(params, rid) + if err: + return err + agent = session.get("agent") + if agent is None or not hasattr(agent, "steer"): + return _err(rid, 4010, "agent does not support steer") + try: + accepted = agent.steer(text) + except Exception as exc: + return _err(rid, 5000, f"steer failed: {exc}") + return _ok(rid, {"status": "queued" if accepted else "rejected", "text": text}) + + @method("terminal.resize") def _(rid, params: dict) -> dict: session, err = _sess_nowait(params, rid) diff --git a/ui-tui/src/app/slash/commands/core.ts b/ui-tui/src/app/slash/commands/core.ts index e0832c7a6..a151b2cdc 100644 --- a/ui-tui/src/app/slash/commands/core.ts +++ b/ui-tui/src/app/slash/commands/core.ts @@ -1,7 +1,7 @@ import { dailyFortune, randomFortune } from '../../../content/fortunes.js' import { HOTKEYS } from '../../../content/hotkeys.js' import { nextDetailsMode, parseDetailsMode } from '../../../domain/details.js' -import type { ConfigGetValueResponse, ConfigSetResponse, SessionUndoResponse } from '../../../gatewayTypes.js' +import type { ConfigGetValueResponse, ConfigSetResponse, SessionSteerResponse, SessionUndoResponse } from '../../../gatewayTypes.js' import { writeOsc52Clipboard } from '../../../lib/osc52.js' import type { DetailsMode, Msg, PanelSection } from '../../../types.js' import { patchOverlayState } from '../../overlayStore.js' @@ -245,6 +245,36 @@ export const coreCommands: SlashCommand[] = [ } }, + { + help: 'inject a message after the next tool call (no interrupt)', + name: 'steer', + run: (arg, ctx) => { + const payload = arg?.trim() ?? '' + + if (!payload) { + return ctx.transcript.sys('usage: /steer ') + } + + // If the agent isn't running, fall back to the queue so the user's + // message isn't lost — identical semantics to the gateway handler. + if (!ctx.ui.busy || !ctx.sid) { + ctx.composer.enqueue(payload) + ctx.transcript.sys(`no active turn — queued for next: "${payload.slice(0, 50)}${payload.length > 50 ? '…' : ''}"`) + return + } + + ctx.gateway.rpc('session.steer', { session_id: ctx.sid, text: payload }).then( + ctx.guarded(r => { + if (r?.status === 'queued') { + ctx.transcript.sys(`⏩ steer queued — arrives after next tool call: "${payload.slice(0, 50)}${payload.length > 50 ? '…' : ''}"`) + } else { + ctx.transcript.sys('steer rejected') + } + }) + ).catch(ctx.guardedErr) + } + }, + { help: 'undo last exchange', name: 'undo', diff --git a/ui-tui/src/gatewayTypes.ts b/ui-tui/src/gatewayTypes.ts index 9e21b9bc5..c8d1c6855 100644 --- a/ui-tui/src/gatewayTypes.ts +++ b/ui-tui/src/gatewayTypes.ts @@ -152,6 +152,11 @@ export interface SessionInterruptResponse { ok?: boolean } +export interface SessionSteerResponse { + status?: 'queued' | 'rejected' + text?: string +} + // ── Prompt / submission ────────────────────────────────────────────── export interface PromptSubmitResponse {