diff --git a/run_agent.py b/run_agent.py
index bf00f86c7c..d2a60c7f05 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -8938,6 +8938,75 @@ class AIAgent:
                     and "skill_manage" in self.valid_tool_names):
                 self._iters_since_skill += 1
 
+            # ── Pre-API-call /steer drain ──────────────────────────────────
+            # If a /steer arrived during the previous API call (while the model
+            # was thinking), drain it now — before we build api_messages — so
+            # the model sees the steer text on THIS iteration. Without this,
+            # steers sent during an API call only land after the NEXT tool batch,
+            # which may never come if the model returns a final response.
+            #
+            # We scan backwards for the most recent tool-role message of the
+            # current turn and append the steer there. The scan stops at the
+            # first user-role message so the steer is never written into a tool
+            # result left over from an earlier turn. If no tool message is found
+            # (first iteration, no tools yet), the steer stays pending for the next
+            # tool batch — injecting into a user message would break role
+            # alternation, and there's no tool output to piggyback on.
+            _pre_api_steer = self._drain_pending_steer()
+            if _pre_api_steer:
+                _steer_marker = f"\n\n[USER STEER (injected mid-run, not tool output): {_pre_api_steer}]"
+                _injected = False
+                for _si in range(len(messages) - 1, -1, -1):
+                    _sm = messages[_si]
+                    if not isinstance(_sm, dict):
+                        continue
+                    _steer_role = _sm.get("role")
+                    if _steer_role == "user":
+                        # Turn boundary — older tool results are history.
+                        break
+                    if _steer_role != "tool":
+                        continue
+                    _steer_content = _sm.get("content", "")
+                    if isinstance(_steer_content, str):
+                        _sm["content"] = _steer_content + _steer_marker
+                        _injected = True
+                    elif not _steer_content or isinstance(_steer_content, (list, tuple)):
+                        # Multimodal content blocks (or empty content) —
+                        # append a text block.
+                        _sm["content"] = [
+                            *(_steer_content or ()),
+                            {"type": "text", "text": _steer_marker},
+                        ]
+                        _injected = True
+                    else:
+                        # Unknown content shape: leave the message alone and
+                        # fall through to the re-stash below, so the steer is
+                        # deferred instead of being silently dropped.
+                        logger.warning(
+                            "Pre-API-call steer drain: unsupported tool content type %s at index %d",
+                            type(_steer_content).__name__,
+                            _si,
+                        )
+                    if _injected:
+                        logger.debug(
+                            "Pre-API-call steer drain: injected into tool msg at index %d",
+                            _si,
+                        )
+                    break
+                if not _injected:
+                    # Nothing to inject into — put the steer back so the
+                    # post-tool-execution drain picks it up later. The drained
+                    # text is older than any steer that arrived since, so it goes
+                    # first to keep steers in the order they were sent.
+                    _lock = getattr(self, "_pending_steer_lock", None)
+                    if _lock is not None:
+                        with _lock:
+                            _newer_steer = self._pending_steer
+                            self._pending_steer = f"{_pre_api_steer}\n{_newer_steer}" if _newer_steer else _pre_api_steer
+                    else:
+                        _newer_steer = getattr(self, "_pending_steer", None)
+                        self._pending_steer = f"{_pre_api_steer}\n{_newer_steer}" if _newer_steer else _pre_api_steer
+
             # Prepare messages for API call
             # If we have an ephemeral system prompt, prepend it to the messages
             # Note: Reasoning is embedded in content via tags for trajectory storage.
diff --git a/tests/run_agent/test_steer.py b/tests/run_agent/test_steer.py
index a298ede8c0..9a9e4b51cc 100644
--- a/tests/run_agent/test_steer.py
+++ b/tests/run_agent/test_steer.py
@@ -199,6 +199,123 @@ class TestSteerClearedOnInterrupt:
         assert agent._pending_steer is None
 
 
+class TestPreApiCallSteerDrain:
+    """Test that steers arriving during an API call are drained before the
+    next API call — not deferred until the next tool batch. This is the
+    fix for the scenario where /steer sent during model thinking only lands
+    after the agent is completely done.
+
+    The drain is inline in ``run_conversation`` and cannot be called on its
+    own, so ``_inject`` and ``_restash`` below mirror it; keep them in sync.
+    """
+
+    @staticmethod
+    def _inject(messages, steer_text):
+        """Mirror of the injection scan; True if a tool message took it."""
+        marker = f"\n\n[USER STEER (injected mid-run, not tool output): {steer_text}]"
+        for msg in reversed(messages):
+            role = msg.get("role")
+            if role == "user":
+                # Turn boundary — never touch an earlier turn's tool output.
+                return False
+            if role == "tool":
+                msg["content"] += marker
+                return True
+        return False
+
+    @staticmethod
+    def _restash(agent, steer_text):
+        """Mirror of the re-stash; drained text goes ahead of newer steers."""
+        newer = agent._pending_steer
+        agent._pending_steer = f"{steer_text}\n{newer}" if newer else steer_text
+
+    def test_pre_api_drain_injects_into_last_tool_result(self):
+        """If a steer is pending when the main loop starts building
+        api_messages, it should be injected into the last tool result
+        in the messages list."""
+        agent = _bare_agent()
+        # Simulate messages after a tool batch completed
+        messages = [
+            {"role": "user", "content": "do something"},
+            {"role": "assistant", "content": "ok", "tool_calls": [
+                {"id": "tc1", "function": {"name": "terminal", "arguments": "{}"}}
+            ]},
+            {"role": "tool", "content": "output here", "tool_call_id": "tc1"},
+        ]
+        # Steer arrives during API call (set after tool execution)
+        agent.steer("focus on error handling")
+        steer_text = agent._drain_pending_steer()
+        assert steer_text == "focus on error handling"
+        assert self._inject(messages, steer_text)
+        assert messages[-1]["content"] == (
+            "output here"
+            "\n\n[USER STEER (injected mid-run, not tool output): focus on error handling]"
+        )
+        assert agent._pending_steer is None
+
+    def test_pre_api_drain_restashes_when_no_tool_message(self):
+        """If there are no tool results yet (first iteration), the steer
+        should be put back into _pending_steer for the post-tool drain."""
+        agent = _bare_agent()
+        messages = [
+            {"role": "user", "content": "hello"},
+        ]
+        agent.steer("early steer")
+        steer_text = agent._drain_pending_steer()
+        assert steer_text == "early steer"
+        assert not self._inject(messages, steer_text)
+        self._restash(agent, steer_text)
+        assert agent._pending_steer == "early steer"
+
+    def test_pre_api_drain_finds_tool_msg_past_assistant(self):
+        """The pre-API drain should scan backwards past a non-tool message
+        (e.g., if an assistant message was somehow appended after tools)
+        and still find the tool result."""
+        agent = _bare_agent()
+        messages = [
+            {"role": "user", "content": "do something"},
+            {"role": "assistant", "content": "let me check", "tool_calls": [
+                {"id": "tc1", "function": {"name": "web_search", "arguments": "{}"}}
+            ]},
+            {"role": "tool", "content": "search results", "tool_call_id": "tc1"},
+            {"role": "assistant", "content": "partial answer"},
+        ]
+        agent.steer("change approach")
+        steer_text = agent._drain_pending_steer()
+        assert steer_text is not None
+        assert self._inject(messages, steer_text)
+        assert "change approach" in messages[2]["content"]
+        assert messages[3]["content"] == "partial answer"
+
+    def test_pre_api_drain_stops_at_turn_boundary(self):
+        """A tool result from an earlier turn must not receive the steer;
+        the scan stops at the most recent user message."""
+        agent = _bare_agent()
+        messages = [
+            {"role": "user", "content": "first task"},
+            {"role": "assistant", "content": "ok", "tool_calls": [
+                {"id": "tc1", "function": {"name": "terminal", "arguments": "{}"}}
+            ]},
+            {"role": "tool", "content": "old output", "tool_call_id": "tc1"},
+            {"role": "assistant", "content": "done"},
+            {"role": "user", "content": "second task"},
+        ]
+        agent.steer("too early")
+        steer_text = agent._drain_pending_steer()
+        assert not self._inject(messages, steer_text)
+        assert messages[2]["content"] == "old output"
+
+    def test_pre_api_restash_keeps_arrival_order(self):
+        """A steer that lands while the drain is scanning is newer than the
+        drained one, so the re-stash must put the drained text first."""
+        agent = _bare_agent()
+        agent.steer("first")
+        steer_text = agent._drain_pending_steer()
+        agent.steer("second")
+        self._restash(agent, steer_text)
+        assert agent._pending_steer == "first\nsecond"
+
+
 class TestSteerCommandRegistry:
     def test_steer_in_command_registry(self):
         """The /steer slash command must be registered so it reaches all