From 4bb0a82a2b8dc4d4fd952d977a81ae2ccbc52fbc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ahmet=20O=C5=9Frak?= Date: Wed, 13 May 2026 00:56:32 +0300 Subject: [PATCH] fix(gateway): enqueue SSE EOS sentinel on task completion --- gateway/platforms/api_server.py | 6 ++ tests/gateway/test_api_server.py | 97 +++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 497adbd19c6..8b53db3a99f 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -1168,6 +1168,9 @@ class APIServerAdapter(BasePlatformAdapter): agent_ref=agent_ref, gateway_session_key=gateway_session_key, )) + # Ensure SSE drain loops can terminate without relying on polling + # agent_task.done(), which can race with queue timeout checks. + agent_task.add_done_callback(lambda _fut: _stream_q.put(None)) return await self._write_sse_chat_completion( request, completion_id, model_name, created, _stream_q, @@ -2197,6 +2200,9 @@ class APIServerAdapter(BasePlatformAdapter): agent_ref=agent_ref, gateway_session_key=gateway_session_key, )) + # Ensure SSE drain loops can terminate without relying on polling + # agent_task.done(), which can race with queue timeout checks. + agent_task.add_done_callback(lambda _fut: _stream_q.put(None)) response_id = f"resp_{uuid.uuid4().hex[:28]}" model_name = body.get("model", self._model_name) diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 9e00a375871..66b304fff51 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -681,6 +681,56 @@ class TestChatCompletionsEndpoint: assert "[DONE]" in body assert "Hello!" 
in body + @pytest.mark.asyncio + async def test_stream_task_done_callback_enqueues_eos_for_chat_completions(self, adapter): + """Regression guard for #24451: completion callback must signal SSE EOS.""" + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + class _FakeTask: + def __init__(self): + self.callbacks = [] + + def add_done_callback(self, cb): + self.callbacks.append(cb) + + fake_task = _FakeTask() + + def _fake_ensure_future(coro): + # We short-circuit task scheduling in this unit test. + coro.close() + return fake_task + + with ( + patch.object( + adapter, + "_run_agent", + new=AsyncMock( + return_value=( + {"final_response": "ok", "messages": [], "api_calls": 1}, + {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + ) + ), + ), + patch("gateway.platforms.api_server.asyncio.ensure_future", side_effect=_fake_ensure_future), + patch.object(adapter, "_write_sse_chat_completion", new_callable=AsyncMock) as mock_write_sse, + ): + mock_write_sse.return_value = web.Response(status=200, text="ok") + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "hi"}], + "stream": True, + }, + ) + assert resp.status == 200 + + assert len(fake_task.callbacks) == 1 + stream_q = mock_write_sse.call_args.args[4] + assert stream_q.empty() + fake_task.callbacks[0](fake_task) + assert stream_q.get_nowait() is None + @pytest.mark.asyncio async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter): """Idle SSE streams should send keepalive comments while tools run silently.""" @@ -1676,6 +1726,52 @@ class TestResponsesStreaming: assert "Hello" in body assert " world" in body + @pytest.mark.asyncio + async def test_stream_task_done_callback_enqueues_eos_for_responses(self, adapter): + """Regression guard for #24451 on /v1/responses streaming path.""" + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + class _FakeTask: + def __init__(self): 
+ self.callbacks = [] + + def add_done_callback(self, cb): + self.callbacks.append(cb) + + fake_task = _FakeTask() + + def _fake_ensure_future(coro): + # We short-circuit task scheduling in this unit test. + coro.close() + return fake_task + + with ( + patch.object( + adapter, + "_run_agent", + new=AsyncMock( + return_value=( + {"final_response": "ok", "messages": [], "api_calls": 1}, + {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + ) + ), + ), + patch("gateway.platforms.api_server.asyncio.ensure_future", side_effect=_fake_ensure_future), + patch.object(adapter, "_write_sse_responses", new_callable=AsyncMock) as mock_write_sse, + ): + mock_write_sse.return_value = web.Response(status=200, text="ok") + resp = await cli.post( + "/v1/responses", + json={"model": "hermes-agent", "input": "hi", "stream": True}, + ) + assert resp.status == 200 + + assert len(fake_task.callbacks) == 1 + stream_q = mock_write_sse.call_args.kwargs["stream_q"] + assert stream_q.empty() + fake_task.callbacks[0](fake_task) + assert stream_q.get_nowait() is None + @pytest.mark.asyncio async def test_stream_emits_function_call_and_output_items(self, adapter): app = _create_app(adapter) @@ -3061,4 +3157,3 @@ class TestSessionKeyHeader: assert resp.status == 200 data = await resp.json() assert data["features"]["session_key_header"] == "X-Hermes-Session-Key" -
# ---------------------------------------------------------------------------
# Review notes (free-form text after the last hunk; not part of any diff hunk).
#
# Purpose: at two call sites inside APIServerAdapter (the hunks at old lines
# 1168 and 2197 of gateway/platforms/api_server.py) the patch registers
# `agent_task.add_done_callback(lambda _fut: _stream_q.put(None))`, so that a
# `None` item is put on `_stream_q` when `agent_task` completes. Per the inline
# comment, this lets the SSE drain loops terminate without polling
# `agent_task.done()`.
#
# Call sites:
#   * hunk 1 sits immediately before the
#     `return await self._write_sse_chat_completion(...)` call;
#   * hunk 2 sits immediately before `response_id` / `model_name` are computed
#     (the added test exercises this path through POST /v1/responses).
#
# Tests (tests/gateway/test_api_server.py): two regression tests citing #24451,
# one in TestChatCompletionsEndpoint and one in TestResponsesStreaming. Each:
#   1. replaces `asyncio.ensure_future` (as reached through
#      `gateway.platforms.api_server`) with a stub that closes the coroutine and
#      returns a `_FakeTask` which only records `add_done_callback` arguments;
#   2. replaces the SSE writer (`_write_sse_chat_completion` or
#      `_write_sse_responses`) with an AsyncMock returning a plain 200 response;
#   3. POSTs a request with `"stream": True` and checks the status is 200;
#   4. asserts exactly one callback was registered, that the queue handed to the
#      writer is empty, and that invoking the callback makes `get_nowait()`
#      return `None`.
# The queue is read from positional argument index 4 for chat completions and
# from the `stream_q` keyword argument for responses.
#
# The final hunk (-3061,4 +3157,3) removes a single line at the end of the test
# file; it looks like a trailing blank line -- TODO confirm.
#
# NOTE(review): the callback calls `_stream_q.put(None)` synchronously and the
# tests read the item back with `get_nowait()` right afterwards, so `_stream_q`
# is presumably a synchronous queue (queue.Queue-like). Confirm it is not an
# asyncio.Queue (whose `put()` is a coroutine function and would never run
# here) and that it is unbounded, so `put()` cannot block inside the callback.
#
# NOTE(review): assuming api_server.py does a plain `import asyncio`, the patch
# target `gateway.platforms.api_server.asyncio.ensure_future` is an attribute of
# the shared `asyncio` module, so the stub is active for every caller of
# `asyncio.ensure_future` while the `with` block is open -- verify nothing else
# in the request path depends on it.
#
# NOTE(review): in this copy the patch's newlines appear to have been collapsed
# into spaces; restore the original line structure before `git am`/`git apply`.
# ---------------------------------------------------------------------------