fix(gateway): enqueue SSE EOS sentinel on task completion

This commit is contained in:
Ahmet Oşrak 2026-05-13 00:56:32 +03:00 committed by Teknium
parent 4fa5f7b765
commit 4bb0a82a2b
2 changed files with 102 additions and 1 deletion

View file

@ -681,6 +681,56 @@ class TestChatCompletionsEndpoint:
assert "[DONE]" in body
assert "Hello!" in body
@pytest.mark.asyncio
async def test_stream_task_done_callback_enqueues_eos_for_chat_completions(self, adapter):
    """Regression guard for #24451: the done-callback must enqueue the SSE EOS marker.

    The agent task is replaced by a stub that only records the callbacks
    registered on it, and the SSE writer is mocked out.  The test then fires
    the recorded callback by hand and checks that ``None`` lands on the
    stream queue that was handed to the writer.
    """

    class _CallbackRecorder:
        """Minimal task stand-in that remembers every done-callback."""

        def __init__(self):
            self.callbacks = []

        def add_done_callback(self, cb):
            self.callbacks.append(cb)

    recorder = _CallbackRecorder()

    def _never_schedule(coro):
        # Scheduling is short-circuited in this unit test: close the
        # coroutine so it is never awaited and hand back the recorder.
        coro.close()
        return recorder

    agent_result = {"final_response": "ok", "messages": [], "api_calls": 1}
    token_usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
    run_agent_stub = AsyncMock(return_value=(agent_result, token_usage))

    request_body = {
        "model": "test",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
    }

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as client:
        with (
            patch.object(adapter, "_run_agent", new=run_agent_stub),
            patch(
                "gateway.platforms.api_server.asyncio.ensure_future",
                side_effect=_never_schedule,
            ),
            patch.object(
                adapter,
                "_write_sse_chat_completion",
                new_callable=AsyncMock,
            ) as sse_writer,
        ):
            sse_writer.return_value = web.Response(status=200, text="ok")

            response = await client.post("/v1/chat/completions", json=request_body)
            assert response.status == 200

            # Exactly one completion callback should have been registered.
            registered = recorder.callbacks
            assert len(registered) == 1
            on_done = registered[0]

            # NOTE(review): the queue is read from positional slot 4 of the
            # writer call -- confirm this still matches the handler's call site.
            queue = sse_writer.call_args.args[4]
            assert queue.empty()

            # Firing the callback must push the end-of-stream marker (None).
            on_done(recorder)
            assert queue.get_nowait() is None
@pytest.mark.asyncio
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
"""Idle SSE streams should send keepalive comments while tools run silently."""
@ -1676,6 +1726,52 @@ class TestResponsesStreaming:
assert "Hello" in body
assert " world" in body
@pytest.mark.asyncio
async def test_stream_task_done_callback_enqueues_eos_for_responses(self, adapter):
    """Regression guard for #24451 on the /v1/responses streaming path.

    The agent task is swapped for a stub that only records done-callbacks and
    the SSE writer is mocked.  After the request returns, the recorded
    callback is invoked manually and the test checks that ``None`` was put on
    the ``stream_q`` passed to the writer.
    """

    class _CallbackRecorder:
        """Minimal task stand-in that remembers every done-callback."""

        def __init__(self):
            self.callbacks = []

        def add_done_callback(self, cb):
            self.callbacks.append(cb)

    recorder = _CallbackRecorder()

    def _never_schedule(coro):
        # Scheduling is short-circuited in this unit test: close the
        # coroutine so it is never awaited and hand back the recorder.
        coro.close()
        return recorder

    agent_result = {"final_response": "ok", "messages": [], "api_calls": 1}
    token_usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
    run_agent_stub = AsyncMock(return_value=(agent_result, token_usage))

    request_body = {"model": "hermes-agent", "input": "hi", "stream": True}

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as client:
        with (
            patch.object(adapter, "_run_agent", new=run_agent_stub),
            patch(
                "gateway.platforms.api_server.asyncio.ensure_future",
                side_effect=_never_schedule,
            ),
            patch.object(
                adapter,
                "_write_sse_responses",
                new_callable=AsyncMock,
            ) as sse_writer,
        ):
            sse_writer.return_value = web.Response(status=200, text="ok")

            response = await client.post("/v1/responses", json=request_body)
            assert response.status == 200

            # Exactly one completion callback should have been registered.
            registered = recorder.callbacks
            assert len(registered) == 1
            on_done = registered[0]

            # The writer receives its queue via the ``stream_q`` keyword.
            queue = sse_writer.call_args.kwargs["stream_q"]
            assert queue.empty()

            # Firing the callback must push the end-of-stream marker (None).
            on_done(recorder)
            assert queue.get_nowait() is None
@pytest.mark.asyncio
async def test_stream_emits_function_call_and_output_items(self, adapter):
app = _create_app(adapter)
@ -3061,4 +3157,3 @@ class TestSessionKeyHeader:
assert resp.status == 200
data = await resp.json()
assert data["features"]["session_key_header"] == "X-Hermes-Session-Key"