mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-18 04:41:56 +00:00
fix(gateway): enqueue SSE EOS sentinel on task completion
This commit is contained in:
parent
4fa5f7b765
commit
4bb0a82a2b
2 changed files with 102 additions and 1 deletions
|
|
@ -1168,6 +1168,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
agent_ref=agent_ref,
|
||||
gateway_session_key=gateway_session_key,
|
||||
))
|
||||
# Ensure SSE drain loops can terminate without relying on polling
|
||||
# agent_task.done(), which can race with queue timeout checks.
|
||||
agent_task.add_done_callback(lambda _fut: _stream_q.put(None))
|
||||
|
||||
return await self._write_sse_chat_completion(
|
||||
request, completion_id, model_name, created, _stream_q,
|
||||
|
|
@ -2197,6 +2200,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
agent_ref=agent_ref,
|
||||
gateway_session_key=gateway_session_key,
|
||||
))
|
||||
# Ensure SSE drain loops can terminate without relying on polling
|
||||
# agent_task.done(), which can race with queue timeout checks.
|
||||
agent_task.add_done_callback(lambda _fut: _stream_q.put(None))
|
||||
|
||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||
model_name = body.get("model", self._model_name)
|
||||
|
|
|
|||
|
|
@ -681,6 +681,56 @@ class TestChatCompletionsEndpoint:
|
|||
assert "[DONE]" in body
|
||||
assert "Hello!" in body
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_task_done_callback_enqueues_eos_for_chat_completions(self, adapter):
    """Regression guard for #24451 on the /v1/chat/completions streaming path.

    The handler must register exactly one done-callback on the agent task,
    and invoking that callback must place ``None`` on the stream queue that
    was handed to ``_write_sse_chat_completion``.
    """
    app = _create_app(adapter)

    class _StubTask:
        """Task stand-in that only records the done-callbacks it is given."""

        def __init__(self):
            self.callbacks = []

        def add_done_callback(self, cb):
            self.callbacks.append(cb)

    stub_task = _StubTask()

    def _stub_ensure_future(coro):
        # Skip real task scheduling: close the coroutine and hand back the stub.
        coro.close()
        return stub_task

    agent_result = {"final_response": "ok", "messages": [], "api_calls": 1}
    token_usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
    run_agent_stub = AsyncMock(return_value=(agent_result, token_usage))

    async with TestClient(TestServer(app)) as client:
        with (
            patch.object(adapter, "_run_agent", new=run_agent_stub),
            patch(
                "gateway.platforms.api_server.asyncio.ensure_future",
                side_effect=_stub_ensure_future,
            ),
            patch.object(
                adapter, "_write_sse_chat_completion", new_callable=AsyncMock
            ) as sse_writer,
        ):
            sse_writer.return_value = web.Response(status=200, text="ok")
            request_body = {
                "model": "test",
                "messages": [{"role": "user", "content": "hi"}],
                "stream": True,
            }
            resp = await client.post("/v1/chat/completions", json=request_body)
            assert resp.status == 200

        assert len(stub_task.callbacks) == 1
        # The stream queue is the fifth positional argument of the SSE writer call.
        stream_q = sse_writer.call_args.args[4]
        assert stream_q.empty()
        # Fire the recorded callback by hand, as a finished task would.
        stub_task.callbacks[0](stub_task)
        assert stream_q.get_nowait() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
||||
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
||||
|
|
@ -1676,6 +1726,52 @@ class TestResponsesStreaming:
|
|||
assert "Hello" in body
|
||||
assert " world" in body
|
||||
|
||||
@pytest.mark.asyncio
async def test_stream_task_done_callback_enqueues_eos_for_responses(self, adapter):
    """Regression guard for #24451 on the /v1/responses streaming path.

    The handler must register exactly one done-callback on the agent task,
    and invoking that callback must place ``None`` on the ``stream_q`` that
    was passed by keyword to ``_write_sse_responses``.
    """
    app = _create_app(adapter)

    class _StubTask:
        """Task stand-in that only records the done-callbacks it is given."""

        def __init__(self):
            self.callbacks = []

        def add_done_callback(self, cb):
            self.callbacks.append(cb)

    stub_task = _StubTask()

    def _stub_ensure_future(coro):
        # Skip real task scheduling: close the coroutine and hand back the stub.
        coro.close()
        return stub_task

    agent_result = {"final_response": "ok", "messages": [], "api_calls": 1}
    token_usage = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
    run_agent_stub = AsyncMock(return_value=(agent_result, token_usage))

    async with TestClient(TestServer(app)) as client:
        with (
            patch.object(adapter, "_run_agent", new=run_agent_stub),
            patch(
                "gateway.platforms.api_server.asyncio.ensure_future",
                side_effect=_stub_ensure_future,
            ),
            patch.object(
                adapter, "_write_sse_responses", new_callable=AsyncMock
            ) as sse_writer,
        ):
            sse_writer.return_value = web.Response(status=200, text="ok")
            request_body = {"model": "hermes-agent", "input": "hi", "stream": True}
            resp = await client.post("/v1/responses", json=request_body)
            assert resp.status == 200

        assert len(stub_task.callbacks) == 1
        # On this path the stream queue reaches the SSE writer as a keyword argument.
        stream_q = sse_writer.call_args.kwargs["stream_q"]
        assert stream_q.empty()
        # Fire the recorded callback by hand, as a finished task would.
        stub_task.callbacks[0](stub_task)
        assert stream_q.get_nowait() is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_emits_function_call_and_output_items(self, adapter):
|
||||
app = _create_app(adapter)
|
||||
|
|
@ -3061,4 +3157,3 @@ class TestSessionKeyHeader:
|
|||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert data["features"]["session_key_header"] == "X-Hermes-Session-Key"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue