fix(api-server): persist incomplete snapshot on asyncio.CancelledError too

Extends PR #15171 to also cover the server-side cancellation path (aiohttp
shutdown, request-level timeout) — previously only ConnectionResetError
triggered the incomplete-snapshot write, so cancellations left the store
stuck at the in_progress snapshot written on response.created.

Factors the incomplete-snapshot build into a _persist_incomplete_if_needed()
helper called from both the ConnectionResetError and CancelledError
branches; the CancelledError handler re-raises so cooperative cancellation
semantics are preserved.

Adds two regression tests that drive _write_sse_responses directly (the
TestClient disconnect path races the server handler, which makes the
end-to-end assertion flaky).
This commit is contained in:
Teknium 2026-04-24 15:21:39 -07:00 committed by Teknium
parent a29bad2a3c
commit 36d68bcb82
2 changed files with 184 additions and 24 deletions

View file

@ -1292,6 +1292,40 @@ class APIServerAdapter(BasePlatformAdapter):
if conversation:
self._response_store.set_conversation(conversation, response_id)
def _persist_incomplete_if_needed() -> None:
"""Persist an ``incomplete`` snapshot if no terminal one was written.
Called from both the client-disconnect (``ConnectionResetError``)
and server-cancellation (``asyncio.CancelledError``) paths so
GET /v1/responses/{id} and ``previous_response_id`` chaining keep
working after abrupt stream termination.
"""
if not store or terminal_snapshot_persisted:
return
incomplete_text = "".join(final_text_parts) or final_response_text
incomplete_items: List[Dict[str, Any]] = list(emitted_items)
if incomplete_text:
incomplete_items.append({
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": incomplete_text}],
})
incomplete_env = _envelope("incomplete")
incomplete_env["output"] = incomplete_items
incomplete_env["usage"] = {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
}
incomplete_history = list(conversation_history)
incomplete_history.append({"role": "user", "content": user_message})
if incomplete_text:
incomplete_history.append({"role": "assistant", "content": incomplete_text})
_persist_response_snapshot(
incomplete_env,
conversation_history_snapshot=incomplete_history,
)
try:
# response.created — initial envelope, status=in_progress
created_env = _envelope("in_progress")
@ -1598,30 +1632,7 @@ class APIServerAdapter(BasePlatformAdapter):
})
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
if store and not terminal_snapshot_persisted:
incomplete_text = "".join(final_text_parts) or final_response_text
incomplete_items: List[Dict[str, Any]] = list(emitted_items)
if incomplete_text:
incomplete_items.append({
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": incomplete_text}],
})
incomplete_env = _envelope("incomplete")
incomplete_env["output"] = incomplete_items
incomplete_env["usage"] = {
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
}
incomplete_history = list(conversation_history)
incomplete_history.append({"role": "user", "content": user_message})
if incomplete_text:
incomplete_history.append({"role": "assistant", "content": incomplete_text})
_persist_response_snapshot(
incomplete_env,
conversation_history_snapshot=incomplete_history,
)
_persist_incomplete_if_needed()
# Client disconnected — interrupt the agent so it stops
# making upstream LLM calls, then cancel the task.
agent = agent_ref[0] if agent_ref else None
@ -1637,6 +1648,22 @@ class APIServerAdapter(BasePlatformAdapter):
except (asyncio.CancelledError, Exception):
pass
logger.info("SSE client disconnected; interrupted agent task %s", response_id)
except asyncio.CancelledError:
# Server-side cancellation (e.g. shutdown, request timeout) —
# persist an incomplete snapshot so GET /v1/responses/{id} and
# previous_response_id chaining still work, then re-raise so the
# runtime's cancellation semantics are respected.
_persist_incomplete_if_needed()
agent = agent_ref[0] if agent_ref else None
if agent is not None:
try:
agent.interrupt("SSE task cancelled")
except Exception:
pass
if not agent_task.done():
agent_task.cancel()
logger.info("SSE task cancelled; persisted incomplete snapshot for %s", response_id)
raise
return response

View file

@ -1374,6 +1374,139 @@ class TestResponsesStreaming:
assert data["status"] == "completed"
assert data["output"][-1]["content"][0]["text"] == "Stored response"
@pytest.mark.asyncio
async def test_stream_cancelled_persists_incomplete_snapshot(self, adapter):
"""Server-side asyncio.CancelledError (shutdown, request timeout) must
still leave an ``incomplete`` snapshot in ResponseStore so
GET /v1/responses/{id} and previous_response_id chaining keep
working. Regression for PR #15171 follow-up.
Calls _write_sse_responses directly so the test can await the
handler to completion (TestClient disconnection races the server
handler, which makes end-to-end assertion on the final stored
snapshot flaky).
"""
# Build a minimal fake request + stream queue the writer understands.
fake_request = MagicMock()
fake_request.headers = {}
written_payloads: list = []
class _FakeStreamResponse:
async def prepare(self, req):
pass
async def write(self, payload):
written_payloads.append(payload)
# Patch web.StreamResponse for the duration of the writer call.
import gateway.platforms.api_server as api_mod
import queue as _q
stream_q: _q.Queue = _q.Queue()
async def _agent_coro():
# Feed one partial delta into the stream queue...
stream_q.put("partial output")
# ...then give the drain loop a moment to pick it up before
# raising CancelledError to simulate a server-side cancel.
await asyncio.sleep(0.01)
raise asyncio.CancelledError()
agent_task = asyncio.ensure_future(_agent_coro())
response_id = f"resp_{uuid.uuid4().hex[:28]}"
with patch.object(api_mod.web, "StreamResponse", return_value=_FakeStreamResponse()):
with pytest.raises(asyncio.CancelledError):
await adapter._write_sse_responses(
request=fake_request,
response_id=response_id,
model="hermes-agent",
created_at=int(time.time()),
stream_q=stream_q,
agent_task=agent_task,
agent_ref=[None],
conversation_history=[],
user_message="will be cancelled",
instructions=None,
conversation=None,
store=True,
session_id=None,
)
# The in_progress snapshot was persisted on response.created,
# and the CancelledError handler must have updated it to
# ``incomplete`` with the partial text it saw.
stored = adapter._response_store.get(response_id)
assert stored is not None, "snapshot must be retrievable after cancellation"
assert stored["response"]["status"] == "incomplete"
# Partial text captured before cancel should be preserved.
output_text = "".join(
part.get("text", "")
for item in stored["response"].get("output", [])
if item.get("type") == "message"
for part in item.get("content", [])
)
assert "partial output" in output_text
@pytest.mark.asyncio
async def test_stream_client_disconnect_persists_incomplete_snapshot(self, adapter):
"""Client disconnect (ConnectionResetError) during streaming must
persist an ``incomplete`` snapshot in ResponseStore. Regression
for PR #15171."""
fake_request = MagicMock()
fake_request.headers = {}
write_call_count = {"n": 0}
class _DisconnectingStreamResponse:
async def prepare(self, req):
pass
async def write(self, payload):
# First two writes succeed (prepare + response.created).
# On the third write (a text delta), the "client"
# disconnects — simulate with ConnectionResetError.
write_call_count["n"] += 1
if write_call_count["n"] >= 3:
raise ConnectionResetError("simulated client disconnect")
import gateway.platforms.api_server as api_mod
import queue as _q
stream_q: _q.Queue = _q.Queue()
stream_q.put("some streamed text")
stream_q.put(None) # EOS sentinel
async def _agent_coro():
await asyncio.sleep(0.01)
return ({"final_response": "", "messages": [], "api_calls": 0},
{"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
agent_task = asyncio.ensure_future(_agent_coro())
response_id = f"resp_{uuid.uuid4().hex[:28]}"
with patch.object(api_mod.web, "StreamResponse", return_value=_DisconnectingStreamResponse()):
await adapter._write_sse_responses(
request=fake_request,
response_id=response_id,
model="hermes-agent",
created_at=int(time.time()),
stream_q=stream_q,
agent_task=agent_task,
agent_ref=[None],
conversation_history=[],
user_message="will disconnect",
instructions=None,
conversation=None,
store=True,
session_id=None,
)
stored = adapter._response_store.get(response_id)
assert stored is not None, "snapshot must survive client disconnect"
assert stored["response"]["status"] == "incomplete"
# ---------------------------------------------------------------------------
# Auth on endpoints