mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Salvage of #3399 by @binhnt92 with true agent interruption added on top. When a streaming /v1/chat/completions client disconnects mid-stream, the agent is now interrupted via agent.interrupt() so it stops making LLM API calls, and the asyncio task wrapper is cancelled. Closes #3399.
This commit is contained in:
parent
5127567d5d
commit
f57ebf52e9
2 changed files with 377 additions and 62 deletions
|
|
@ -495,17 +495,21 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
if delta is not None:
|
||||
_stream_q.put(delta)
|
||||
|
||||
# Start agent in background
|
||||
# Start agent in background. agent_ref is a mutable container
|
||||
# so the SSE writer can interrupt the agent on client disconnect.
|
||||
agent_ref = [None]
|
||||
agent_task = asyncio.ensure_future(self._run_agent(
|
||||
user_message=user_message,
|
||||
conversation_history=history,
|
||||
ephemeral_system_prompt=system_prompt,
|
||||
session_id=session_id,
|
||||
stream_delta_callback=_on_delta,
|
||||
agent_ref=agent_ref,
|
||||
))
|
||||
|
||||
return await self._write_sse_chat_completion(
|
||||
request, completion_id, model_name, created, _stream_q, agent_task
|
||||
request, completion_id, model_name, created, _stream_q,
|
||||
agent_task, agent_ref,
|
||||
)
|
||||
|
||||
# Non-streaming: run the agent (with optional Idempotency-Key)
|
||||
|
|
@ -568,9 +572,14 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
|
||||
async def _write_sse_chat_completion(
|
||||
self, request: "web.Request", completion_id: str, model: str,
|
||||
created: int, stream_q, agent_task,
|
||||
created: int, stream_q, agent_task, agent_ref=None,
|
||||
) -> "web.StreamResponse":
|
||||
"""Write real streaming SSE from agent's stream_delta_callback queue."""
|
||||
"""Write real streaming SSE from agent's stream_delta_callback queue.
|
||||
|
||||
If the client disconnects mid-stream (network drop, browser tab close),
|
||||
the agent is interrupted via ``agent.interrupt()`` so it stops making
|
||||
LLM API calls, and the asyncio task wrapper is cancelled.
|
||||
"""
|
||||
import queue as _q
|
||||
|
||||
response = web.StreamResponse(
|
||||
|
|
@ -579,69 +588,87 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
)
|
||||
await response.prepare(request)
|
||||
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
try:
|
||||
delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
|
||||
except _q.Empty:
|
||||
if agent_task.done():
|
||||
# Drain any remaining items
|
||||
while True:
|
||||
try:
|
||||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
continue
|
||||
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
content_chunk = {
|
||||
try:
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
try:
|
||||
result, agent_usage = await agent_task
|
||||
usage = agent_usage or usage
|
||||
except Exception:
|
||||
pass
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
try:
|
||||
delta = await loop.run_in_executor(None, lambda: stream_q.get(timeout=0.5))
|
||||
except _q.Empty:
|
||||
if agent_task.done():
|
||||
# Drain any remaining items
|
||||
while True:
|
||||
try:
|
||||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
continue
|
||||
|
||||
# Finish chunk
|
||||
finish_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
content_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {"content": delta}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
try:
|
||||
result, agent_usage = await agent_task
|
||||
usage = agent_usage or usage
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Finish chunk
|
||||
finish_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
"created": created, "model": model,
|
||||
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||
"usage": {
|
||||
"prompt_tokens": usage.get("input_tokens", 0),
|
||||
"completion_tokens": usage.get("output_tokens", 0),
|
||||
"total_tokens": usage.get("total_tokens", 0),
|
||||
},
|
||||
}
|
||||
await response.write(f"data: {json.dumps(finish_chunk)}\n\n".encode())
|
||||
await response.write(b"data: [DONE]\n\n")
|
||||
except (ConnectionResetError, ConnectionAbortedError, BrokenPipeError, OSError):
|
||||
# Client disconnected mid-stream. Interrupt the agent so it
|
||||
# stops making LLM API calls at the next loop iteration, then
|
||||
# cancel the asyncio task wrapper.
|
||||
agent = agent_ref[0] if agent_ref else None
|
||||
if agent is not None:
|
||||
try:
|
||||
agent.interrupt("SSE client disconnected")
|
||||
except Exception:
|
||||
pass
|
||||
if not agent_task.done():
|
||||
agent_task.cancel()
|
||||
try:
|
||||
await agent_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
logger.info("SSE client disconnected; interrupted agent task %s", completion_id)
|
||||
|
||||
return response
|
||||
|
||||
|
|
@ -1144,12 +1171,18 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
ephemeral_system_prompt: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
stream_delta_callback=None,
|
||||
agent_ref: Optional[list] = None,
|
||||
) -> tuple:
|
||||
"""
|
||||
Create an agent and run a conversation in a thread executor.
|
||||
|
||||
Returns ``(result_dict, usage_dict)`` where *usage_dict* contains
|
||||
``input_tokens``, ``output_tokens`` and ``total_tokens``.
|
||||
|
||||
If *agent_ref* is a one-element list, the AIAgent instance is stored
|
||||
at ``agent_ref[0]`` before ``run_conversation`` begins. This allows
|
||||
callers (e.g. the SSE writer) to call ``agent.interrupt()`` from
|
||||
another thread to stop in-progress LLM calls.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
|
|
@ -1159,6 +1192,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||
session_id=session_id,
|
||||
stream_delta_callback=stream_delta_callback,
|
||||
)
|
||||
if agent_ref is not None:
|
||||
agent_ref[0] = agent
|
||||
result = agent.run_conversation(
|
||||
user_message=user_message,
|
||||
conversation_history=conversation_history,
|
||||
|
|
|
|||
280
tests/gateway/test_sse_agent_cancel.py
Normal file
280
tests/gateway/test_sse_agent_cancel.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""Tests for SSE client disconnect → agent task cancellation.
|
||||
|
||||
When a streaming /v1/chat/completions client disconnects mid-stream
|
||||
(network drop, browser tab close), the agent is interrupted via
|
||||
agent.interrupt() so it stops making LLM API calls, and the asyncio
|
||||
task wrapper is cancelled.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import queue
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_adapter():
    """Return an ``APIServerAdapter`` built from an enabled test config.

    The project imports are done lazily inside the helper, so importing this
    test module does not itself import the gateway package.
    """
    from gateway.platforms.api_server import APIServerAdapter
    from gateway.config import PlatformConfig

    return APIServerAdapter(PlatformConfig(enabled=True, token="test-key"))
|
||||
|
||||
|
||||
def _make_request():
|
||||
"""Build a mock aiohttp request."""
|
||||
req = MagicMock()
|
||||
req.headers = {}
|
||||
return req
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSSEAgentCancelOnDisconnect:
    """gateway/platforms/api_server.py — _write_sse_chat_completion()

    Each test drives the SSE writer with a mocked ``web.StreamResponse`` and a
    fake agent task, then checks what happened to the task (and, where an
    ``agent_ref`` is supplied, to the agent) once the stream ends or the
    client connection fails.
    """

    # ------------------------------------------------------------------
    # Shared fixtures (underscore-prefixed so pytest does not collect them)
    # ------------------------------------------------------------------

    @staticmethod
    async def _blocked_agent():
        """Fake agent coroutine that never finishes unless it is cancelled.

        The event is created inside the coroutine, i.e. while the test's
        event loop is already running, so it can never end up associated
        with a different loop than the one ``asyncio.run`` created.
        """
        await asyncio.Event().wait()
        return {"final_response": "done"}, {}

    @staticmethod
    async def _finished_agent(usage):
        """Fake agent coroutine that returns a ``(result, usage)`` pair at once."""
        return {"final_response": "done"}, usage

    @staticmethod
    def _mock_stream_response(fail_from_write=None, error=None):
        """Build an ``AsyncMock`` shaped like ``web.StreamResponse``.

        ``write`` succeeds until its ``fail_from_write``-th call (1-based) and
        raises ``error`` from then on; with ``fail_from_write=None`` it never
        fails.  Must be called *before* ``web.StreamResponse`` is patched so
        that ``spec`` refers to the real class.
        """
        from aiohttp import web

        response = AsyncMock(spec=web.StreamResponse)
        writes_seen = 0

        async def write(data):
            nonlocal writes_seen
            writes_seen += 1
            if fail_from_write is not None and writes_seen >= fail_from_write:
                raise error

        response.write = AsyncMock(side_effect=write)
        response.prepare = AsyncMock()
        return response

    @staticmethod
    async def _stream(adapter, response, completion_id, stream_q, agent_task,
                      agent_ref=None):
        """Run ``_write_sse_chat_completion`` with ``response`` substituted
        for the ``web.StreamResponse`` the adapter would otherwise create."""
        with patch("gateway.platforms.api_server.web.StreamResponse",
                   return_value=response):
            await adapter._write_sse_chat_completion(
                _make_request(), completion_id, "gpt-4", 1234567890,
                stream_q, agent_task, agent_ref,
            )

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_agent_task_cancelled_on_client_disconnect(self):
        """When response.write raises ConnectionResetError (client dropped),
        the agent task must be cancelled."""
        adapter = _make_adapter()

        stream_q = queue.Queue()
        stream_q.put("hello ")  # Some data already queued

        async def run():
            # Agent task that runs forever (simulates a long LLM call)
            agent_task = asyncio.ensure_future(self._blocked_agent())

            # First write (role chunk) succeeds, second (content) fails.
            response = self._mock_stream_response(
                fail_from_write=2,
                error=ConnectionResetError("client disconnected"),
            )
            await self._stream(adapter, response, "cmpl-123", stream_q, agent_task)

            # The critical assertion.  ``cancelled()`` alone is checked:
            # ``done()`` is also true for a task that merely finished or
            # raised, which would not prove the writer cancelled it.
            assert agent_task.cancelled()

        asyncio.run(run())

    def test_agent_task_not_cancelled_on_normal_completion(self):
        """On normal stream completion, agent task should NOT be cancelled."""
        adapter = _make_adapter()

        stream_q = queue.Queue()
        stream_q.put("hello")
        stream_q.put(None)  # End-of-stream sentinel

        async def run():
            agent_task = asyncio.ensure_future(self._finished_agent(
                {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
            ))
            await asyncio.sleep(0)  # Let agent complete

            response = self._mock_stream_response()
            await self._stream(adapter, response, "cmpl-456", stream_q, agent_task)

            # Agent should have completed normally, not been cancelled
            assert agent_task.done()
            assert not agent_task.cancelled()

        asyncio.run(run())

    def test_broken_pipe_also_cancels_agent(self):
        """BrokenPipeError (another disconnect variant) also cancels the task."""
        adapter = _make_adapter()

        stream_q = queue.Queue()

        async def run():
            agent_task = asyncio.ensure_future(self._blocked_agent())

            # Every write fails, starting with the very first (role) chunk.
            response = self._mock_stream_response(
                fail_from_write=1,
                error=BrokenPipeError("pipe broken"),
            )
            await self._stream(adapter, response, "cmpl-789", stream_q, agent_task)

            assert agent_task.cancelled()

        asyncio.run(run())

    def test_already_done_task_not_cancelled_on_disconnect(self):
        """If agent already finished before disconnect, don't try to cancel."""
        adapter = _make_adapter()

        stream_q = queue.Queue()
        stream_q.put("data")

        async def run():
            agent_task = asyncio.ensure_future(self._finished_agent({}))
            await asyncio.sleep(0)  # Let agent complete

            response = self._mock_stream_response(
                fail_from_write=2,
                error=ConnectionResetError("late disconnect"),
            )
            await self._stream(adapter, response, "cmpl-done", stream_q, agent_task)

            # Task was already done — should not be cancelled
            assert agent_task.done()
            assert not agent_task.cancelled()

        asyncio.run(run())

    def test_agent_interrupt_called_on_disconnect(self):
        """When the client disconnects, agent.interrupt() must be called
        so the agent thread stops making LLM API calls."""
        adapter = _make_adapter()

        stream_q = queue.Queue()
        stream_q.put("hello ")

        # Mock agent with an interrupt method
        mock_agent = MagicMock()
        mock_agent.interrupt = MagicMock()

        async def run():
            agent_task = asyncio.ensure_future(self._blocked_agent())

            response = self._mock_stream_response(
                fail_from_write=2,
                error=ConnectionResetError("client disconnected"),
            )
            await self._stream(
                adapter, response, "cmpl-int", stream_q, agent_task,
                agent_ref=[mock_agent],
            )

            # agent.interrupt() must have been called, and the task wrapper
            # cancelled as well.
            mock_agent.interrupt.assert_called_once_with("SSE client disconnected")
            assert agent_task.cancelled()

        asyncio.run(run())

    def test_agent_ref_none_still_cancels_task(self):
        """When agent_ref is not provided (None), the task is still cancelled
        on disconnect — just without the interrupt() call."""
        adapter = _make_adapter()

        stream_q = queue.Queue()

        async def run():
            agent_task = asyncio.ensure_future(self._blocked_agent())

            response = self._mock_stream_response(
                fail_from_write=1,
                error=BrokenPipeError("gone"),
            )
            # No agent_ref passed — should still handle disconnect cleanly
            await self._stream(adapter, response, "cmpl-noref", stream_q, agent_task)

            assert agent_task.cancelled()

        asyncio.run(run())
|
||||
Loading…
Add table
Add a link
Reference in a new issue