mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gemini): tighten native routing and streaming replay
- only use the native adapter for the canonical Gemini native endpoint
- keep custom and /openai base URLs on the OpenAI-compatible path
- preserve Hermes keepalive transport injection for native Gemini clients
- stabilize streaming tool-call replay across repeated SSE events
- add follow-up tests for base_url precedence, async streaming, and duplicate tool-call chunks
This commit is contained in:
parent 3dea497b20
commit d393104bad

7 changed files with 225 additions and 56 deletions
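Net effect of the routing change, as a minimal sketch: the decision is a single predicate over the configured base URL, and only the canonical endpoint gets the native adapter (`pick_client_path` below is a hypothetical wrapper for illustration, not code from this commit):

```python
from agent.gemini_native_adapter import is_native_gemini_base_url

def pick_client_path(base_url: str) -> str:
    # Canonical native endpoint -> native adapter; custom proxies and the
    # /openai facade stay on the OpenAI-compatible client.
    return "native" if is_native_gemini_base_url(base_url) else "openai-compatible"

assert pick_client_path("https://generativelanguage.googleapis.com/v1beta") == "native"
assert pick_client_path("https://generativelanguage.googleapis.com/v1beta/openai") == "openai-compatible"
assert pick_client_path("https://proxy.example.com/v1") == "openai-compatible"
```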
@@ -815,8 +815,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
                 continue  # skip provider if we don't know a valid aux model
             logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
             if provider_id == "gemini":
-                from agent.gemini_native_adapter import GeminiNativeClient
+                from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url

-                return GeminiNativeClient(api_key=api_key, base_url=base_url), model
+                if is_native_gemini_base_url(base_url):
+                    return GeminiNativeClient(api_key=api_key, base_url=base_url), model
             extra = {}
             if "api.kimi.com" in base_url.lower():
@@ -840,8 +841,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
                 continue  # skip provider if we don't know a valid aux model
             logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
             if provider_id == "gemini":
-                from agent.gemini_native_adapter import GeminiNativeClient
+                from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url

-                return GeminiNativeClient(api_key=api_key, base_url=base_url), model
+                if is_native_gemini_base_url(base_url):
+                    return GeminiNativeClient(api_key=api_key, base_url=base_url), model
             extra = {}
             if "api.kimi.com" in base_url.lower():
@@ -1703,8 +1705,9 @@ def resolve_provider_client(
     final_model = _normalize_resolved_model(model or default_model, provider)

     if provider == "gemini":
-        from agent.gemini_native_adapter import GeminiNativeClient
+        from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url

-        client = GeminiNativeClient(api_key=api_key, base_url=base_url)
+        if is_native_gemini_base_url(base_url):
+            client = GeminiNativeClient(api_key=api_key, base_url=base_url)
     logger.debug("resolve_provider_client: %s (%s)", provider, final_model)
     return (_to_async_client(client, final_model) if async_mode
@@ -32,6 +32,16 @@ logger = logging.getLogger(__name__)
 DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"


+def is_native_gemini_base_url(base_url: str) -> bool:
+    """Return True when the endpoint speaks Gemini's native REST API."""
+    normalized = str(base_url or "").strip().rstrip("/").lower()
+    if not normalized:
+        return False
+    if "generativelanguage.googleapis.com" not in normalized:
+        return False
+    return not normalized.endswith("/openai")
+
+
 class GeminiAPIError(Exception):
     """Error shape compatible with Hermes retry/error classification."""
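The predicate normalizes its input before matching (strip whitespace, drop trailing slashes, lowercase), so equivalent spellings of the canonical endpoint route identically. Illustrative cases, derived directly from the function body above:

```python
from agent.gemini_native_adapter import is_native_gemini_base_url

assert is_native_gemini_base_url("https://generativelanguage.googleapis.com/v1beta")
assert is_native_gemini_base_url("https://GenerativeLanguage.googleapis.com/v1beta/")  # case- and slash-insensitive
assert not is_native_gemini_base_url("https://generativelanguage.googleapis.com/v1beta/openai")
assert not is_native_gemini_base_url("https://generativelanguage.googleapis.com/v1beta/openai/")  # rstrip happens first
assert not is_native_gemini_base_url("")  # empty input never routes native
assert not is_native_gemini_base_url("https://proxy.example.com/v1")
```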
@@ -520,7 +530,7 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices:
     parts = ((cand.get("content") or {}).get("parts") or []) if isinstance(cand, dict) else []
     chunks: List[_GeminiStreamChunk] = []

-    for part in parts:
+    for part_index, part in enumerate(parts):
         if not isinstance(part, dict):
             continue
         if part.get("thought") is True and isinstance(part.get("text"), str):
@@ -536,14 +546,30 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices:
             except (TypeError, ValueError):
                 args_str = "{}"
             thought_signature = part.get("thoughtSignature") if isinstance(part.get("thoughtSignature"), str) else ""
-            call_key = json.dumps({"name": name, "args": args_str, "thought_signature": thought_signature}, sort_keys=True)
+            call_key = json.dumps(
+                {
+                    "part_index": part_index,
+                    "name": name,
+                    "thought_signature": thought_signature,
+                },
+                sort_keys=True,
+            )
             slot = tool_call_indices.get(call_key)
             if slot is None:
                 slot = {
                     "index": len(tool_call_indices),
                     "id": f"call_{uuid.uuid4().hex[:12]}",
+                    "last_arguments": "",
                 }
                 tool_call_indices[call_key] = slot
+            emitted_arguments = args_str
+            last_arguments = str(slot.get("last_arguments") or "")
+            if last_arguments:
+                if args_str == last_arguments:
+                    emitted_arguments = ""
+                elif args_str.startswith(last_arguments):
+                    emitted_arguments = args_str[len(last_arguments):]
+            slot["last_arguments"] = args_str
             chunks.append(
                 _make_stream_chunk(
                     model=model,
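The replay rule above emits only the portion of a tool call's argument string the client has not yet seen: a byte-identical replay emits nothing, incremental growth emits the suffix, anything else falls back to the full string. The same rule in isolation (a sketch; `arguments_delta` is a standalone restatement, not a function in the adapter):

```python
def arguments_delta(last_arguments: str, args_str: str) -> str:
    # Mirror of the emitted_arguments logic in translate_stream_event.
    if not last_arguments:
        return args_str                        # first sighting: emit everything
    if args_str == last_arguments:
        return ""                              # exact replay: emit nothing
    if args_str.startswith(last_arguments):
        return args_str[len(last_arguments):]  # growth: emit only the new suffix
    return args_str                            # divergent payload: re-emit in full

assert arguments_delta("", '{"q": "abc"}') == '{"q": "abc"}'
assert arguments_delta('{"q": "abc"}', '{"q": "abc"}') == ""
assert arguments_delta('{"q": "ab', '{"q": "abc"}') == 'c"}'
```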
@@ -551,7 +577,7 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices:
                     "index": slot["index"],
                     "id": slot["id"],
                     "name": name,
-                    "arguments": args_str,
+                    "arguments": emitted_arguments,
                     "extra_content": _tool_call_extra_from_part(part),
                 },
             )
@@ -672,6 +698,7 @@ class GeminiNativeClient:
         base_url: Optional[str] = None,
         default_headers: Optional[Dict[str, str]] = None,
         timeout: Any = None,
+        http_client: Optional[httpx.Client] = None,
         **_: Any,
     ) -> None:
         self.api_key = api_key
@@ -682,7 +709,9 @@ class GeminiNativeClient:
         self._default_headers = dict(default_headers or {})
         self.chat = _GeminiChatNamespace(self)
         self.is_closed = False
-        self._http = httpx.Client(timeout=timeout or httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0))
+        self._http = http_client or httpx.Client(
+            timeout=timeout or httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0)
+        )

     def close(self) -> None:
         self.is_closed = True
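With the new parameter, transport injection is explicit: when a client is passed in, the adapter reuses it instead of constructing its own. A short sketch (the shared client here is an assumption for illustration):

```python
import httpx
from agent.gemini_native_adapter import GeminiNativeClient

shared = httpx.Client(timeout=httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0))
client = GeminiNativeClient(api_key="AIza-example", http_client=shared)
assert client._http is shared  # no second httpx.Client is constructed
```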
@@ -707,6 +736,13 @@ class GeminiNativeClient:
         headers.update(self._default_headers)
         return headers

+    @staticmethod
+    def _advance_stream_iterator(iterator: Iterator[_GeminiStreamChunk]) -> tuple[bool, Optional[_GeminiStreamChunk]]:
+        try:
+            return False, next(iterator)
+        except StopIteration:
+            return True, None
+
     def _create_chat_completion(
         self,
         *,
@@ -767,7 +803,7 @@ class GeminiNativeClient:
         if response.status_code != 200:
             response.read()
             raise gemini_http_error(response)
-        tool_call_indices: Dict[str, int] = {}
+        tool_call_indices: Dict[str, Dict[str, Any]] = {}
         for event in _iter_sse_events(response):
             for chunk in translate_stream_event(event, model, tool_call_indices):
                 yield chunk
@@ -790,7 +826,19 @@ class AsyncGeminiNativeClient:
         self.chat = _AsyncGeminiChatNamespace(self)

     async def _create_chat_completion(self, **kwargs: Any) -> Any:
-        return await asyncio.to_thread(self._sync.chat.completions.create, **kwargs)
+        stream = bool(kwargs.get("stream"))
+        result = await asyncio.to_thread(self._sync.chat.completions.create, **kwargs)
+        if not stream:
+            return result
+
+        async def _async_stream() -> Any:
+            while True:
+                done, chunk = await asyncio.to_thread(self._sync._advance_stream_iterator, result)
+                if done:
+                    break
+                yield chunk
+
+        return _async_stream()

     async def close(self) -> None:
         await asyncio.to_thread(self._sync.close)
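The bridge never asks the sync stream for an async iterator; each step is one `asyncio.to_thread` hop through `_advance_stream_iterator`, so a plain sync generator suffices. A consumption sketch under that assumption (request kwargs follow the OpenAI-compatible surface the chat namespace exposes):

```python
async def collect_text(async_client, model: str) -> str:
    # create(stream=True) returns an async generator driven by
    # worker-thread next() calls against the sync iterator.
    stream = await async_client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    parts = []
    async for chunk in stream:
        delta = chunk.choices[0].delta
        if getattr(delta, "content", None):
            parts.append(delta.content)
    return "".join(parts)
```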
@@ -4,7 +4,7 @@
 Add a first-class `gemini` provider that authenticates via Google OAuth, using the standard Gemini API (not Cloud Code Assist). Users who have a Google AI subscription or Gemini API access can authenticate through the browser without needing to manually copy API keys.

 ## Architecture Decision
-- **Path A (chosen):** Standard Gemini API at `generativelanguage.googleapis.com/v1beta/openai/`
+- **Path A (chosen):** Standard Gemini API at `generativelanguage.googleapis.com/v1beta`
 - **NOT Path B:** Cloud Code Assist (`cloudcode-pa.googleapis.com`) — rate-limited free tier, internal API, account ban risk
 - Standard `chat_completions` api_mode via OpenAI SDK — no new api_mode needed
 - Our own OAuth credentials — NOT sharing tokens with Gemini CLI
@@ -32,9 +32,9 @@ Add a first-class `gemini` provider that authenticates via Google OAuth, using t
 - File locking for concurrent access (multiple agent sessions)

 ## API Integration
-- Base URL: `https://generativelanguage.googleapis.com/v1beta/openai/`
-- Auth: `Authorization: Bearer <access_token>` (passed as `api_key` to OpenAI SDK)
-- api_mode: `chat_completions` (standard)
+- Base URL: `https://generativelanguage.googleapis.com/v1beta`
+- Auth: native Gemini API authentication handled by the provider adapter
+- api_mode: `chat_completions` (standard facade over native transport)
 - Models: gemini-2.5-pro, gemini-2.5-flash, gemini-2.0-flash, etc.

 ## Files to Create/Modify
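For contrast with the `/openai` facade, the native endpoint uses per-model paths and Google's API-key header rather than a bearer token. A minimal request sketch against the public native API (model and key are placeholders; the shape follows Google's published REST surface, not code from this repo):

```python
import httpx

resp = httpx.post(
    "https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-flash:generateContent",
    headers={"x-goog-api-key": "AIza-example"},
    json={"contents": [{"role": "user", "parts": [{"text": "hello"}]}]},
)
resp.raise_for_status()
print(resp.json()["candidates"][0]["content"]["parts"][0]["text"])
```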
run_agent.py (59 changed lines)
@@ -4705,6 +4705,30 @@ class AIAgent:
             return bool(getattr(http_client, "is_closed", False))
         return False

+    @staticmethod
+    def _build_keepalive_http_client() -> Any:
+        try:
+            import httpx as _httpx
+            import socket as _socket
+
+            _sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)]
+            if hasattr(_socket, "TCP_KEEPIDLE"):
+                _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, 30))
+                _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, 10))
+                _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, 3))
+            elif hasattr(_socket, "TCP_KEEPALIVE"):
+                _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, 30))
+            # When a custom transport is provided, httpx won't auto-read proxy
+            # from env vars (allow_env_proxies = trust_env and transport is None).
+            # Explicitly read proxy settings to ensure HTTP_PROXY/HTTPS_PROXY work.
+            _proxy = _get_proxy_from_env()
+            return _httpx.Client(
+                transport=_httpx.HTTPTransport(socket_options=_sock_opts),
+                proxy=_proxy,
+            )
+        except Exception:
+            return None
+
     def _create_openai_client(self, client_kwargs: dict, *, reason: str, shared: bool) -> Any:
         from agent.auxiliary_client import _validate_base_url, _validate_proxy_env_urls
         # Treat client_kwargs as read-only. Callers pass self._client_kwargs (or shallow
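The helper's constants give a worst-case dead-peer detection of roughly 30s idle + 3 probes x 10s, about one minute on Linux, which matters for long-lived SSE reads. The transport shape it builds, restated standalone (Linux branch only; a sketch, not the helper itself):

```python
import httpx
import socket

sock_opts = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
if hasattr(socket, "TCP_KEEPIDLE"):  # Linux
    sock_opts += [
        (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 30),   # start probing after 30s idle
        (socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10),  # probe every 10s
        (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 3),     # give up after 3 missed probes
    ]

keepalive = httpx.Client(transport=httpx.HTTPTransport(socket_options=sock_opts))
```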
@@ -4746,12 +4770,18 @@ class AIAgent:
             )
             return client
         if self.provider == "gemini":
-            from agent.gemini_native_adapter import GeminiNativeClient
+            from agent.gemini_native_adapter import GeminiNativeClient, is_native_gemini_base_url

+            base_url = str(client_kwargs.get("base_url", "") or "")
+            if is_native_gemini_base_url(base_url):
                 safe_kwargs = {
                     k: v for k, v in client_kwargs.items()
-                    if k in {"api_key", "base_url", "default_headers", "timeout"}
+                    if k in {"api_key", "base_url", "default_headers", "timeout", "http_client"}
                 }
+                if "http_client" not in safe_kwargs:
+                    keepalive_http = self._build_keepalive_http_client()
+                    if keepalive_http is not None:
+                        safe_kwargs["http_client"] = keepalive_http
                 client = GeminiNativeClient(**safe_kwargs)
                 logger.info(
                     "Gemini native client created (%s, shared=%s) %s",
@@ -4778,28 +4808,9 @@ class AIAgent:
         # Tests in ``tests/run_agent/test_create_openai_client_reuse.py`` and
         # ``tests/run_agent/test_sequential_chats_live.py`` pin this invariant.
         if "http_client" not in client_kwargs:
-            try:
-                import httpx as _httpx
-                import socket as _socket
-                _sock_opts = [(_socket.SOL_SOCKET, _socket.SO_KEEPALIVE, 1)]
-                if hasattr(_socket, "TCP_KEEPIDLE"):
-                    # Linux
-                    _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPIDLE, 30))
-                    _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPINTVL, 10))
-                    _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPCNT, 3))
-                elif hasattr(_socket, "TCP_KEEPALIVE"):
-                    # macOS (uses TCP_KEEPALIVE instead of TCP_KEEPIDLE)
-                    _sock_opts.append((_socket.IPPROTO_TCP, _socket.TCP_KEEPALIVE, 30))
-                # When a custom transport is provided, httpx won't auto-read proxy
-                # from env vars (allow_env_proxies = trust_env and transport is None).
-                # Explicitly read proxy settings to ensure HTTP_PROXY/HTTPS_PROXY work.
-                _proxy = _get_proxy_from_env()
-                client_kwargs["http_client"] = _httpx.Client(
-                    transport=_httpx.HTTPTransport(socket_options=_sock_opts),
-                    proxy=_proxy,
-                )
-            except Exception:
-                pass  # Fall through to default transport if socket opts fail
+            keepalive_http = self._build_keepalive_http_client()
+            if keepalive_http is not None:
+                client_kwargs["http_client"] = keepalive_http
         client = OpenAI(**client_kwargs)
         logger.info(
             "OpenAI client created (%s, shared=%s) %s",
@@ -186,6 +186,43 @@ def test_native_http_error_keeps_status_and_retry_after():
     assert "quota exhausted" in str(err)


+def test_native_client_accepts_injected_http_client():
+    from agent.gemini_native_adapter import GeminiNativeClient
+
+    injected = SimpleNamespace(close=lambda: None)
+    client = GeminiNativeClient(api_key="AIza-test", http_client=injected)
+    assert client._http is injected
+
+
+@pytest.mark.asyncio
+async def test_async_native_client_streams_without_requiring_async_iterator_from_sync_client():
+    from agent.gemini_native_adapter import AsyncGeminiNativeClient
+
+    chunk = SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="hi"), finish_reason=None)])
+    sync_stream = iter([chunk])
+
+    def _advance(iterator):
+        try:
+            return False, next(iterator)
+        except StopIteration:
+            return True, None
+
+    sync_client = SimpleNamespace(
+        api_key="AIza-test",
+        base_url="https://generativelanguage.googleapis.com/v1beta",
+        chat=SimpleNamespace(completions=SimpleNamespace(create=lambda **kwargs: sync_stream)),
+        _advance_stream_iterator=_advance,
+        close=lambda: None,
+    )
+
+    async_client = AsyncGeminiNativeClient(sync_client)
+    stream = await async_client.chat.completions.create(stream=True)
+    collected = []
+    async for item in stream:
+        collected.append(item)
+    assert collected == [chunk]
+
+
 def test_stream_event_translation_emits_tool_call_delta_with_stable_index():
     from agent.gemini_native_adapter import translate_stream_event
@@ -209,4 +246,30 @@ def test_stream_event_translation_emits_tool_call_delta_with_stable_index():
     assert first[0].choices[0].delta.tool_calls[0].index == 0
     assert second[0].choices[0].delta.tool_calls[0].index == 0
     assert first[0].choices[0].delta.tool_calls[0].id == second[0].choices[0].delta.tool_calls[0].id
+    assert first[0].choices[0].delta.tool_calls[0].function.arguments == '{"q": "abc"}'
+    assert second[0].choices[0].delta.tool_calls[0].function.arguments == ""
     assert first[-1].choices[0].finish_reason == "tool_calls"
+
+
+def test_stream_event_translation_keeps_identical_calls_in_distinct_parts():
+    from agent.gemini_native_adapter import translate_stream_event
+
+    event = {
+        "candidates": [
+            {
+                "content": {
+                    "parts": [
+                        {"functionCall": {"name": "search", "args": {"q": "abc"}}},
+                        {"functionCall": {"name": "search", "args": {"q": "abc"}}},
+                    ]
+                },
+                "finishReason": "STOP",
+            }
+        ]
+    }
+
+    chunks = translate_stream_event(event, model="gemini-2.5-flash", tool_call_indices={})
+    tool_chunks = [chunk for chunk in chunks if chunk.choices[0].delta.tool_calls]
+    assert tool_chunks[0].choices[0].delta.tool_calls[0].index == 0
+    assert tool_chunks[1].choices[0].delta.tool_calls[0].index == 1
+    assert tool_chunks[0].choices[0].delta.tool_calls[0].id != tool_chunks[1].choices[0].delta.tool_calls[0].id
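Because `part_index` is part of the slot key, two byte-identical calls within one event get distinct slots, while the same part replayed across SSE events keeps its slot. The key construction, restated (mirrors the updated `call_key`; not importable as-is):

```python
import json

def call_key(part_index: int, name: str, thought_signature: str = "") -> str:
    return json.dumps(
        {"part_index": part_index, "name": name, "thought_signature": thought_signature},
        sort_keys=True,
    )

assert call_key(0, "search") != call_key(1, "search")  # duplicates in one event: two slots
assert call_key(0, "search") == call_key(0, "search")  # replayed event: same slot
```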
@@ -13,7 +13,7 @@ class TestCustomProvidersValidation:
         issues = validate_config_structure({
             "custom_providers": {
                 "name": "Generativelanguage.googleapis.com",
-                "base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
+                "base_url": "https://generativelanguage.googleapis.com/v1beta",
                 "api_key": "xxx",
                 "model": "models/gemini-2.5-flash",
                 "rate_limit_delay": 2.0,
@@ -210,8 +210,10 @@ class TestGeminiAgentInit:
     def test_gemini_agent_uses_native_client(self, monkeypatch):
         monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY")
         with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \
-             patch("run_agent.OpenAI") as mock_openai:
+             patch("run_agent.OpenAI") as mock_openai, \
+             patch("run_agent.ContextCompressor") as mock_compressor:
             mock_client.return_value = MagicMock()
+            mock_compressor.return_value = MagicMock(context_length=1048576, threshold_tokens=524288)
             from run_agent import AIAgent
             AIAgent(
                 model="gemini-2.5-flash",
@@ -222,6 +224,38 @@ class TestGeminiAgentInit:
         assert mock_client.called
         mock_openai.assert_not_called()

+    def test_gemini_custom_base_url_keeps_openai_client(self, monkeypatch):
+        monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY")
+        with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \
+             patch("run_agent.OpenAI") as mock_openai, \
+             patch("run_agent.ContextCompressor") as mock_compressor:
+            mock_openai.return_value = MagicMock()
+            mock_compressor.return_value = MagicMock(context_length=128000, threshold_tokens=64000)
+            from run_agent import AIAgent
+            AIAgent(
+                model="gemini-2.5-flash",
+                provider="gemini",
+                api_key="AIzaSy_REAL_KEY",
+                base_url="https://proxy.example.com/v1",
+            )
+            mock_openai.assert_called_once()
+
+    def test_gemini_openai_compat_base_url_keeps_openai_client(self, monkeypatch):
+        monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_REAL_KEY")
+        with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \
+             patch("run_agent.OpenAI") as mock_openai, \
+             patch("run_agent.ContextCompressor") as mock_compressor:
+            mock_openai.return_value = MagicMock()
+            mock_compressor.return_value = MagicMock(context_length=1048576, threshold_tokens=524288)
+            from run_agent import AIAgent
+            AIAgent(
+                model="gemini-2.5-flash",
+                provider="gemini",
+                api_key="AIzaSy_REAL_KEY",
+                base_url="https://generativelanguage.googleapis.com/v1beta/openai",
+            )
+            mock_openai.assert_called_once()
+
     def test_gemini_resolve_provider_client_uses_native_client(self, monkeypatch):
         """resolve_provider_client('gemini') should build GeminiNativeClient."""
         monkeypatch.setenv("GEMINI_API_KEY", "AIzaSy_TEST_KEY")
@@ -233,6 +267,16 @@ class TestGeminiAgentInit:
         assert mock_client.called
         mock_openai.assert_not_called()

+    def test_gemini_resolve_provider_client_keeps_openai_for_non_native_base_url(self, monkeypatch):
+        monkeypatch.setenv("GOOGLE_API_KEY", "AIzaSy_TEST_KEY")
+        monkeypatch.setenv("GEMINI_BASE_URL", "https://proxy.example.com/v1")
+        with patch("agent.gemini_native_adapter.GeminiNativeClient") as mock_client, \
+             patch("agent.auxiliary_client.OpenAI") as mock_openai:
+            mock_openai.return_value = MagicMock()
+            from agent.auxiliary_client import resolve_provider_client
+            resolve_provider_client("gemini")
+            mock_openai.assert_called_once()
+

 # ── models.dev Integration ──