mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(api-server): persist incomplete snapshot on asyncio.CancelledError too
Extends PR #15171 to also cover the server-side cancellation path (aiohttp shutdown, request-level timeout) — previously only ConnectionResetError triggered the incomplete-snapshot write, so cancellations left the store stuck at the in_progress snapshot written on response.created. Factors the incomplete-snapshot build into a _persist_incomplete_if_needed() helper called from both the ConnectionResetError and CancelledError branches; the CancelledError handler re-raises so cooperative cancellation semantics are preserved. Adds two regression tests that drive _write_sse_responses directly (the TestClient disconnect path races the server handler, which makes the end-to-end assertion flaky).
This commit is contained in:
parent
a29bad2a3c
commit
36d68bcb82
2 changed files with 184 additions and 24 deletions
|
|
@ -1292,6 +1292,40 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
if conversation:
|
||||
self._response_store.set_conversation(conversation, response_id)
|
||||
|
||||
def _persist_incomplete_if_needed() -> None:
|
||||
"""Persist an ``incomplete`` snapshot if no terminal one was written.
|
||||
|
||||
Called from both the client-disconnect (``ConnectionResetError``)
|
||||
and server-cancellation (``asyncio.CancelledError``) paths so
|
||||
GET /v1/responses/{id} and ``previous_response_id`` chaining keep
|
||||
working after abrupt stream termination.
|
||||
"""
|
||||
if not store or terminal_snapshot_persisted:
|
||||
return
|
||||
incomplete_text = "".join(final_text_parts) or final_response_text
|
||||
incomplete_items: List[Dict[str, Any]] = list(emitted_items)
|
||||
if incomplete_text:
|
||||
incomplete_items.append({
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": incomplete_text}],
|
||||
})
|
||||
incomplete_env = _envelope("incomplete")
|
||||
incomplete_env["output"] = incomplete_items
|
||||
incomplete_env["usage"] = {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
incomplete_history = list(conversation_history)
|
||||
incomplete_history.append({"role": "user", "content": user_message})
|
||||
if incomplete_text:
|
||||
incomplete_history.append({"role": "assistant", "content": incomplete_text})
|
||||
_persist_response_snapshot(
|
||||
incomplete_env,
|
||||
conversation_history_snapshot=incomplete_history,
|
||||
)
|
||||
|
||||
try:
|
||||
# response.created — initial envelope, status=in_progress
|
||||
created_env = _envelope("in_progress")
|
||||
|
|
@ -1598,30 +1632,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
})
|
||||
|
||||
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
|
||||
if store and not terminal_snapshot_persisted:
|
||||
incomplete_text = "".join(final_text_parts) or final_response_text
|
||||
incomplete_items: List[Dict[str, Any]] = list(emitted_items)
|
||||
if incomplete_text:
|
||||
incomplete_items.append({
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": incomplete_text}],
|
||||
})
|
||||
incomplete_env = _envelope("incomplete")
|
||||
incomplete_env["output"] = incomplete_items
|
||||
incomplete_env["usage"] = {
|
||||
"input_tokens": usage.get("input_tokens", 0),
|
||||
"output_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
}
|
||||
incomplete_history = list(conversation_history)
|
||||
incomplete_history.append({"role": "user", "content": user_message})
|
||||
if incomplete_text:
|
||||
incomplete_history.append({"role": "assistant", "content": incomplete_text})
|
||||
_persist_response_snapshot(
|
||||
incomplete_env,
|
||||
conversation_history_snapshot=incomplete_history,
|
||||
)
|
||||
_persist_incomplete_if_needed()
|
||||
# Client disconnected — interrupt the agent so it stops
|
||||
# making upstream LLM calls, then cancel the task.
|
||||
agent = agent_ref[0] if agent_ref else None
|
||||
|
|
@ -1637,6 +1648,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
logger.info("SSE client disconnected; interrupted agent task %s", response_id)
|
||||
except asyncio.CancelledError:
|
||||
# Server-side cancellation (e.g. shutdown, request timeout) —
|
||||
# persist an incomplete snapshot so GET /v1/responses/{id} and
|
||||
# previous_response_id chaining still work, then re-raise so the
|
||||
# runtime's cancellation semantics are respected.
|
||||
_persist_incomplete_if_needed()
|
||||
agent = agent_ref[0] if agent_ref else None
|
||||
if agent is not None:
|
||||
try:
|
||||
agent.interrupt("SSE task cancelled")
|
||||
except Exception:
|
||||
pass
|
||||
if not agent_task.done():
|
||||
agent_task.cancel()
|
||||
logger.info("SSE task cancelled; persisted incomplete snapshot for %s", response_id)
|
||||
raise
|
||||
|
||||
return response
|
||||
|
||||
|
|
|
|||
|
|
@ -1374,6 +1374,139 @@ class TestResponsesStreaming:
|
|||
assert data["status"] == "completed"
|
||||
assert data["output"][-1]["content"][0]["text"] == "Stored response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_cancelled_persists_incomplete_snapshot(self, adapter):
|
||||
"""Server-side asyncio.CancelledError (shutdown, request timeout) must
|
||||
still leave an ``incomplete`` snapshot in ResponseStore so
|
||||
GET /v1/responses/{id} and previous_response_id chaining keep
|
||||
working. Regression for PR #15171 follow-up.
|
||||
|
||||
Calls _write_sse_responses directly so the test can await the
|
||||
handler to completion (TestClient disconnection races the server
|
||||
handler, which makes end-to-end assertion on the final stored
|
||||
snapshot flaky).
|
||||
"""
|
||||
# Build a minimal fake request + stream queue the writer understands.
|
||||
fake_request = MagicMock()
|
||||
fake_request.headers = {}
|
||||
|
||||
written_payloads: list = []
|
||||
|
||||
class _FakeStreamResponse:
|
||||
async def prepare(self, req):
|
||||
pass
|
||||
|
||||
async def write(self, payload):
|
||||
written_payloads.append(payload)
|
||||
|
||||
# Patch web.StreamResponse for the duration of the writer call.
|
||||
import gateway.platforms.api_server as api_mod
|
||||
import queue as _q
|
||||
|
||||
stream_q: _q.Queue = _q.Queue()
|
||||
|
||||
async def _agent_coro():
|
||||
# Feed one partial delta into the stream queue...
|
||||
stream_q.put("partial output")
|
||||
# ...then give the drain loop a moment to pick it up before
|
||||
# raising CancelledError to simulate a server-side cancel.
|
||||
await asyncio.sleep(0.01)
|
||||
raise asyncio.CancelledError()
|
||||
|
||||
agent_task = asyncio.ensure_future(_agent_coro())
|
||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||
|
||||
with patch.object(api_mod.web, "StreamResponse", return_value=_FakeStreamResponse()):
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await adapter._write_sse_responses(
|
||||
request=fake_request,
|
||||
response_id=response_id,
|
||||
model="hermes-agent",
|
||||
created_at=int(time.time()),
|
||||
stream_q=stream_q,
|
||||
agent_task=agent_task,
|
||||
agent_ref=[None],
|
||||
conversation_history=[],
|
||||
user_message="will be cancelled",
|
||||
instructions=None,
|
||||
conversation=None,
|
||||
store=True,
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
# The in_progress snapshot was persisted on response.created,
|
||||
# and the CancelledError handler must have updated it to
|
||||
# ``incomplete`` with the partial text it saw.
|
||||
stored = adapter._response_store.get(response_id)
|
||||
assert stored is not None, "snapshot must be retrievable after cancellation"
|
||||
assert stored["response"]["status"] == "incomplete"
|
||||
# Partial text captured before cancel should be preserved.
|
||||
output_text = "".join(
|
||||
part.get("text", "")
|
||||
for item in stored["response"].get("output", [])
|
||||
if item.get("type") == "message"
|
||||
for part in item.get("content", [])
|
||||
)
|
||||
assert "partial output" in output_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_client_disconnect_persists_incomplete_snapshot(self, adapter):
|
||||
"""Client disconnect (ConnectionResetError) during streaming must
|
||||
persist an ``incomplete`` snapshot in ResponseStore. Regression
|
||||
for PR #15171."""
|
||||
fake_request = MagicMock()
|
||||
fake_request.headers = {}
|
||||
|
||||
write_call_count = {"n": 0}
|
||||
|
||||
class _DisconnectingStreamResponse:
|
||||
async def prepare(self, req):
|
||||
pass
|
||||
|
||||
async def write(self, payload):
|
||||
# First two writes succeed (prepare + response.created).
|
||||
# On the third write (a text delta), the "client"
|
||||
# disconnects — simulate with ConnectionResetError.
|
||||
write_call_count["n"] += 1
|
||||
if write_call_count["n"] >= 3:
|
||||
raise ConnectionResetError("simulated client disconnect")
|
||||
|
||||
import gateway.platforms.api_server as api_mod
|
||||
import queue as _q
|
||||
|
||||
stream_q: _q.Queue = _q.Queue()
|
||||
stream_q.put("some streamed text")
|
||||
stream_q.put(None) # EOS sentinel
|
||||
|
||||
async def _agent_coro():
|
||||
await asyncio.sleep(0.01)
|
||||
return ({"final_response": "", "messages": [], "api_calls": 0},
|
||||
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||
|
||||
agent_task = asyncio.ensure_future(_agent_coro())
|
||||
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||
|
||||
with patch.object(api_mod.web, "StreamResponse", return_value=_DisconnectingStreamResponse()):
|
||||
await adapter._write_sse_responses(
|
||||
request=fake_request,
|
||||
response_id=response_id,
|
||||
model="hermes-agent",
|
||||
created_at=int(time.time()),
|
||||
stream_q=stream_q,
|
||||
agent_task=agent_task,
|
||||
agent_ref=[None],
|
||||
conversation_history=[],
|
||||
user_message="will disconnect",
|
||||
instructions=None,
|
||||
conversation=None,
|
||||
store=True,
|
||||
session_id=None,
|
||||
)
|
||||
|
||||
stored = adapter._response_store.get(response_id)
|
||||
assert stored is not None, "snapshot must survive client disconnect"
|
||||
assert stored["response"]["status"] == "incomplete"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Auth on endpoints
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue