diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index db3304a096..659197f696 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -9,6 +9,7 @@ Exposes an HTTP server with endpoints: - GET /v1/models — lists hermes-agent as an available model - POST /v1/runs — start a run, returns run_id immediately (202) - GET /v1/runs/{run_id}/events — SSE stream of structured lifecycle events +- POST /v1/runs/{run_id}/stop — interrupt a running agent - GET /health — health check - GET /health/detailed — rich status for cross-container dashboard probing @@ -586,6 +587,9 @@ class APIServerAdapter(BasePlatformAdapter): self._run_streams: Dict[str, "asyncio.Queue[Optional[Dict]]"] = {} # Creation timestamps for orphaned-run TTL sweep self._run_streams_created: Dict[str, float] = {} + # Active run agent/task references for stop support + self._active_run_agents: Dict[str, Any] = {} + self._active_run_tasks: Dict[str, "asyncio.Task"] = {} self._session_db: Optional[Any] = None # Lazy-init SessionDB for session continuity @staticmethod @@ -2441,6 +2445,7 @@ class APIServerAdapter(BasePlatformAdapter): stream_delta_callback=_text_cb, tool_progress_callback=event_cb, ) + self._active_run_agents[run_id] = agent def _run_sync(): r = agent.run_conversation( user_message=user_message, @@ -2480,8 +2485,11 @@ class APIServerAdapter(BasePlatformAdapter): q.put_nowait(None) except Exception: pass + self._active_run_agents.pop(run_id, None) + self._active_run_tasks.pop(run_id, None) task = asyncio.create_task(_run_and_close()) + self._active_run_tasks[run_id] = task try: self._background_tasks.add(task) except TypeError: @@ -2540,6 +2548,34 @@ class APIServerAdapter(BasePlatformAdapter): return response + async def _handle_stop_run(self, request: "web.Request") -> "web.Response": + """POST /v1/runs/{run_id}/stop — interrupt a running agent.""" + auth_err = self._check_auth(request) + if auth_err: + return auth_err + + run_id = request.match_info["run_id"] + agent = self._active_run_agents.get(run_id) + task = self._active_run_tasks.get(run_id) + + if agent is None and task is None: + return web.json_response(_openai_error(f"Run not found: {run_id}", code="run_not_found"), status=404) + + if agent is not None: + try: + agent.interrupt("Stop requested via API") + except Exception: + pass + + if task is not None and not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + return web.json_response({"run_id": run_id, "status": "stopping"}) + async def _sweep_orphaned_runs(self) -> None: """Periodically clean up run streams that were never consumed.""" while True: @@ -2554,6 +2590,8 @@ class APIServerAdapter(BasePlatformAdapter): logger.debug("[api_server] sweeping orphaned run %s", run_id) self._run_streams.pop(run_id, None) self._run_streams_created.pop(run_id, None) + self._active_run_agents.pop(run_id, None) + self._active_run_tasks.pop(run_id, None) # ------------------------------------------------------------------ # BasePlatformAdapter interface @@ -2589,6 +2627,7 @@ class APIServerAdapter(BasePlatformAdapter): # Structured event streaming self._app.router.add_post("/v1/runs", self._handle_runs) self._app.router.add_get("/v1/runs/{run_id}/events", self._handle_run_events) + self._app.router.add_post("/v1/runs/{run_id}/stop", self._handle_stop_run) # Start background sweep to clean up orphaned (unconsumed) run streams sweep_task = asyncio.create_task(self._sweep_orphaned_runs()) try: diff --git a/tests/gateway/test_api_server_runs.py b/tests/gateway/test_api_server_runs.py new file mode 100644 index 0000000000..e485bad5ce --- /dev/null +++ b/tests/gateway/test_api_server_runs.py @@ -0,0 +1,365 @@ +"""Tests for /v1/runs endpoints: start, events, and stop. + +Covers: +- POST /v1/runs — start a run (202) +- GET /v1/runs/{run_id}/events — SSE event stream +- POST /v1/runs/{run_id}/stop — interrupt a running agent +- Auth, error handling, and cleanup +""" + +import asyncio +import json +import threading +import time as _time +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aiohttp import web +from aiohttp.test_utils import TestClient, TestServer + +from gateway.config import PlatformConfig +from gateway.platforms.api_server import ( + APIServerAdapter, + cors_middleware, + security_headers_middleware, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_adapter(api_key: str = "") -> APIServerAdapter: + """Create an adapter with optional API key.""" + extra = {} + if api_key: + extra["key"] = api_key + config = PlatformConfig(enabled=True, extra=extra) + adapter = APIServerAdapter(config) + return adapter + + +def _create_runs_app(adapter: APIServerAdapter) -> web.Application: + """Create an aiohttp app with /v1/runs routes registered.""" + mws = [mw for mw in (cors_middleware, security_headers_middleware) if mw is not None] + app = web.Application(middlewares=mws) + app["api_server_adapter"] = adapter + app.router.add_post("/v1/runs", adapter._handle_runs) + app.router.add_get("/v1/runs/{run_id}/events", adapter._handle_run_events) + app.router.add_post("/v1/runs/{run_id}/stop", adapter._handle_stop_run) + return app + + +def _make_slow_agent(**kwargs): + """Create a mock agent that blocks in run_conversation until interrupted. + + Returns (mock_agent, agent_ready_event, interrupt_event) where + agent_ready_event is set once run_conversation starts, and + interrupt_event is set when interrupt() is called. + """ + ready = threading.Event() + interrupted = threading.Event() + + mock_agent = MagicMock() + + def _do_interrupt(message=None): + interrupted.set() + + mock_agent.interrupt = MagicMock(side_effect=_do_interrupt) + + def _slow_run(user_message=None, conversation_history=None, task_id=None): + ready.set() + # Block until interrupt() is called + interrupted.wait(timeout=10) + return {"final_response": "interrupted"} + + mock_agent.run_conversation.side_effect = _slow_run + mock_agent.session_prompt_tokens = 0 + mock_agent.session_completion_tokens = 0 + mock_agent.session_total_tokens = 0 + + return mock_agent, ready, interrupted + + +@pytest.fixture +def adapter(): + return _make_adapter() + + +@pytest.fixture +def auth_adapter(): + return _make_adapter(api_key="sk-secret") + + +# --------------------------------------------------------------------------- +# POST /v1/runs — start a run +# --------------------------------------------------------------------------- + + +class TestStartRun: + @pytest.mark.asyncio + async def test_start_returns_202(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() + mock_agent.run_conversation.return_value = {"final_response": "done"} + mock_agent.session_prompt_tokens = 10 + mock_agent.session_completion_tokens = 5 + mock_agent.session_total_tokens = 15 + mock_create.return_value = mock_agent + + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 202 + data = await resp.json() + assert data["status"] == "started" + assert data["run_id"].startswith("run_") + + @pytest.mark.asyncio + async def test_start_invalid_json_returns_400(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post( + "/v1/runs", + data="not json", + headers={"Content-Type": "application/json"}, + ) + assert resp.status == 400 + + @pytest.mark.asyncio + async def test_start_missing_input_returns_400(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post("/v1/runs", json={"model": "test"}) + assert resp.status == 400 + data = await resp.json() + assert "input" in data["error"]["message"] + + @pytest.mark.asyncio + async def test_start_empty_input_returns_400(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post("/v1/runs", json={"input": ""}) + assert resp.status == 400 + + @pytest.mark.asyncio + async def test_start_requires_auth(self, auth_adapter): + app = _create_runs_app(auth_adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 401 + + @pytest.mark.asyncio + async def test_start_with_valid_auth(self, auth_adapter): + app = _create_runs_app(auth_adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(auth_adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() + mock_agent.run_conversation.return_value = {"final_response": "ok"} + mock_agent.session_prompt_tokens = 0 + mock_agent.session_completion_tokens = 0 + mock_agent.session_total_tokens = 0 + mock_create.return_value = mock_agent + + resp = await cli.post( + "/v1/runs", + json={"input": "hello"}, + headers={"Authorization": "Bearer sk-secret"}, + ) + assert resp.status == 202 + + +# --------------------------------------------------------------------------- +# GET /v1/runs/{run_id}/events — SSE event stream +# --------------------------------------------------------------------------- + + +class TestRunEvents: + @pytest.mark.asyncio + async def test_events_stream_returns_completed(self, adapter): + """Events stream should receive run.completed when agent finishes.""" + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() + mock_agent.run_conversation.return_value = {"final_response": "Hello!"} + mock_agent.session_prompt_tokens = 10 + mock_agent.session_completion_tokens = 5 + mock_agent.session_total_tokens = 15 + mock_create.return_value = mock_agent + + # Start run + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 202 + data = await resp.json() + run_id = data["run_id"] + + # Subscribe to events + events_resp = await cli.get(f"/v1/runs/{run_id}/events") + assert events_resp.status == 200 + body = await events_resp.text() + + # Should contain run.completed + assert "run.completed" in body + assert "Hello!" in body + + @pytest.mark.asyncio + async def test_events_not_found_returns_404(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/runs/run_nonexistent/events") + assert resp.status == 404 + + @pytest.mark.asyncio + async def test_events_requires_auth(self, auth_adapter): + app = _create_runs_app(auth_adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.get("/v1/runs/run_any/events") + assert resp.status == 401 + + +# --------------------------------------------------------------------------- +# POST /v1/runs/{run_id}/stop — interrupt a running agent +# --------------------------------------------------------------------------- + + +class TestStopRun: + @pytest.mark.asyncio + async def test_stop_running_agent(self, adapter): + """Stop should interrupt the agent and cancel the task.""" + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent, agent_ready, _ = _make_slow_agent() + mock_create.return_value = mock_agent + + # Start run + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 202 + data = await resp.json() + run_id = data["run_id"] + + # Wait for agent to start running in the thread + agent_ready.wait(timeout=3.0) + await asyncio.sleep(0.1) + + # Verify agent ref is stored + assert run_id in adapter._active_run_agents + + # Stop the run + stop_resp = await cli.post(f"/v1/runs/{run_id}/stop") + assert stop_resp.status == 200 + stop_data = await stop_resp.json() + assert stop_data["run_id"] == run_id + assert stop_data["status"] == "stopping" + + # Agent interrupt should have been called + mock_agent.interrupt.assert_called_once_with("Stop requested via API") + + # Refs should be cleaned up + await asyncio.sleep(0.5) + assert run_id not in adapter._active_run_agents + assert run_id not in adapter._active_run_tasks + + @pytest.mark.asyncio + async def test_stop_nonexistent_run_returns_404(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post("/v1/runs/run_nonexistent/stop") + assert resp.status == 404 + + @pytest.mark.asyncio + async def test_stop_requires_auth(self, auth_adapter): + app = _create_runs_app(auth_adapter) + async with TestClient(TestServer(app)) as cli: + resp = await cli.post("/v1/runs/run_any/stop") + assert resp.status == 401 + + @pytest.mark.asyncio + async def test_stop_already_completed_run_returns_404(self, adapter): + """Stopping a run that already finished should return 404 (refs cleaned up).""" + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() + mock_agent.run_conversation.return_value = {"final_response": "done"} + mock_agent.session_prompt_tokens = 0 + mock_agent.session_completion_tokens = 0 + mock_agent.session_total_tokens = 0 + mock_create.return_value = mock_agent + + # Start and wait for completion + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 202 + data = await resp.json() + run_id = data["run_id"] + + await asyncio.sleep(0.3) + + # Run should be done, refs cleaned up + assert run_id not in adapter._active_run_agents + + # Stop should return 404 + stop_resp = await cli.post(f"/v1/runs/{run_id}/stop") + assert stop_resp.status == 404 + + @pytest.mark.asyncio + async def test_stop_interrupt_exception_does_not_crash(self, adapter): + """If agent.interrupt() raises, stop should still succeed.""" + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent, agent_ready, _ = _make_slow_agent() + # Override the interrupt side_effect to raise + mock_agent.interrupt = MagicMock(side_effect=RuntimeError("interrupt failed")) + mock_create.return_value = mock_agent + + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 202 + data = await resp.json() + run_id = data["run_id"] + + agent_ready.wait(timeout=3.0) + await asyncio.sleep(0.1) + + stop_resp = await cli.post(f"/v1/runs/{run_id}/stop") + assert stop_resp.status == 200 + stop_data = await stop_resp.json() + assert stop_data["status"] == "stopping" + + @pytest.mark.asyncio + async def test_stop_sends_sentinel_to_events_stream(self, adapter): + """After stop, the events stream should close.""" + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent, agent_ready, _ = _make_slow_agent() + mock_create.return_value = mock_agent + + # Start run + resp = await cli.post("/v1/runs", json={"input": "hello"}) + assert resp.status == 202 + data = await resp.json() + run_id = data["run_id"] + + agent_ready.wait(timeout=3.0) + await asyncio.sleep(0.1) + + # Subscribe to events in background + events_task = asyncio.ensure_future( + cli.get(f"/v1/runs/{run_id}/events") + ) + + await asyncio.sleep(0.1) + + # Stop the run + stop_resp = await cli.post(f"/v1/runs/{run_id}/stop") + assert stop_resp.status == 200 + + # Events stream should close + events_resp = await asyncio.wait_for(events_task, timeout=5.0) + assert events_resp.status == 200 + body = await events_resp.text() + # Stream should have received run.failed and closed + assert "run.failed" in body or "stream closed" in body