diff --git a/cli.py b/cli.py
index 6292993e2b..b7203e9063 100755
--- a/cli.py
+++ b/cli.py
@@ -1474,9 +1474,10 @@ class HermesCLI:
                 self._in_reasoning_block = False
                 after = self._stream_prefilt[idx + len(tag):]
                 self._stream_prefilt = ""
-                # Process remaining text after close tag
+                # Process remaining text after close tag through full
+                # filtering (it could contain another open tag)
                 if after:
-                    self._emit_stream_text(after)
+                    self._stream_delta(after)
                 return
             # Still inside reasoning block — keep only the tail that could
             # be a partial close tag prefix (save memory on long blocks).
diff --git a/gateway/run.py b/gateway/run.py
index 8bc860c3a0..71f453d889 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -4371,10 +4371,19 @@ class GatewayRunner:
         if tool_progress_enabled:
             progress_task = asyncio.create_task(send_progress_messages())

-        # Start stream consumer task if configured
+        # Start stream consumer task — polls for consumer creation since it
+        # happens inside run_sync (thread pool) after the agent is constructed.
         stream_task = None
-        if stream_consumer_holder[0] is not None:
-            stream_task = asyncio.create_task(stream_consumer_holder[0].run())
+
+        async def _start_stream_consumer():
+            """Wait for the stream consumer to be created, then run it."""
+            for _ in range(200):  # Up to 10s wait
+                if stream_consumer_holder[0] is not None:
+                    await stream_consumer_holder[0].run()
+                    return
+                await asyncio.sleep(0.05)
+
+        stream_task = asyncio.create_task(_start_stream_consumer())

         # Track this agent as running for this session (for interrupt support)
         # We do this in a callback after the agent is created
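A note on the 200 × 50 ms poll above: it is simple and bounded, but it adds up to 50 ms of latency before the first token and gives up silently after 10 s. If that ever matters, the holder could be paired with an asyncio.Event set from the worker thread. A minimal sketch, assuming the consumer is created inside a run_sync worker thread; consumer_ready and on_consumer_created are hypothetical names, not part of this codebase:

    import asyncio

    stream_consumer_holder: list = [None]
    consumer_ready = asyncio.Event()

    def on_consumer_created(consumer, loop: asyncio.AbstractEventLoop) -> None:
        # Runs in the worker thread, so the event must be set on the loop's
        # own thread via call_soon_threadsafe.
        stream_consumer_holder[0] = consumer
        loop.call_soon_threadsafe(consumer_ready.set)

    async def _start_stream_consumer() -> None:
        # Same 10 s cap as the polling version, without the periodic wake-ups.
        try:
            await asyncio.wait_for(consumer_ready.wait(), timeout=10.0)
        except asyncio.TimeoutError:
            return  # Consumer was never created; nothing to stream.
        await stream_consumer_holder[0].run()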
- """ - result = {"response": None, "error": None} - request_client_holder = {"client": None} - - def _call(): - try: - stream_kwargs = {**api_kwargs, "stream": True} - request_client_holder["client"] = self._create_request_openai_client( - reason="chat_completion_stream_request" - ) - stream = request_client_holder["client"].chat.completions.create(**stream_kwargs) - - content_parts: list[str] = [] - tool_calls_acc: dict[int, dict] = {} - finish_reason = None - model_name = None - role = "assistant" - - for chunk in stream: - if not chunk.choices: - if hasattr(chunk, "model") and chunk.model: - model_name = chunk.model - continue - - delta = chunk.choices[0].delta - if hasattr(chunk, "model") and chunk.model: - model_name = chunk.model - - if delta and delta.content: - content_parts.append(delta.content) - try: - stream_callback(delta.content) - except Exception: - pass - - if delta and delta.tool_calls: - for tc_delta in delta.tool_calls: - idx = tc_delta.index if tc_delta.index is not None else 0 - if idx in tool_calls_acc and tc_delta.id and tc_delta.id != tool_calls_acc[idx]["id"]: - matched = False - for eidx, eentry in tool_calls_acc.items(): - if eentry["id"] == tc_delta.id: - idx = eidx - matched = True - break - if not matched: - idx = (max(k for k in tool_calls_acc if isinstance(k, int)) + 1) if tool_calls_acc else 0 - if idx not in tool_calls_acc: - tool_calls_acc[idx] = { - "id": tc_delta.id or "", - "type": "function", - "function": {"name": "", "arguments": ""}, - } - entry = tool_calls_acc[idx] - if tc_delta.id: - entry["id"] = tc_delta.id - if tc_delta.function: - if tc_delta.function.name: - entry["function"]["name"] += tc_delta.function.name - if tc_delta.function.arguments: - entry["function"]["arguments"] += tc_delta.function.arguments - - if chunk.choices[0].finish_reason: - finish_reason = chunk.choices[0].finish_reason - - full_content = "".join(content_parts) or None - mock_tool_calls = None - if tool_calls_acc: - mock_tool_calls = [] - for idx in sorted(tool_calls_acc): - tc = tool_calls_acc[idx] - mock_tool_calls.append(SimpleNamespace( - id=tc["id"], - type=tc["type"], - function=SimpleNamespace( - name=tc["function"]["name"], - arguments=tc["function"]["arguments"], - ), - )) - - mock_message = SimpleNamespace( - role=role, - content=full_content, - tool_calls=mock_tool_calls, - reasoning_content=None, - ) - mock_choice = SimpleNamespace( - index=0, - message=mock_message, - finish_reason=finish_reason or "stop", - ) - mock_response = SimpleNamespace( - id="stream-" + str(uuid.uuid4()), - model=model_name, - choices=[mock_choice], - usage=None, - ) - result["response"] = mock_response - - except Exception as e: - result["error"] = e - finally: - request_client = request_client_holder.get("client") - if request_client is not None: - self._close_request_openai_client(request_client, reason="stream_request_complete") - - t = threading.Thread(target=_call, daemon=True) - t.start() - while t.is_alive(): - t.join(timeout=0.3) - if self._interrupt_requested: - try: - if self.api_mode == "anthropic_messages": - from agent.anthropic_adapter import build_anthropic_client - - self._anthropic_client.close() - self._anthropic_client = build_anthropic_client( - self._anthropic_api_key, - getattr(self, "_anthropic_base_url", None), - ) - else: - request_client = request_client_holder.get("client") - if request_client is not None: - self._close_request_openai_client(request_client, reason="stream_interrupt_abort") - except Exception: - pass - raise InterruptedError("Agent 
interrupted during API call") - if result["error"] is not None: - raise result["error"] - return result["response"] - # ── Unified streaming API call ───────────────────────────────────────── def _fire_stream_delta(self, text: str) -> None: @@ -3039,12 +2905,20 @@ class AIAgent: streaming is not supported. """ if self.api_mode == "codex_responses": - # Codex already streams internally; we just need to pass callbacks - return self._interruptible_api_call(api_kwargs) + # Codex streams internally via _run_codex_stream. The main dispatch + # in _interruptible_api_call already calls it; we just need to + # ensure on_first_delta reaches it. Store it on the instance + # temporarily so _run_codex_stream can pick it up. + self._codex_on_first_delta = on_first_delta + try: + return self._interruptible_api_call(api_kwargs) + finally: + self._codex_on_first_delta = None result = {"response": None, "error": None} request_client_holder = {"client": None} first_delta_fired = {"done": False} + deltas_were_sent = {"yes": False} # Track if any deltas were fired (for fallback) def _fire_first_delta(): if not first_delta_fired["done"] and on_first_delta: @@ -3098,6 +2972,7 @@ class AIAgent: if not tool_calls_acc: _fire_first_delta() self._fire_stream_delta(delta.content) + deltas_were_sent["yes"] = True # Accumulate tool call deltas (silently, no callback) if delta and delta.tool_calls: @@ -3208,17 +3083,22 @@ class AIAgent: else: result["response"] = _call_chat_completions() except Exception as e: - # Always fall back to non-streaming on ANY streaming error. - # Many third-party/extrinsic providers have partial or broken - # streaming support — rejecting stream=True, crashing on - # stream_options, dropping connections mid-stream, etc. - # A clean fallback to the standard request path ensures the - # agent still works even if streaming doesn't. - logger.info("Streaming failed, falling back to non-streaming: %s", e) - try: - result["response"] = self._interruptible_api_call(api_kwargs) - except Exception as fallback_err: - result["error"] = fallback_err + if deltas_were_sent["yes"]: + # Streaming failed AFTER some tokens were already delivered + # to consumers. Don't fall back — that would cause + # double-delivery (partial streamed + full non-streamed). + # Let the error propagate; the partial content already + # reached the user via the stream. + logger.warning("Streaming failed after partial delivery, not falling back: %s", e) + result["error"] = e + else: + # Streaming failed before any tokens reached consumers. + # Safe to fall back to the standard non-streaming path. 
+ logger.info("Streaming failed before delivery, falling back to non-streaming: %s", e) + try: + result["response"] = self._interruptible_api_call(api_kwargs) + except Exception as fallback_err: + result["error"] = fallback_err finally: request_client = request_client_holder.get("client") if request_client is not None: diff --git a/tests/test_openai_client_lifecycle.py b/tests/test_openai_client_lifecycle.py index 695737895d..72d92fd15e 100644 --- a/tests/test_openai_client_lifecycle.py +++ b/tests/test_openai_client_lifecycle.py @@ -59,8 +59,11 @@ def _build_agent(shared_client=None): agent._interrupt_requested = False agent._interrupt_message = None agent._client_lock = threading.RLock() - agent._client_kwargs = {"api_key": "test-key", "base_url": agent.base_url} + agent._client_kwargs = {"api_key": "***", "base_url": agent.base_url} agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True}) + agent.stream_delta_callback = None + agent._stream_callback = None + agent.reasoning_callback = None return agent @@ -173,7 +176,11 @@ def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatc monkeypatch.setattr(run_agent, "OpenAI", factory) agent = _build_agent(shared_client=stale_shared) - response = agent._streaming_api_call({"model": agent.model, "messages": []}, lambda _delta: None) + agent.stream_delta_callback = lambda _delta: None + # Force chat_completions mode so the streaming path uses + # chat.completions.create(stream=True) instead of Codex responses.stream() + agent.api_mode = "chat_completions" + response = agent._interruptible_streaming_api_call({"model": agent.model, "messages": []}) assert response.choices[0].message.content == "Hello world" assert agent.client is replacement_shared diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index 2cc37fc51c..cfe8bab208 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -2329,8 +2329,9 @@ class TestStreamingApiCall: ] agent.client.chat.completions.create.return_value = iter(chunks) callback = MagicMock() + agent.stream_delta_callback = callback - resp = agent._streaming_api_call({"messages": []}, callback) + resp = agent._interruptible_streaming_api_call({"messages": []}) assert resp.choices[0].message.content == "Hello World" assert resp.choices[0].finish_reason == "stop" @@ -2347,7 +2348,7 @@ class TestStreamingApiCall: ] agent.client.chat.completions.create.return_value = iter(chunks) - resp = agent._streaming_api_call({"messages": []}, MagicMock()) + resp = agent._interruptible_streaming_api_call({"messages": []}) tc = resp.choices[0].message.tool_calls assert len(tc) == 1 @@ -2363,7 +2364,7 @@ class TestStreamingApiCall: ] agent.client.chat.completions.create.return_value = iter(chunks) - resp = agent._streaming_api_call({"messages": []}, MagicMock()) + resp = agent._interruptible_streaming_api_call({"messages": []}) tc = resp.choices[0].message.tool_calls assert len(tc) == 2 @@ -2378,7 +2379,7 @@ class TestStreamingApiCall: ] agent.client.chat.completions.create.return_value = iter(chunks) - resp = agent._streaming_api_call({"messages": []}, MagicMock()) + resp = agent._interruptible_streaming_api_call({"messages": []}) assert resp.choices[0].message.content == "I'll search" assert len(resp.choices[0].message.tool_calls) == 1 @@ -2387,7 +2388,7 @@ class TestStreamingApiCall: chunks = [_make_chunk(finish_reason="stop")] agent.client.chat.completions.create.return_value = iter(chunks) - resp = agent._streaming_api_call({"messages": []}, MagicMock()) + resp 
diff --git a/tests/test_openai_client_lifecycle.py b/tests/test_openai_client_lifecycle.py
index 695737895d..72d92fd15e 100644
--- a/tests/test_openai_client_lifecycle.py
+++ b/tests/test_openai_client_lifecycle.py
@@ -59,8 +59,11 @@ def _build_agent(shared_client=None):
     agent._interrupt_requested = False
     agent._interrupt_message = None
     agent._client_lock = threading.RLock()
-    agent._client_kwargs = {"api_key": "test-key", "base_url": agent.base_url}
+    agent._client_kwargs = {"api_key": "***", "base_url": agent.base_url}
     agent.client = shared_client or FakeSharedClient(lambda **kwargs: {"shared": True})
+    agent.stream_delta_callback = None
+    agent._stream_callback = None
+    agent.reasoning_callback = None
     return agent

@@ -173,7 +176,11 @@ def test_streaming_call_recreates_closed_shared_client_before_request(monkeypatc
     monkeypatch.setattr(run_agent, "OpenAI", factory)
     agent = _build_agent(shared_client=stale_shared)

-    response = agent._streaming_api_call({"model": agent.model, "messages": []}, lambda _delta: None)
+    agent.stream_delta_callback = lambda _delta: None
+    # Force chat_completions mode so the streaming path uses
+    # chat.completions.create(stream=True) instead of Codex responses.stream()
+    agent.api_mode = "chat_completions"
+    response = agent._interruptible_streaming_api_call({"model": agent.model, "messages": []})

    assert response.choices[0].message.content == "Hello world"
    assert agent.client is replacement_shared
diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py
index 2cc37fc51c..cfe8bab208 100644
--- a/tests/test_run_agent.py
+++ b/tests/test_run_agent.py
@@ -2329,8 +2329,9 @@ class TestStreamingApiCall:
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)
         callback = MagicMock()
+        agent.stream_delta_callback = callback

-        resp = agent._streaming_api_call({"messages": []}, callback)
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.choices[0].message.content == "Hello World"
         assert resp.choices[0].finish_reason == "stop"
@@ -2347,7 +2348,7 @@
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         tc = resp.choices[0].message.tool_calls
         assert len(tc) == 1
@@ -2363,7 +2364,7 @@
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         tc = resp.choices[0].message.tool_calls
         assert len(tc) == 2
@@ -2378,7 +2379,7 @@
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.choices[0].message.content == "I'll search"
         assert len(resp.choices[0].message.tool_calls) == 1
@@ -2387,7 +2388,7 @@
         chunks = [_make_chunk(finish_reason="stop")]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.choices[0].message.content is None
         assert resp.choices[0].message.tool_calls is None
@@ -2399,9 +2400,9 @@
             _make_chunk(finish_reason="stop"),
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)
-        callback = MagicMock(side_effect=ValueError("boom"))
+        agent.stream_delta_callback = MagicMock(side_effect=ValueError("boom"))

-        resp = agent._streaming_api_call({"messages": []}, callback)
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.choices[0].message.content == "Hello World"
@@ -2412,7 +2413,7 @@
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.model == "gpt-4o"
@@ -2420,22 +2421,23 @@
         chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        agent._streaming_api_call({"messages": [], "model": "test"}, MagicMock())
+        agent._interruptible_streaming_api_call({"messages": [], "model": "test"})

         call_kwargs = agent.client.chat.completions.create.call_args
         assert call_kwargs[1].get("stream") is True or call_kwargs.kwargs.get("stream") is True

-    def test_api_exception_propagated(self, agent):
+    def test_api_exception_falls_back_to_non_streaming(self, agent):
+        """When streaming fails before any deltas, fallback to non-streaming is attempted."""
         agent.client.chat.completions.create.side_effect = ConnectionError("fail")
-
+        # The fallback also uses the same client, so it'll fail too
         with pytest.raises(ConnectionError, match="fail"):
-            agent._streaming_api_call({"messages": []}, MagicMock())
+            agent._interruptible_streaming_api_call({"messages": []})

     def test_response_has_uuid_id(self, agent):
         chunks = [_make_chunk(content="x"), _make_chunk(finish_reason="stop")]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.id.startswith("stream-")
         assert len(resp.id) > len("stream-")
@@ -2449,7 +2451,7 @@
         ]
         agent.client.chat.completions.create.return_value = iter(chunks)

-        resp = agent._streaming_api_call({"messages": []}, MagicMock())
+        resp = agent._interruptible_streaming_api_call({"messages": []})

         assert resp.choices[0].message.content == "Hello"
         assert resp.model == "gpt-4"
@@ -2505,7 +2507,7 @@ class TestAnthropicInterruptHandler:
     def test_streaming_has_anthropic_branch(self):
-        """_streaming_api_call must also handle Anthropic interrupt."""
+        """_interruptible_streaming_api_call must also handle Anthropic interrupt."""
         import inspect
-        source = inspect.getsource(AIAgent._streaming_api_call)
+        source = inspect.getsource(AIAgent._interruptible_streaming_api_call)
         assert "anthropic_messages" in source, \
-            "_streaming_api_call must handle Anthropic interrupt"
+            "_interruptible_streaming_api_call must handle Anthropic interrupt"
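The test updates show the migration pattern for callers of the old API: the delta callback moves from a positional argument to agent state, and the method is renamed. A sketch of the call-site change (the surrounding agent setup is assumed; attribute and method names are the ones exercised in the tests above):

    # Before (removed): callback passed per call.
    # resp = agent._streaming_api_call(api_kwargs, on_delta)

    # After (unified): callback set once on the agent.
    agent.stream_delta_callback = lambda delta: print(delta, end="", flush=True)
    resp = agent._interruptible_streaming_api_call({"model": "gpt-4o", "messages": []})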