diff --git a/run_agent.py b/run_agent.py index fb03ee5c4f..2dcfcd682b 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3824,10 +3824,8 @@ class AIAgent: ) if _is_timeout or _is_conn_err: - # Transient network / timeout error. Retry the - # streaming request with a fresh connection rather - # than falling back to non-streaming (which would - # hang for up to 15 min on the same dead server). + # Transient network / timeout error. Retry the + # streaming request with a fresh connection first. if _stream_attempt < _max_stream_retries: logger.info( "Streaming attempt %s/%s failed (%s: %s), " @@ -3845,30 +3843,34 @@ class AIAgent: ) request_client_holder["client"] = None continue - # Exhausted retries — propagate to outer loop logger.warning( - "Streaming exhausted %s retries on transient error: %s", + "Streaming exhausted %s retries on transient error, " + "falling back to non-streaming: %s", _max_stream_retries + 1, e, ) - result["error"] = e - return - - # Non-transient error (e.g. "streaming not supported", - # auth error, 4xx). Fall back to non-streaming once. - err_msg = str(e).lower() - if "stream" in err_msg and "not supported" in err_msg: - logger.info( - "Streaming not supported, falling back to non-streaming: %s", e + else: + _err_lower = str(e).lower() + _is_stream_unsupported = ( + "stream" in _err_lower + and "not supported" in _err_lower + ) + if _is_stream_unsupported: + self._safe_print( + "\n⚠ Streaming is not supported for this " + "model/provider. Falling back to non-streaming.\n" + " To avoid this delay, set display.streaming: false " + "in config.yaml\n" + ) + logger.info( + "Streaming failed before delivery, falling back to non-streaming: %s", + e, ) - try: - result["response"] = self._interruptible_api_call(api_kwargs) - except Exception as fallback_err: - result["error"] = fallback_err - return - # Unknown error — propagate to outer retry loop - result["error"] = e + try: + result["response"] = self._interruptible_api_call(api_kwargs) + except Exception as fallback_err: + result["error"] = fallback_err return finally: request_client = request_client_holder.get("client") diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 9d3ed6f320..88e3aa9e87 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -487,6 +487,51 @@ class TestStreamingFallback: with pytest.raises(Exception, match="Rate limit exceeded"): agent._interruptible_streaming_api_call({}) + @patch("run_agent.AIAgent._interruptible_api_call") + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_exhausted_transient_stream_error_falls_back(self, mock_close, mock_create, mock_non_stream): + """Transient stream errors retry first, then fall back after retries are exhausted.""" + from run_agent import AIAgent + import httpx + + mock_client = MagicMock() + mock_client.chat.completions.create.side_effect = httpx.ConnectError("socket closed") + mock_create.return_value = mock_client + + fallback_response = SimpleNamespace( + id="fallback", + model="test", + choices=[SimpleNamespace( + index=0, + message=SimpleNamespace( + role="assistant", + content="fallback after retries exhausted", + tool_calls=None, + reasoning_content=None, + ), + finish_reason="stop", + )], + usage=None, + ) + mock_non_stream.return_value = fallback_response + + agent = AIAgent( + model="test/model", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "fallback after retries exhausted" + assert mock_client.chat.completions.create.call_count == 3 + mock_non_stream.assert_called_once() + assert mock_close.call_count >= 1 + # ── Test: Reasoning Streaming ────────────────────────────────────────────