mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gemini): assign unique stream indices to parallel tool calls
The streaming translator in agent/gemini_cloudcode_adapter.py keyed OpenAI tool-call indices by function name, so when the model emitted multiple parallel functionCall parts with the same name in a single turn (e.g. three read_file calls in one response), they all collapsed onto index 0. Downstream aggregators that key chunks by index would overwrite or drop all but the first call. Replace the name-keyed dict with a per-stream counter that persists across SSE events. Each functionCall part now gets a fresh, unique index, matching the non-streaming path which already uses enumerate(parts). Add TestTranslateStreamEvent covering parallel-same-name calls, index persistence across events, and finish-reason promotion to tool_calls.
This commit is contained in:
parent
d990fa52ed
commit
49282b6e04
2 changed files with 77 additions and 6 deletions
|
|
@ -505,9 +505,16 @@ def _iter_sse_events(response: httpx.Response) -> Iterator[Dict[str, Any]]:
|
||||||
def _translate_stream_event(
|
def _translate_stream_event(
|
||||||
event: Dict[str, Any],
|
event: Dict[str, Any],
|
||||||
model: str,
|
model: str,
|
||||||
tool_call_indices: Dict[str, int],
|
tool_call_counter: List[int],
|
||||||
) -> List[_GeminiStreamChunk]:
|
) -> List[_GeminiStreamChunk]:
|
||||||
"""Unwrap Code Assist envelope and emit OpenAI-shaped chunk(s)."""
|
"""Unwrap Code Assist envelope and emit OpenAI-shaped chunk(s).
|
||||||
|
|
||||||
|
``tool_call_counter`` is a single-element list used as a mutable counter
|
||||||
|
across events in the same stream. Each ``functionCall`` part gets a
|
||||||
|
fresh, unique OpenAI ``index`` — keying by function name would collide
|
||||||
|
whenever the model issues parallel calls to the same tool (e.g. reading
|
||||||
|
three files in one turn).
|
||||||
|
"""
|
||||||
inner = event.get("response") if isinstance(event.get("response"), dict) else event
|
inner = event.get("response") if isinstance(event.get("response"), dict) else event
|
||||||
candidates = inner.get("candidates") or []
|
candidates = inner.get("candidates") or []
|
||||||
if not candidates:
|
if not candidates:
|
||||||
|
|
@ -533,7 +540,8 @@ def _translate_stream_event(
|
||||||
fc = part.get("functionCall")
|
fc = part.get("functionCall")
|
||||||
if isinstance(fc, dict) and fc.get("name"):
|
if isinstance(fc, dict) and fc.get("name"):
|
||||||
name = str(fc["name"])
|
name = str(fc["name"])
|
||||||
idx = tool_call_indices.setdefault(name, len(tool_call_indices))
|
idx = tool_call_counter[0]
|
||||||
|
tool_call_counter[0] += 1
|
||||||
try:
|
try:
|
||||||
args_str = json.dumps(fc.get("args") or {}, ensure_ascii=False)
|
args_str = json.dumps(fc.get("args") or {}, ensure_ascii=False)
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
|
|
@ -550,7 +558,7 @@ def _translate_stream_event(
|
||||||
finish_reason_raw = str(cand.get("finishReason") or "")
|
finish_reason_raw = str(cand.get("finishReason") or "")
|
||||||
if finish_reason_raw:
|
if finish_reason_raw:
|
||||||
mapped = _map_gemini_finish_reason(finish_reason_raw)
|
mapped = _map_gemini_finish_reason(finish_reason_raw)
|
||||||
if tool_call_indices:
|
if tool_call_counter[0] > 0:
|
||||||
mapped = "tool_calls"
|
mapped = "tool_calls"
|
||||||
chunks.append(_make_stream_chunk(model=model, finish_reason=mapped))
|
chunks.append(_make_stream_chunk(model=model, finish_reason=mapped))
|
||||||
return chunks
|
return chunks
|
||||||
|
|
@ -734,9 +742,9 @@ class GeminiCloudCodeClient:
|
||||||
# Materialize error body for better diagnostics
|
# Materialize error body for better diagnostics
|
||||||
response.read()
|
response.read()
|
||||||
raise _gemini_http_error(response)
|
raise _gemini_http_error(response)
|
||||||
tool_call_indices: Dict[str, int] = {}
|
tool_call_counter: List[int] = [0]
|
||||||
for event in _iter_sse_events(response):
|
for event in _iter_sse_events(response):
|
||||||
for chunk in _translate_stream_event(event, model, tool_call_indices):
|
for chunk in _translate_stream_event(event, model, tool_call_counter):
|
||||||
yield chunk
|
yield chunk
|
||||||
except httpx.HTTPError as exc:
|
except httpx.HTTPError as exc:
|
||||||
raise CodeAssistError(
|
raise CodeAssistError(
|
||||||
|
|
|
||||||
|
|
@ -850,6 +850,69 @@ class TestTranslateGeminiResponse:
|
||||||
assert _map_gemini_finish_reason("RECITATION") == "content_filter"
|
assert _map_gemini_finish_reason("RECITATION") == "content_filter"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranslateStreamEvent:
|
||||||
|
def test_parallel_calls_to_same_tool_get_unique_indices(self):
|
||||||
|
"""Gemini may emit several functionCall parts with the same name in a
|
||||||
|
single turn (e.g. parallel file reads). Each must get its own OpenAI
|
||||||
|
``index`` — otherwise downstream aggregators collapse them into one.
|
||||||
|
"""
|
||||||
|
from agent.gemini_cloudcode_adapter import _translate_stream_event
|
||||||
|
|
||||||
|
event = {
|
||||||
|
"response": {
|
||||||
|
"candidates": [{
|
||||||
|
"content": {"parts": [
|
||||||
|
{"functionCall": {"name": "read_file", "args": {"path": "a"}}},
|
||||||
|
{"functionCall": {"name": "read_file", "args": {"path": "b"}}},
|
||||||
|
{"functionCall": {"name": "read_file", "args": {"path": "c"}}},
|
||||||
|
]},
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
counter = [0]
|
||||||
|
chunks = _translate_stream_event(event, model="gemini-2.5-flash",
|
||||||
|
tool_call_counter=counter)
|
||||||
|
indices = [c.choices[0].delta.tool_calls[0].index for c in chunks]
|
||||||
|
assert indices == [0, 1, 2]
|
||||||
|
assert counter[0] == 3
|
||||||
|
|
||||||
|
def test_counter_persists_across_events(self):
|
||||||
|
"""Index assignment must continue across SSE events in the same stream."""
|
||||||
|
from agent.gemini_cloudcode_adapter import _translate_stream_event
|
||||||
|
|
||||||
|
def _event(name):
|
||||||
|
return {"response": {"candidates": [{
|
||||||
|
"content": {"parts": [{"functionCall": {"name": name, "args": {}}}]},
|
||||||
|
}]}}
|
||||||
|
|
||||||
|
counter = [0]
|
||||||
|
chunks_a = _translate_stream_event(_event("foo"), model="m", tool_call_counter=counter)
|
||||||
|
chunks_b = _translate_stream_event(_event("bar"), model="m", tool_call_counter=counter)
|
||||||
|
chunks_c = _translate_stream_event(_event("foo"), model="m", tool_call_counter=counter)
|
||||||
|
|
||||||
|
assert chunks_a[0].choices[0].delta.tool_calls[0].index == 0
|
||||||
|
assert chunks_b[0].choices[0].delta.tool_calls[0].index == 1
|
||||||
|
assert chunks_c[0].choices[0].delta.tool_calls[0].index == 2
|
||||||
|
|
||||||
|
def test_finish_reason_switches_to_tool_calls_when_any_seen(self):
|
||||||
|
from agent.gemini_cloudcode_adapter import _translate_stream_event
|
||||||
|
|
||||||
|
counter = [0]
|
||||||
|
# First event emits one tool call.
|
||||||
|
_translate_stream_event(
|
||||||
|
{"response": {"candidates": [{
|
||||||
|
"content": {"parts": [{"functionCall": {"name": "x", "args": {}}}]},
|
||||||
|
}]}},
|
||||||
|
model="m", tool_call_counter=counter,
|
||||||
|
)
|
||||||
|
# Second event carries only the terminal finishReason.
|
||||||
|
chunks = _translate_stream_event(
|
||||||
|
{"response": {"candidates": [{"finishReason": "STOP"}]}},
|
||||||
|
model="m", tool_call_counter=counter,
|
||||||
|
)
|
||||||
|
assert chunks[-1].choices[0].finish_reason == "tool_calls"
|
||||||
|
|
||||||
|
|
||||||
class TestGeminiCloudCodeClient:
|
class TestGeminiCloudCodeClient:
|
||||||
def test_client_exposes_openai_interface(self):
|
def test_client_exposes_openai_interface(self):
|
||||||
from agent.gemini_cloudcode_adapter import GeminiCloudCodeClient
|
from agent.gemini_cloudcode_adapter import GeminiCloudCodeClient
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue