mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-26 01:01:40 +00:00
fix(api-server): persist incomplete snapshot on asyncio.CancelledError too
Extends PR #15171 to also cover the server-side cancellation path (aiohttp shutdown, request-level timeout) — previously only ConnectionResetError triggered the incomplete-snapshot write, so cancellations left the store stuck at the in_progress snapshot written on response.created. Factors the incomplete-snapshot build into a _persist_incomplete_if_needed() helper called from both the ConnectionResetError and CancelledError branches; the CancelledError handler re-raises so cooperative cancellation semantics are preserved. Adds two regression tests that drive _write_sse_responses directly (the TestClient disconnect path races the server handler, which makes the end-to-end assertion flaky).
This commit is contained in:
parent
a29bad2a3c
commit
36d68bcb82
2 changed files with 184 additions and 24 deletions
|
|
@ -1374,6 +1374,139 @@ class TestResponsesStreaming:
|
|||
assert data["status"] == "completed"
|
||||
assert data["output"][-1]["content"][0]["text"] == "Stored response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_cancelled_persists_incomplete_snapshot(self, adapter):
|
||||
"""Server-side asyncio.CancelledError (shutdown, request timeout) must
|
||||
still leave an ``incomplete`` snapshot in ResponseStore so
|
||||
GET /v1/responses/{id} and previous_response_id chaining keep
|
||||
working. Regression for PR #15171 follow-up.
|
||||
|
||||
Calls _write_sse_responses directly so the test can await the
|
||||
handler to completion (TestClient disconnection races the server
|
||||
handler, which makes end-to-end assertion on the final stored
|
||||
snapshot flaky).
|
||||
"""
|
||||
# Build a minimal fake request + stream queue the writer understands.
|
||||
fake_request = MagicMock()
|
||||
fake_request.headers = {}
|
||||
|
||||
written_payloads: list = []
|
||||
|
||||
class _FakeStreamResponse:
|
||||
async def prepare(self, req):
|
||||
pass
|
||||
|
||||
async def write(self, payload):
|
||||
written_payloads.append(payload)
|
||||
|
||||
# Patch web.StreamResponse for the duration of the writer call.
|
||||
import gateway.platforms.api_server as api_mod
|
||||
import queue as _q
|
||||
|
||||
stream_q: _q.Queue = _q.Queue()
|
||||
|
||||
async def _agent_coro():
|
||||
# Feed one partial delta into the stream queue...
|
||||
stream_q.put("partial output")
|
||||
# ...then give the drain loop a moment to pick it up before
|
||||
# raising CancelledError to simulate a server-side cancel.
|
||||
await asyncio.sleep(0.01)
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
agent_task = asyncio.ensure_future(_agent_coro())
|
||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||
|
||||
with patch.object(api_mod.web, "StreamResponse", return_value=_FakeStreamResponse()):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await adapter._write_sse_responses(
|
||||
request=fake_request,
|
||||
response_id=response_id,
|
||||
model="hermes-agent",
|
||||
created_at=int(time.time()),
|
||||
stream_q=stream_q,
|
||||
agent_task=agent_task,
|
||||
agent_ref=[None],
|
||||
conversation_history=[],
|
||||
user_message="will be cancelled",
|
||||
instructions=None,
|
||||
conversation=None,
|
||||
store=True,
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
# The in_progress snapshot was persisted on response.created,
|
||||
# and the CancelledError handler must have updated it to
|
||||
# ``incomplete`` with the partial text it saw.
|
||||
stored = adapter._response_store.get(response_id)
|
||||
assert stored is not None, "snapshot must be retrievable after cancellation"
|
||||
assert stored["response"]["status"] == "incomplete"
|
||||
# Partial text captured before cancel should be preserved.
|
||||
output_text = "".join(
|
||||
part.get("text", "")
|
||||
for item in stored["response"].get("output", [])
|
||||
if item.get("type") == "message"
|
||||
for part in item.get("content", [])
|
||||
)
|
||||
assert "partial output" in output_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_client_disconnect_persists_incomplete_snapshot(self, adapter):
|
||||
"""Client disconnect (ConnectionResetError) during streaming must
|
||||
persist an ``incomplete`` snapshot in ResponseStore. Regression
|
||||
for PR #15171."""
|
||||
fake_request = MagicMock()
|
||||
fake_request.headers = {}
|
||||
|
||||
write_call_count = {"n": 0}
|
||||
|
||||
class _DisconnectingStreamResponse:
|
||||
async def prepare(self, req):
|
||||
pass
|
||||
|
||||
async def write(self, payload):
|
||||
# First two writes succeed (prepare + response.created).
|
||||
# On the third write (a text delta), the "client"
|
||||
# disconnects — simulate with ConnectionResetError.
|
||||
write_call_count["n"] += 1
|
||||
if write_call_count["n"] >= 3:
|
||||
raise ConnectionResetError("simulated client disconnect")
|
||||
|
||||
import gateway.platforms.api_server as api_mod
|
||||
import queue as _q
|
||||
|
||||
stream_q: _q.Queue = _q.Queue()
|
||||
stream_q.put("some streamed text")
|
||||
stream_q.put(None) # EOS sentinel
|
||||
|
||||
async def _agent_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
return ({"final_response": "", "messages": [], "api_calls": 0},
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
|
||||
agent_task = asyncio.ensure_future(_agent_coro())
|
||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||
|
||||
with patch.object(api_mod.web, "StreamResponse", return_value=_DisconnectingStreamResponse()):
|
||||
await adapter._write_sse_responses(
|
||||
request=fake_request,
|
||||
response_id=response_id,
|
||||
model="hermes-agent",
|
||||
created_at=int(time.time()),
|
||||
stream_q=stream_q,
|
||||
agent_task=agent_task,
|
||||
agent_ref=[None],
|
||||
conversation_history=[],
|
||||
user_message="will disconnect",
|
||||
instructions=None,
|
||||
conversation=None,
|
||||
store=True,
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
stored = adapter._response_store.get(response_id)
|
||||
assert stored is not None, "snapshot must survive client disconnect"
|
||||
assert stored["response"]["status"] == "incomplete"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth on endpoints
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue