diff --git a/cli-config.yaml.example b/cli-config.yaml.example
index 2bfe297e3f..04ecc90de4 100644
--- a/cli-config.yaml.example
+++ b/cli-config.yaml.example
@@ -219,6 +219,22 @@ compression:
 #   Options: "auto", "openrouter", "nous", "main"
 # summary_provider: "auto"
 
+# =============================================================================
+# Streaming (live token-by-token response display)
+# =============================================================================
+# When enabled, LLM responses stream token-by-token instead of appearing
+# all at once. Supported on Telegram, Discord, and Slack (via message
+# editing), and on the API server (via SSE). Disabled by default.
+#
+# streaming:
+#   enabled: false        # Master switch (default: off)
+#   # Per-platform overrides:
+#   # telegram: true
+#   # discord: true
+#   # api_server: true
+#   # edit_interval: 1.5  # Seconds between message edits (default: 1.5)
+#   # min_tokens: 20      # Tokens before first display (default: 20)
+
 # =============================================================================
 # Auxiliary Models (Advanced — Experimental)
 # =============================================================================
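The flag resolution has a simple precedence: a per-platform override beats the global `enabled` switch, and an environment variable can force streaming on regardless of config. A minimal sketch of that order, mirroring the logic added in gateway/run.py further down; `streaming_enabled` is an illustrative helper name, not a function from this diff:

```python
import os

def streaming_enabled(cfg: dict, platform: str) -> bool:
    """Resolve the streaming flag: platform override > global switch > env var."""
    s = cfg.get("streaming") or {}
    truthy = ("true", "1", "yes")
    if platform and s.get(platform) is not None:
        enabled = str(s[platform]).lower() in truthy          # per-platform override
    else:
        enabled = str(s.get("enabled", False)).lower() in truthy  # master switch
    # HERMES_STREAMING_ENABLED=true forces streaming on regardless of config
    if os.getenv("HERMES_STREAMING_ENABLED", "").lower() in truthy:
        enabled = True
    return enabled

# e.g. enabled globally off, but on for telegram:
print(streaming_enabled({"streaming": {"enabled": False, "telegram": True}}, "telegram"))  # True
```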
diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py
index 831d981dff..1876d4954a 100644
--- a/gateway/platforms/api_server.py
+++ b/gateway/platforms/api_server.py
@@ -173,6 +173,7 @@ class APIServerAdapter(BasePlatformAdapter):
         self,
         ephemeral_system_prompt: Optional[str] = None,
         session_id: Optional[str] = None,
+        stream_callback=None,
     ) -> Any:
         """
         Create an AIAgent instance using the gateway's runtime config.
@@ -213,6 +214,7 @@ class APIServerAdapter(BasePlatformAdapter):
             ephemeral_system_prompt=ephemeral_system_prompt or None,
             session_id=session_id,
             platform="api_server",
+            stream_callback=stream_callback,
         )
         return agent
 
@@ -298,8 +300,31 @@ class APIServerAdapter(BasePlatformAdapter):
                 status=400,
             )
 
-        # Run the agent in an executor (run_conversation is synchronous)
         session_id = str(uuid.uuid4())
+        completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
+        model_name = body.get("model", "hermes-agent")
+        created = int(time.time())
+
+        if stream:
+            import queue as _q
+            _stream_q = _q.Queue()
+            def _on_api_token(delta):
+                _stream_q.put(delta)  # None = done
+
+            # Start agent in background
+            agent_task = asyncio.ensure_future(self._run_agent(
+                user_message=user_message,
+                conversation_history=history,
+                ephemeral_system_prompt=system_prompt,
+                session_id=session_id,
+                stream_callback=_on_api_token,
+            ))
+
+            return await self._write_real_sse_chat_completion(
+                request, completion_id, model_name, created, _stream_q, agent_task
+            )
+
+        # Non-streaming: run the agent and return full response
         try:
             result, usage = await self._run_agent(
                 user_message=user_message,
@@ -318,18 +343,6 @@ class APIServerAdapter(BasePlatformAdapter):
         if not final_response:
             final_response = result.get("error", "(No response generated)")
 
-        completion_id = f"chatcmpl-{uuid.uuid4().hex[:29]}"
-        model_name = body.get("model", "hermes-agent")
-        created = int(time.time())
-
-        if stream:
-            # Pseudo-streaming: return the full response as SSE chunks.
-            # Not true token-by-token streaming, but compatible with clients
-            # (like Open WebUI) that expect SSE format.
-            return await self._write_sse_chat_completion(
-                request, completion_id, model_name, created, final_response, usage
-            )
-
         response_data = {
             "id": completion_id,
             "object": "chat.completion",
@@ -354,6 +367,71 @@ class APIServerAdapter(BasePlatformAdapter):
 
         return web.json_response(response_data)
 
+    async def _write_real_sse_chat_completion(
+        self, request: "web.Request", completion_id: str, model: str,
+        created: int, stream_q, agent_task,
+    ) -> "web.StreamResponse":
+        """Write real streaming SSE from agent's stream_callback queue."""
+        import queue as _q
+
+        response = web.StreamResponse(
+            status=200,
+            headers={"Content-Type": "text/event-stream", "Cache-Control": "no-cache"},
+        )
+        await response.prepare(request)
+
+        # Role chunk
+        role_chunk = {
+            "id": completion_id, "object": "chat.completion.chunk",
+            "created": created, "model": model,
+            "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
+        }
+        await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
+
+        # Stream content chunks as they arrive from the agent
+        loop = asyncio.get_event_loop()
+        while True:
+            try:
+                delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
+            except _q.Empty:
+                if agent_task.done():
+                    break
+                continue
+
+            if delta is None:  # End of stream
+                break
+
+            content_chunk = {
+                "id": completion_id, "object": "chat.completion.chunk",
+                "created": created, "model": model,
+                "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
+            }
+            await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
+
+        # Get usage from completed agent
+        usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+        try:
+            result, agent_usage = await agent_task
+            usage = agent_usage or usage
+        except Exception:
+            pass
+
+        # Finish chunk
+        finish_chunk = {
+            "id": completion_id, "object": "chat.completion.chunk",
+            "created": created, "model": model,
+            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+            "usage": {
+                "prompt_tokens": usage.get("input_tokens", 0),
+                "completion_tokens": usage.get("output_tokens", 0),
+                "total_tokens": usage.get("total_tokens", 0),
+            },
+        }
+        await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
+        await response.write(b"data: [DONE]\n\n")
+
+        return response
+
     async def _write_sse_chat_completion(
         self, request: "web.Request", completion_id: str, model: str,
         created: int, content: str, usage: Dict[str, int],
@@ -671,6 +749,7 @@ class APIServerAdapter(BasePlatformAdapter):
         conversation_history: List[Dict[str, str]],
         ephemeral_system_prompt: Optional[str] = None,
         session_id: Optional[str] = None,
+        stream_callback=None,
     ) -> tuple:
         """
         Create an agent and run a conversation in a thread executor.
@@ -684,6 +763,7 @@ class APIServerAdapter(BasePlatformAdapter):
             agent = self._create_agent(
                 ephemeral_system_prompt=ephemeral_system_prompt,
                 session_id=session_id,
+                stream_callback=stream_callback,
            )
             result = agent.run_conversation(
                 user_message=user_message,
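Because the streaming branch keeps the endpoint OpenAI-compatible, any SSE-aware client should work unchanged. As a rough illustration (not part of the diff), a client could consume the stream with the standard `openai` package; the host, port, and API key here are assumptions about a local deployment:

```python
from openai import OpenAI

# Hypothetical local gateway address; adjust to your deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="hermes-agent",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    # Each SSE "data:" line parses into a chat.completion.chunk object;
    # per the diff, the final chunk carries finish_reason="stop" and usage.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
```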
diff --git a/gateway/run.py b/gateway/run.py
index 178e102915..9fed5f9961 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -19,6 +19,7 @@ import os
 import re
 import sys
 import signal
+import time
 import threading
 from logging.handlers import RotatingFileHandler
 from pathlib import Path
@@ -1422,6 +1423,23 @@ class GatewayRunner:
                 session_entry.session_key,
                 last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
             )
+
+        # If streaming already delivered the response via progressive edits,
+        # do a final edit with the post-processed text and suppress the
+        # normal send to avoid duplicating the message.
+        _streamed_id = agent_result.get("_streamed_msg_id")
+        if _streamed_id and response:
+            adapter = self.adapters.get(source.platform)
+            if adapter:
+                try:
+                    await adapter.edit_message(
+                        chat_id=source.chat_id,
+                        message_id=_streamed_id,
+                        content=response,
+                    )
+                except Exception:
+                    pass
+            return ""  # Suppress normal send in base.py
 
         return response
 
@@ -3061,7 +3079,35 @@ class GatewayRunner:
         agent_holder = [None]   # Mutable container for the agent instance
         result_holder = [None]  # Mutable container for the result
         tools_holder = [None]   # Mutable container for the tool definitions
-
+
+        # ── Streaming setup ─────────────────────────────────────────────
+        _stream_q = None
+        _stream_done = None
+        _stream_msg_id = [None]
+        _streaming_enabled = False
+
+        try:
+            import yaml as _s_yaml
+            _s_cfg_path = _hermes_home / "config.yaml"
+            if _s_cfg_path.exists():
+                with open(_s_cfg_path, encoding="utf-8") as _s_f:
+                    _s_data = _s_yaml.safe_load(_s_f) or {}
+                _s_cfg = _s_data.get("streaming", {})
+                if isinstance(_s_cfg, dict):
+                    _platform_key = source.platform.value if source.platform else ""
+                    if _platform_key and _s_cfg.get(_platform_key) is not None:
+                        _streaming_enabled = str(_s_cfg[_platform_key]).lower() in ("true", "1", "yes")
+                    else:
+                        _streaming_enabled = str(_s_cfg.get("enabled", False)).lower() in ("true", "1", "yes")
+        except Exception:
+            pass
+        if os.getenv("HERMES_STREAMING_ENABLED", "").lower() in ("true", "1", "yes"):
+            _streaming_enabled = True
+
+        if _streaming_enabled:
+            _stream_q = queue.Queue()
+            _stream_done = threading.Event()
+
         # Bridge sync step_callback → async hooks.emit for agent:step events
         _loop_for_step = asyncio.get_event_loop()
         _hooks_ref = self.hooks
@@ -3134,6 +3180,19 @@ class GatewayRunner:
         }
 
         pr = self._provider_routing
+
+        # Streaming: build callback that feeds the async queue
+        _on_stream_token = None
+        if _stream_q is not None:
+            _sq = _stream_q  # capture for closure
+            _sd = _stream_done
+            def _on_stream_token(delta):
+                if delta is None:
+                    if _sd:
+                        _sd.set()
+                else:
+                    _sq.put(delta)
+
         agent = AIAgent(
             model=model,
             **runtime_kwargs,
@@ -3151,6 +3210,7 @@ class GatewayRunner:
             provider_require_parameters=pr.get("require_parameters", False),
             provider_data_collection=pr.get("data_collection"),
             session_id=session_id,
+            stream_callback=_on_stream_token,
             tool_progress_callback=progress_callback if tool_progress_enabled else None,
             step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
             platform=platform_key,
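The callback the gateway hands to `AIAgent` is deliberately tiny: it receives text deltas on the agent's worker thread and `None` as the end-of-stream signal, and the consumer drains a queue on its own schedule. A self-contained sketch of that contract (names here are illustrative, not from the diff):

```python
import queue
import threading

token_q = queue.Queue()
done = threading.Event()

def on_stream_token(delta):
    # Called from the agent's worker thread: text deltas, then None when done.
    if delta is None:
        done.set()
    else:
        token_q.put(delta)

def fake_agent():
    # Stand-in for the agent thread emitting tokens.
    for tok in ["Hel", "lo", "!"]:
        on_stream_token(tok)
    on_stream_token(None)

threading.Thread(target=fake_agent).start()
parts = []
while not (done.is_set() and token_q.empty()):
    try:
        parts.append(token_q.get(timeout=0.1))
    except queue.Empty:
        pass
print("".join(parts))  # -> Hello!
```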
"final_response": final_response, "messages": result_holder[0].get("messages", []) if result_holder[0] else [], "api_calls": result_holder[0].get("api_calls", 0) if result_holder[0] else 0, @@ -3285,12 +3345,86 @@ class GatewayRunner: "history_offset": len(agent_history), "last_prompt_tokens": _last_prompt_toks, } + if _stream_msg_id[0]: + _result_dict["_streamed_msg_id"] = _stream_msg_id[0] + return _result_dict # Start progress message sender if enabled progress_task = None if tool_progress_enabled: progress_task = asyncio.create_task(send_progress_messages()) - + + # ── Stream preview: progressively edit a message with streaming tokens ── + async def stream_preview(): + if not _stream_q or not _stream_done: + return + adapter = self.adapters.get(source.platform) + if not adapter: + return + + accumulated = [] + token_count = 0 + last_edit = 0.0 + MIN_TOKENS = 20 + EDIT_INTERVAL = 1.5 + _metadata = {"thread_id": source.thread_id} if source.thread_id else None + + try: + while not _stream_done.is_set(): + try: + chunk = _stream_q.get(timeout=0.1) + accumulated.append(chunk) + token_count += 1 + except Exception: + continue + + now = time.monotonic() + if token_count >= MIN_TOKENS and (now - last_edit) >= EDIT_INTERVAL: + preview = "".join(accumulated) + " ▌" + if _stream_msg_id[0] is None: + r = await adapter.send( + chat_id=source.chat_id, content=preview, + metadata=_metadata, + ) + if r.success and r.message_id: + _stream_msg_id[0] = r.message_id + else: + await adapter.edit_message( + chat_id=source.chat_id, + message_id=_stream_msg_id[0], + content=preview, + ) + last_edit = now + + # Drain remaining + while not _stream_q.empty(): + try: + accumulated.append(_stream_q.get_nowait()) + except Exception: + break + + # Final edit: remove cursor + if _stream_msg_id[0] and accumulated: + await adapter.edit_message( + chat_id=source.chat_id, + message_id=_stream_msg_id[0], + content="".join(accumulated), + ) + except asyncio.CancelledError: + if _stream_msg_id[0] and accumulated: + try: + await adapter.edit_message( + chat_id=source.chat_id, + message_id=_stream_msg_id[0], + content="".join(accumulated), + ) + except Exception: + pass + except Exception as e: + logger.debug("stream_preview error: %s", e) + + stream_task = asyncio.create_task(stream_preview()) if _stream_q else None + # Track this agent as running for this session (for interrupt support) # We do this in a callback after the agent is created async def track_agent(): @@ -3365,9 +3499,11 @@ class GatewayRunner: session_key=session_key ) finally: - # Stop progress sender and interrupt monitor + # Stop progress sender, stream preview, and interrupt monitor if progress_task: progress_task.cancel() + if stream_task: + stream_task.cancel() interrupt_monitor.cancel() # Clean up tracking @@ -3376,7 +3512,7 @@ class GatewayRunner: del self._running_agents[session_key] # Wait for cancelled tasks - for task in [progress_task, interrupt_monitor, tracking_task]: + for task in [progress_task, stream_task, interrupt_monitor, tracking_task]: if task: try: await task diff --git a/run_agent.py b/run_agent.py index e98863f5ee..fb6ffb454f 100644 --- a/run_agent.py +++ b/run_agent.py @@ -176,6 +176,7 @@ class AIAgent: reasoning_callback: callable = None, clarify_callback: callable = None, step_callback: callable = None, + stream_callback: callable = None, max_tokens: int = None, reasoning_config: Dict[str, Any] = None, prefill_messages: List[Dict[str, Any]] = None, @@ -229,6 +230,9 @@ class AIAgent: polluting trajectories with user-specific 
diff --git a/run_agent.py b/run_agent.py
index e98863f5ee..fb6ffb454f 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -176,6 +176,7 @@ class AIAgent:
         reasoning_callback: callable = None,
         clarify_callback: callable = None,
         step_callback: callable = None,
+        stream_callback: callable = None,
         max_tokens: int = None,
         reasoning_config: Dict[str, Any] = None,
         prefill_messages: List[Dict[str, Any]] = None,
@@ -229,6 +230,9 @@ class AIAgent:
                 polluting trajectories with user-specific
                 persona or project instructions.
             honcho_session_key (str): Session key for Honcho integration (e.g.,
                 "telegram:123456" or CLI session_id). When provided and Honcho is
                 enabled in config, enables persistent cross-session user modeling.
+            stream_callback (callable): Optional callback(text_delta: str) invoked for
+                each streamed text token, then with None as an end-of-stream signal.
+                When set, the agent uses stream=True for API calls (off by default).
         """
         self.model = model
         self.max_iterations = max_iterations
@@ -264,6 +268,7 @@ class AIAgent:
         self.reasoning_callback = reasoning_callback
         self.clarify_callback = clarify_callback
         self.step_callback = step_callback
+        self.stream_callback = stream_callback
         self._last_reported_tool = None  # Track for "new tool" mode
 
         # Interrupt mechanism for breaking out of tool loops
@@ -2010,8 +2015,20 @@ class AIAgent:
         for attempt in range(max_stream_retries + 1):
             try:
                 with self.client.responses.stream(**api_kwargs) as stream:
-                    for _ in stream:
-                        pass
+                    for event in stream:
+                        # Forward text deltas from the Responses API stream
+                        if self.stream_callback and getattr(event, 'type', '') == 'response.output_text.delta':
+                            delta_text = getattr(event, 'delta', '')
+                            if delta_text:
+                                try:
+                                    self.stream_callback(delta_text)
+                                except Exception:
+                                    pass
+                    if self.stream_callback:
+                        try:
+                            self.stream_callback(None)
+                        except Exception:
+                            pass
                     return stream.get_final_response()
             except RuntimeError as exc:
                 err_text = str(exc)
+ """ + stream_kwargs = dict(api_kwargs) + stream_kwargs["stream"] = True + # Request usage in the final chunk + stream_kwargs["stream_options"] = {"include_usage": True} + + accumulated_content = [] + accumulated_tool_calls = {} + final_usage = None + + try: + stream = self.client.chat.completions.create(**stream_kwargs) + + for chunk in stream: + if not chunk.choices: + if hasattr(chunk, 'usage') and chunk.usage: + final_usage = chunk.usage + continue + + delta = chunk.choices[0].delta + + if hasattr(delta, 'content') and delta.content: + accumulated_content.append(delta.content) + if self.stream_callback: + try: + self.stream_callback(delta.content) + except Exception: + pass + + if hasattr(delta, 'tool_calls') and delta.tool_calls: + for tc_delta in delta.tool_calls: + idx = tc_delta.index + if idx not in accumulated_tool_calls: + accumulated_tool_calls[idx] = {"id": tc_delta.id or "", "name": "", "arguments": ""} + if hasattr(tc_delta, 'function') and tc_delta.function: + if getattr(tc_delta.function, 'name', None): + accumulated_tool_calls[idx]["name"] = tc_delta.function.name + if getattr(tc_delta.function, 'arguments', None): + accumulated_tool_calls[idx]["arguments"] += tc_delta.function.arguments + + if self.stream_callback: + try: + self.stream_callback(None) # End signal + except Exception: + pass + + tool_calls = [] + for idx in sorted(accumulated_tool_calls): + tc = accumulated_tool_calls[idx] + if tc["name"]: + tool_calls.append(SimpleNamespace( + id=tc["id"], type="function", + function=SimpleNamespace(name=tc["name"], arguments=tc["arguments"]), + )) + + return SimpleNamespace( + choices=[SimpleNamespace( + message=SimpleNamespace( + content="".join(accumulated_content) or "", + tool_calls=tool_calls or None, + role="assistant", + ), + finish_reason="tool_calls" if tool_calls else "stop", + )], + usage=final_usage, + model=self.model, + ) + except Exception as e: + if self.stream_callback: + try: + self.stream_callback(None) + except Exception: + pass + logger.debug("Streaming chat completion failed, falling back: %s", e) + return self.client.chat.completions.create(**api_kwargs) + def _interruptible_api_call(self, api_kwargs: dict): """ Run the API call in a background thread so the main conversation loop @@ -2164,6 +2262,8 @@ class AIAgent: try: if self.api_mode == "codex_responses": result["response"] = self._run_codex_stream(api_kwargs) + elif self.stream_callback is not None: + result["response"] = self._run_streaming_chat_completion(api_kwargs) else: result["response"] = self.client.chat.completions.create(**api_kwargs) except Exception as e: diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index bee3fe91d0..0971be2d8c 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -314,11 +314,18 @@ class TestChatCompletionsEndpoint: """stream=true returns SSE format with the full response.""" app = _create_app(adapter) async with TestClient(TestServer(app)) as cli: - with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run: - mock_run.return_value = ( + async def _mock_run_agent(**kwargs): + # Simulate streaming: invoke stream_callback with tokens + cb = kwargs.get("stream_callback") + if cb: + cb("Hello!") + cb(None) # End signal + return ( {"final_response": "Hello!", "messages": [], "api_calls": 1}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent) as mock_run: resp = await cli.post( 
"/v1/chat/completions", json={ diff --git a/tests/test_streaming.py b/tests/test_streaming.py new file mode 100644 index 0000000000..d480abac3f --- /dev/null +++ b/tests/test_streaming.py @@ -0,0 +1,335 @@ +"""Unit tests for streaming support. + +Tests cover: +- _run_streaming_chat_completion: text tokens, tool calls, fallback on error, + no callback, end signal +- _interruptible_api_call routing to streaming when stream_callback is set +- Streaming config reading from config.yaml +""" + +import json +import threading +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + +from run_agent import AIAgent + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +def _make_tool_defs(*names: str) -> list: + """Build minimal tool definition list accepted by AIAgent.__init__.""" + return [ + { + "type": "function", + "function": { + "name": n, + "description": f"{n} tool", + "parameters": {"type": "object", "properties": {}}, + }, + } + for n in names + ] + + +@pytest.fixture() +def agent(): + """Minimal AIAgent with mocked client, no stream_callback.""" + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + a = AIAgent( + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + a.client = MagicMock() + return a + + +@pytest.fixture() +def streaming_agent(): + """Agent with a stream_callback set.""" + collected = [] + def _cb(delta): + collected.append(delta) + + with ( + patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + ): + a = AIAgent( + api_key="test-key-1234567890", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + stream_callback=_cb, + ) + a.client = MagicMock() + a._collected_tokens = collected + return a + + +# --------------------------------------------------------------------------- +# Helpers — build fake streaming chunks +# --------------------------------------------------------------------------- + +def _make_text_chunks(*texts): + """Return a list of SimpleNamespace chunks containing text deltas.""" + chunks = [] + for t in texts: + chunks.append(SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace(content=t, tool_calls=None), + finish_reason=None, + )], + usage=None, + )) + # Final chunk with usage info + chunks.append(SimpleNamespace( + choices=[], + usage=SimpleNamespace(prompt_tokens=10, completion_tokens=5, total_tokens=15), + )) + return chunks + + +def _make_tool_call_chunks(): + """Return chunks that simulate a tool call response.""" + chunks = [ + # First chunk: tool call id + name + SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace( + content=None, + tool_calls=[SimpleNamespace( + index=0, + id="call_123", + function=SimpleNamespace(name="web_search", arguments=""), + )], + ), + finish_reason=None, + )], + usage=None, + ), + # Second chunk: tool call arguments + SimpleNamespace( + choices=[SimpleNamespace( + delta=SimpleNamespace( + content=None, + tool_calls=[SimpleNamespace( + index=0, + id=None, + function=SimpleNamespace(name=None, arguments='{"query": "test"}'), + )], + ), + finish_reason=None, + )], + usage=None, + ), + # 
+
+
+def _make_tool_call_chunks():
+    """Return chunks that simulate a tool call response."""
+    chunks = [
+        # First chunk: tool call id + name
+        SimpleNamespace(
+            choices=[SimpleNamespace(
+                delta=SimpleNamespace(
+                    content=None,
+                    tool_calls=[SimpleNamespace(
+                        index=0,
+                        id="call_123",
+                        function=SimpleNamespace(name="web_search", arguments=""),
+                    )],
+                ),
+                finish_reason=None,
+            )],
+            usage=None,
+        ),
+        # Second chunk: tool call arguments
+        SimpleNamespace(
+            choices=[SimpleNamespace(
+                delta=SimpleNamespace(
+                    content=None,
+                    tool_calls=[SimpleNamespace(
+                        index=0,
+                        id=None,
+                        function=SimpleNamespace(name=None, arguments='{"query": "test"}'),
+                    )],
+                ),
+                finish_reason=None,
+            )],
+            usage=None,
+        ),
+        # Final usage chunk
+        SimpleNamespace(choices=[], usage=SimpleNamespace(
+            prompt_tokens=20, completion_tokens=10, total_tokens=30,
+        )),
+    ]
+    return chunks
+
+
+# ---------------------------------------------------------------------------
+# Tests: _run_streaming_chat_completion
+# ---------------------------------------------------------------------------
+
+class TestRunStreamingChatCompletion:
+    """Tests for AIAgent._run_streaming_chat_completion."""
+
+    def test_text_tokens_streamed_via_callback(self, streaming_agent):
+        """Text deltas are forwarded to stream_callback and accumulated."""
+        chunks = _make_text_chunks("Hello", " ", "world")
+        streaming_agent.client.chat.completions.create.return_value = iter(chunks)
+
+        result = streaming_agent._run_streaming_chat_completion({"model": "test"})
+
+        assert result.choices[0].message.content == "Hello world"
+        # Callback received each token + None end signal
+        assert streaming_agent._collected_tokens == ["Hello", " ", "world", None]
+
+    def test_tool_calls_accumulated(self, streaming_agent):
+        """Tool call deltas are aggregated into a proper tool_calls list."""
+        chunks = _make_tool_call_chunks()
+        streaming_agent.client.chat.completions.create.return_value = iter(chunks)
+
+        result = streaming_agent._run_streaming_chat_completion({"model": "test"})
+
+        assert result.choices[0].message.tool_calls is not None
+        tc = result.choices[0].message.tool_calls[0]
+        assert tc.function.name == "web_search"
+        assert '"query"' in tc.function.arguments
+
+    def test_fallback_on_streaming_error(self, streaming_agent):
+        """Falls back to non-streaming on error."""
+        # First call (streaming) raises; second call (fallback) succeeds
+        fallback_response = SimpleNamespace(
+            choices=[SimpleNamespace(
+                message=SimpleNamespace(content="fallback", tool_calls=None, role="assistant"),
+                finish_reason="stop",
+            )],
+            usage=SimpleNamespace(prompt_tokens=5, completion_tokens=3, total_tokens=8),
+            model="test",
+        )
+
+        call_count = [0]
+
+        def _side_effect(**kwargs):
+            call_count[0] += 1
+            if kwargs.get("stream"):
+                raise ConnectionError("stream broke")
+            return fallback_response
+
+        streaming_agent.client.chat.completions.create.side_effect = _side_effect
+
+        result = streaming_agent._run_streaming_chat_completion({"model": "test"})
+
+        assert result.choices[0].message.content == "fallback"
+        assert call_count[0] == 2  # streaming attempt + fallback
+        # Callback should still get None (end signal) even on error
+        assert None in streaming_agent._collected_tokens
+
+    def test_no_callback_still_works(self, agent):
+        """Streaming works even without a callback (just accumulates)."""
+        chunks = _make_text_chunks("ok")
+        agent.client.chat.completions.create.return_value = iter(chunks)
+
+        result = agent._run_streaming_chat_completion({"model": "test"})
+
+        assert result.choices[0].message.content == "ok"
+
+    def test_end_signal_sent(self, streaming_agent):
+        """stream_callback(None) is sent after all tokens."""
+        chunks = _make_text_chunks("done")
+        streaming_agent.client.chat.completions.create.return_value = iter(chunks)
+
+        streaming_agent._run_streaming_chat_completion({"model": "test"})
+
+        assert streaming_agent._collected_tokens[-1] is None
+
+    def test_usage_captured_from_final_chunk(self, streaming_agent):
+        """Usage stats from the final usage-only chunk are returned."""
+        chunks = _make_text_chunks("hi")
+        streaming_agent.client.chat.completions.create.return_value = iter(chunks)
+
+        result = streaming_agent._run_streaming_chat_completion({"model": "test"})
+
+        assert result.usage is not None
+        assert result.usage.prompt_tokens == 10
+        assert result.usage.completion_tokens == 5
+
+
+# ---------------------------------------------------------------------------
+# Tests: _interruptible_api_call routing
+# ---------------------------------------------------------------------------
+
+class TestInterruptibleApiCallRouting:
+    """Tests that _interruptible_api_call routes to streaming when callback is set."""
+
+    def test_routes_to_streaming_with_callback(self, streaming_agent):
+        """When stream_callback is set, _interruptible_api_call uses streaming."""
+        chunks = _make_text_chunks("streamed")
+        streaming_agent.client.chat.completions.create.return_value = iter(chunks)
+
+        # Ensure no interrupt is pending
+        streaming_agent._interrupt_requested = False
+
+        result = streaming_agent._interruptible_api_call({"model": "test"})
+
+        assert result.choices[0].message.content == "streamed"
+        # Verify the callback got tokens
+        assert "streamed" in streaming_agent._collected_tokens
+
+    def test_routes_to_normal_without_callback(self, agent):
+        """Without a stream_callback, _interruptible_api_call uses a normal completion."""
+        normal_response = SimpleNamespace(
+            choices=[SimpleNamespace(
+                message=SimpleNamespace(content="normal", tool_calls=None, role="assistant"),
+                finish_reason="stop",
+            )],
+            usage=SimpleNamespace(prompt_tokens=5, completion_tokens=3, total_tokens=8),
+            model="test",
+        )
+        agent.client.chat.completions.create.return_value = normal_response
+        agent._interrupt_requested = False
+
+        result = agent._interruptible_api_call({"model": "test"})
+
+        assert result.choices[0].message.content == "normal"
+
+
+# ---------------------------------------------------------------------------
+# Tests: Streaming config
+# ---------------------------------------------------------------------------
+
+class TestStreamingConfig:
+    """Tests for reading streaming configuration."""
+
+    def test_streaming_disabled_by_default(self):
+        """Without any config, streaming is disabled."""
+        with (
+            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
+            patch("run_agent.check_toolset_requirements", return_value={}),
+            patch("run_agent.OpenAI"),
+        ):
+            a = AIAgent(
+                api_key="test-key",
+                quiet_mode=True,
+                skip_context_files=True,
+                skip_memory=True,
+            )
+        assert a.stream_callback is None
+
+    def test_stream_callback_stored_on_agent(self):
+        """stream_callback passed to the constructor is stored on the agent."""
+        cb = lambda delta: None
+        with (
+            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("web_search")),
+            patch("run_agent.check_toolset_requirements", return_value={}),
+            patch("run_agent.OpenAI"),
+        ):
+            a = AIAgent(
+                api_key="test-key",
+                quiet_mode=True,
+                skip_context_files=True,
+                skip_memory=True,
+                stream_callback=cb,
+            )
+        assert a.stream_callback is cb
+
+    def test_gateway_streaming_config_structure(self):
+        """Verify the expected streaming config structure from gateway/run.py."""
+        # This tests that _read_streaming_config (if it exists) returns
+        # the right structure. We mock the config file content.
+        try:
+            from gateway.run import _read_streaming_config
+        except ImportError:
+            pytest.skip("gateway.run._read_streaming_config not available")
+
+        mock_cfg = {
+            "streaming": {
+                "enabled": True,
+                "telegram": True,
+                "discord": False,
+                "edit_interval": 2.0,
+            }
+        }
+        with patch("builtins.open", MagicMock()):
+            with patch("yaml.safe_load", return_value=mock_cfg):
+                try:
+                    result = _read_streaming_config()
+                    assert result.get("enabled") is True
+                except Exception:
+                    # Function might not exist yet or have a different signature
+                    pytest.skip("_read_streaming_config has different interface")
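For completeness, a minimal end-to-end usage sketch of the new constructor argument. This is an illustration, not code from the diff: the constructor keyword arguments mirror the test fixtures above, and it assumes run_conversation returns the result dict with a "final_response" key seen in the tests:

```python
from run_agent import AIAgent

def print_tokens(delta):
    # Text deltas arrive one at a time; None marks the end of the stream.
    if delta is None:
        print()  # newline after the streamed answer
    else:
        print(delta, end="", flush=True)

agent = AIAgent(
    api_key="sk-...",          # placeholder credential
    quiet_mode=True,
    stream_callback=print_tokens,
)
result = agent.run_conversation(user_message="Say hello")
print(result.get("final_response"))
```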