diff --git a/run_agent.py b/run_agent.py index 37572db5e..eb89e83cf 100644 --- a/run_agent.py +++ b/run_agent.py @@ -5391,13 +5391,22 @@ class AIAgent: # a new API call, creating a duplicate message. Return a # partial "stop" response instead so the outer loop treats this # turn as complete (no retry, no fallback). + # Recover whatever content was already streamed to the user. + # _current_streamed_assistant_text accumulates text fired + # through _fire_stream_delta, so it has exactly what the + # user saw before the connection died. + _partial_text = ( + getattr(self, "_current_streamed_assistant_text", "") or "" + ).strip() or None logger.warning( "Partial stream delivered before error; returning stub " - "response to prevent duplicate messages: %s", + "response with %s chars of recovered content to prevent " + "duplicate messages: %s", + len(_partial_text or ""), result["error"], ) _stub_msg = SimpleNamespace( - role="assistant", content=None, tool_calls=None, + role="assistant", content=_partial_text, tool_calls=None, reasoning_content=None, ) return SimpleNamespace( @@ -9889,6 +9898,30 @@ class AIAgent: # Check if response only has think block with no actual content after it if not self._has_content_after_think_block(final_response): + # ── Partial stream recovery ───────────────────── + # If content was already streamed to the user before + # the connection died, use it as the final response + # instead of falling through to prior-turn fallback + # or wasting API calls on retries. + _partial_streamed = ( + getattr(self, "_current_streamed_assistant_text", "") or "" + ) + if self._has_content_after_think_block(_partial_streamed): + _turn_exit_reason = "partial_stream_recovery" + _recovered = self._strip_think_blocks(_partial_streamed).strip() + logger.info( + "Partial stream content delivered (%d chars) " + "— using as final response", + len(_recovered), + ) + self._emit_status( + "↻ Stream interrupted — using delivered content " + "as final response" + ) + final_response = _recovered + self._response_was_previewed = True + break + # If the previous turn already delivered real content alongside # tool calls (e.g. "You're welcome!" + memory save), the model # has nothing more to say. Use the earlier content immediately diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index 2112ddc3f..49a5a33d1 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -1949,6 +1949,88 @@ class TestRunConversation: failure_msgs = [m for m in status_messages if "no content" in m.lower() or "no fallback" in m.lower()] assert len(failure_msgs) >= 1, f"Expected at least 1 failure status, got: {status_messages}" + def test_partial_stream_recovery_uses_streamed_content(self, agent): + """When streaming fails after partial delivery, recovered partial content becomes final response.""" + self._setup_agent(agent) + # Simulate a partial-stream-stub response: content recovered from streaming + partial_resp = _mock_response( + content="Here is the partial answer that was stream", + finish_reason="stop", + ) + agent.client.chat.completions.create.return_value = partial_resp + # Simulate that streaming had already delivered this text + agent._current_streamed_assistant_text = "Here is the partial answer that was stream" + with ( + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("explain something") + # The partial content should be used as-is (not empty, not retried) + assert result["completed"] is True + assert result["final_response"] == "Here is the partial answer that was stream" + assert result["api_calls"] == 1 # No retries + + def test_partial_stream_recovery_on_empty_stub(self, agent): + """When stub response has no content but text was streamed, use streamed text.""" + self._setup_agent(agent) + # Stub response with no content (old behavior before fix) + empty_stub = _mock_response(content=None, finish_reason="stop") + + def _fake_api_call(api_kwargs): + # Simulate what streaming does: accumulate text before returning + # a stub with no content (connection died mid-stream) + agent._current_streamed_assistant_text = "The answer to your question is that" + return empty_stub + + status_messages = [] + + def _capture_status(msg): + status_messages.append(msg) + + with ( + patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call), + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + patch.object(agent, "_emit_status", side_effect=_capture_status), + ): + result = agent.run_conversation("ask me") + # Should recover partial streamed content, not fall through to (empty) + assert result["completed"] is True + assert result["final_response"] == "The answer to your question is that" + assert result["api_calls"] == 1 # No wasted retries + # Should emit the stream-interrupted status, NOT the empty-retry status + recovery_msgs = [m for m in status_messages if "stream interrupted" in m.lower()] + assert len(recovery_msgs) >= 1, f"Expected stream recovery status, got: {status_messages}" + # Should NOT have retry statuses + retry_msgs = [m for m in status_messages if "retrying" in m.lower()] + assert len(retry_msgs) == 0, f"Should not retry when stream content exists: {status_messages}" + + def test_partial_stream_recovery_preempts_prior_turn_fallback(self, agent): + """Partial streamed content takes priority over _last_content_with_tools fallback.""" + self._setup_agent(agent) + # Set up the prior-turn fallback content (from a previous turn with tool calls) + agent._last_content_with_tools = "Old content from prior turn with tools" + # Stub response with no content + empty_stub = _mock_response(content=None, finish_reason="stop") + + def _fake_api_call(api_kwargs): + # Simulate partial streaming before connection death + agent._current_streamed_assistant_text = "Fresh partial content from this turn" + return empty_stub + + with ( + patch.object(agent, "_interruptible_api_call", side_effect=_fake_api_call), + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("question") + # Should use the streamed content, not the old prior-turn fallback + assert result["final_response"] == "Fresh partial content from this turn" + assert result["api_calls"] == 1 + def test_nous_401_refreshes_after_remint_and_retries(self, agent): self._setup_agent(agent) agent.provider = "nous"