fix(gemini): tighten native routing and streaming replay

- only use the native adapter for the canonical Gemini native endpoint
- keep custom and /openai base URLs on the OpenAI-compatible path
- preserve Hermes keepalive transport injection for native Gemini clients
- stabilize streaming tool-call replay across repeated SSE events
- add follow-up tests for base_url precedence, async streaming, and duplicate tool-call chunks
This commit is contained in:
kshitijk4poor 2026-04-20 00:41:20 +05:30 committed by Teknium
parent 3dea497b20
commit d393104bad
7 changed files with 225 additions and 56 deletions

View file

@ -32,6 +32,16 @@ logger = logging.getLogger(__name__)
DEFAULT_GEMINI_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
def is_native_gemini_base_url(base_url: str) -> bool:
"""Return True when the endpoint speaks Gemini's native REST API."""
normalized = str(base_url or "").strip().rstrip("/").lower()
if not normalized:
return False
if "generativelanguage.googleapis.com" not in normalized:
return False
return not normalized.endswith("/openai")
class GeminiAPIError(Exception):
"""Error shape compatible with Hermes retry/error classification."""
@ -520,7 +530,7 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices:
parts = ((cand.get("content") or {}).get("parts") or []) if isinstance(cand, dict) else []
chunks: List[_GeminiStreamChunk] = []
for part in parts:
for part_index, part in enumerate(parts):
if not isinstance(part, dict):
continue
if part.get("thought") is True and isinstance(part.get("text"), str):
@ -536,14 +546,30 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices:
except (TypeError, ValueError):
args_str = "{}"
thought_signature = part.get("thoughtSignature") if isinstance(part.get("thoughtSignature"), str) else ""
call_key = json.dumps({"name": name, "args": args_str, "thought_signature": thought_signature}, sort_keys=True)
call_key = json.dumps(
{
"part_index": part_index,
"name": name,
"thought_signature": thought_signature,
},
sort_keys=True,
)
slot = tool_call_indices.get(call_key)
if slot is None:
slot = {
"index": len(tool_call_indices),
"id": f"call_{uuid.uuid4().hex[:12]}",
"last_arguments": "",
}
tool_call_indices[call_key] = slot
emitted_arguments = args_str
last_arguments = str(slot.get("last_arguments") or "")
if last_arguments:
if args_str == last_arguments:
emitted_arguments = ""
elif args_str.startswith(last_arguments):
emitted_arguments = args_str[len(last_arguments):]
slot["last_arguments"] = args_str
chunks.append(
_make_stream_chunk(
model=model,
@ -551,7 +577,7 @@ def translate_stream_event(event: Dict[str, Any], model: str, tool_call_indices:
"index": slot["index"],
"id": slot["id"],
"name": name,
"arguments": args_str,
"arguments": emitted_arguments,
"extra_content": _tool_call_extra_from_part(part),
},
)
@ -672,6 +698,7 @@ class GeminiNativeClient:
base_url: Optional[str] = None,
default_headers: Optional[Dict[str, str]] = None,
timeout: Any = None,
http_client: Optional[httpx.Client] = None,
**_: Any,
) -> None:
self.api_key = api_key
@ -682,7 +709,9 @@ class GeminiNativeClient:
self._default_headers = dict(default_headers or {})
self.chat = _GeminiChatNamespace(self)
self.is_closed = False
self._http = httpx.Client(timeout=timeout or httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0))
self._http = http_client or httpx.Client(
timeout=timeout or httpx.Timeout(connect=15.0, read=600.0, write=30.0, pool=30.0)
)
def close(self) -> None:
self.is_closed = True
@ -707,6 +736,13 @@ class GeminiNativeClient:
headers.update(self._default_headers)
return headers
@staticmethod
def _advance_stream_iterator(iterator: Iterator[_GeminiStreamChunk]) -> tuple[bool, Optional[_GeminiStreamChunk]]:
try:
return False, next(iterator)
except StopIteration:
return True, None
def _create_chat_completion(
self,
*,
@ -767,7 +803,7 @@ class GeminiNativeClient:
if response.status_code != 200:
response.read()
raise gemini_http_error(response)
tool_call_indices: Dict[str, int] = {}
tool_call_indices: Dict[str, Dict[str, Any]] = {}
for event in _iter_sse_events(response):
for chunk in translate_stream_event(event, model, tool_call_indices):
yield chunk
@ -790,7 +826,19 @@ class AsyncGeminiNativeClient:
self.chat = _AsyncGeminiChatNamespace(self)
async def _create_chat_completion(self, **kwargs: Any) -> Any:
return await asyncio.to_thread(self._sync.chat.completions.create, **kwargs)
stream = bool(kwargs.get("stream"))
result = await asyncio.to_thread(self._sync.chat.completions.create, **kwargs)
if not stream:
return result
async def _async_stream() -> Any:
while True:
done, chunk = await asyncio.to_thread(self._sync._advance_stream_iterator, result)
if done:
break
yield chunk
return _async_stream()
async def close(self) -> None:
await asyncio.to_thread(self._sync.close)