diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py
index 37b7121a5..db3304a09 100644
--- a/gateway/platforms/api_server.py
+++ b/gateway/platforms/api_server.py
@@ -1292,6 +1292,47 @@ class APIServerAdapter(BasePlatformAdapter):
             if conversation:
                 self._response_store.set_conversation(conversation, response_id)
 
+        def _persist_incomplete_if_needed() -> None:
+            """Persist an ``incomplete`` snapshot if no terminal one was written.
+
+            Called from both the client-disconnect (``ConnectionResetError``)
+            and server-cancellation (``asyncio.CancelledError``) paths so
+            GET /v1/responses/{id} and ``previous_response_id`` chaining keep
+            working after abrupt stream termination.
+
+            Best-effort: a failure while building or storing the snapshot is
+            logged and swallowed so the caller can still interrupt the agent
+            and, on cancellation, re-raise ``CancelledError``.
+            """
+            if not store or terminal_snapshot_persisted:
+                return
+            try:
+                incomplete_text = "".join(final_text_parts) or final_response_text
+                incomplete_items: List[Dict[str, Any]] = list(emitted_items)
+                if incomplete_text:
+                    incomplete_items.append({
+                        "type": "message",
+                        "role": "assistant",
+                        "content": [{"type": "output_text", "text": incomplete_text}],
+                    })
+                incomplete_env = _envelope("incomplete")
+                incomplete_env["output"] = incomplete_items
+                incomplete_env["usage"] = {
+                    "input_tokens": usage.get("input_tokens", 0),
+                    "output_tokens": usage.get("output_tokens", 0),
+                    "total_tokens": usage.get("total_tokens", 0),
+                }
+                incomplete_history = list(conversation_history)
+                incomplete_history.append({"role": "user", "content": user_message})
+                if incomplete_text:
+                    incomplete_history.append({"role": "assistant", "content": incomplete_text})
+                _persist_response_snapshot(
+                    incomplete_env,
+                    conversation_history_snapshot=incomplete_history,
+                )
+            except Exception:
+                logger.exception("Failed to persist incomplete snapshot for %s", response_id)
+
         try:
             # response.created — initial envelope, status=in_progress
             created_env = _envelope("in_progress")
@@ -1598,30 +1639,7 @@ class APIServerAdapter(BasePlatformAdapter):
             })
 
         except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
-            if store and not terminal_snapshot_persisted:
-                incomplete_text = "".join(final_text_parts) or final_response_text
-                incomplete_items: List[Dict[str, Any]] = list(emitted_items)
-                if incomplete_text:
-                    incomplete_items.append({
-                        "type": "message",
-                        "role": "assistant",
-                        "content": [{"type": "output_text", "text": incomplete_text}],
-                    })
-                incomplete_env = _envelope("incomplete")
-                incomplete_env["output"] = incomplete_items
-                incomplete_env["usage"] = {
-                    "input_tokens": usage.get("input_tokens", 0),
-                    "output_tokens": usage.get("output_tokens", 0),
-                    "total_tokens": usage.get("total_tokens", 0),
-                }
-                incomplete_history = list(conversation_history)
-                incomplete_history.append({"role": "user", "content": user_message})
-                if incomplete_text:
-                    incomplete_history.append({"role": "assistant", "content": incomplete_text})
-                _persist_response_snapshot(
-                    incomplete_env,
-                    conversation_history_snapshot=incomplete_history,
-                )
+            _persist_incomplete_if_needed()
             # Client disconnected — interrupt the agent so it stops
             # making upstream LLM calls, then cancel the task.
             agent = agent_ref[0] if agent_ref else None
@@ -1637,6 +1655,22 @@ class APIServerAdapter(BasePlatformAdapter):
                 except (asyncio.CancelledError, Exception):
                     pass
             logger.info("SSE client disconnected; interrupted agent task %s", response_id)
+        except asyncio.CancelledError:
+            # Server-side cancellation (e.g. shutdown, request timeout) —
+            # persist an incomplete snapshot so GET /v1/responses/{id} and
+            # previous_response_id chaining still work, then re-raise so the
+            # runtime's cancellation semantics are respected.
+            _persist_incomplete_if_needed()
+            agent = agent_ref[0] if agent_ref else None
+            if agent is not None:
+                try:
+                    agent.interrupt("SSE task cancelled")
+                except Exception:
+                    logger.debug("agent.interrupt failed during cancellation", exc_info=True)
+            if not agent_task.done():
+                agent_task.cancel()
+            logger.info("SSE task cancelled; interrupted agent task %s", response_id)
+            raise
 
         return response
 
diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py
index ca229f26f..828585106 100644
--- a/tests/gateway/test_api_server.py
+++ b/tests/gateway/test_api_server.py
@@ -1374,6 +1374,139 @@ class TestResponsesStreaming:
                 assert data["status"] == "completed"
                 assert data["output"][-1]["content"][0]["text"] == "Stored response"
 
+    @pytest.mark.asyncio
+    async def test_stream_cancelled_persists_incomplete_snapshot(self, adapter):
+        """Server-side asyncio.CancelledError (shutdown, request timeout) must
+        still leave an ``incomplete`` snapshot in ResponseStore so
+        GET /v1/responses/{id} and previous_response_id chaining keep
+        working. Regression for PR #15171 follow-up.
+
+        Calls _write_sse_responses directly so the test can await the
+        handler to completion (TestClient disconnection races the server
+        handler, which makes end-to-end assertion on the final stored
+        snapshot flaky).
+        """
+        # Build a minimal fake request + stream queue the writer understands.
+        fake_request = MagicMock()
+        fake_request.headers = {}
+
+        written_payloads: list = []
+
+        class _FakeStreamResponse:
+            async def prepare(self, req):
+                pass
+
+            async def write(self, payload):
+                written_payloads.append(payload)
+
+        # Patch web.StreamResponse for the duration of the writer call.
+        import gateway.platforms.api_server as api_mod
+        import queue as _q
+
+        stream_q: _q.Queue = _q.Queue()
+
+        async def _agent_coro():
+            # Feed one partial delta into the stream queue...
+            stream_q.put("partial output")
+            # ...then give the drain loop a moment to pick it up before
+            # raising CancelledError to simulate a server-side cancel.
+            await asyncio.sleep(0.01)
+            raise asyncio.CancelledError()
+
+        agent_task = asyncio.ensure_future(_agent_coro())
+        response_id = f"resp_{uuid.uuid4().hex[:28]}"
+
+        with patch.object(api_mod.web, "StreamResponse", return_value=_FakeStreamResponse()):
+            with pytest.raises(asyncio.CancelledError):
+                await adapter._write_sse_responses(
+                    request=fake_request,
+                    response_id=response_id,
+                    model="hermes-agent",
+                    created_at=int(time.time()),
+                    stream_q=stream_q,
+                    agent_task=agent_task,
+                    agent_ref=[None],
+                    conversation_history=[],
+                    user_message="will be cancelled",
+                    instructions=None,
+                    conversation=None,
+                    store=True,
+                    session_id=None,
+                )
+
+        # The in_progress snapshot was persisted on response.created,
+        # and the CancelledError handler must have updated it to
+        # ``incomplete`` with the partial text it saw.
+        stored = adapter._response_store.get(response_id)
+        assert stored is not None, "snapshot must be retrievable after cancellation"
+        assert stored["response"]["status"] == "incomplete"
+        # Partial text captured before cancel should be preserved.
+        output_text = "".join(
+            part.get("text", "")
+            for item in stored["response"].get("output", [])
+            if item.get("type") == "message"
+            for part in item.get("content", [])
+        )
+        assert "partial output" in output_text
+
+    @pytest.mark.asyncio
+    async def test_stream_client_disconnect_persists_incomplete_snapshot(self, adapter):
+        """Client disconnect (ConnectionResetError) during streaming must
+        persist an ``incomplete`` snapshot in ResponseStore. Regression
+        for PR #15171."""
+        fake_request = MagicMock()
+        fake_request.headers = {}
+
+        write_call_count = {"n": 0}
+
+        class _DisconnectingStreamResponse:
+            async def prepare(self, req):
+                pass
+
+            async def write(self, payload):
+                # The first two write() calls succeed (prepare() is not
+                # counted); from the third write() on, the "client"
+                # disconnects — simulate with ConnectionResetError.
+                write_call_count["n"] += 1
+                if write_call_count["n"] >= 3:
+                    raise ConnectionResetError("simulated client disconnect")
+
+        import gateway.platforms.api_server as api_mod
+        import queue as _q
+
+        stream_q: _q.Queue = _q.Queue()
+        stream_q.put("some streamed text")
+        stream_q.put(None)  # EOS sentinel
+
+        async def _agent_coro():
+            await asyncio.sleep(0.01)
+            return ({"final_response": "", "messages": [], "api_calls": 0},
+                    {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
+
+        agent_task = asyncio.ensure_future(_agent_coro())
+        response_id = f"resp_{uuid.uuid4().hex[:28]}"
+
+        with patch.object(api_mod.web, "StreamResponse", return_value=_DisconnectingStreamResponse()):
+            await adapter._write_sse_responses(
+                request=fake_request,
+                response_id=response_id,
+                model="hermes-agent",
+                created_at=int(time.time()),
+                stream_q=stream_q,
+                agent_task=agent_task,
+                agent_ref=[None],
+                conversation_history=[],
+                user_message="will disconnect",
+                instructions=None,
+                conversation=None,
+                store=True,
+                session_id=None,
+            )
+
+        stored = adapter._response_store.get(response_id)
+        assert stored is not None, "snapshot must survive client disconnect"
+        assert stored["response"]["status"] == "incomplete"
+
 
 # ---------------------------------------------------------------------------
 # Auth on endpoints