diff --git a/agent/chat_completion_helpers.py b/agent/chat_completion_helpers.py index 8bab29cae47..7a5e7534723 100644 --- a/agent/chat_completion_helpers.py +++ b/agent/chat_completion_helpers.py @@ -28,6 +28,7 @@ from typing import Any, Dict, Optional from hermes_cli.timeouts import get_provider_request_timeout, get_provider_stale_timeout from hermes_constants import PARTIAL_STREAM_STUB_ID, FINISH_REASON_LENGTH from agent.error_classifier import FailoverReason +from agent.gemini_native_adapter import is_native_gemini_base_url from agent.model_metadata import is_local_endpoint from agent.message_sanitization import ( _sanitize_surrogates, @@ -1911,7 +1912,6 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= stream_kwargs = { **api_kwargs, "stream": True, - "stream_options": {"include_usage": True}, "timeout": _httpx.Timeout( connect=_conn_cap, read=_stream_read_timeout, @@ -1919,6 +1919,14 @@ def interruptible_streaming_api_call(agent, api_kwargs: dict, *, on_first_delta= pool=_conn_cap, ), } + # OpenAI's `stream_options={"include_usage": True}` drives usage + # accounting on OpenAI-compatible endpoints (incl. the Gemini OpenAI + # compat shim and aggregators like OpenRouter). Google's *native* + # Gemini REST endpoint rejects the keyword outright + # (`Completions.create() got an unexpected keyword argument + # 'stream_options'`), so omit it only for that endpoint. + if not is_native_gemini_base_url(agent.base_url): + stream_kwargs["stream_options"] = {"include_usage": True} request_client = _set_request_client( agent._create_request_openai_client( reason="chat_completion_stream_request", diff --git a/tests/run_agent/test_streaming.py b/tests/run_agent/test_streaming.py index 11f8e72d632..134f789388c 100644 --- a/tests/run_agent/test_streaming.py +++ b/tests/run_agent/test_streaming.py @@ -95,6 +95,99 @@ class TestStreamingAccumulator: assert response.usage is not None assert response.usage.completion_tokens == 3 + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_native_gemini_endpoint_omits_stream_options(self, mock_close, mock_create): + """Google's native Gemini REST endpoint rejects OpenAI-only stream_options.""" + from run_agent import AIAgent + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = iter([ + _make_stream_chunk(content="Paris", finish_reason="stop", model="gemini"), + ]) + mock_create.return_value = mock_client + + agent = AIAgent( + api_key="test-key", + base_url="https://generativelanguage.googleapis.com/v1beta", + model="gemini-3-flash-preview", + provider="gemini", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "Paris" + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert call_kwargs["stream"] is True + assert "stream_options" not in call_kwargs + + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_gemini_openai_compat_shim_keeps_stream_options(self, mock_close, mock_create): + """The Gemini OpenAI-compat shim (.../openai) accepts stream_options.""" + from run_agent import AIAgent + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = iter([ + _make_stream_chunk(content="ok", finish_reason="stop", model="gemini"), + _make_empty_chunk(usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1)), + ]) + mock_create.return_value = mock_client + + agent = AIAgent( + api_key="test-key", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", + model="gemini-3-flash-preview", + provider="gemini", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "ok" + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert call_kwargs["stream_options"] == {"include_usage": True} + + @patch("run_agent.AIAgent._create_request_openai_client") + @patch("run_agent.AIAgent._close_request_openai_client") + def test_openai_compatible_streaming_keeps_stream_options(self, mock_close, mock_create): + """OpenAI-compatible aggregators still request final usage chunks.""" + from run_agent import AIAgent + + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = iter([ + _make_stream_chunk(content="ok", finish_reason="stop", model="test-model"), + _make_empty_chunk(usage=SimpleNamespace(prompt_tokens=2, completion_tokens=1)), + ]) + mock_create.return_value = mock_client + + agent = AIAgent( + api_key="test-key", + base_url="https://openrouter.ai/api/v1", + model="test/model", + provider="openrouter", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + agent.api_mode = "chat_completions" + agent._interrupt_requested = False + + response = agent._interruptible_streaming_api_call({}) + + assert response.choices[0].message.content == "ok" + call_kwargs = mock_client.chat.completions.create.call_args.kwargs + assert call_kwargs["stream_options"] == {"include_usage": True} + @patch("run_agent.AIAgent._create_request_openai_client") @patch("run_agent.AIAgent._close_request_openai_client") def test_tool_call_response(self, mock_close, mock_create):