# Reconstructed from a whitespace-collapsed unified diff against
# tests/gateway/test_sse_agent_cancel.py (hunk "@@ -276,3 +276,108 @@").
# Everything below is what that hunk adds after the final test of
# TestSSEAgentCancelOnDisconnect.


def _capturing_response():
    """Build a mocked ``web.StreamResponse`` and the list its writes land in.

    Returns:
        A ``(response, captured)`` pair. Each payload handed to
        ``response.write`` is appended to ``captured``; ``bytes`` and
        ``bytearray`` payloads are decoded to ``str`` first, anything else is
        stored unchanged.
    """
    from aiohttp import web

    captured: list = []

    async def _record(payload):
        if isinstance(payload, (bytes, bytearray)):
            payload = payload.decode()
        captured.append(payload)

    response = AsyncMock(spec=web.StreamResponse)
    response.prepare = AsyncMock()
    response.write = AsyncMock(side_effect=_record)
    return response, captured


def _finish_reason(chunks: list):
    """Find the last SSE ``data:`` event that carries a non-null finish_reason.

    Args:
        chunks: Captured SSE text fragments; they are concatenated and then
            scanned line by line.

    Returns:
        A ``(reason, event, sse)`` triple: the finish_reason value (``None``
        when no event carries one), the decoded JSON event it came from (or
        ``None``), and the full concatenated SSE text.
    """
    import json

    prefix = "data: "
    sse = "".join(chunks)
    terminal = None
    for raw in sse.splitlines():
        # Lines lacking the literal "finish_reason" key (for example
        # "data: [DONE]") are skipped before any JSON parsing happens.
        if not raw.startswith(prefix) or '"finish_reason"' not in raw:
            continue
        event = json.loads(raw[len(prefix):])
        if event["choices"][0].get("finish_reason") is not None:
            # Keep overwriting so the last matching event wins.
            terminal = event

    if not terminal:
        return None, terminal, sse
    return terminal["choices"][0]["finish_reason"], terminal, sse


class TestSSEAgentFailureFinishReason:
    """gateway/platforms/api_server.py — _write_sse_chat_completion()

    The stream queue can terminate cleanly (sentinel received) even though the
    agent itself failed. In that case the final chunk must NOT carry
    finish_reason: "stop". Both failure modes — an ``agent_task`` that raises
    and a ``result`` dict flagged failed — are expected to surface as
    finish_reason: "error", mirroring the non-streaming path. Issue #12422.
    """

    def _run(self, fake_agent, queue_items=("partial",)):
        """Drive ``_write_sse_chat_completion`` once and summarise its output.

        Args:
            fake_agent: Zero-argument coroutine function scheduled as the
                agent task.
            queue_items: Items placed on the stream queue ahead of the
                ``None`` end-of-stream sentinel.

        Returns:
            The ``(reason, finish_chunk, sse)`` triple from ``_finish_reason``
            applied to everything written to the mocked response.
        """
        adapter = _make_adapter()

        stream_q = queue.Queue()
        # The trailing None is the clean end-of-stream sentinel.
        for item in (*queue_items, None):
            stream_q.put(item)

        async def _drive():
            agent_task = asyncio.ensure_future(fake_agent())
            response, written = _capturing_response()
            stream_response_patch = patch(
                "gateway.platforms.api_server.web.StreamResponse",
                return_value=response,
            )
            with stream_response_patch:
                await adapter._write_sse_chat_completion(
                    _make_request(),
                    "cmpl-fail",
                    "gpt-4",
                    1234567890,
                    stream_q,
                    agent_task,
                )
            return _finish_reason(written)

        return asyncio.run(_drive())

    def test_agent_task_raises_reports_error_not_stop(self):
        """A raising agent yields finish_reason "error" and a [DONE] marker."""

        async def exploding_agent():
            raise RuntimeError("boom from agent")

        reason, finish_chunk, sse_text = self._run(exploding_agent)

        assert reason == "error"
        assert "error" in finish_chunk
        assert "data: [DONE]" in sse_text

    def test_failed_result_dict_reports_error_not_stop(self):
        """A result flagged ``failed`` yields "error" and hermes.failed True."""

        async def failing_agent():
            result = {
                "final_response": "",
                "failed": True,
                "completed": False,
                "error": "upstream model 500",
            }
            usage = {"input_tokens": 5, "output_tokens": 0, "total_tokens": 5}
            return result, usage

        reason, finish_chunk, _ = self._run(failing_agent)

        assert reason == "error"
        hermes_meta = finish_chunk.get("hermes", {})
        assert hermes_meta.get("failed") is True

    def test_truncated_result_reports_length(self):
        """A partial, truncated result yields "length" and output_truncated."""

        async def truncated_agent():
            result = {
                "final_response": "half",
                "partial": True,
                "completed": False,
                "error": "output was truncated",
            }
            usage = {"input_tokens": 5, "output_tokens": 3, "total_tokens": 8}
            return result, usage

        reason, finish_chunk, _ = self._run(truncated_agent)

        assert reason == "length"
        assert finish_chunk["hermes"]["error_code"] == "output_truncated"

    def test_successful_completion_reports_stop(self):
        """A completed result yields "stop" with no error/hermes members."""

        async def healthy_agent():
            result = {"final_response": "hi", "completed": True}
            usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
            return result, usage

        reason, finish_chunk, _ = self._run(healthy_agent)

        assert reason == "stop"
        # No error/hermes pollution on the happy path.
        for polluting_key in ("error", "hermes"):
            assert polluting_key not in finish_chunk