From 526c0e018a2087303cf31b25b949a64a029d0718 Mon Sep 17 00:00:00 2001 From: Zhicheng Han Date: Tue, 5 May 2026 18:34:58 +0200 Subject: [PATCH] feat(api-server): expose run approval events --- gateway/platforms/api_server.py | 185 +++++++++++++++++++++++++- tests/gateway/test_api_server_runs.py | 93 +++++++++++++ tools/approval.py | 26 +++- 3 files changed, 295 insertions(+), 9 deletions(-) diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 3b0375ff03..cde7813623 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -11,7 +11,8 @@ Exposes an HTTP server with endpoints: - POST /v1/runs — start a run, returns run_id immediately (202) - GET /v1/runs/{run_id} — retrieve current run status - GET /v1/runs/{run_id}/events — SSE stream of structured lifecycle events -- POST /v1/runs/{run_id}/stop — interrupt a running agent +- POST /v1/runs/{run_id}/approval — resolve a pending run approval +- POST /v1/runs/{run_id}/stop — interrupt a running agent - GET /health — health check - GET /health/detailed — rich status for cross-container dashboard probing @@ -605,6 +606,10 @@ class APIServerAdapter(BasePlatformAdapter): self._active_run_tasks: Dict[str, "asyncio.Task"] = {} # Pollable run status for dashboards and external control-plane UIs. self._run_statuses: Dict[str, Dict[str, Any]] = {} + # Active approval session key for each run_id. The approval core + # resolves requests by session key, while API clients address the + # in-flight run by run_id. + self._run_approval_sessions: Dict[str, str] = {} self._session_db: Optional[Any] = None # Lazy-init SessionDB for session continuity @staticmethod @@ -936,7 +941,9 @@ class APIServerAdapter(BasePlatformAdapter): "run_status": True, "run_events_sse": True, "run_stop": True, + "run_approval_response": True, "tool_progress_events": True, + "approval_events": True, "session_continuity_header": "X-Hermes-Session-Id", "session_key_header": "X-Hermes-Session-Key", "cors": bool(self._cors_origins), @@ -950,6 +957,7 @@ class APIServerAdapter(BasePlatformAdapter): "runs": {"method": "POST", "path": "/v1/runs"}, "run_status": {"method": "GET", "path": "/v1/runs/{run_id}"}, "run_events": {"method": "GET", "path": "/v1/runs/{run_id}/events"}, + "run_approval": {"method": "POST", "path": "/v1/runs/{run_id}/approval"}, "run_stop": {"method": "POST", "path": "/v1/runs/{run_id}/stop"}, }, }) @@ -2821,12 +2829,14 @@ class APIServerAdapter(BasePlatformAdapter): run_id = f"run_{uuid.uuid4().hex}" session_id = body.get("session_id") or stored_session_id or run_id + approval_session_key = gateway_session_key or session_id or run_id ephemeral_system_prompt = instructions loop = asyncio.get_running_loop() q: "asyncio.Queue[Optional[Dict]]" = asyncio.Queue() created_at = time.time() self._run_streams[run_id] = q self._run_streams_created[run_id] = created_at + self._run_approval_sessions[run_id] = approval_session_key event_cb = self._make_run_event_callback(run_id, loop) @@ -2863,13 +2873,66 @@ class APIServerAdapter(BasePlatformAdapter): gateway_session_key=gateway_session_key, ) self._active_run_agents[run_id] = agent - def _run_sync(): - effective_task_id = session_id or run_id - r = agent.run_conversation( - user_message=user_message, - conversation_history=conversation_history, - task_id=effective_task_id, + + def _approval_notify(approval_data: Dict[str, Any]) -> None: + event = dict(approval_data or {}) + event.update({ + "event": "approval.request", + "run_id": run_id, + "timestamp": time.time(), + "choices": ["once", "session", "always", "deny"], + }) + self._set_run_status( + run_id, + "waiting_for_approval", + last_event="approval.request", ) + try: + loop.call_soon_threadsafe(q.put_nowait, event) + except Exception: + pass + + def _run_sync(): + from gateway.session_context import clear_session_vars, set_session_vars + from tools.approval import ( + register_gateway_notify, + reset_current_session_key, + set_current_session_key, + unregister_gateway_notify, + ) + + effective_task_id = session_id or run_id + approval_token = None + session_tokens = [] + try: + # Bind approval/session identity for this API run via + # contextvars so concurrent runs do not share process + # environment state. + approval_token = set_current_session_key(approval_session_key) + session_tokens = set_session_vars( + platform="api_server", + session_key=approval_session_key, + ) + register_gateway_notify(approval_session_key, _approval_notify) + r = agent.run_conversation( + user_message=user_message, + conversation_history=conversation_history, + task_id=effective_task_id, + ) + finally: + try: + unregister_gateway_notify(approval_session_key) + finally: + if approval_token is not None: + try: + reset_current_session_key(approval_token) + except Exception: + pass + if session_tokens: + try: + clear_session_vars(session_tokens) + except Exception: + pass u = { "input_tokens": getattr(agent, "session_prompt_tokens", 0) or 0, "output_tokens": getattr(agent, "session_completion_tokens", 0) or 0, @@ -2944,6 +3007,17 @@ class APIServerAdapter(BasePlatformAdapter): except Exception: pass finally: + # If the asyncio wrapper is cancelled (for example via + # /stop), the executor thread can still be blocked waiting + # on an approval Event. Unregistering here releases those + # waits immediately; the in-thread unregister is harmlessly + # idempotent on normal completion. + try: + from tools.approval import unregister_gateway_notify + + unregister_gateway_notify(approval_session_key) + except Exception: + pass # Sentinel: signal SSE stream to close try: q.put_nowait(None) @@ -2951,6 +3025,7 @@ class APIServerAdapter(BasePlatformAdapter): pass self._active_run_agents.pop(run_id, None) self._active_run_tasks.pop(run_id, None) + self._run_approval_sessions.pop(run_id, None) task = asyncio.create_task(_run_and_close()) self._active_run_tasks[run_id] = task @@ -3034,6 +3109,92 @@ class APIServerAdapter(BasePlatformAdapter): return response + + async def _handle_run_approval(self, request: "web.Request") -> "web.Response": + """POST /v1/runs/{run_id}/approval — resolve a pending run approval.""" + auth_err = self._check_auth(request) + if auth_err: + return auth_err + + run_id = request.match_info["run_id"] + status = self._run_statuses.get(run_id) + if status is None: + return web.json_response( + _openai_error(f"Run not found: {run_id}", code="run_not_found"), + status=404, + ) + + try: + body = await request.json() + except Exception: + return web.json_response(_openai_error("Invalid JSON"), status=400) + + raw_choice = str(body.get("choice", "")).strip().lower() + aliases = {"approve": "once", "approved": "once", "allow": "once"} + choice = aliases.get(raw_choice, raw_choice) + allowed = {"once", "session", "always", "deny"} + if choice not in allowed: + return web.json_response( + _openai_error( + "Invalid approval choice; expected one of: once, session, always, deny", + code="invalid_approval_choice", + ), + status=400, + ) + + approval_session_key = self._run_approval_sessions.get(run_id) + if not approval_session_key: + return web.json_response( + _openai_error( + f"Run has no active approval session: {run_id}", + code="approval_not_active", + ), + status=409, + ) + + resolve_all = bool(body.get("all") or body.get("resolve_all")) + try: + from tools.approval import resolve_gateway_approval + + resolved = resolve_gateway_approval( + approval_session_key, + choice, + resolve_all=resolve_all, + ) + except Exception as exc: + logger.exception("[api_server] approval resolution failed for run %s", run_id) + return web.json_response(_openai_error(str(exc)), status=500) + + if resolved <= 0: + return web.json_response( + _openai_error( + f"Run has no pending approval: {run_id}", + code="approval_not_pending", + ), + status=409, + ) + + self._set_run_status(run_id, "running", last_event="approval.responded") + q = self._run_streams.get(run_id) + if q is not None: + try: + q.put_nowait({ + "event": "approval.responded", + "run_id": run_id, + "timestamp": time.time(), + "choice": choice, + "resolved": resolved, + }) + except Exception: + pass + + return web.json_response({ + "object": "hermes.run.approval_response", + "run_id": run_id, + "choice": choice, + "resolved": resolved, + }) + async def _handle_stop_run(self, request: "web.Request") -> "web.Response": """POST /v1/runs/{run_id}/stop — interrupt a running agent.""" auth_err = self._check_auth(request) @@ -3086,10 +3247,19 @@ class APIServerAdapter(BasePlatformAdapter): ] for run_id in stale: logger.debug("[api_server] sweeping orphaned run %s", run_id) + try: + from tools.approval import unregister_gateway_notify + + approval_session_key = self._run_approval_sessions.get(run_id) + if approval_session_key: + unregister_gateway_notify(approval_session_key) + except Exception: + pass self._run_streams.pop(run_id, None) self._run_streams_created.pop(run_id, None) self._active_run_agents.pop(run_id, None) self._active_run_tasks.pop(run_id, None) + self._run_approval_sessions.pop(run_id, None) stale_statuses = [ run_id @@ -3136,6 +3306,7 @@ class APIServerAdapter(BasePlatformAdapter): self._app.router.add_post("/v1/runs", self._handle_runs) self._app.router.add_get("/v1/runs/{run_id}", self._handle_get_run) self._app.router.add_get("/v1/runs/{run_id}/events", self._handle_run_events) + self._app.router.add_post("/v1/runs/{run_id}/approval", self._handle_run_approval) self._app.router.add_post("/v1/runs/{run_id}/stop", self._handle_stop_run) # Start background sweep to clean up orphaned (unconsumed) run streams sweep_task = asyncio.create_task(self._sweep_orphaned_runs()) diff --git a/tests/gateway/test_api_server_runs.py b/tests/gateway/test_api_server_runs.py index 6ce67db923..f47060d068 100644 --- a/tests/gateway/test_api_server_runs.py +++ b/tests/gateway/test_api_server_runs.py @@ -49,6 +49,7 @@ def _create_runs_app(adapter: APIServerAdapter) -> web.Application: app.router.add_post("/v1/runs", adapter._handle_runs) app.router.add_get("/v1/runs/{run_id}", adapter._handle_get_run) app.router.add_get("/v1/runs/{run_id}/events", adapter._handle_run_events) + app.router.add_post("/v1/runs/{run_id}/approval", adapter._handle_run_approval) app.router.add_post("/v1/runs/{run_id}/stop", adapter._handle_stop_run) return app @@ -305,6 +306,98 @@ class TestRunEvents: assert "run.completed" in body assert "Hello!" in body + + @pytest.mark.asyncio + async def test_approval_request_event_and_response_unblock_run(self, adapter): + """Dangerous-command approvals should surface on the run SSE stream.""" + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + guard_result = {} + + mock_agent = MagicMock() + + def _run_with_approval(user_message=None, conversation_history=None, task_id=None): + from tools.approval import check_all_command_guards + + result = check_all_command_guards("git reset --hard HEAD", "local") + guard_result.update(result) + return {"final_response": "approved" if result.get("approved") else "blocked"} + + mock_agent.run_conversation.side_effect = _run_with_approval + mock_agent.session_prompt_tokens = 0 + mock_agent.session_completion_tokens = 0 + mock_agent.session_total_tokens = 0 + mock_create.return_value = mock_agent + + resp = await cli.post("/v1/runs", json={"input": "needs approval"}) + assert resp.status == 202 + data = await resp.json() + run_id = data["run_id"] + + events_resp = await cli.get(f"/v1/runs/{run_id}/events") + assert events_resp.status == 200 + + approval_event = None + for _ in range(20): + line = await asyncio.wait_for(events_resp.content.readline(), timeout=3.0) + text = line.decode() + if not text.startswith("data: "): + continue + event = json.loads(text[len("data: "):]) + if event.get("event") == "approval.request": + approval_event = event + break + + assert approval_event is not None + assert approval_event["run_id"] == run_id + assert approval_event["command"] == "git reset --hard HEAD" + assert approval_event["pattern_key"] + assert "pattern_keys" in approval_event + assert approval_event["choices"] == ["once", "session", "always", "deny"] + + approval_resp = await cli.post( + f"/v1/runs/{run_id}/approval", + json={"choice": "once"}, + ) + assert approval_resp.status == 200 + approval_data = await approval_resp.json() + assert approval_data["resolved"] == 1 + assert approval_data["choice"] == "once" + + body = await events_resp.text() + assert "approval.responded" in body + assert "run.completed" in body + + assert guard_result.get("approved") is True + + @pytest.mark.asyncio + async def test_approval_response_without_pending_returns_409(self, adapter): + app = _create_runs_app(adapter) + async with TestClient(TestServer(app)) as cli: + with patch.object(adapter, "_create_agent") as mock_create: + mock_agent = MagicMock() + mock_agent.run_conversation.return_value = {"final_response": "done"} + mock_agent.session_prompt_tokens = 0 + mock_agent.session_completion_tokens = 0 + mock_agent.session_total_tokens = 0 + mock_create.return_value = mock_agent + + resp = await cli.post("/v1/runs", json={"input": "hello"}) + data = await resp.json() + run_id = data["run_id"] + + approval_resp = await cli.post( + f"/v1/runs/{run_id}/approval", + json={"choice": "once"}, + ) + assert approval_resp.status == 409 + approval_data = await approval_resp.json() + assert approval_data["error"]["code"] in { + "approval_not_active", + "approval_not_pending", + } + @pytest.mark.asyncio async def test_events_not_found_returns_404(self, adapter): app = _create_runs_app(adapter) diff --git a/tools/approval.py b/tools/approval.py index a7faaff21f..1322098ebc 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -83,6 +83,28 @@ def get_current_session_key(default: str = "default") -> str: from gateway.session_context import get_session_env return get_session_env("HERMES_SESSION_KEY", default) + +def _get_session_platform() -> str: + """Return the current gateway platform from contextvars/env fallback.""" + try: + from gateway.session_context import get_session_env + + return get_session_env("HERMES_SESSION_PLATFORM", "") or "" + except Exception: + return os.getenv("HERMES_SESSION_PLATFORM", "") or "" + + +def _is_gateway_approval_context() -> bool: + """True when this call is inside a gateway/API session. + + Legacy gateway integrations set HERMES_GATEWAY_SESSION in process env. + Newer concurrent gateway paths bind HERMES_SESSION_PLATFORM via + contextvars so approval mode does not depend on process-global flags. + """ + if os.getenv("HERMES_GATEWAY_SESSION"): + return True + return bool(_get_session_platform()) + # Sensitive write targets that should trigger approval even when referenced # via shell expansions like $HOME or $HERMES_HOME. _SSH_SENSITIVE_PATH = r'(?:~|\$home|\$\{home\})/\.ssh(?:/|$)' @@ -829,7 +851,7 @@ def check_dangerous_command(command: str, env_type: str, return {"approved": True, "message": None} is_cli = os.getenv("HERMES_INTERACTIVE") - is_gateway = os.getenv("HERMES_GATEWAY_SESSION") + is_gateway = _is_gateway_approval_context() if not is_cli and not is_gateway: # Cron sessions: respect cron_mode config @@ -946,7 +968,7 @@ def check_all_command_guards(command: str, env_type: str, return {"approved": True, "message": None} is_cli = os.getenv("HERMES_INTERACTIVE") - is_gateway = os.getenv("HERMES_GATEWAY_SESSION") + is_gateway = _is_gateway_approval_context() is_ask = os.getenv("HERMES_EXEC_ASK") # Preserve the existing non-interactive behavior: outside CLI/gateway/ask