mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(api-server): persist incomplete snapshot on asyncio.CancelledError too
Extends PR #15171 to also cover the server-side cancellation path (aiohttp shutdown, request-level timeout) — previously only ConnectionResetError triggered the incomplete-snapshot write, so cancellations left the store stuck at the in_progress snapshot written on response.created. Factors the incomplete-snapshot build into a _persist_incomplete_if_needed() helper called from both the ConnectionResetError and CancelledError branches; the CancelledError handler re-raises so cooperative cancellation semantics are preserved. Adds two regression tests that drive _write_sse_responses directly (the TestClient disconnect path races the server handler, which makes the end-to-end assertion flaky).
This commit is contained in:
parent
a29bad2a3c
commit
36d68bcb82
2 changed files with 184 additions and 24 deletions
|
|
@ -1292,6 +1292,40 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||||
if conversation:
|
if conversation:
|
||||||
self._response_store.set_conversation(conversation, response_id)
|
self._response_store.set_conversation(conversation, response_id)
|
||||||
|
|
||||||
|
def _persist_incomplete_if_needed() -> None:
|
||||||
|
"""Persist an ``incomplete`` snapshot if no terminal one was written.
|
||||||
|
|
||||||
|
Called from both the client-disconnect (``ConnectionResetError``)
|
||||||
|
and server-cancellation (``asyncio.CancelledError``) paths so
|
||||||
|
GET /v1/responses/{id} and ``previous_response_id`` chaining keep
|
||||||
|
working after abrupt stream termination.
|
||||||
|
"""
|
||||||
|
if not store or terminal_snapshot_persisted:
|
||||||
|
return
|
||||||
|
incomplete_text = "".join(final_text_parts) or final_response_text
|
||||||
|
incomplete_items: List[Dict[str, Any]] = list(emitted_items)
|
||||||
|
if incomplete_text:
|
||||||
|
incomplete_items.append({
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "output_text", "text": incomplete_text}],
|
||||||
|
})
|
||||||
|
incomplete_env = _envelope("incomplete")
|
||||||
|
incomplete_env["output"] = incomplete_items
|
||||||
|
incomplete_env["usage"] = {
|
||||||
|
"input_tokens": usage.get("input_tokens", 0),
|
||||||
|
"output_tokens": usage.get("output_tokens", 0),
|
||||||
|
"total_tokens": usage.get("total_tokens", 0),
|
||||||
|
}
|
||||||
|
incomplete_history = list(conversation_history)
|
||||||
|
incomplete_history.append({"role": "user", "content": user_message})
|
||||||
|
if incomplete_text:
|
||||||
|
incomplete_history.append({"role": "assistant", "content": incomplete_text})
|
||||||
|
_persist_response_snapshot(
|
||||||
|
incomplete_env,
|
||||||
|
conversation_history_snapshot=incomplete_history,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# response.created — initial envelope, status=in_progress
|
# response.created — initial envelope, status=in_progress
|
||||||
created_env = _envelope("in_progress")
|
created_env = _envelope("in_progress")
|
||||||
|
|
@ -1598,30 +1632,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||||
})
|
})
|
||||||
|
|
||||||
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
|
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
|
||||||
if store and not terminal_snapshot_persisted:
|
_persist_incomplete_if_needed()
|
||||||
incomplete_text = "".join(final_text_parts) or final_response_text
|
|
||||||
incomplete_items: List[Dict[str, Any]] = list(emitted_items)
|
|
||||||
if incomplete_text:
|
|
||||||
incomplete_items.append({
|
|
||||||
"type": "message",
|
|
||||||
"role": "assistant",
|
|
||||||
"content": [{"type": "output_text", "text": incomplete_text}],
|
|
||||||
})
|
|
||||||
incomplete_env = _envelope("incomplete")
|
|
||||||
incomplete_env["output"] = incomplete_items
|
|
||||||
incomplete_env["usage"] = {
|
|
||||||
"input_tokens": usage.get("input_tokens", 0),
|
|
||||||
"output_tokens": usage.get("output_tokens", 0),
|
|
||||||
"total_tokens": usage.get("total_tokens", 0),
|
|
||||||
}
|
|
||||||
incomplete_history = list(conversation_history)
|
|
||||||
incomplete_history.append({"role": "user", "content": user_message})
|
|
||||||
if incomplete_text:
|
|
||||||
incomplete_history.append({"role": "assistant", "content": incomplete_text})
|
|
||||||
_persist_response_snapshot(
|
|
||||||
incomplete_env,
|
|
||||||
conversation_history_snapshot=incomplete_history,
|
|
||||||
)
|
|
||||||
# Client disconnected — interrupt the agent so it stops
|
# Client disconnected — interrupt the agent so it stops
|
||||||
# making upstream LLM calls, then cancel the task.
|
# making upstream LLM calls, then cancel the task.
|
||||||
agent = agent_ref[0] if agent_ref else None
|
agent = agent_ref[0] if agent_ref else None
|
||||||
|
|
@ -1637,6 +1648,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||||
except (asyncio.CancelledError, Exception):
|
except (asyncio.CancelledError, Exception):
|
||||||
pass
|
pass
|
||||||
logger.info("SSE client disconnected; interrupted agent task %s", response_id)
|
logger.info("SSE client disconnected; interrupted agent task %s", response_id)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Server-side cancellation (e.g. shutdown, request timeout) —
|
||||||
|
# persist an incomplete snapshot so GET /v1/responses/{id} and
|
||||||
|
# previous_response_id chaining still work, then re-raise so the
|
||||||
|
# runtime's cancellation semantics are respected.
|
||||||
|
_persist_incomplete_if_needed()
|
||||||
|
agent = agent_ref[0] if agent_ref else None
|
||||||
|
if agent is not None:
|
||||||
|
try:
|
||||||
|
agent.interrupt("SSE task cancelled")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
if not agent_task.done():
|
||||||
|
agent_task.cancel()
|
||||||
|
logger.info("SSE task cancelled; persisted incomplete snapshot for %s", response_id)
|
||||||
|
raise
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1374,6 +1374,139 @@ class TestResponsesStreaming:
|
||||||
assert data["status"] == "completed"
|
assert data["status"] == "completed"
|
||||||
assert data["output"][-1]["content"][0]["text"] == "Stored response"
|
assert data["output"][-1]["content"][0]["text"] == "Stored response"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_cancelled_persists_incomplete_snapshot(self, adapter):
|
||||||
|
"""Server-side asyncio.CancelledError (shutdown, request timeout) must
|
||||||
|
still leave an ``incomplete`` snapshot in ResponseStore so
|
||||||
|
GET /v1/responses/{id} and previous_response_id chaining keep
|
||||||
|
working. Regression for PR #15171 follow-up.
|
||||||
|
|
||||||
|
Calls _write_sse_responses directly so the test can await the
|
||||||
|
handler to completion (TestClient disconnection races the server
|
||||||
|
handler, which makes end-to-end assertion on the final stored
|
||||||
|
snapshot flaky).
|
||||||
|
"""
|
||||||
|
# Build a minimal fake request + stream queue the writer understands.
|
||||||
|
fake_request = MagicMock()
|
||||||
|
fake_request.headers = {}
|
||||||
|
|
||||||
|
written_payloads: list = []
|
||||||
|
|
||||||
|
class _FakeStreamResponse:
|
||||||
|
async def prepare(self, req):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def write(self, payload):
|
||||||
|
written_payloads.append(payload)
|
||||||
|
|
||||||
|
# Patch web.StreamResponse for the duration of the writer call.
|
||||||
|
import gateway.platforms.api_server as api_mod
|
||||||
|
import queue as _q
|
||||||
|
|
||||||
|
stream_q: _q.Queue = _q.Queue()
|
||||||
|
|
||||||
|
async def _agent_coro():
|
||||||
|
# Feed one partial delta into the stream queue...
|
||||||
|
stream_q.put("partial output")
|
||||||
|
# ...then give the drain loop a moment to pick it up before
|
||||||
|
# raising CancelledError to simulate a server-side cancel.
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
agent_task = asyncio.ensure_future(_agent_coro())
|
||||||
|
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||||
|
|
||||||
|
with patch.object(api_mod.web, "StreamResponse", return_value=_FakeStreamResponse()):
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await adapter._write_sse_responses(
|
||||||
|
request=fake_request,
|
||||||
|
response_id=response_id,
|
||||||
|
model="hermes-agent",
|
||||||
|
created_at=int(time.time()),
|
||||||
|
stream_q=stream_q,
|
||||||
|
agent_task=agent_task,
|
||||||
|
agent_ref=[None],
|
||||||
|
conversation_history=[],
|
||||||
|
user_message="will be cancelled",
|
||||||
|
instructions=None,
|
||||||
|
conversation=None,
|
||||||
|
store=True,
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# The in_progress snapshot was persisted on response.created,
|
||||||
|
# and the CancelledError handler must have updated it to
|
||||||
|
# ``incomplete`` with the partial text it saw.
|
||||||
|
stored = adapter._response_store.get(response_id)
|
||||||
|
assert stored is not None, "snapshot must be retrievable after cancellation"
|
||||||
|
assert stored["response"]["status"] == "incomplete"
|
||||||
|
# Partial text captured before cancel should be preserved.
|
||||||
|
output_text = "".join(
|
||||||
|
part.get("text", "")
|
||||||
|
for item in stored["response"].get("output", [])
|
||||||
|
if item.get("type") == "message"
|
||||||
|
for part in item.get("content", [])
|
||||||
|
)
|
||||||
|
assert "partial output" in output_text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stream_client_disconnect_persists_incomplete_snapshot(self, adapter):
|
||||||
|
"""Client disconnect (ConnectionResetError) during streaming must
|
||||||
|
persist an ``incomplete`` snapshot in ResponseStore. Regression
|
||||||
|
for PR #15171."""
|
||||||
|
fake_request = MagicMock()
|
||||||
|
fake_request.headers = {}
|
||||||
|
|
||||||
|
write_call_count = {"n": 0}
|
||||||
|
|
||||||
|
class _DisconnectingStreamResponse:
|
||||||
|
async def prepare(self, req):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def write(self, payload):
|
||||||
|
# First two writes succeed (prepare + response.created).
|
||||||
|
# On the third write (a text delta), the "client"
|
||||||
|
# disconnects — simulate with ConnectionResetError.
|
||||||
|
write_call_count["n"] += 1
|
||||||
|
if write_call_count["n"] >= 3:
|
||||||
|
raise ConnectionResetError("simulated client disconnect")
|
||||||
|
|
||||||
|
import gateway.platforms.api_server as api_mod
|
||||||
|
import queue as _q
|
||||||
|
|
||||||
|
stream_q: _q.Queue = _q.Queue()
|
||||||
|
stream_q.put("some streamed text")
|
||||||
|
stream_q.put(None) # EOS sentinel
|
||||||
|
|
||||||
|
async def _agent_coro():
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
return ({"final_response": "", "messages": [], "api_calls": 0},
|
||||||
|
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
|
||||||
|
|
||||||
|
agent_task = asyncio.ensure_future(_agent_coro())
|
||||||
|
response_id = f"resp_{uuid.uuid4().hex[:28]}"
|
||||||
|
|
||||||
|
with patch.object(api_mod.web, "StreamResponse", return_value=_DisconnectingStreamResponse()):
|
||||||
|
await adapter._write_sse_responses(
|
||||||
|
request=fake_request,
|
||||||
|
response_id=response_id,
|
||||||
|
model="hermes-agent",
|
||||||
|
created_at=int(time.time()),
|
||||||
|
stream_q=stream_q,
|
||||||
|
agent_task=agent_task,
|
||||||
|
agent_ref=[None],
|
||||||
|
conversation_history=[],
|
||||||
|
user_message="will disconnect",
|
||||||
|
instructions=None,
|
||||||
|
conversation=None,
|
||||||
|
store=True,
|
||||||
|
session_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
stored = adapter._response_store.get(response_id)
|
||||||
|
assert stored is not None, "snapshot must survive client disconnect"
|
||||||
|
assert stored["response"]["status"] == "incomplete"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Auth on endpoints
|
# Auth on endpoints
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue