mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-18 04:41:56 +00:00
fix(gateway): enqueue SSE EOS sentinel on task completion
This commit is contained in:
parent
4fa5f7b765
commit
4bb0a82a2b
2 changed files with 102 additions and 1 deletions
|
|
@ -1168,6 +1168,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||||
agent_ref=agent_ref,
|
agent_ref=agent_ref,
|
||||||
gateway_session_key=gateway_session_key,
|
gateway_session_key=gateway_session_key,
|
||||||
))
|
))
|
||||||
|
# Ensure SSE drain loops can terminate without relying on polling
|
||||||
|
# agent_task.done(), which can race with queue timeout checks.
|
||||||
|
agent_task.add_done_callback(lambda _fut: _stream_q.put(None))
|
||||||
|
|
||||||
return await self._write_sse_chat_completion(
|
return await self._write_sse_chat_completion(
|
||||||
request, completion_id, model_name, created, _stream_q,
|
request, completion_id, model_name, created, _stream_q,
|
||||||
|
|
@ -2197,6 +2200,9 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||||
agent_ref=agent_ref,
|
agent_ref=agent_ref,
|
||||||
gateway_session_key=gateway_session_key,
|
gateway_session_key=gateway_session_key,
|
||||||
))
|
))
|
||||||
|
# Ensure SSE drain loops can terminate without relying on polling
|
||||||
|
# agent_task.done(), which can race with queue timeout checks.
|
||||||
|
agent_task.add_done_callback(lambda _fut: _stream_q.put(None))
|
||||||
|
|
||||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||||
model_name = body.get("model", self._model_name)
|
model_name = body.get("model", self._model_name)
|
||||||
|
|
|
||||||
|
|
@ -681,6 +681,56 @@ class TestChatCompletionsEndpoint:
|
||||||
assert "[DONE]" in body
|
assert "[DONE]" in body
|
||||||
assert "Hello!" in body
|
assert "Hello!" in body
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_stream_task_done_callback_enqueues_eos_for_chat_completions(self, adapter):
    """Regression guard for #24451: completion callback must signal SSE EOS.

    The streaming handler is driven with a fake agent task so the test can
    capture the done-callback it registers, then invoke that callback by
    hand and check that ``None`` (the end-of-stream sentinel) lands on the
    queue handed to ``_write_sse_chat_completion``.
    """
    import asyncio

    class _FakeTask:
        """Stand-in for the agent task; it only records done-callbacks."""

        def __init__(self):
            self.callbacks = []

        def add_done_callback(self, cb):
            self.callbacks.append(cb)

    fake_task = _FakeTask()
    # Captured before patching so the stub can delegate to the real function.
    real_ensure_future = asyncio.ensure_future

    def _fake_ensure_future(obj, *args, **kwargs):
        # The patch target is reached through the module's ``asyncio``
        # attribute, which is presumably the shared asyncio module, so other
        # callers may hit this stub too.  Mirror the real signature
        # (``ensure_future(obj, *, loop=None)``) and only short-circuit bare
        # coroutines; Futures/Tasks have no ``close()`` and are passed on.
        if not asyncio.iscoroutine(obj):
            return real_ensure_future(obj, *args, **kwargs)
        # We short-circuit task scheduling in this unit test.
        obj.close()
        return fake_task

    run_agent_result = (
        {"final_response": "ok", "messages": [], "api_calls": 1},
        {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2},
    )
    request_payload = {
        "model": "test",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
    }

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as cli:
        with (
            patch.object(
                adapter,
                "_run_agent",
                new=AsyncMock(return_value=run_agent_result),
            ),
            patch(
                "gateway.platforms.api_server.asyncio.ensure_future",
                side_effect=_fake_ensure_future,
            ),
            patch.object(
                adapter, "_write_sse_chat_completion", new_callable=AsyncMock
            ) as mock_write_sse,
        ):
            mock_write_sse.return_value = web.Response(status=200, text="ok")
            resp = await cli.post("/v1/chat/completions", json=request_payload)
            assert resp.status == 200

            # Exactly one done-callback must have been registered on the task.
            assert len(fake_task.callbacks) == 1
            # The stream queue is the fifth positional argument of the writer.
            stream_q = mock_write_sse.call_args.args[4]
            assert stream_q.empty()
            # Simulate task completion: the callback must enqueue the sentinel.
            fake_task.callbacks[0](fake_task)
            assert stream_q.get_nowait() is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
||||||
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
||||||
|
|
@ -1676,6 +1726,52 @@ class TestResponsesStreaming:
|
||||||
assert "Hello" in body
|
assert "Hello" in body
|
||||||
assert " world" in body
|
assert " world" in body
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_stream_task_done_callback_enqueues_eos_for_responses(self, adapter):
    """Regression guard for #24451 on /v1/responses streaming path.

    A fake agent task captures the done-callback registered by the handler;
    invoking that callback must place ``None`` (the end-of-stream sentinel)
    on the queue passed to ``_write_sse_responses`` as ``stream_q``.
    """
    import asyncio

    class _FakeTask:
        """Stand-in for the agent task; it only records done-callbacks."""

        def __init__(self):
            self.callbacks = []

        def add_done_callback(self, cb):
            self.callbacks.append(cb)

    fake_task = _FakeTask()
    # Captured before patching so the stub can delegate to the real function.
    real_ensure_future = asyncio.ensure_future

    def _fake_ensure_future(obj, *args, **kwargs):
        # The patch target is reached through the module's ``asyncio``
        # attribute, which is presumably the shared asyncio module, so other
        # callers may hit this stub too.  Mirror the real signature
        # (``ensure_future(obj, *, loop=None)``) and only short-circuit bare
        # coroutines; Futures/Tasks have no ``close()`` and are passed on.
        if not asyncio.iscoroutine(obj):
            return real_ensure_future(obj, *args, **kwargs)
        # We short-circuit task scheduling in this unit test.
        obj.close()
        return fake_task

    run_agent_result = (
        {"final_response": "ok", "messages": [], "api_calls": 1},
        {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2},
    )

    app = _create_app(adapter)
    async with TestClient(TestServer(app)) as cli:
        with (
            patch.object(
                adapter,
                "_run_agent",
                new=AsyncMock(return_value=run_agent_result),
            ),
            patch(
                "gateway.platforms.api_server.asyncio.ensure_future",
                side_effect=_fake_ensure_future,
            ),
            patch.object(
                adapter, "_write_sse_responses", new_callable=AsyncMock
            ) as mock_write_sse,
        ):
            mock_write_sse.return_value = web.Response(status=200, text="ok")
            resp = await cli.post(
                "/v1/responses",
                json={"model": "hermes-agent", "input": "hi", "stream": True},
            )
            assert resp.status == 200

            # Exactly one done-callback must have been registered on the task.
            assert len(fake_task.callbacks) == 1
            # On this path the writer receives the queue as a keyword argument.
            stream_q = mock_write_sse.call_args.kwargs["stream_q"]
            assert stream_q.empty()
            # Simulate task completion: the callback must enqueue the sentinel.
            fake_task.callbacks[0](fake_task)
            assert stream_q.get_nowait() is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_emits_function_call_and_output_items(self, adapter):
|
async def test_stream_emits_function_call_and_output_items(self, adapter):
|
||||||
app = _create_app(adapter)
|
app = _create_app(adapter)
|
||||||
|
|
@ -3061,4 +3157,3 @@ class TestSessionKeyHeader:
|
||||||
assert resp.status == 200
|
assert resp.status == 200
|
||||||
data = await resp.json()
|
data = await resp.json()
|
||||||
assert data["features"]["session_key_header"] == "X-Hermes-Session-Key"
|
assert data["features"]["session_key_header"] == "X-Hermes-Session-Key"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue