fix(api-server): cancel orphaned agent + true interrupt on SSE disconnect (salvage #3399) (#3427)

Salvage of #3399 by @binhnt92 with true agent interruption added on top.

When a streaming /v1/chat/completions client disconnects mid-stream, the agent is now interrupted via agent.interrupt() so it stops making LLM API calls, and the asyncio task wrapper is cancelled.

Closes #3399.
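For reviewers skimming the diff: the core pattern is a one-element list that hands the agent instance from the agent-running coroutine back to the SSE writer, which interrupts and cancels when a write fails. A stripped-down, self-contained sketch of that flow (simplified names, with a FakeAgent standing in for the real AIAgent; this is not part of the patch):

```python
import asyncio
import threading

class FakeAgent:
    """Stand-in for AIAgent: blocking work that honors interrupt()."""
    def __init__(self):
        self._stop = threading.Event()

    def interrupt(self, reason):
        # Called from the event-loop thread while run() blocks in an executor.
        self._stop.set()

    def run(self):
        while not self._stop.is_set():  # would be LLM calls in the real agent
            self._stop.wait(0.05)
        return "interrupted"

async def run_agent(agent_ref):
    agent = FakeAgent()
    agent_ref[0] = agent  # publish the agent before the blocking work starts
    loop = asyncio.get_event_loop()
    return await loop.run_in_executor(None, agent.run)

async def write_sse(agent_task, agent_ref):
    try:
        raise ConnectionResetError("client disconnected")  # simulated drop
    except (ConnectionResetError, BrokenPipeError, OSError):
        agent = agent_ref[0]
        if agent is not None:
            agent.interrupt("SSE client disconnected")
        if not agent_task.done():
            agent_task.cancel()

async def main():
    agent_ref = [None]
    agent_task = asyncio.ensure_future(run_agent(agent_ref))
    await asyncio.sleep(0.1)  # let the agent start
    await write_sse(agent_task, agent_ref)
    try:
        await agent_task
    except asyncio.CancelledError:
        print("agent task cancelled after interrupt")

asyncio.run(main())
```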
Teknium 2026-03-27 11:33:19 -07:00 committed by GitHub
parent 5127567d5d
commit f57ebf52e9
2 changed files with 377 additions and 62 deletions

gateway/platforms/api_server.py

@@ -495,17 +495,21 @@ class APIServerAdapter(BasePlatformAdapter):
                 if delta is not None:
                     _stream_q.put(delta)
 
-            # Start agent in background
+            # Start agent in background. agent_ref is a mutable container
+            # so the SSE writer can interrupt the agent on client disconnect.
+            agent_ref = [None]
             agent_task = asyncio.ensure_future(self._run_agent(
                 user_message=user_message,
                 conversation_history=history,
                 ephemeral_system_prompt=system_prompt,
                 session_id=session_id,
                 stream_delta_callback=_on_delta,
+                agent_ref=agent_ref,
             ))
 
             return await self._write_sse_chat_completion(
-                request, completion_id, model_name, created, _stream_q, agent_task
+                request, completion_id, model_name, created, _stream_q,
+                agent_task, agent_ref,
             )
 
         # Non-streaming: run the agent (with optional Idempotency-Key)
@@ -568,9 +572,14 @@ class APIServerAdapter(BasePlatformAdapter):
     async def _write_sse_chat_completion(
         self, request: "web.Request", completion_id: str, model: str,
-        created: int, stream_q, agent_task,
+        created: int, stream_q, agent_task, agent_ref=None,
     ) -> "web.StreamResponse":
-        """Write real streaming SSE from agent's stream_delta_callback queue."""
+        """Write real streaming SSE from agent's stream_delta_callback queue.
+
+        If the client disconnects mid-stream (network drop, browser tab close),
+        the agent is interrupted via ``agent.interrupt()`` so it stops making
+        LLM API calls, and the asyncio task wrapper is cancelled.
+        """
         import queue as _q
 
         response = web.StreamResponse(
@@ -579,69 +588,87 @@ class APIServerAdapter(BasePlatformAdapter):
         )
         await response.prepare(request)
 
-        # Role chunk
-        role_chunk = {
-            "id": completion_id, "object": "chat.completion.chunk",
-            "created": created, "model": model,
-            "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
-        }
-        await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
-
-        # Stream content chunks as they arrive from the agent
-        loop = asyncio.get_event_loop()
-        while True:
-            try:
-                delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
-            except _q.Empty:
-                if agent_task.done():
-                    # Drain any remaining items
-                    while True:
-                        try:
-                            delta = stream_q.get_nowait()
-                            if delta is None:
-                                break
-                            content_chunk = {
-                                "id": completion_id, "object": "chat.completion.chunk",
-                                "created": created, "model": model,
-                                "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
-                            }
-                            await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
-                        except _q.Empty:
-                            break
-                    break
-                continue
-            if delta is None:  # End of stream sentinel
-                break
-            content_chunk = {
-                "id": completion_id, "object": "chat.completion.chunk",
-                "created": created, "model": model,
-                "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
-            }
-            await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
-
-        # Get usage from completed agent
-        usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
-        try:
-            result, agent_usage = await agent_task
-            usage = agent_usage or usage
-        except Exception:
-            pass
-
-        # Finish chunk
-        finish_chunk = {
-            "id": completion_id, "object": "chat.completion.chunk",
-            "created": created, "model": model,
-            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
-            "usage": {
-                "prompt_tokens": usage.get("input_tokens", 0),
-                "completion_tokens": usage.get("output_tokens", 0),
-                "total_tokens": usage.get("total_tokens", 0),
-            },
-        }
-        await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
-        await response.write(b"data: [DONE]\n\n")
+        try:
+            # Role chunk
+            role_chunk = {
+                "id": completion_id, "object": "chat.completion.chunk",
+                "created": created, "model": model,
+                "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
+            }
+            await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
+
+            # Stream content chunks as they arrive from the agent
+            loop = asyncio.get_event_loop()
+            while True:
+                try:
+                    delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
+                except _q.Empty:
+                    if agent_task.done():
+                        # Drain any remaining items
+                        while True:
+                            try:
+                                delta = stream_q.get_nowait()
+                                if delta is None:
+                                    break
+                                content_chunk = {
+                                    "id": completion_id, "object": "chat.completion.chunk",
+                                    "created": created, "model": model,
+                                    "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
+                                }
+                                await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
+                            except _q.Empty:
+                                break
+                        break
+                    continue
+                if delta is None:  # End of stream sentinel
+                    break
+                content_chunk = {
+                    "id": completion_id, "object": "chat.completion.chunk",
+                    "created": created, "model": model,
+                    "choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
+                }
+                await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
+
+            # Get usage from completed agent
+            usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+            try:
+                result, agent_usage = await agent_task
+                usage = agent_usage or usage
+            except Exception:
+                pass
+
+            # Finish chunk
+            finish_chunk = {
+                "id": completion_id, "object": "chat.completion.chunk",
+                "created": created, "model": model,
+                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+                "usage": {
+                    "prompt_tokens": usage.get("input_tokens", 0),
+                    "completion_tokens": usage.get("output_tokens", 0),
+                    "total_tokens": usage.get("total_tokens", 0),
+                },
+            }
+            await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
+            await response.write(b"data: [DONE]\n\n")
+        except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
+            # Client disconnected mid-stream. Interrupt the agent so it
+            # stops making LLM API calls at the next loop iteration, then
+            # cancel the asyncio task wrapper.
+            agent = agent_ref[0] if agent_ref else None
+            if agent is not None:
+                try:
+                    agent.interrupt("SSE client disconnected")
+                except Exception:
+                    pass
+            if not agent_task.done():
+                agent_task.cancel()
+                try:
+                    await agent_task
+                except (asyncio.CancelledError, Exception):
+                    pass
+            logger.info("SSE client disconnected; interrupted agent task %s", completion_id)
 
         return response
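One subtlety in the disconnect handler above: since Python 3.8, asyncio.CancelledError subclasses BaseException rather than Exception, so a bare `except Exception` would not swallow the cancellation. That is why the handler catches `(asyncio.CancelledError, Exception)` explicitly. A small self-contained demonstration:

```python
import asyncio

async def sleeper():
    await asyncio.sleep(10)

async def main():
    task = asyncio.ensure_future(sleeper())
    await asyncio.sleep(0)  # let the task start
    task.cancel()
    try:
        await task
    except Exception:
        print("not reached: CancelledError is not an Exception subclass")
    except asyncio.CancelledError:
        print("caught CancelledError (a BaseException subclass)")

asyncio.run(main())  # prints: caught CancelledError (a BaseException subclass)
```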
@@ -1144,12 +1171,18 @@ class APIServerAdapter(BasePlatformAdapter):
         ephemeral_system_prompt: Optional[str] = None,
         session_id: Optional[str] = None,
         stream_delta_callback=None,
+        agent_ref: Optional[list] = None,
     ) -> tuple:
         """
         Create an agent and run a conversation in a thread executor.
 
         Returns ``(result_dict, usage_dict)`` where *usage_dict* contains
         ``input_tokens``, ``output_tokens`` and ``total_tokens``.
+
+        If *agent_ref* is a one-element list, the AIAgent instance is stored
+        at ``agent_ref[0]`` before ``run_conversation`` begins. This allows
+        callers (e.g. the SSE writer) to call ``agent.interrupt()`` from
+        another thread to stop in-progress LLM calls.
         """
         loop = asyncio.get_event_loop()
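The patch relies on AIAgent.run_conversation noticing the interrupt between LLM calls, but the agent internals are not part of this diff. A minimal sketch of the cooperative-interrupt shape it implies (class and method names here are illustrative assumptions, not the real AIAgent):

```python
import threading

class InterruptibleAgent:
    """Illustrative only: a blocking loop that can be stopped from another thread."""

    def __init__(self):
        self._interrupted = threading.Event()
        self.interrupt_reason = None

    def interrupt(self, reason: str) -> None:
        # Called from the event-loop thread (e.g. the SSE writer).
        self.interrupt_reason = reason
        self._interrupted.set()

    def run_conversation(self, user_message: str) -> dict:
        # Runs in a thread executor; checks the flag before each LLM call.
        for step in range(10):
            if self._interrupted.is_set():
                return {"final_response": None, "stopped": self.interrupt_reason}
            _ = f"LLM call {step} for {user_message!r}"  # stand-in for a real API call
        return {"final_response": "done", "stopped": None}
```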
@@ -1159,6 +1192,8 @@ class APIServerAdapter(BasePlatformAdapter):
             session_id=session_id,
             stream_delta_callback=stream_delta_callback,
         )
+        if agent_ref is not None:
+            agent_ref[0] = agent
         result = agent.run_conversation(
             user_message=user_message,
             conversation_history=conversation_history,
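To exercise the new path by hand, open a streaming completion and drop the connection after a few chunks. A sketch using the aiohttp client; the URL, port, API key, and model are placeholder assumptions for whatever your gateway deployment uses:

```python
import asyncio
import aiohttp

async def main():
    # URL, bearer token, and model name below are placeholders.
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://localhost:8080/v1/chat/completions",
            headers={"Authorization": "Bearer test-key"},
            json={"model": "gpt-4", "stream": True,
                  "messages": [{"role": "user", "content": "Write a long story."}]},
        ) as resp:
            count = 0
            async for chunk in resp.content.iter_any():
                count += 1
                if count >= 3:
                    break  # leaving the block closes the connection mid-stream

# The server's next SSE write should fail, triggering the interrupt path and
# logging "SSE client disconnected; interrupted agent task ...".
asyncio.run(main())
```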

New test file:

@@ -0,0 +1,280 @@
"""Tests for SSE client disconnect → agent task cancellation.

When a streaming /v1/chat/completions client disconnects mid-stream
(network drop, browser tab close), the agent is interrupted via
agent.interrupt() so it stops making LLM API calls, and the asyncio
task wrapper is cancelled.
"""
import asyncio
import json
import queue
from unittest.mock import AsyncMock, MagicMock, patch

import pytest


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_adapter():
    """Build a minimal APIServerAdapter with mocked internals."""
    from gateway.platforms.api_server import APIServerAdapter
    from gateway.config import PlatformConfig

    config = PlatformConfig(enabled=True, token="test-key")
    adapter = APIServerAdapter(config)
    return adapter


def _make_request():
    """Build a mock aiohttp request."""
    req = MagicMock()
    req.headers = {}
    return req


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------

class TestSSEAgentCancelOnDisconnect:
    """gateway/platforms/api_server.py — _write_sse_chat_completion()"""
    def test_agent_task_cancelled_on_client_disconnect(self):
        """When response.write raises ConnectionResetError (client dropped),
        the agent task must be cancelled."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("hello ")  # Some data already queued

        # Agent task that runs forever (simulates a long LLM call)
        agent_done = asyncio.Event()

        async def fake_agent():
            await agent_done.wait()
            return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}

        async def run():
            from aiohttp import web

            agent_task = asyncio.ensure_future(fake_agent())

            # Mock response that raises ConnectionResetError on second write
            mock_response = AsyncMock(spec=web.StreamResponse)
            call_count = 0

            async def write_side_effect(data):
                nonlocal call_count
                call_count += 1
                if call_count >= 2:
                    raise ConnectionResetError("client disconnected")

            mock_response.write = AsyncMock(side_effect=write_side_effect)
            mock_response.prepare = AsyncMock()

            with patch.object(type(adapter), '_write_sse_chat_completion',
                              adapter._write_sse_chat_completion):
                # Patch StreamResponse creation
                with patch("gateway.platforms.api_server.web.StreamResponse",
                           return_value=mock_response):
                    await adapter._write_sse_chat_completion(
                        _make_request(), "cmpl-123", "gpt-4", 1234567890,
                        stream_q, agent_task,
                    )

            # The critical assertion: agent_task must be cancelled
            assert agent_task.cancelled() or agent_task.done()

            # Clean up
            agent_done.set()

        asyncio.run(run())
    def test_agent_task_not_cancelled_on_normal_completion(self):
        """On normal stream completion, agent task should NOT be cancelled."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("hello")
        stream_q.put(None)  # End-of-stream sentinel

        async def fake_agent():
            return {"final_response": "done"}, {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}

        async def run():
            from aiohttp import web

            agent_task = asyncio.ensure_future(fake_agent())
            await asyncio.sleep(0)  # Let agent complete

            mock_response = AsyncMock(spec=web.StreamResponse)
            mock_response.write = AsyncMock()
            mock_response.prepare = AsyncMock()

            with patch("gateway.platforms.api_server.web.StreamResponse",
                       return_value=mock_response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-456", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )

            # Agent should have completed normally, not been cancelled
            assert agent_task.done()
            assert not agent_task.cancelled()

        asyncio.run(run())
    def test_broken_pipe_also_cancels_agent(self):
        """BrokenPipeError (another disconnect variant) also cancels the task."""
        adapter = _make_adapter()
        stream_q = queue.Queue()

        async def fake_agent():
            await asyncio.sleep(999)  # Never completes
            return {}, {}

        async def run():
            from aiohttp import web

            agent_task = asyncio.ensure_future(fake_agent())

            mock_response = AsyncMock(spec=web.StreamResponse)
            mock_response.write = AsyncMock(side_effect=BrokenPipeError("pipe broken"))
            mock_response.prepare = AsyncMock()

            with patch("gateway.platforms.api_server.web.StreamResponse",
                       return_value=mock_response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-789", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )

            assert agent_task.cancelled() or agent_task.done()

        asyncio.run(run())
    def test_already_done_task_not_cancelled_on_disconnect(self):
        """If agent already finished before disconnect, don't try to cancel."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("data")

        async def fake_agent():
            return {"final_response": "done"}, {}

        async def run():
            from aiohttp import web

            agent_task = asyncio.ensure_future(fake_agent())
            await asyncio.sleep(0)  # Let agent complete

            mock_response = AsyncMock(spec=web.StreamResponse)
            call_count = 0

            async def write_side_effect(data):
                nonlocal call_count
                call_count += 1
                if call_count >= 2:
                    raise ConnectionResetError("late disconnect")

            mock_response.write = AsyncMock(side_effect=write_side_effect)
            mock_response.prepare = AsyncMock()

            with patch("gateway.platforms.api_server.web.StreamResponse",
                       return_value=mock_response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-done", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )

            # Task was already done — should not be cancelled
            assert agent_task.done()
            assert not agent_task.cancelled()

        asyncio.run(run())
    def test_agent_interrupt_called_on_disconnect(self):
        """When the client disconnects, agent.interrupt() must be called
        so the agent thread stops making LLM API calls."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("hello ")

        agent_done = asyncio.Event()

        async def fake_agent():
            await agent_done.wait()
            return {"final_response": "done"}, {}

        # Mock agent with an interrupt method
        mock_agent = MagicMock()
        mock_agent.interrupt = MagicMock()

        async def run():
            from aiohttp import web

            agent_task = asyncio.ensure_future(fake_agent())
            agent_ref = [mock_agent]

            mock_response = AsyncMock(spec=web.StreamResponse)
            call_count = 0

            async def write_side_effect(data):
                nonlocal call_count
                call_count += 1
                if call_count >= 2:
                    raise ConnectionResetError("client disconnected")

            mock_response.write = AsyncMock(side_effect=write_side_effect)
            mock_response.prepare = AsyncMock()

            with patch("gateway.platforms.api_server.web.StreamResponse",
                       return_value=mock_response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-int", "gpt-4", 1234567890,
                    stream_q, agent_task, agent_ref,
                )

            # agent.interrupt() must have been called
            mock_agent.interrupt.assert_called_once_with("SSE client disconnected")

            # Clean up
            agent_done.set()

        asyncio.run(run())
    def test_agent_ref_none_still_cancels_task(self):
        """When agent_ref is not provided (None), the task is still cancelled
        on disconnect, just without the interrupt() call."""
        adapter = _make_adapter()
        stream_q = queue.Queue()

        async def fake_agent():
            await asyncio.sleep(999)
            return {}, {}

        async def run():
            from aiohttp import web

            agent_task = asyncio.ensure_future(fake_agent())

            mock_response = AsyncMock(spec=web.StreamResponse)
            mock_response.write = AsyncMock(side_effect=BrokenPipeError("gone"))
            mock_response.prepare = AsyncMock()

            with patch("gateway.platforms.api_server.web.StreamResponse",
                       return_value=mock_response):
                # No agent_ref passed — should still handle disconnect cleanly
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-noref", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )

            assert agent_task.cancelled() or agent_task.done()

        asyncio.run(run())