diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 80d2033b7..9156eaa26 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -815,9 +815,10 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: continue # skip provider if we don't know a valid aux model logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model) if provider_id == "gemini": - from agent.gemini_native_adapter import GeminiNativeClient + from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url - return GeminiNativeClient(api_key=api_key, base_url=base_url), model + if is_native_gemini_base_url(base_url): + return GeminiNativeClient(api_key=api_key, base_url=base_url), model extra = {} if "api.kimi.com" in base_url.lower(): extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"} @@ -840,9 +841,10 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]: continue # skip provider if we don't know a valid aux model logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model) if provider_id == "gemini": - from agent.gemini_native_adapter import GeminiNativeClient + from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url - return GeminiNativeClient(api_key=api_key, base_url=base_url), model + if is_native_gemini_base_url(base_url): + return GeminiNativeClient(api_key=api_key, base_url=base_url), model extra = {} if "api.kimi.com" in base_url.lower(): extra["default_headers"] = {"User-Agent": "KimiCLI/1.30.0"} @@ -1703,12 +1705,13 @@ def resolve_provider_client( final_model = _normalize_resolved_model(model or default_model, provider) if provider == "gemini": - from agent.gemini_native_adapter import GeminiNativeClient + from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url - client = GeminiNativeClient(api_key=api_key, base_url=base_url) - logger.debug("resolve_provider_client: %s (%s)", provider, final_model) - return (_to_async_client(client, final_model) if async_mode - else (client, final_model)) + if is_native_gemini_base_url(base_url): + client = GeminiNativeClient(api_key=api_key, base_url=base_url) + logger.debug("resolve_provider_client: %s (%s)", provider, final_model) + return (_to_async_client(client, final_model) if async_mode + else (client, final_model)) # Provider-specific headers headers = {} diff --git a/agent/gemini_native_adapter.py b/agent/gemini_native_adapter.py index a495137a8..72fba8f29 100644 --- a/agent/gemini_native_adapter.py +++ b/agent/gemini_native_adapter.py @@ -32,6 +32,16 @@ logger = logging.getLogger(__name__) DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta" +def is_native_gemini_base_url(base_url: str) -> bool: + """Return True when the endpoint speaks Gemini's native REST API.""" + normalized = str(base_url or "").strip().rstrip("/").lower() + if not normalized: + return False + if "generativelanguage.googleapis.com" not in normalized: + return False + return not normalized.endswith("/openai") + + class GeminiAPIError(Exception): """Error shape compatible with Hermes retry/error classification.""" @@ -520,7 +530,7 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices: parts = ((cand.get("content") or {}).get("parts") or []) if isinstance(cand, dict) else [] chunks: List[_GeminiStreamChunk] = [] - for part in parts: + for part_index, part in enumerate(parts): if not isinstance(part, dict): continue if part.get("thought") is True and isinstance(part.get("text"), str): @@ -536,14 +546,30 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices: except (TypeError, ValueError): args_str = "{}" thought_signature = part.get("thoughtSignature") if isinstance(part.get("thoughtSignature"), str) else "" - call_key = json.dumps({"name": name, "args": args_str, "thought_signature": thought_signature}, sort_keys=True) + call_key = json.dumps( + { + "part_index": part_index, + "name": name, + "thought_signature": thought_signature, + }, + sort_keys=True, + ) slot = tool_call_indices.get(call_key) if slot is None: slot = { "index": len(tool_call_indices), "id": f"call_{uuid.uuid4().hex[:12]}", + "last_arguments": "", } tool_call_indices[call_key] = slot + emitted_arguments = args_str + last_arguments = str(slot.get("last_arguments") or "") + if last_arguments: + if args_str == last_arguments: + emitted_arguments = "" + elif args_str.startswith(last_arguments): + emitted_arguments = args_str[len(last_arguments):] + slot["last_arguments"] = args_str chunks.append( _make_stream_chunk( model=model, @@ -551,7 +577,7 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices: "index": slot["index"], "id": slot["id"], "name": name, - "arguments": args_str, + "arguments": emitted_arguments, "extra_content": _tool_call_extra_from_part(part), }, ) @@ -672,6 +698,7 @@ class GeminiNativeClient: base_url: Optional[str] = None, default_headers: Optional[Dict[str, str]] = None, timeout: Any = None, + http_client: Optional[httpx.Client] = None, **_: Any, ) -> None: self.api_key = api_key @@ -682,7 +709,9 @@ class GeminiNativeClient: self._default_headers = dict(default_headers or {}) self.chat = _GeminiChatNamespace(self) self.is_closed = False - self._http = httpx.Client(timeout=timeout or httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0)) + self._http = http_client or httpx.Client( + timeout=timeout or httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0) + ) def close(self) -> None: self.is_closed = True @@ -707,6 +736,13 @@ class GeminiNativeClient: headers.update(self._default_headers) return headers + @staticmethod + def _advance_stream_iterator(iterator: Iterator[_GeminiStreamChunk]) -> tuple[bool, Optional[_GeminiStreamChunk]]: + try: + return False, next(iterator) + except StopIteration: + return True, None + def _create_chat_completion( self, *, @@ -767,7 +803,7 @@ class GeminiNativeClient: if response.status_code != 200: response.read() raise gemini_http_error(response) - tool_call_indices: Dict[str, int] = {} + tool_call_indices: Dict[str, Dict[str, Any]] = {} for event in _iter_sse_events(response): for chunk in translate_stream_event(event, model, tool_call_indices): yield chunk @@ -790,7 +826,19 @@ class AsyncGeminiNativeClient: self.chat = _AsyncGeminiChatNamespace(self) async def _create_chat_completion(self, **kwargs: Any) -> Any: - return await asyncio.to_thread(self._sync.chat.completions.create, **kwargs) + stream = bool(kwargs.get("stream")) + result = await asyncio.to_thread(self._sync.chat.completions.create, **kwargs) + if not stream: + return result + + async def _async_stream() -> Any: + while True: + done, chunk = await asyncio.to_thread(self._sync._advance_stream_iterator, result) + if done: + break + yield chunk + + return _async_stream() async def close(self) -> None: await asyncio.to_thread(self._sync.close) diff --git a/plans/gemini-oauth-provider.md b/plans/gemini-oauth-provider.md index 9953d0eca..a466183e8 100644 --- a/plans/gemini-oauth-provider.md +++ b/plans/gemini-oauth-provider.md @@ -4,7 +4,7 @@ Add a first-class `gemini` provider that authenticates via Google OAuth, using the standard Gemini API (not Cloud Code Assist). Users who have a Google AI subscription or Gemini API access can authenticate through the browser without needing to manually copy API keys. ## Architecture Decision -- **Path A (chosen):** Standard Gemini API at `generativelanguage.googleapis.com/v1beta/openai/` +- **Path A (chosen):** Standard Gemini API at `generativelanguage.googleapis.com/v1beta` - **NOT Path B:** Cloud Code Assist (`cloudcode-pa.googleapis.com`) — rate-limited free tier, internal API, account ban risk - Standard `chat_completions` api_mode via OpenAI SDK — no new api_mode needed - Our own OAuth credentials — NOT sharing tokens with Gemini CLI @@ -32,9 +32,9 @@ Add a first-class `gemini` provider that authenticates via Google OAuth, using t - File locking for concurrent access (multiple agent sessions) ## API Integration -- Base URL: `https://generativelanguage.googleapis.com/v1beta/openai/` -- Auth: `Authorization: Bearer ` (passed as `api_key` to OpenAI SDK) -- api_mode: `chat_completions` (standard) +- Base URL: `https://generativelanguage.googleapis.com/v1beta` +- Auth: native Gemini API authentication handled by the provider adapter +- api_mode: `chat_completions` (standard facade over native transport) - Models: gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash, etc. ## Files to Create/Modify diff --git a/run_agent.py b/run_agent.py index a6831d586..85eaad1b3 100644 --- a/run_agent.py +++ b/run_agent.py @@ -4705,6 +4705,30 @@ class AIAgent: return bool(getattr(http_client, "is_closed", False)) return False + @staticmethod + def _build_keepalive_http_client() -> Any: + try: + import httpx as _httpx + import socket as _socket + + _sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)] + if hasattr(_socket, "TCP_KEEPIDLE"): + _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, 30)) + _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, 10)) + _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, 3)) + elif hasattr(_socket, "TCP_KEEPALIVE"): + _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, 30)) + # When a custom transport is provided, httpx won't auto-read proxy + # from env vars (allow_env_proxies = trust_env and transport is None). + # Explicitly read proxy settings to ensure HTTP_PROXY/HTTPS_PROXY work. + _proxy = _get_proxy_from_env() + return _httpx.Client( + transport=_httpx.HTTPTransport(socket_options=_sock_opts), + proxy=_proxy, + ) + except Exception: + return None + def _create_openai_client(self, client_kwargs: dict, *, reason: str, shared: bool) -> Any: from agent.auxiliary_client import _validate_base_url, _validate_proxy_env_urls # Treat client_kwargs as read-only. Callers pass self._client_kwargs (or shallow @@ -4746,20 +4770,26 @@ class AIAgent: ) return client if self.provider == "gemini": - from agent.gemini_native_adapter import GeminiNativeClient + from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url - safe_kwargs = { - k: v for k, v in client_kwargs.items() - if k in {"api_key", "base_url", "default_headers", "timeout"} - } - client = GeminiNativeClient(**safe_kwargs) - logger.info( - "Gemini native client created (%s, shared=%s) %s", - reason, - shared, - self._client_log_context(), - ) - return client + base_url = str(client_kwargs.get("base_url", "") or "") + if is_native_gemini_base_url(base_url): + safe_kwargs = { + k: v for k, v in client_kwargs.items() + if k in {"api_key", "base_url", "default_headers", "timeout", "http_client"} + } + if "http_client" not in safe_kwargs: + keepalive_http = self._build_keepalive_http_client() + if keepalive_http is not None: + safe_kwargs["http_client"] = keepalive_http + client = GeminiNativeClient(**safe_kwargs) + logger.info( + "Gemini native client created (%s, shared=%s) %s", + reason, + shared, + self._client_log_context(), + ) + return client # Inject TCP keepalives so the kernel detects dead provider connections # instead of letting them sit silently in CLOSE-WAIT (#10324). Without # this, a peer that drops mid-stream leaves the socket in a state where @@ -4778,28 +4808,9 @@ class AIAgent: # Tests in ``tests/run_agent/test_create_openai_client_reuse.py`` and # ``tests/run_agent/test_sequential_chats_live.py`` pin this invariant. if "http_client" not in client_kwargs: - try: - import httpx as _httpx - import socket as _socket - _sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)] - if hasattr(_socket, "TCP_KEEPIDLE"): - # Linux - _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, 30)) - _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, 10)) - _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, 3)) - elif hasattr(_socket, "TCP_KEEPALIVE"): - # macOS (uses TCP_KEEPALIVE instead of TCP_KEEPIDLE) - _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, 30)) - # When a custom transport is provided, httpx won't auto-read proxy - # from env vars (allow_env_proxies = trust_env and transport is None). - # Explicitly read proxy settings to ensure HTTP_PROXY/HTTPS_PROXY work. - _proxy = _get_proxy_from_env() - client_kwargs["http_client"] = _httpx.Client( - transport=_httpx.HTTPTransport(socket_options=_sock_opts), - proxy=_proxy, - ) - except Exception: - pass # Fall through to default transport if socket opts fail + keepalive_http = self._build_keepalive_http_client() + if keepalive_http is not None: + client_kwargs["http_client"] = keepalive_http client = OpenAI(**client_kwargs) logger.info( "OpenAI client created (%s, shared=%s) %s", diff --git a/tests/agent/test_gemini_native_adapter.py b/tests/agent/test_gemini_native_adapter.py index daa825522..0141c7410 100644 --- a/tests/agent/test_gemini_native_adapter.py +++ b/tests/agent/test_gemini_native_adapter.py @@ -186,6 +186,43 @@ def test_native_http_error_keeps_status_and_retry_after(): assert "quota exhausted" in str(err) +def test_native_client_accepts_injected_http_client(): + from agent.gemini_native_adapter import GeminiNativeClient + + injected = SimpleNamespace(close=lambda: None) + client = GeminiNativeClient(api_key="AIza-test", http_client=injected) + assert client._http is injected + + +@pytest.mark.asyncio +async def test_async_native_client_streams_without_requiring_async_iterator_from_sync_client(): + from agent.gemini_native_adapter import AsyncGeminiNativeClient + + chunk = SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"), finish_reason=None)]) + sync_stream = iter([chunk]) + + def _advance(iterator): + try: + return False, next(iterator) + except StopIteration: + return True, None + + sync_client = SimpleNamespace( + api_key="AIza-test", + base_url="https://generativelanguage.googleapis.com/v1beta", + chat=SimpleNamespace(completions=SimpleNamespace(create=lambda **kwargs: sync_stream)), + _advance_stream_iterator=_advance, + close=lambda: None, + ) + + async_client = AsyncGeminiNativeClient(sync_client) + stream = await async_client.chat.completions.create(stream=True) + collected = [] + async for item in stream: + collected.append(item) + assert collected == [chunk] + + def test_stream_event_translation_emits_tool_call_delta_with_stable_index(): from agent.gemini_native_adapter import translate_stream_event @@ -209,4 +246,30 @@ def test_stream_event_translation_emits_tool_call_delta_with_stable_index(): assert first[0].choices[0].delta.tool_calls[0].index == 0 assert second[0].choices[0].delta.tool_calls[0].index == 0 assert first[0].choices[0].delta.tool_calls[0].id == second[0].choices[0].delta.tool_calls[0].id + assert first[0].choices[0].delta.tool_calls[0].function.arguments == '{"q": "abc"}' + assert second[0].choices[0].delta.tool_calls[0].function.arguments == "" assert first[-1].choices[0].finish_reason == "tool_calls" + + +def test_stream_event_translation_keeps_identical_calls_in_distinct_parts(): + from agent.gemini_native_adapter import translate_stream_event + + event = { + "candidates": [ + { + "content": { + "parts": [ + {"functionCall": {"name": "search", "args": {"q": "abc"}}}, + {"functionCall": {"name": "search", "args": {"q": "abc"}}}, + ] + }, + "finishReason": "STOP", + } + ] + } + + chunks = translate_stream_event(event, model="gemini-2.5-flash", tool_call_indices={}) + tool_chunks = [chunk for chunk in chunks if chunk.choices[0].delta.tool_calls] + assert tool_chunks[0].choices[0].delta.tool_calls[0].index == 0 + assert tool_chunks[1].choices[0].delta.tool_calls[0].index == 1 + assert tool_chunks[0].choices[0].delta.tool_calls[0].id != tool_chunks[1].choices[0].delta.tool_calls[0].id diff --git a/tests/hermes_cli/test_config_validation.py b/tests/hermes_cli/test_config_validation.py index 39a3eca72..c18afc911 100644 --- a/tests/hermes_cli/test_config_validation.py +++ b/tests/hermes_cli/test_config_validation.py @@ -13,7 +13,7 @@ class TestCustomProvidersValidation: issues = validate_config_structure({ "custom_providers": { "name": "Generativelanguage.googleapis.com", - "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", + "base_url": "https://generativelanguage.googleapis.com/v1beta", "api_key": "xxx", "model": "models/gemini-2.5-flash", "rate_limit_delay": 2.0, diff --git a/tests/hermes_cli/test_gemini_provider.py b/tests/hermes_cli/test_gemini_provider.py index d0fcad7a4..7f9348be4 100644 --- a/tests/hermes_cli/test_gemini_provider.py +++ b/tests/hermes_cli/test_gemini_provider.py @@ -210,8 +210,10 @@ class TestGeminiAgentInit: def test_gemini_agent_uses_native_client(self, monkeypatch): monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY") with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \ - patch("run_agent.OpenAI") as mock_openai: + patch("run_agent.OpenAI") as mock_openai, \ + patch("run_agent.ContextCompressor") as mock_compressor: mock_client.return_value = MagicMock() + mock_compressor.return_value = MagicMock(context_length=1048576, threshold_tokens=524288) from run_agent import AIAgent AIAgent( model="gemini-2.5-flash", @@ -222,6 +224,38 @@ class TestGeminiAgentInit: assert mock_client.called mock_openai.assert_not_called() + def test_gemini_custom_base_url_keeps_openai_client(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY") + with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \ + patch("run_agent.OpenAI") as mock_openai, \ + patch("run_agent.ContextCompressor") as mock_compressor: + mock_openai.return_value = MagicMock() + mock_compressor.return_value = MagicMock(context_length=128000, threshold_tokens=64000) + from run_agent import AIAgent + AIAgent( + model="gemini-2.5-flash", + provider="gemini", + api_key="AIzaSy_REAL_KEY", + base_url="https://proxy.example.com/v1", + ) + mock_openai.assert_called_once() + + def test_gemini_openai_compat_base_url_keeps_openai_client(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY") + with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \ + patch("run_agent.OpenAI") as mock_openai, \ + patch("run_agent.ContextCompressor") as mock_compressor: + mock_openai.return_value = MagicMock() + mock_compressor.return_value = MagicMock(context_length=1048576, threshold_tokens=524288) + from run_agent import AIAgent + AIAgent( + model="gemini-2.5-flash", + provider="gemini", + api_key="AIzaSy_REAL_KEY", + base_url="https://generativelanguage.googleapis.com/v1beta/openai", + ) + mock_openai.assert_called_once() + def test_gemini_resolve_provider_client_uses_native_client(self, monkeypatch): """resolve_provider_client('gemini') should build GeminiNativeClient.""" monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy_TEST_KEY") @@ -233,6 +267,16 @@ class TestGeminiAgentInit: assert mock_client.called mock_openai.assert_not_called() + def test_gemini_resolve_provider_client_keeps_openai_for_non_native_base_url(self, monkeypatch): + monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_TEST_KEY") + monkeypatch.setenv("GEMINI_BASE_URL", "https://proxy.example.com/v1") + with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \ + patch("agent.auxiliary_client.OpenAI") as mock_openai: + mock_openai.return_value = MagicMock() + from agent.auxiliary_client import resolve_provider_client + resolve_provider_client("gemini") + mock_openai.assert_called_once() + # ── models.dev Integration ──