diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 1b2e2e156b..8c46cc6157 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -981,39 +981,62 @@ class APIServerAdapter(BasePlatformAdapter): if delta is not None: _stream_q.put(delta) - def _on_tool_progress(event_type, name, preview, args, **kwargs): - """Send tool progress as a separate SSE event. + # Track which tool_call_ids we've emitted a "running" lifecycle + # event for, so a "completed" event without a matching "running" + # (e.g. internal/filtered tools) is silently dropped instead of + # producing an orphaned event clients can't correlate. + _started_tool_call_ids: set[str] = set() - Previously, progress markers like ``⏰ list`` were injected - directly into ``delta.content``. OpenAI-compatible frontends - (Open WebUI, LobeChat, …) store ``delta.content`` verbatim as - the assistant message and send it back on subsequent requests. - After enough turns the model learns to *emit* the markers as - plain text instead of issuing real tool calls — silently - hallucinating tool results. See #6972. + def _on_tool_start(tool_call_id, function_name, function_args): + """Emit ``hermes.tool.progress`` with ``status: running``. - The fix: push a tagged tuple ``("__tool_progress__", payload)`` - onto the stream queue. The SSE writer emits it as a custom - ``event: hermes.tool.progress`` line that compliant frontends - can render for UX but will *not* persist into conversation - history. Clients that don't understand the custom event type - silently ignore it per the SSE specification. + Replaces the old ``tool_progress_callback("tool.started", + ...)`` emit so SSE consumers receive a single event per + tool start, carrying both the legacy ``tool``/``emoji``/ + ``label`` payload (for #6972 frontends) and the new + ``toolCallId``/``status`` correlation fields (#16588). + + Skips tools whose names start with ``_`` so internal + events (``_thinking``, …) stay off the wire — matching + the prior ``_on_tool_progress`` filter exactly. """ - if event_type != "tool.started": + if not tool_call_id or function_name.startswith("_"): return - if name.startswith("_"): - return - from agent.display import get_tool_emoji - emoji = get_tool_emoji(name) - label = preview or name + _started_tool_call_ids.add(tool_call_id) + from agent.display import build_tool_preview, get_tool_emoji + label = build_tool_preview(function_name, function_args) or function_name _stream_q.put(("__tool_progress__", { - "tool": name, - "emoji": emoji, + "tool": function_name, + "emoji": get_tool_emoji(function_name), "label": label, + "toolCallId": tool_call_id, + "status": "running", + })) + + def _on_tool_complete(tool_call_id, function_name, function_args, function_result): + """Emit the matching ``status: completed`` event. + + Dropped if the start was filtered (internal tool, missing + id, or never seen) so clients never get an orphaned + ``completed`` they can't correlate to a prior ``running``. + """ + if not tool_call_id or tool_call_id not in _started_tool_call_ids: + return + _started_tool_call_ids.discard(tool_call_id) + _stream_q.put(("__tool_progress__", { + "tool": function_name, + "toolCallId": tool_call_id, + "status": "completed", })) # Start agent in background. agent_ref is a mutable container # so the SSE writer can interrupt the agent on client disconnect. + # + # ``tool_progress_callback`` is intentionally not wired here: + # it would duplicate every emit because ``run_agent`` fires it + # side-by-side with ``tool_start_callback``/``tool_complete_callback``. + # The structured callbacks are strictly richer (they carry the + # tool_call id), so they own the chat-completions SSE channel. agent_ref = [None] agent_task = asyncio.ensure_future(self._run_agent( user_message=user_message, @@ -1021,7 +1044,8 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt=system_prompt, session_id=session_id, stream_delta_callback=_on_delta, - tool_progress_callback=_on_tool_progress, + tool_start_callback=_on_tool_start, + tool_complete_callback=_on_tool_complete, agent_ref=agent_ref, )) @@ -1136,7 +1160,8 @@ class APIServerAdapter(BasePlatformAdapter): Tagged tuples ``("__tool_progress__", payload)`` are sent as a custom ``event: hermes.tool.progress`` SSE event so frontends can display them without storing the markers in - conversation history. See #6972. + conversation history. See #6972 for the original event, + #16588 for the ``toolCallId``/``status`` lifecycle fields. """ if isinstance(item, tuple) and len(item) == 2 and item[0] == "__tool_progress__": event_data = json.dumps(item[1]) diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 75386097c8..2ebb48bcf4 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -688,17 +688,17 @@ class TestChatCompletionsEndpoint: @pytest.mark.asyncio async def test_stream_includes_tool_progress(self, adapter): - """tool_progress_callback fires → progress appears as custom SSE event, not in delta.content.""" + """tool_start_callback fires → progress appears as custom SSE event, not in delta.content.""" import asyncio app = _create_app(adapter) async with TestClient(TestServer(app)) as cli: async def _mock_run_agent(**kwargs): cb = kwargs.get("stream_delta_callback") - tp_cb = kwargs.get("tool_progress_callback") - # Simulate tool progress before streaming content - if tp_cb: - tp_cb("tool.started", "terminal", "ls -la", {"command": "ls -la"}) + ts_cb = kwargs.get("tool_start_callback") + # Simulate the structured tool start the gateway now consumes. + if ts_cb: + ts_cb("call_terminal_1", "terminal", {"command": "ls -la"}) if cb: await asyncio.sleep(0.05) cb("Here are the files.") @@ -724,7 +724,10 @@ class TestChatCompletionsEndpoint: # markers instead of calling tools (#6972). assert "event: hermes.tool.progress" in body assert '"tool": "terminal"' in body - assert '"label": "ls -la"' in body + # ``label`` is now derived by ``build_tool_preview`` from the + # tool args rather than passed by the caller, so we assert + # only that *some* label exists rather than a literal value. + assert '"label":' in body # The progress marker must NOT appear inside any # chat.completion.chunk delta.content field. import json as _json @@ -744,17 +747,17 @@ class TestChatCompletionsEndpoint: @pytest.mark.asyncio async def test_stream_tool_progress_skips_internal_events(self, adapter): - """Internal events (name starting with _) are not streamed.""" + """Internal tool calls (name starting with ``_``) are not streamed.""" import asyncio app = _create_app(adapter) async with TestClient(TestServer(app)) as cli: async def _mock_run_agent(**kwargs): cb = kwargs.get("stream_delta_callback") - tp_cb = kwargs.get("tool_progress_callback") - if tp_cb: - tp_cb("tool.started", "_thinking", "some internal state", {}) - tp_cb("tool.started", "web_search", "Python docs", {"query": "Python docs"}) + ts_cb = kwargs.get("tool_start_callback") + if ts_cb: + ts_cb("call_internal_1", "_thinking", {"text": "some internal state"}) + ts_cb("call_search_1", "web_search", {"query": "Python docs"}) if cb: await asyncio.sleep(0.05) cb("Found it.") @@ -776,10 +779,142 @@ class TestChatCompletionsEndpoint: body = await resp.text() # Internal _thinking event should NOT appear anywhere assert "some internal state" not in body + assert "call_internal_1" not in body # Real tool progress should appear as custom SSE event assert "event: hermes.tool.progress" in body assert '"tool": "web_search"' in body - assert '"label": "Python docs"' in body + # Label is derived from the args dict by build_tool_preview; + # asserting on the structural fact (label exists, call id + # is correlated) rather than a literal preview string keeps + # the test robust against preview-formatter tweaks. + assert '"label":' in body + assert '"toolCallId": "call_search_1"' in body + + @pytest.mark.asyncio + async def test_stream_emits_tool_lifecycle_with_call_id(self, adapter): + """Regression for #16588. + + ``/v1/chat/completions`` streaming previously emitted only a + ``tool.started``-style ``hermes.tool.progress`` event; clients + rendering tool lifecycle UI had no way to mark a tool as finished + because no matching ``status: completed`` event was emitted, and + no ``toolCallId`` was carried for correlation. + + The fix adds ``tool_start_callback`` / ``tool_complete_callback`` + to the chat completions agent invocation and writes both halves + of the lifecycle pair on the same ``event: hermes.tool.progress`` + SSE line, with stable ``toolCallId`` and ``status``. + """ + import asyncio + import json as _json + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + ts_cb = kwargs.get("tool_start_callback") + tc_cb = kwargs.get("tool_complete_callback") + # The structured callbacks own the chat-completions SSE + # channel now; ``tool_progress_callback`` is intentionally + # not wired so each tool start emits exactly one event. + if ts_cb: + ts_cb("call_terminal_1", "terminal", {"command": "ls -la"}) + if tc_cb: + tc_cb("call_terminal_1", "terminal", {"command": "ls -la"}, "ok") + if cb: + await asyncio.sleep(0.05) + cb("done.") + return ( + {"final_response": "done.", "messages": [], "api_calls": 1}, + {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "list"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + + # Walk the SSE body and collect *(status, toolCallId)* pairs + # per event so the assertions verify per-event correlation — + # an event missing ``toolCallId`` would not pass even if a + # different event happens to carry the right id. + pairs: list[tuple[str | None, str | None]] = [] + lines = body.splitlines() + for i, line in enumerate(lines): + if line.strip() != "event: hermes.tool.progress": + continue + for follow in lines[i + 1: i + 4]: + if follow.startswith("data: "): + try: + payload = _json.loads(follow[len("data: "):]) + except _json.JSONDecodeError: + break + pairs.append((payload.get("status"), payload.get("toolCallId"))) + break + + # Each tool start must emit exactly one event (no duplicate + # legacy + new emit), and each lifecycle pair must carry the + # same toolCallId on every event — not just somewhere in the + # aggregate. + assert len(pairs) == 2, f"expected 2 events (running+completed), got {pairs}" + assert pairs[0] == ("running", "call_terminal_1"), pairs + assert pairs[1] == ("completed", "call_terminal_1"), pairs + + @pytest.mark.asyncio + async def test_stream_tool_lifecycle_skips_internal_and_orphan_completes(self, adapter): + """Internal tools (``_thinking``-style) and ``completed`` events + without a prior matching ``running`` must produce no lifecycle + events on the wire — otherwise clients would see orphaned + ``status: completed`` updates they cannot correlate.""" + import asyncio + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + ts_cb = kwargs.get("tool_start_callback") + tc_cb = kwargs.get("tool_complete_callback") + # Internal tool — must be filtered. + if ts_cb: + ts_cb("call_internal_1", "_thinking", {}) + if tc_cb: + tc_cb("call_internal_1", "_thinking", {}, "") + # Completion without start — orphan, must be dropped. + if tc_cb: + tc_cb("call_orphan_1", "web_search", {}, "ok") + if cb: + await asyncio.sleep(0.05) + cb("ok.") + return ( + {"final_response": "ok.", "messages": [], "api_calls": 1}, + {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "ok"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + + # Neither the internal call_id nor the orphan call_id should + # surface as a lifecycle payload on the wire. + assert "call_internal_1" not in body + assert "call_orphan_1" not in body + assert '"status": "running"' not in body + assert '"status": "completed"' not in body @pytest.mark.asyncio async def test_no_user_message_returns_400(self, adapter):