diff --git a/tests/tools/test_mcp_elicitation.py b/tests/tools/test_mcp_elicitation.py new file mode 100644 index 00000000000..35321eb35ea --- /dev/null +++ b/tests/tools/test_mcp_elicitation.py @@ -0,0 +1,296 @@ +"""Tests for the MCP elicitation handler in tools.mcp_tool. + +These tests exercise ElicitationHandler in isolation -- the underlying +approval system and the MCP transport layer are mocked, so no real MCP +server or user input is required. + +Tests skip cleanly if the optional `mcp` SDK is not installed (it is an +optional dependency under the `[mcp]` extra). +""" + +import asyncio +from unittest.mock import patch + +import pytest + + +pytest.importorskip("mcp.types") + +from mcp.types import ElicitResult # noqa: E402 -- after importorskip + +from tools.mcp_tool import ( # noqa: E402 + ElicitationHandler, + _format_elicitation_schema_summary, +) + + +def _form_params(message="please confirm", schema=None): + """Build a stand-in for ElicitRequestFormParams. + + We use a plain object (not the SDK type directly) so the test doesn't + couple to optional Pydantic validation -- the handler reads fields via + getattr() and tolerates duck-typed inputs. 
+ """ + from types import SimpleNamespace + return SimpleNamespace( + mode="form", + message=message, + requested_schema=schema or {}, + ) + + +def _url_params(message="open this url", url="https://example.com/auth", elicitation_id="e1"): + from types import SimpleNamespace + return SimpleNamespace( + mode="url", + message=message, + url=url, + elicitation_id=elicitation_id, + ) + + +class TestSchemaSummary: + def test_empty_schema_falls_back_to_generic_message(self): + out = _format_elicitation_schema_summary({}, "pay") + assert "pay" in out + assert "Approval requested" in out + + def test_properties_render_with_type_and_description(self): + schema = { + "type": "object", + "properties": { + "amount": {"type": "string", "description": "USD amount"}, + "recipient": {"type": "string"}, + }, + } + out = _format_elicitation_schema_summary(schema, "pay") + assert "amount (string): USD amount" in out + assert "recipient (string)" in out + + +class TestElicitationHandlerFormMode: + def test_user_accepts_once_returns_accept(self): + handler = ElicitationHandler("pay", {"timeout": 5}) + params = _form_params( + "authorize a payment of $0.50", + {"properties": {"approved": {"type": "boolean"}}}, + ) + + with patch("tools.approval.request_elicitation_consent", return_value="accept"): + result = asyncio.run(handler(context=None, params=params)) + + assert isinstance(result, ElicitResult) + assert result.action == "accept" + assert result.content == {} + assert handler.metrics["accepted"] == 1 + assert handler.metrics["declined"] == 0 + + def test_user_denies_returns_decline(self): + handler = ElicitationHandler("pay", {"timeout": 5}) + params = _form_params() + + with patch("tools.approval.request_elicitation_consent", return_value="decline"): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "decline" + assert handler.metrics["declined"] == 1 + assert handler.metrics["accepted"] == 0 + + def test_cancel_propagates_through(self): + 
"""request_elicitation_consent returns 'cancel' when the gateway + wait times out (resolved=False). The handler should propagate + that as ElicitResult(action='cancel') so the server can + distinguish 'no answer' from 'no'.""" + handler = ElicitationHandler("pay", {"timeout": 5}) + params = _form_params() + + with patch("tools.approval.request_elicitation_consent", return_value="cancel"): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "cancel" + assert handler.metrics["errors"] == 1 + + +class TestElicitationHandlerFailureModes: + def test_url_mode_is_declined_without_prompting(self): + handler = ElicitationHandler("pay", {"timeout": 5}) + params = _url_params() + + # If the handler tried to prompt, this would raise AssertionError + # because the side_effect treats the call as a test failure. + with patch( + "tools.approval.request_elicitation_consent", + side_effect=AssertionError("URL mode must not prompt"), + ): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "decline" + assert handler.metrics["declined"] == 1 + + def test_exception_in_approval_fails_closed_to_decline(self): + handler = ElicitationHandler("pay", {"timeout": 5}) + params = _form_params() + + with patch( + "tools.approval.request_elicitation_consent", + side_effect=RuntimeError("approval system blew up"), + ): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "decline" + assert handler.metrics["errors"] == 1 + + def test_timeout_returns_cancel(self, monkeypatch): + # Shrink the outer grace window so the test budget is just the + # handler timeout. Default grace is 5s, which makes stall durations + # tight and the test flaky. + monkeypatch.setattr( + ElicitationHandler, "_OUTER_TIMEOUT_GRACE_SECONDS", 0 + ) + # _safe_numeric clamps `timeout` to a minimum of 1s, so the + # effective wait_for budget is 1s here. 
Stall longer than that + # so the wait_for reliably fires TimeoutError. + handler = ElicitationHandler("pay", {"timeout": 0.05}) + params = _form_params() + + def stall(*_args, **_kwargs): + import time as _t + _t.sleep(2) + return "accept" + + with patch("tools.approval.request_elicitation_consent", side_effect=stall): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "cancel" + assert handler.metrics["errors"] == 1 + + +class TestElicitationHandlerWiring: + def test_session_kwargs_returns_callback(self): + handler = ElicitationHandler("pay", {}) + kwargs = handler.session_kwargs() + assert kwargs == {"elicitation_callback": handler} + + def test_default_timeout_is_300_seconds(self): + handler = ElicitationHandler("pay", {}) + assert handler.timeout == 300 + + def test_disabled_config_does_not_construct_handler(self): + """The server task initializer checks ``elicitation.enabled`` -- + an explicit ``False`` should suppress handler creation. The unit + of that decision lives in MCPServerTask, but the handler itself + must remain harmless to instantiate with arbitrary config.""" + handler = ElicitationHandler("pay", {"enabled": False, "timeout": 10}) + # Just confirm it instantiates and reads timeout; the gate lives + # at the higher layer. + assert handler.timeout == 10 + + +class TestElicitationHandlerContextBridge: + """The MCP recv-loop task that fires elicitation callbacks does NOT + inherit the agent's contextvars (HERMES_SESSION_PLATFORM etc.). The + handler reads ``owner._pending_call_context`` -- a snapshot captured + by the MCP tool wrapper around ``session.call_tool`` -- and replays + it before invoking the approval router so gateway-session detection + survives the task hop. 
Regression tests for that bridge.""" + + def test_captured_context_is_replayed_in_consent_call(self): + """The captured context's contextvar values must be observable + when ``request_elicitation_consent`` runs -- otherwise the + gateway-platform detection in approval.py sees an empty platform + string and falls back to the CLI path (the bug this fixes).""" + import contextvars + from types import SimpleNamespace + + probe: contextvars.ContextVar[str] = contextvars.ContextVar( + "elicitation_test_probe", default="" + ) + seen: list[str] = [] + + def fake_consent(*_args, **_kwargs): + seen.append(probe.get()) + return "accept" + + token = probe.set("gateway:telegram") + try: + captured = contextvars.copy_context() + finally: + probe.reset(token) + assert probe.get() == "", ( + "Sanity check: the probe must be empty outside the captured " + "context, otherwise the test would pass even without replay." + ) + + owner = SimpleNamespace(_pending_call_context=captured) + handler = ElicitationHandler("pay", {"timeout": 5}, owner=owner) + params = _form_params() + + with patch("tools.approval.request_elicitation_consent", side_effect=fake_consent): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "accept" + assert seen == ["gateway:telegram"], ( + f"Expected the captured contextvar to be visible inside the " + f"consent call; got {seen!r}" + ) + + def test_missing_captured_context_falls_back_to_direct_call(self): + """Without an owner (or with an owner that hasn't entered a tool + call) the handler must still invoke the consent router -- just + without the contextvar replay. 
Otherwise CLI/TUI sessions, which + don't set HERMES_SESSION_PLATFORM, would break.""" + handler = ElicitationHandler("pay", {"timeout": 5}, owner=None) + params = _form_params() + + with patch("tools.approval.request_elicitation_consent", return_value="accept") as m: + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "accept" + assert m.call_count == 1 + + def test_captured_context_can_be_replayed_multiple_times(self): + """A single tool call may trigger more than one elicitation + (e.g. the agent retries an MCP call within the same wrapper). + ``Context.run`` raises if a context is re-entered, so the handler + must ``.copy()`` before each run.""" + import contextvars + from types import SimpleNamespace + + probe: contextvars.ContextVar[str] = contextvars.ContextVar( + "elicitation_test_probe_multi", default="" + ) + seen: list[str] = [] + + def fake_consent(*_args, **_kwargs): + seen.append(probe.get()) + return "accept" + + token = probe.set("gateway:slack") + try: + captured = contextvars.copy_context() + finally: + probe.reset(token) + + owner = SimpleNamespace(_pending_call_context=captured) + handler = ElicitationHandler("pay", {"timeout": 5}, owner=owner) + params = _form_params() + + with patch("tools.approval.request_elicitation_consent", side_effect=fake_consent): + for _ in range(3): + asyncio.run(handler(context=None, params=params)) + + assert seen == ["gateway:slack"] * 3 + + def test_pending_call_context_none_does_not_crash(self): + """``owner._pending_call_context`` is set to None between tool + calls. 
An elicitation arriving in that window must not crash.""" + from types import SimpleNamespace + + owner = SimpleNamespace(_pending_call_context=None) + handler = ElicitationHandler("pay", {"timeout": 5}, owner=owner) + params = _form_params() + + with patch("tools.approval.request_elicitation_consent", return_value="decline"): + result = asyncio.run(handler(context=None, params=params)) + + assert result.action == "decline" diff --git a/tools/approval.py b/tools/approval.py index 6e4cca276b8..4d619d435d7 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -1852,5 +1852,92 @@ def check_execute_code_guard(code: str, env_type: str) -> dict: "user_approved": True, "description": description} +# ========================================================================= +# MCP elicitation entry point +# ========================================================================= + +def request_elicitation_consent( + message: str, + description: str, + *, + timeout_seconds: int | None = None, + surface: str = "mcp-elicitation", +) -> str: + """Route an MCP elicitation request to whichever approval surface owns + the active session and return a normalized result. + + Gateway sessions (Telegram, Slack, Discord, etc.) go through + ``_await_gateway_decision`` (not given ``timeout_seconds``): the notify_cb + posts a message and the agent thread blocks until the user responds via + the platform UI. CLI/TUI sessions go through ``prompt_dangerous_approval``. + + Fails closed: a missing notify_cb, a failed notification, an exception, + or any non-approving choice maps to ``"decline"``; a gateway wait that + ends unresolved (e.g. timed out) maps to ``"cancel"`` ("no answer"). + + Returns one of ``"accept" | "decline" | "cancel"``. 
+ """ + try: + session_key = get_current_session_key() + except Exception as exc: # pragma: no cover -- defensive + logger.warning("Elicitation consent: session lookup failed: %s", exc) + return "decline" + + if _is_gateway_approval_context(): + with _lock: + notify_cb = _gateway_notify_cbs.get(session_key) + if notify_cb is None: + logger.warning( + "Elicitation requested in gateway session %s but no " + "notify_cb is registered — failing closed", + session_key, + ) + return "decline" + + approval_data = { + "command": message, + "description": description, + "pattern_key": "mcp_elicitation", + "pattern_keys": ["mcp_elicitation"], + } + try: + decision = _await_gateway_decision( + session_key, notify_cb, approval_data, surface=surface, + ) + except Exception as exc: + logger.error( + "Elicitation gateway dispatch failed: %s", exc, exc_info=True, + ) + return "decline" + + if decision.get("notify_failed"): + return "decline" + if not decision.get("resolved"): + return "cancel" + choice = decision.get("choice") + if choice in ("once", "session", "always"): + return "accept" + return "decline" + + # CLI / TUI path. allow_permanent=False because elicitation is a + # per-call confirmation — there is no pattern to remember. 
+ try: + choice = prompt_dangerous_approval( + message, + description, + timeout_seconds=timeout_seconds, + allow_permanent=False, + ) + except Exception as exc: + logger.error( + "Elicitation CLI prompt failed: %s", exc, exc_info=True, + ) + return "decline" + + if choice in ("once", "session", "always"): + return "accept" + return "decline" + + # Load permanent allowlist from config on module import load_permanent_allowlist() diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index 2c5a1be5975..c7f0b4eb732 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -78,6 +78,7 @@ Thread safety: """ import asyncio +import contextvars import concurrent.futures import inspect import json @@ -176,6 +177,7 @@ _MCP_AVAILABLE = False _MCP_HTTP_AVAILABLE = False _MCP_SAMPLING_TYPES = False _MCP_NOTIFICATION_TYPES = False +_MCP_ELICITATION_TYPES = False _MCP_MESSAGE_HANDLER_SUPPORTED = False # Conservative fallback for SDK builds that don't export LATEST_PROTOCOL_VERSION. # Streamable HTTP was introduced by 2025-03-26, so this remains valid for the @@ -221,6 +223,16 @@ try: _MCP_SAMPLING_TYPES = True except ImportError: logger.debug("MCP sampling types not available -- sampling disabled") + # Elicitation types -- gated separately for the same reason as sampling. + # Added in mcp Python SDK 1.11.0 (Jul 2025); servers use elicitation to + # ask the client for structured input mid-tool-call (e.g. payment + # authorization). Missing types just disable the feature; everything + # else keeps working. 
+ try: + from mcp.types import ElicitRequestParams, ElicitResult + _MCP_ELICITATION_TYPES = True + except ImportError: + logger.debug("MCP elicitation types not available -- elicitation disabled") # Notification types for dynamic tool discovery (tools/list_changed) try: from mcp.types import ( @@ -1141,6 +1153,193 @@ class SamplingHandler: return self._build_text_result(choice, response) +# --------------------------------------------------------------------------- +# Elicitation handler +# --------------------------------------------------------------------------- + +def _format_elicitation_schema_summary(schema: dict, server_name: str) -> str: + """Render a JSON-schema-ish requested_schema to a human-readable field list. + + Elicitation schemas are restricted to a flat object with named top-level + properties. We surface field names, types, and descriptions so the user + can tell what the server is asking for before approving. + """ + props = schema.get("properties") if isinstance(schema, dict) else None + if not isinstance(props, dict) or not props: + return f"Approval requested by MCP server '{server_name}'." + + lines = [f"Fields requested by MCP server '{server_name}':"] + for field_name, field_spec in props.items(): + field_type = "" + field_desc = "" + if isinstance(field_spec, dict): + field_type = str(field_spec.get("type", "") or "") + field_desc = str(field_spec.get("description", "") or "") + suffix = f" ({field_type})" if field_type else "" + if field_desc: + lines.append(f" - {field_name}{suffix}: {field_desc}") + else: + lines.append(f" - {field_name}{suffix}") + return "\n".join(lines) + + +class ElicitationHandler: + """Handles ``elicitation/create`` requests for a single MCP server. + + Each ``MCPServerTask`` that has elicitation enabled creates one handler. + The handler is callable and passed directly to ``ClientSession`` as the + ``elicitation_callback`` (added in mcp Python SDK 1.11.0). 
+ + Elicitation lets a server ask the client to collect structured input from + the user mid-tool-call (e.g. payment authorization, OAuth confirmation). + Form-mode elicitations are routed through Hermes' existing approval + system (``tools.approval.request_elicitation_consent``), which surfaces + the prompt on whichever surface the active session uses -- CLI, TUI, + Telegram, Slack, etc. URL-mode elicitations are declined as unsupported. + + Failure modes are fail-closed: any timeout, exception, or unexpected + state returns ``decline``/``cancel`` rather than silently accepting. + The server treats this as the user not approving. + """ + + # Outer cap for the approval await. ``prompt_dangerous_approval`` runs + # its own input() timeout via the approval-config value; this is an + # asyncio-side safety net so the MCP event loop never blocks + # indefinitely if the inner timeout machinery is bypassed. + _OUTER_TIMEOUT_GRACE_SECONDS = 5 + + def __init__(self, server_name: str, config: dict, owner: Optional["MCPServerTask"] = None): + self.server_name = server_name + # Per-elicitation timeout. Default 5 min mirrors the gateway approval + # default so users on async surfaces (Telegram, Slack) have time to + # respond before the server gives up. + self.timeout = _safe_numeric(config.get("timeout", 300), 300, float) + # Back-reference to the MCPServerTask so we can read the agent's + # captured contextvars snapshot at elicitation time. Optional so + # the handler stays unit-testable in isolation. + self.owner = owner + self.metrics = { + "requests": 0, + "accepted": 0, + "declined": 0, + "errors": 0, + } + + def session_kwargs(self) -> dict: + """Return kwargs to pass to ClientSession for elicitation support.""" + return {"elicitation_callback": self} + + async def __call__(self, context, params): + """Elicitation callback invoked by the MCP SDK. + + Conforms to ``ElicitationFnT`` protocol. Always returns an + ``ElicitResult``; the ``ErrorData`` alternative is never used. 
+ """ + self.metrics["requests"] += 1 + + # URL-mode elicitations point the user to an external URL for + # sensitive out-of-band flows (OAuth, payment processing). Honouring + # them requires opening a browser to that URL and waiting for the + # server's notifications/elicitation/complete -- out of scope for + # the initial implementation. Decline cleanly so the server does + # not hang. + mode = getattr(params, "mode", "form") + if mode == "url": + logger.info( + "MCP server '%s' requested URL-mode elicitation; " + "declining (URL-mode elicitation not implemented)", + self.server_name, + ) + self.metrics["declined"] += 1 + return ElicitResult(action="decline") + + message = getattr(params, "message", "") or ( + f"MCP server '{self.server_name}' is requesting your approval" + ) + schema = getattr(params, "requested_schema", {}) or {} + description = _format_elicitation_schema_summary(schema, self.server_name) + + logger.info( + "MCP server '%s' elicitation request: %s", + self.server_name, _sanitize_error(message)[:200], + ) + + # Lazy import: tools.approval is imported very early during process + # bootstrap; matching the lazy pattern used by _fire_approval_hook + # avoids any chance of import-order coupling. + try: + from tools.approval import request_elicitation_consent + except Exception as exc: # pragma: no cover -- defensive + logger.error( + "MCP server '%s' elicitation: approval system unavailable: %s", + self.server_name, exc, + ) + self.metrics["errors"] += 1 + return ElicitResult(action="decline") + + # Offload the sync consent flow to a worker thread. Running it + # inline would freeze the MCP background event loop, blocking every + # other RPC on this session. request_elicitation_consent() routes + # itself to the right surface (gateway notify_cb for Telegram / + # Slack / etc., prompt_dangerous_approval for CLI / TUI) and + # normalizes the answer to one of accept / decline / cancel. 
+ # + # The recv-loop task that fires this callback does NOT inherit + # the agent's contextvars (HERMES_SESSION_PLATFORM etc.). When + # the MCP tool wrapper captured the agent's context onto + # owner._pending_call_context we replay it here via + # contextvars.Context.run so the gateway-platform detection in + # request_elicitation_consent picks up the right session. + captured = getattr(self.owner, "_pending_call_context", None) if self.owner else None + + def _invoke_consent() -> str: + if captured is None: + return request_elicitation_consent( + message, + description, + timeout_seconds=int(self.timeout), + surface=f"mcp-elicitation/{self.server_name}", + ) + # Context.run raises RuntimeError on a context that is already + # entered; run a copy so overlapping elicitations cannot collide. + return captured.copy().run( + request_elicitation_consent, + message, + description, + timeout_seconds=int(self.timeout), + surface=f"mcp-elicitation/{self.server_name}", + ) + + try: + answer = await asyncio.wait_for( + asyncio.to_thread(_invoke_consent), + timeout=self.timeout + self._OUTER_TIMEOUT_GRACE_SECONDS, + ) + except asyncio.TimeoutError: + logger.warning( + "MCP server '%s' elicitation timed out after %ds", + self.server_name, int(self.timeout), + ) + self.metrics["errors"] += 1 + return ElicitResult(action="cancel") + except Exception as exc: + logger.error( + "MCP server '%s' elicitation failed: %s", + self.server_name, exc, exc_info=True, + ) + self.metrics["errors"] += 1 + return ElicitResult(action="decline") + + if answer == "accept": + self.metrics["accepted"] += 1 + return ElicitResult(action="accept", content={}) + if answer == "cancel": + self.metrics["errors"] += 1 + return ElicitResult(action="cancel") + self.metrics["declined"] += 1 + return ElicitResult(action="decline") + + # --------------------------------------------------------------------------- # Server task -- each MCP server lives in one long-lived asyncio Task # 
--------------------------------------------------------------------------- @@ -1159,8 +1358,10 @@ class MCPServerTask: "name", "session", "tool_timeout", "_task", "_ready", "_shutdown_event", "_reconnect_event", "_tools", "_error", "_config", - "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock", + "_sampling", "_elicitation", + "_registered_tool_names", "_auth_type", "_refresh_lock", "_rpc_lock", "_pending_refresh_tasks", + "_pending_call_context", "initialize_result", ) @@ -1181,6 +1382,7 @@ class MCPServerTask: self._error: Optional[Exception] = None self._config: dict = {} self._sampling: Optional[SamplingHandler] = None + self._elicitation: Optional[ElicitationHandler] = None self._registered_tool_names: list[str] = [] self._auth_type: str = "" self._refresh_lock = asyncio.Lock() @@ -1192,6 +1394,16 @@ class MCPServerTask: # transports for conservative per-server ordering. self._rpc_lock = asyncio.Lock() self._pending_refresh_tasks: set[asyncio.Task] = set() + # contextvars snapshot of the agent task that's currently in + # session.call_tool(). The MCP recv loop dispatches incoming + # elicitation/create requests on a SEPARATE asyncio task whose + # context doesn't inherit HERMES_SESSION_PLATFORM, so the + # elicitation handler has no way to detect the gateway session + # that triggered the call. Capturing the agent's context here + # and replaying it inside the elicitation callback restores + # gateway-platform attribution and routes the approval prompt + # to the right surface (Telegram, Slack, etc.). 
+ self._pending_call_context: Optional[contextvars.Context] = None # Captures the ``InitializeResult`` returned by # ``await session.initialize()`` so downstream code can inspect the # server's real advertised capabilities (``.capabilities.resources``, @@ -1463,6 +1675,8 @@ class MCPServerTask: ) sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} + if self._elicitation: + sampling_kwargs.update(self._elicitation.session_kwargs()) if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: sampling_kwargs["message_handler"] = self._make_message_handler() @@ -1664,6 +1878,8 @@ class MCPServerTask: raise sampling_kwargs = self._sampling.session_kwargs() if self._sampling else {} + if self._elicitation: + sampling_kwargs.update(self._elicitation.session_kwargs()) if _MCP_NOTIFICATION_TYPES and _MCP_MESSAGE_HANDLER_SUPPORTED: sampling_kwargs["message_handler"] = self._make_message_handler() @@ -1859,6 +2075,16 @@ class MCPServerTask: else: self._sampling = None + # Set up elicitation handler if enabled and SDK types are available. + # Servers use elicitation/create to ask the client for structured + # input mid-tool-call (e.g. payment authorization). The handler + # routes those requests through Hermes' approval system. 
+ elicitation_config = config.get("elicitation", {}) + if elicitation_config.get("enabled", True) and _MCP_ELICITATION_TYPES: + self._elicitation = ElicitationHandler(self.name, elicitation_config, owner=self) + else: + self._elicitation = None + # Validate: warn if both url and command are present if "url" in config and "command" in config: logger.warning( @@ -2817,7 +3043,15 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): async def _call(): async with server._rpc_lock: - result = await server.session.call_tool(tool_name, arguments=args) + # Snapshot the agent's context so an elicitation callback + # triggered during this call (fired on the MCP recv loop + # task, which doesn't inherit our contextvars) can replay + # it and detect the gateway platform / session for routing. + server._pending_call_context = contextvars.copy_context() + try: + result = await server.session.call_tool(tool_name, arguments=args) + finally: + server._pending_call_context = None # MCP CallToolResult has .content (list of content blocks) and .isError if result.isError: error_text = ""