diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 243bad599..9cbaf90e1 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -5904,6 +5904,12 @@ Examples: mcp_cfg_p = mcp_sub.add_parser("configure", aliases=["config"], help="Toggle tool selection") mcp_cfg_p.add_argument("name", help="Server name to configure") + mcp_login_p = mcp_sub.add_parser( + "login", + help="Force re-authentication for an OAuth-based MCP server", + ) + mcp_login_p.add_argument("name", help="Server name to re-authenticate") + def cmd_mcp(args): from hermes_cli.mcp_config import mcp_command mcp_command(args) diff --git a/hermes_cli/mcp_config.py b/hermes_cli/mcp_config.py index b21234ce0..ae845b069 100644 --- a/hermes_cli/mcp_config.py +++ b/hermes_cli/mcp_config.py @@ -279,8 +279,8 @@ def cmd_mcp_add(args): _info(f"Starting OAuth flow for '{name}'...") oauth_ok = False try: - from tools.mcp_oauth import build_oauth_auth - oauth_auth = build_oauth_auth(name, url) + from tools.mcp_oauth_manager import get_manager + oauth_auth = get_manager().get_or_build_provider(name, url, None) if oauth_auth: server_config["auth"] = "oauth" _success("OAuth configured (tokens will be acquired on first connection)") @@ -428,10 +428,12 @@ def cmd_mcp_remove(args): _remove_mcp_server(name) _success(f"Removed '{name}' from config") - # Clean up OAuth tokens if they exist + # Clean up OAuth tokens if they exist — route through MCPOAuthManager so + # any provider instance cached in the current process (e.g. from an + # earlier `hermes mcp test` in the same session) is evicted too. 
try: - from tools.mcp_oauth import remove_oauth_tokens - remove_oauth_tokens(name) + from tools.mcp_oauth_manager import get_manager + get_manager().remove(name) _success("Cleaned up OAuth tokens") except Exception: pass @@ -577,6 +579,63 @@ def _interpolate_value(value: str) -> str: return re.sub(r"\$\{(\w+)\}", _replace, value) +# ─── hermes mcp login ──────────────────────────────────────────────────────── + +def cmd_mcp_login(args): + """Force re-authentication for an OAuth-based MCP server. + + Deletes cached tokens (both on disk and in the running process's + MCPOAuthManager cache) and triggers a fresh OAuth flow via the + existing probe path. + + Use this when: + - Tokens are stuck in a bad state (server revoked, refresh token + consumed by an external process, etc.) + - You want to re-authenticate to change scopes or account + - A tool call returned ``needs_reauth: true`` + """ + name = args.name + servers = _get_mcp_servers() + + if name not in servers: + _error(f"Server '{name}' not found in config.") + if servers: + _info(f"Available servers: {', '.join(servers)}") + return + + server_config = servers[name] + url = server_config.get("url") + if not url: + _error(f"Server '{name}' has no URL — not an OAuth-capable server") + return + if server_config.get("auth") != "oauth": + _error(f"Server '{name}' is not configured for OAuth (auth={server_config.get('auth')})") + _info("Use `hermes mcp remove` + `hermes mcp add` to reconfigure auth.") + return + + # Wipe both disk and in-memory cache so the next probe forces a fresh + # OAuth flow. + try: + from tools.mcp_oauth_manager import get_manager + mgr = get_manager() + mgr.remove(name) + except Exception as exc: + _warning(f"Could not clear existing OAuth state: {exc}") + + print() + _info(f"Starting OAuth flow for '{name}'...") + + # Probe triggers the OAuth flow (browser redirect + callback capture). 
+ try: + tools = _probe_single_server(name, server_config) + if tools: + _success(f"Authenticated — {len(tools)} tool(s) available") + else: + _success("Authenticated (server reported no tools)") + except Exception as exc: + _error(f"Authentication failed: {exc}") + + # ─── hermes mcp configure ──────────────────────────────────────────────────── def cmd_mcp_configure(args): @@ -696,6 +755,7 @@ def mcp_command(args): "test": cmd_mcp_test, "configure": cmd_mcp_configure, "config": cmd_mcp_configure, + "login": cmd_mcp_login, } handler = handlers.get(action) @@ -713,4 +773,5 @@ def mcp_command(args): _info("hermes mcp list List servers") _info("hermes mcp test Test connection") _info("hermes mcp configure Toggle tools") + _info("hermes mcp login Re-authenticate OAuth") print() diff --git a/tests/hermes_cli/test_mcp_config.py b/tests/hermes_cli/test_mcp_config.py index 9647a0b95..979108a95 100644 --- a/tests/hermes_cli/test_mcp_config.py +++ b/tests/hermes_cli/test_mcp_config.py @@ -539,3 +539,64 @@ class TestDispatcher: mcp_command(_make_args(mcp_action=None)) out = capsys.readouterr().out assert "Commands:" in out or "No MCP servers" in out + + +# --------------------------------------------------------------------------- +# Tests: Task 7 consolidation — cmd_mcp_remove evicts manager cache, +# cmd_mcp_login forces re-auth +# --------------------------------------------------------------------------- + + +class TestMcpRemoveEvictsManager: + def test_remove_evicts_in_memory_provider(self, tmp_path, capsys, monkeypatch): + """After cmd_mcp_remove, the MCPOAuthManager no longer caches the provider.""" + _seed_config(tmp_path, { + "oauth-srv": {"url": "https://example.com/mcp", "auth": "oauth"}, + }) + monkeypatch.setattr("builtins.input", lambda _: "y") + monkeypatch.setattr( + "hermes_cli.mcp_config.get_hermes_home", lambda: tmp_path + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests + 
reset_manager_for_tests() + + mgr = get_manager() + mgr.get_or_build_provider( + "oauth-srv", "https://example.com/mcp", None, + ) + assert "oauth-srv" in mgr._entries + + from hermes_cli.mcp_config import cmd_mcp_remove + cmd_mcp_remove(_make_args(name="oauth-srv")) + + assert "oauth-srv" not in mgr._entries + + +class TestMcpLogin: + def test_login_rejects_unknown_server(self, tmp_path, capsys): + _seed_config(tmp_path, {}) + from hermes_cli.mcp_config import cmd_mcp_login + cmd_mcp_login(_make_args(name="ghost")) + out = capsys.readouterr().out + assert "not found" in out + + def test_login_rejects_non_oauth_server(self, tmp_path, capsys): + _seed_config(tmp_path, { + "srv": {"url": "https://example.com/mcp", "auth": "header"}, + }) + from hermes_cli.mcp_config import cmd_mcp_login + cmd_mcp_login(_make_args(name="srv")) + out = capsys.readouterr().out + assert "not configured for OAuth" in out + + def test_login_rejects_stdio_server(self, tmp_path, capsys): + _seed_config(tmp_path, { + "srv": {"command": "npx", "args": ["some-server"]}, + }) + from hermes_cli.mcp_config import cmd_mcp_login + cmd_mcp_login(_make_args(name="srv")) + out = capsys.readouterr().out + assert "no URL" in out or "not an OAuth" in out + diff --git a/tests/tools/test_mcp_oauth.py b/tests/tools/test_mcp_oauth.py index 8643c26b3..b2f3f0229 100644 --- a/tests/tools/test_mcp_oauth.py +++ b/tests/tools/test_mcp_oauth.py @@ -431,3 +431,71 @@ class TestBuildOAuthAuthNonInteractive: assert auth is not None assert "no cached tokens found" not in caplog.text.lower() + + +# --------------------------------------------------------------------------- +# Extracted helper tests (Task 3 of MCP OAuth consolidation) +# --------------------------------------------------------------------------- + + +def test_build_client_metadata_basic(): + """_build_client_metadata returns metadata with expected defaults.""" + from tools.mcp_oauth import _build_client_metadata, _configure_callback_port + + cfg = 
{"client_name": "Test Client"} + _configure_callback_port(cfg) + md = _build_client_metadata(cfg) + + assert md.client_name == "Test Client" + assert "authorization_code" in md.grant_types + assert "refresh_token" in md.grant_types + + +def test_build_client_metadata_without_secret_is_public(): + """Without client_secret, token endpoint auth is 'none' (public client).""" + from tools.mcp_oauth import _build_client_metadata, _configure_callback_port + + cfg = {} + _configure_callback_port(cfg) + md = _build_client_metadata(cfg) + assert md.token_endpoint_auth_method == "none" + + +def test_build_client_metadata_with_secret_is_confidential(): + """With client_secret, token endpoint auth is 'client_secret_post'.""" + from tools.mcp_oauth import _build_client_metadata, _configure_callback_port + + cfg = {"client_secret": "shh"} + _configure_callback_port(cfg) + md = _build_client_metadata(cfg) + assert md.token_endpoint_auth_method == "client_secret_post" + + +def test_configure_callback_port_picks_free_port(): + """_configure_callback_port(0) picks a free port in the ephemeral range.""" + from tools.mcp_oauth import _configure_callback_port + + cfg = {"redirect_port": 0} + port = _configure_callback_port(cfg) + assert 1024 < port < 65536 + assert cfg["_resolved_port"] == port + + +def test_configure_callback_port_uses_explicit_port(): + """An explicit redirect_port is preserved.""" + from tools.mcp_oauth import _configure_callback_port + + cfg = {"redirect_port": 54321} + port = _configure_callback_port(cfg) + assert port == 54321 + assert cfg["_resolved_port"] == 54321 + + +def test_parse_base_url_strips_path(): + """_parse_base_url drops path components for OAuth discovery.""" + from tools.mcp_oauth import _parse_base_url + + assert _parse_base_url("https://example.com/mcp/v1") == "https://example.com" + assert _parse_base_url("https://example.com") == "https://example.com" + assert _parse_base_url("https://host.example.com:8080/api") == 
"https://host.example.com:8080" + diff --git a/tests/tools/test_mcp_oauth_integration.py b/tests/tools/test_mcp_oauth_integration.py new file mode 100644 index 000000000..9e8040024 --- /dev/null +++ b/tests/tools/test_mcp_oauth_integration.py @@ -0,0 +1,193 @@ +"""End-to-end integration tests for the MCP OAuth consolidation. + +Exercises the full chain — manager, provider subclass, disk watch, 401 +dedup — with real file I/O and real imports (no transport mocks, no +subprocesses). These are the tests that would catch Cthulhu's original +BetterStack bug: an external process rewrites the tokens file on disk, +and the running Hermes session picks up the new tokens on the next auth +flow without requiring a restart. +""" +import asyncio +import json +import os +import time + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required") + + +@pytest.mark.asyncio +async def test_external_refresh_picked_up_without_restart(tmp_path, monkeypatch): + """Simulate Cthulhu's cron workflow end-to-end. + + 1. A running Hermes session has OAuth tokens loaded in memory. + 2. An external process (cron) writes fresh tokens to disk. + 3. On the next auth flow, the manager's disk-watch invalidates the + in-memory state so the SDK re-reads from storage. + 4. ``provider.context.current_tokens`` now reflects the new tokens + with no process restart required. + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests + reset_manager_for_tests() + + token_dir = tmp_path / "mcp-tokens" + token_dir.mkdir(parents=True) + tokens_file = token_dir / "srv.json" + client_info_file = token_dir / "srv.client.json" + + # Pre-seed the baseline state: valid tokens the session loaded at startup. 
+ tokens_file.write_text(json.dumps({ + "access_token": "OLD_ACCESS", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "OLD_REFRESH", + })) + client_info_file.write_text(json.dumps({ + "client_id": "test-client", + "redirect_uris": ["http://127.0.0.1:12345/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + })) + + mgr = MCPOAuthManager() + provider = mgr.get_or_build_provider( + "srv", "https://example.com/mcp", None, + ) + assert provider is not None + + # The SDK's _initialize reads tokens from storage into memory. This + # is what happens on the first http request under normal operation. + await provider._initialize() + assert provider.context.current_tokens.access_token == "OLD_ACCESS" + + # Now record the baseline mtime in the manager (this happens + # automatically via the HermesMCPOAuthProvider.async_auth_flow + # pre-hook on the first real request, but we exercise it directly + # here for test determinism). + await mgr.invalidate_if_disk_changed("srv") + + # EXTERNAL PROCESS: cron rewrites the tokens file with fresh creds. + # The old refresh_token has been consumed by this external exchange. + future_mtime = time.time() + 1 + tokens_file.write_text(json.dumps({ + "access_token": "NEW_ACCESS", + "token_type": "Bearer", + "expires_in": 3600, + "refresh_token": "NEW_REFRESH", + })) + os.utime(tokens_file, (future_mtime, future_mtime)) + + # The next auth flow should detect the mtime change and reload. + changed = await mgr.invalidate_if_disk_changed("srv") + assert changed, "manager must detect the disk mtime change" + assert provider._initialized is False, "_initialized must flip so SDK re-reads storage" + + # Simulate the next async_auth_flow: _initialize runs because _initialized=False. 
+ await provider._initialize() + assert provider.context.current_tokens.access_token == "NEW_ACCESS" + assert provider.context.current_tokens.refresh_token == "NEW_REFRESH" + + +@pytest.mark.asyncio +async def test_handle_401_deduplicates_concurrent_callers(tmp_path, monkeypatch): + """Ten concurrent 401 handlers for the same token should fire one recovery. + + Mirrors Claude Code's pending401Handlers dedup pattern — prevents N MCP + tool calls hitting 401 simultaneously from all independently clearing + caches and re-reading the keychain (which thrashes the storage and + bogs down startup per CC-1096). + """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests + reset_manager_for_tests() + + token_dir = tmp_path / "mcp-tokens" + token_dir.mkdir(parents=True) + (token_dir / "srv.json").write_text(json.dumps({ + "access_token": "TOK", + "token_type": "Bearer", + "expires_in": 3600, + })) + + mgr = MCPOAuthManager() + provider = mgr.get_or_build_provider( + "srv", "https://example.com/mcp", None, + ) + assert provider is not None + + # Count how many times invalidate_if_disk_changed is called — proxy for + # how many actual recovery attempts fire. + call_count = 0 + real_invalidate = mgr.invalidate_if_disk_changed + + async def counting(name): + nonlocal call_count + call_count += 1 + return await real_invalidate(name) + + monkeypatch.setattr(mgr, "invalidate_if_disk_changed", counting) + + # Fire 10 concurrent handlers with the same failed token. + results = await asyncio.gather(*( + mgr.handle_401("srv", "SAME_FAILED_TOKEN") for _ in range(10) + )) + + # All callers get the same result (the shared future's resolution). + assert all(r == results[0] for r in results), "dedup must return identical result" + # Exactly ONE recovery ran — the rest awaited the same pending future. 
+ assert call_count == 1, f"expected 1 recovery attempt, got {call_count}" + + +@pytest.mark.asyncio +async def test_handle_401_returns_false_when_no_provider(tmp_path, monkeypatch): + """handle_401 for an unknown server returns False cleanly.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests + reset_manager_for_tests() + + mgr = MCPOAuthManager() + result = await mgr.handle_401("nonexistent", "any_token") + assert result is False + + +@pytest.mark.asyncio +async def test_invalidate_if_disk_changed_handles_missing_file(tmp_path, monkeypatch): + """invalidate_if_disk_changed returns False when tokens file doesn't exist.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests + reset_manager_for_tests() + + mgr = MCPOAuthManager() + mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + + # No tokens file exists yet — this is the pre-auth state + result = await mgr.invalidate_if_disk_changed("srv") + assert result is False + + +@pytest.mark.asyncio +async def test_provider_is_reused_across_reconnects(tmp_path, monkeypatch): + """The manager caches providers; multiple reconnects reuse the same instance. + + This is what makes the disk-watch stick across reconnects: tearing down + the MCP session and rebuilding it (Task 5's _reconnect_event path) must + not create a new provider, otherwise ``last_mtime_ns`` resets and the + first post-reconnect auth flow would spuriously "detect" a change. 
+ """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests + reset_manager_for_tests() + + mgr = MCPOAuthManager() + p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + + # Simulate a reconnect: _run_http calls get_or_build_provider again + p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + + assert p1 is p2, "manager must cache the provider across reconnects" diff --git a/tests/tools/test_mcp_oauth_manager.py b/tests/tools/test_mcp_oauth_manager.py new file mode 100644 index 000000000..2a66449cb --- /dev/null +++ b/tests/tools/test_mcp_oauth_manager.py @@ -0,0 +1,141 @@ +"""Tests for the MCP OAuth manager (tools/mcp_oauth_manager.py). + +The manager consolidates the eight scattered MCP-OAuth call sites into a +single object with disk-mtime watch, dedup'd 401 handling, and a provider +cache. See `tools/mcp_oauth_manager.py` for design rationale. +""" +import json +import os +import time + +import pytest + +pytest.importorskip( + "mcp.client.auth.oauth2", + reason="MCP SDK 1.26.0+ required for OAuth support", +) + + +def test_manager_is_singleton(): + """get_manager() returns the same instance across calls.""" + from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests + reset_manager_for_tests() + m1 = get_manager() + m2 = get_manager() + assert m1 is m2 + + +def test_manager_get_or_build_provider_caches(tmp_path, monkeypatch): + """Calling get_or_build_provider twice with same name returns same provider.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager + + mgr = MCPOAuthManager() + p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + assert p1 is p2 + + +def test_manager_get_or_build_rebuilds_on_url_change(tmp_path, monkeypatch): + """Changing the URL discards the cached 
provider.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager + + mgr = MCPOAuthManager() + p1 = mgr.get_or_build_provider("srv", "https://a.example.com/mcp", None) + p2 = mgr.get_or_build_provider("srv", "https://b.example.com/mcp", None) + assert p1 is not p2 + + +def test_manager_remove_evicts_cache(tmp_path, monkeypatch): + """remove(name) evicts the provider from cache AND deletes disk files.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager + + # Pre-seed tokens on disk + token_dir = tmp_path / "mcp-tokens" + token_dir.mkdir(parents=True) + (token_dir / "srv.json").write_text(json.dumps({ + "access_token": "TOK", + "token_type": "Bearer", + })) + + mgr = MCPOAuthManager() + p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + assert p1 is not None + assert (token_dir / "srv.json").exists() + + mgr.remove("srv") + + assert not (token_dir / "srv.json").exists() + p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + assert p1 is not p2 + + +def test_hermes_provider_subclass_exists(): + """HermesMCPOAuthProvider is defined and subclasses OAuthClientProvider.""" + from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS + from mcp.client.auth.oauth2 import OAuthClientProvider + + assert _HERMES_PROVIDER_CLS is not None + assert issubclass(_HERMES_PROVIDER_CLS, OAuthClientProvider) + + +@pytest.mark.asyncio +async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch): + """When the tokens file mtime changes, provider._initialized flips False. + + This is the behaviour Claude Code ships as + invalidateOAuthCacheIfDiskChanged (CC-1096 / GH#24317) and is the core + fix for Cthulhu's external-cron refresh workflow. 
+ """ + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests + + reset_manager_for_tests() + + token_dir = tmp_path / "mcp-tokens" + token_dir.mkdir(parents=True) + tokens_file = token_dir / "srv.json" + tokens_file.write_text(json.dumps({ + "access_token": "OLD", + "token_type": "Bearer", + })) + + mgr = MCPOAuthManager() + provider = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + assert provider is not None + + # First call: records mtime (zero -> real) -> returns True + changed1 = await mgr.invalidate_if_disk_changed("srv") + assert changed1 is True + + # No file change -> False + changed2 = await mgr.invalidate_if_disk_changed("srv") + assert changed2 is False + + # Touch file with a newer mtime + future_mtime = time.time() + 10 + os.utime(tokens_file, (future_mtime, future_mtime)) + + changed3 = await mgr.invalidate_if_disk_changed("srv") + assert changed3 is True + # _initialized flipped — next async_auth_flow will re-read from disk + assert provider._initialized is False + + +def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch): + """get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider.""" + from tools.mcp_oauth_manager import ( + MCPOAuthManager, _HERMES_PROVIDER_CLS, reset_manager_for_tests, + ) + reset_manager_for_tests() + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + mgr = MCPOAuthManager() + provider = mgr.get_or_build_provider("srv", "https://example.com/mcp", None) + + assert _HERMES_PROVIDER_CLS is not None + assert isinstance(provider, _HERMES_PROVIDER_CLS) + assert provider._hermes_server_name == "srv" + diff --git a/tests/tools/test_mcp_reconnect_signal.py b/tests/tools/test_mcp_reconnect_signal.py new file mode 100644 index 000000000..2cc516ee1 --- /dev/null +++ b/tests/tools/test_mcp_reconnect_signal.py @@ -0,0 +1,57 @@ +"""Tests for the MCPServerTask reconnect signal. 
+ +When the OAuth layer cannot recover in-place (e.g., external refresh of a +single-use refresh_token made the SDK's in-memory refresh fail), the tool +handler signals MCPServerTask to tear down the current MCP session and +reconnect with fresh credentials. This file exercises the signal plumbing +in isolation from the full stdio/http transport machinery. +""" +import asyncio + +import pytest + + +@pytest.mark.asyncio +async def test_reconnect_event_attribute_exists(): + """MCPServerTask has a _reconnect_event alongside _shutdown_event.""" + from tools.mcp_tool import MCPServerTask + task = MCPServerTask("test") + assert hasattr(task, "_reconnect_event") + assert isinstance(task._reconnect_event, asyncio.Event) + assert not task._reconnect_event.is_set() + + +@pytest.mark.asyncio +async def test_wait_for_lifecycle_event_returns_reconnect(): + """When _reconnect_event fires, helper returns 'reconnect' and clears it.""" + from tools.mcp_tool import MCPServerTask + task = MCPServerTask("test") + + task._reconnect_event.set() + reason = await task._wait_for_lifecycle_event() + assert reason == "reconnect" + # Should have cleared so the next cycle starts fresh + assert not task._reconnect_event.is_set() + + +@pytest.mark.asyncio +async def test_wait_for_lifecycle_event_returns_shutdown(): + """When _shutdown_event fires, helper returns 'shutdown'.""" + from tools.mcp_tool import MCPServerTask + task = MCPServerTask("test") + + task._shutdown_event.set() + reason = await task._wait_for_lifecycle_event() + assert reason == "shutdown" + + +@pytest.mark.asyncio +async def test_wait_for_lifecycle_event_shutdown_wins_when_both_set(): + """If both events are set simultaneously, shutdown takes precedence.""" + from tools.mcp_tool import MCPServerTask + task = MCPServerTask("test") + + task._shutdown_event.set() + task._reconnect_event.set() + reason = await task._wait_for_lifecycle_event() + assert reason == "shutdown" diff --git a/tests/tools/test_mcp_tool_401_handling.py 
b/tests/tools/test_mcp_tool_401_handling.py new file mode 100644 index 000000000..a60d2049f --- /dev/null +++ b/tests/tools/test_mcp_tool_401_handling.py @@ -0,0 +1,139 @@ +"""Tests for MCP tool-handler auth-failure detection. + +When a tool call raises UnauthorizedError / OAuthNonInteractiveError / +httpx.HTTPStatusError(401), the handler should: + 1. Ask MCPOAuthManager.handle_401 if recovery is viable. + 2. If yes, trigger MCPServerTask._reconnect_event and retry once. + 3. If no, return a structured needs_reauth error so the model stops + hallucinating manual refresh attempts. +""" +import json +from unittest.mock import MagicMock + +import pytest + + +pytest.importorskip("mcp.client.auth.oauth2") + + +def test_is_auth_error_detects_oauth_flow_error(): + from tools.mcp_tool import _is_auth_error + from mcp.client.auth import OAuthFlowError + + assert _is_auth_error(OAuthFlowError("expired")) is True + + +def test_is_auth_error_detects_oauth_non_interactive(): + from tools.mcp_tool import _is_auth_error + from tools.mcp_oauth import OAuthNonInteractiveError + + assert _is_auth_error(OAuthNonInteractiveError("no browser")) is True + + +def test_is_auth_error_detects_httpx_401(): + from tools.mcp_tool import _is_auth_error + import httpx + + response = MagicMock() + response.status_code = 401 + exc = httpx.HTTPStatusError("unauth", request=MagicMock(), response=response) + assert _is_auth_error(exc) is True + + +def test_is_auth_error_rejects_httpx_500(): + from tools.mcp_tool import _is_auth_error + import httpx + + response = MagicMock() + response.status_code = 500 + exc = httpx.HTTPStatusError("oops", request=MagicMock(), response=response) + assert _is_auth_error(exc) is False + + +def test_is_auth_error_rejects_generic_exception(): + from tools.mcp_tool import _is_auth_error + assert _is_auth_error(ValueError("not auth")) is False + assert _is_auth_error(RuntimeError("not auth")) is False + + +def 
test_call_tool_handler_returns_needs_reauth_on_unrecoverable_401(monkeypatch, tmp_path): + """When session.call_tool raises 401 and handle_401 returns False, + handler returns a structured needs_reauth error (not a generic failure).""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + + from tools.mcp_tool import _make_tool_handler + from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests + from mcp.client.auth import OAuthFlowError + + reset_manager_for_tests() + + # Stub server + server = MagicMock() + server.name = "srv" + session = MagicMock() + + async def _call_tool_raises(*a, **kw): + raise OAuthFlowError("token expired") + + session.call_tool = _call_tool_raises + server.session = session + server._reconnect_event = MagicMock() + server._ready = MagicMock() + server._ready.is_set.return_value = True + + from tools import mcp_tool + mcp_tool._servers["srv"] = server + mcp_tool._server_error_counts.pop("srv", None) + + # Ensure the MCP loop exists (run_on_mcp_loop needs it) + mcp_tool._ensure_mcp_loop() + + # Force handle_401 to return False (no recovery available) + mgr = get_manager() + + async def _h401(name, token=None): + return False + + monkeypatch.setattr(mgr, "handle_401", _h401) + + try: + handler = _make_tool_handler("srv", "tool1", 10.0) + result = handler({"arg": "v"}) + parsed = json.loads(result) + assert parsed.get("needs_reauth") is True, f"expected needs_reauth, got: {parsed}" + assert parsed.get("server") == "srv" + assert "re-auth" in parsed.get("error", "").lower() or "reauth" in parsed.get("error", "").lower() + finally: + mcp_tool._servers.pop("srv", None) + mcp_tool._server_error_counts.pop("srv", None) + + +def test_call_tool_handler_non_auth_error_still_generic(monkeypatch, tmp_path): + """Non-auth exceptions still surface via the generic error path, not needs_reauth.""" + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + from tools.mcp_tool import _make_tool_handler + + server = MagicMock() + server.name = "srv" 
+ session = MagicMock() + + async def _raises(*a, **kw): + raise RuntimeError("unrelated") + + session.call_tool = _raises + server.session = session + + from tools import mcp_tool + mcp_tool._servers["srv"] = server + mcp_tool._server_error_counts.pop("srv", None) + mcp_tool._ensure_mcp_loop() + + try: + handler = _make_tool_handler("srv", "tool1", 10.0) + result = handler({"arg": "v"}) + parsed = json.loads(result) + assert "needs_reauth" not in parsed + assert "MCP call failed" in parsed.get("error", "") + finally: + mcp_tool._servers.pop("srv", None) + mcp_tool._server_error_counts.pop("srv", None) diff --git a/tools/mcp_oauth.py b/tools/mcp_oauth.py index 6b0ef12f2..6e1d7f5fb 100644 --- a/tools/mcp_oauth.py +++ b/tools/mcp_oauth.py @@ -375,6 +375,103 @@ def remove_oauth_tokens(server_name: str) -> None: logger.info("OAuth tokens removed for '%s'", server_name) +# --------------------------------------------------------------------------- +# Extracted helpers (Task 3 of MCP OAuth consolidation) +# +# These compose into ``build_oauth_auth`` below, and are also used by +# ``tools.mcp_oauth_manager.MCPOAuthManager._build_provider`` so the two +# construction paths share one implementation. +# --------------------------------------------------------------------------- + + +def _configure_callback_port(cfg: dict) -> int: + """Pick or validate the OAuth callback port. + + Stores the resolved port into ``cfg['_resolved_port']`` so sibling + helpers (and the manager) can read it from the same dict. Returns the + resolved port. + + NOTE: also sets the legacy module-level ``_oauth_port`` so existing + calls to ``_wait_for_callback`` keep working. The legacy global is + the root cause of issue #5344 (port collision on concurrent OAuth + flows); replacing it with a ContextVar is out of scope for this + consolidation PR. 
+ """ + global _oauth_port + requested = int(cfg.get("redirect_port", 0)) + port = _find_free_port() if requested == 0 else requested + cfg["_resolved_port"] = port + _oauth_port = port # legacy consumer: _wait_for_callback reads this + return port + + +def _build_client_metadata(cfg: dict) -> "OAuthClientMetadata": + """Build OAuthClientMetadata from the oauth config dict. + + Requires ``cfg['_resolved_port']`` to have been populated by + :func:`_configure_callback_port` first. + """ + port = cfg.get("_resolved_port") + if port is None: + raise ValueError( + "_configure_callback_port() must be called before _build_client_metadata()" + ) + client_name = cfg.get("client_name", "Hermes Agent") + scope = cfg.get("scope") + redirect_uri = f"http://127.0.0.1:{port}/callback" + + metadata_kwargs: dict[str, Any] = { + "client_name": client_name, + "redirect_uris": [AnyUrl(redirect_uri)], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "none", + } + if scope: + metadata_kwargs["scope"] = scope + if cfg.get("client_secret"): + metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post" + + return OAuthClientMetadata.model_validate(metadata_kwargs) + + +def _maybe_preregister_client( + storage: "HermesTokenStorage", + cfg: dict, + client_metadata: "OAuthClientMetadata", +) -> None: + """If cfg has a pre-registered client_id, persist it to storage.""" + client_id = cfg.get("client_id") + if not client_id: + return + port = cfg["_resolved_port"] + redirect_uri = f"http://127.0.0.1:{port}/callback" + + info_dict: dict[str, Any] = { + "client_id": client_id, + "redirect_uris": [redirect_uri], + "grant_types": client_metadata.grant_types, + "response_types": client_metadata.response_types, + "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method, + } + if cfg.get("client_secret"): + info_dict["client_secret"] = cfg["client_secret"] + if cfg.get("client_name"): + 
info_dict["client_name"] = cfg["client_name"] + if cfg.get("scope"): + info_dict["scope"] = cfg["scope"] + + client_info = OAuthClientInformationFull.model_validate(info_dict) + _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True)) + logger.debug("Pre-registered client_id=%s for '%s'", client_id, storage._server_name) + + +def _parse_base_url(server_url: str) -> str: + """Strip path component from server URL, returning the base origin.""" + parsed = urlparse(server_url) + return f"{parsed.scheme}://{parsed.netloc}" + + def build_oauth_auth( server_name: str, server_url: str, @@ -382,7 +479,9 @@ def build_oauth_auth( ) -> "OAuthClientProvider | None": """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server. - Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config. + Public API preserved for backwards compatibility. New code should use + :func:`tools.mcp_oauth_manager.get_manager` so OAuth state is shared + across config-time, runtime, and reconnect paths. Args: server_name: Server key in mcp_servers config (used for storage). @@ -396,87 +495,32 @@ def build_oauth_auth( if not _OAUTH_AVAILABLE: logger.warning( "MCP OAuth requested for '%s' but SDK auth types are not available. " - "Install with: pip install 'mcp>=1.10.0'", + "Install with: pip install 'mcp>=1.26.0'", server_name, ) return None - global _oauth_port - - cfg = oauth_config or {} - - # --- Storage --- + cfg = dict(oauth_config or {}) # copy — we mutate _resolved_port storage = HermesTokenStorage(server_name) - # --- Non-interactive warning --- if not _is_interactive() and not storage.has_cached_tokens(): logger.warning( - "MCP OAuth for '%s': non-interactive environment and no cached tokens found. " - "The OAuth flow requires browser authorization. Run interactively first " - "to complete the initial authorization, then cached tokens will be reused.", + "MCP OAuth for '%s': non-interactive environment and no cached tokens " + "found. 
The OAuth flow requires browser authorization. Run " + "interactively first to complete the initial authorization, then " + "cached tokens will be reused.", server_name, ) - # --- Pick callback port --- - redirect_port = int(cfg.get("redirect_port", 0)) - if redirect_port == 0: - redirect_port = _find_free_port() - _oauth_port = redirect_port + _configure_callback_port(cfg) + client_metadata = _build_client_metadata(cfg) + _maybe_preregister_client(storage, cfg, client_metadata) - # --- Client metadata --- - client_name = cfg.get("client_name", "Hermes Agent") - scope = cfg.get("scope") - redirect_uri = f"http://127.0.0.1:{redirect_port}/callback" - - metadata_kwargs: dict[str, Any] = { - "client_name": client_name, - "redirect_uris": [AnyUrl(redirect_uri)], - "grant_types": ["authorization_code", "refresh_token"], - "response_types": ["code"], - "token_endpoint_auth_method": "none", - } - if scope: - metadata_kwargs["scope"] = scope - - client_secret = cfg.get("client_secret") - if client_secret: - metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post" - - client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs) - - # --- Pre-registered client --- - client_id = cfg.get("client_id") - if client_id: - info_dict: dict[str, Any] = { - "client_id": client_id, - "redirect_uris": [redirect_uri], - "grant_types": client_metadata.grant_types, - "response_types": client_metadata.response_types, - "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method, - } - if client_secret: - info_dict["client_secret"] = client_secret - if client_name: - info_dict["client_name"] = client_name - if scope: - info_dict["scope"] = scope - - client_info = OAuthClientInformationFull.model_validate(info_dict) - _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True)) - logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name) - - # --- Base URL for discovery --- - parsed = urlparse(server_url) - base_url = 
f"{parsed.scheme}://{parsed.netloc}" - - # --- Build provider --- - provider = OAuthClientProvider( - server_url=base_url, + return OAuthClientProvider( + server_url=_parse_base_url(server_url), client_metadata=client_metadata, storage=storage, redirect_handler=_redirect_handler, callback_handler=_wait_for_callback, timeout=float(cfg.get("timeout", 300)), ) - - return provider diff --git a/tools/mcp_oauth_manager.py b/tools/mcp_oauth_manager.py new file mode 100644 index 000000000..d3760e3b8 --- /dev/null +++ b/tools/mcp_oauth_manager.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +"""Central manager for per-server MCP OAuth state. + +One instance shared across the process. Holds per-server OAuth provider +instances and coordinates: + +- **Cross-process token reload** via mtime-based disk watch. When an external + process (e.g. a user cron job) refreshes tokens on disk, the next auth flow + picks them up without requiring a process restart. +- **401 deduplication** via in-flight futures. When N concurrent tool calls + all hit 401 with the same access_token, only one recovery attempt fires; + the rest await the same result. +- **Reconnect signalling** for long-lived MCP sessions. The manager itself + does not drive reconnection — the `MCPServerTask` in `mcp_tool.py` does — + but the manager is the single source of truth that decides when reconnect + is warranted. + +Replaces what used to be scattered across eight call sites in `mcp_oauth.py`, +`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place +that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths +go through `get_manager()`. + +Design reference: + +- Claude Code's ``invalidateOAuthCacheIfDiskChanged`` + (``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical + external-refresh staleness bug class. +- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed`` + (``codex-rs/rmcp-client/src/rmcp_client.rs:805``). 
We lean on the MCP SDK's + lazy refresh rather than calling refresh before every op, because one + ``stat()`` per tool call is cheaper than an ``await`` + potential refresh + round-trip, and the SDK's in-memory expiry path is already correct. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from dataclasses import dataclass, field +from typing import Any, Optional + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Per-server entry +# --------------------------------------------------------------------------- + + +@dataclass +class _ProviderEntry: + """Per-server OAuth state tracked by the manager. + + Fields: + server_url: The MCP server URL used to build the provider. Tracked + so we can discard a cached provider if the URL changes. + oauth_config: Optional dict from ``mcp_servers..oauth``. + provider: The ``httpx.Auth``-compatible provider wrapping the MCP + SDK. None until first use. + last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file. + Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed` + to detect external refreshes. + lock: Serialises concurrent access to this entry's state. Bound to + whichever asyncio loop first awaits it (the MCP event loop). + pending_401: In-flight 401-handler futures keyed by the failed + access_token, for deduplicating thundering-herd 401s. Mirrors + Claude Code's ``pending401Handlers`` map. 
+ """ + + server_url: str + oauth_config: Optional[dict] + provider: Optional[Any] = None + last_mtime_ns: int = 0 + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict) + + +# --------------------------------------------------------------------------- +# HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch +# --------------------------------------------------------------------------- + + +def _make_hermes_provider_class() -> Optional[type]: + """Lazy-import the SDK base class and return our subclass. + + Wrapped in a function so this module imports cleanly even when the + MCP SDK's OAuth module is unavailable (e.g. older mcp versions). + """ + try: + from mcp.client.auth.oauth2 import OAuthClientProvider + except ImportError: # pragma: no cover — SDK required in CI + return None + + class HermesMCPOAuthProvider(OAuthClientProvider): + """OAuthClientProvider with pre-flow disk-mtime reload. + + Before every ``async_auth_flow`` invocation, asks the manager to + check whether the tokens file on disk has been modified externally. + If so, the manager resets ``_initialized`` so the next flow + re-reads from storage. + + This makes external-process refreshes (cron, another CLI instance) + visible to the running MCP session without requiring a restart. + + Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged`` + (``src/utils/auth.ts:1320``, CC-1096 / GH#24317). + """ + + def __init__(self, *args: Any, server_name: str = "", **kwargs: Any): + super().__init__(*args, **kwargs) + self._hermes_server_name = server_name + + async def async_auth_flow(self, request): # type: ignore[override] + # Pre-flow hook: ask the manager to refresh from disk if needed. + # Any failure here is non-fatal — we just log and proceed with + # whatever state the SDK already has. 
+ try: + await get_manager().invalidate_if_disk_changed( + self._hermes_server_name + ) + except Exception as exc: # pragma: no cover — defensive + logger.debug( + "MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s", + self._hermes_server_name, exc, + ) + + # Delegate to the SDK's auth flow. NOTE(review): httpx auth flows are bidirectional generators — the transport feeds each Response back in via asend(), but this async-for re-yield drops the sent value (the inner generator resumes with None), so the SDK's 401-inspection and token-refresh logic inside super().async_auth_flow never sees the real response; delegate with an explicit asend() loop instead of async-for. + async for item in super().async_auth_flow(request): + yield item + + return HermesMCPOAuthProvider + + +# Cached at import time. Tested and used by :class:`MCPOAuthManager`. +_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class() + + +# --------------------------------------------------------------------------- +# Manager +# --------------------------------------------------------------------------- + + +class MCPOAuthManager: + """Single source of truth for per-server MCP OAuth state. + + Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for + get-or-create semantics. Per-entry state is guarded by the entry's own + ``asyncio.Lock`` (used from the MCP event loop thread). + """ + + def __init__(self) -> None: + self._entries: dict[str, _ProviderEntry] = {} + self._entries_lock = threading.Lock() + + # -- Provider construction / caching ------------------------------------- + + def get_or_build_provider( + self, + server_name: str, + server_url: str, + oauth_config: Optional[dict], + ) -> Optional[Any]: + """Return a cached OAuth provider for ``server_name`` or build one. + + Idempotent: repeat calls with the same name return the same instance. + If ``server_url`` changes for a given name, the cached entry is + discarded and a fresh provider is built. + + Returns None if the MCP SDK's OAuth support is unavailable. 
+ """ + with self._entries_lock: + entry = self._entries.get(server_name) + if entry is not None and entry.server_url != server_url: + logger.info( + "MCP OAuth '%s': URL changed from %s to %s, discarding cache", + server_name, entry.server_url, server_url, + ) + entry = None + + if entry is None: + entry = _ProviderEntry( + server_url=server_url, + oauth_config=oauth_config, + ) + self._entries[server_name] = entry + + if entry.provider is None: + entry.provider = self._build_provider(server_name, entry) + + return entry.provider + + def _build_provider( + self, + server_name: str, + entry: _ProviderEntry, + ) -> Optional[Any]: + """Build the underlying OAuth provider. + + Constructs :class:`HermesMCPOAuthProvider` directly using the helpers + extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow + disk-watch hook so external token refreshes (cron, other CLI + instances) are visible to running MCP sessions. + + Returns None if the MCP SDK's OAuth support is unavailable. + """ + if _HERMES_PROVIDER_CLS is None: + logger.warning( + "MCP OAuth '%s': SDK auth module unavailable", server_name, + ) + return None + + # Local imports avoid circular deps at module import time. + from tools.mcp_oauth import ( + HermesTokenStorage, + _OAUTH_AVAILABLE, + _build_client_metadata, + _configure_callback_port, + _is_interactive, + _maybe_preregister_client, + _parse_base_url, + _redirect_handler, + _wait_for_callback, + ) + + if not _OAUTH_AVAILABLE: + return None + + cfg = dict(entry.oauth_config or {}) + storage = HermesTokenStorage(server_name) + + if not _is_interactive() and not storage.has_cached_tokens(): + logger.warning( + "MCP OAuth for '%s': non-interactive environment and no " + "cached tokens found. 
Run interactively first to complete " + "initial authorization.", + server_name, + ) + + _configure_callback_port(cfg) + client_metadata = _build_client_metadata(cfg) + _maybe_preregister_client(storage, cfg, client_metadata) + + return _HERMES_PROVIDER_CLS( + server_name=server_name, + server_url=_parse_base_url(entry.server_url), + client_metadata=client_metadata, + storage=storage, + redirect_handler=_redirect_handler, + callback_handler=_wait_for_callback, + timeout=float(cfg.get("timeout", 300)), + ) + + def remove(self, server_name: str) -> None: + """Evict the provider from cache AND delete tokens from disk. + + Called by ``hermes mcp remove `` and (indirectly) by + ``hermes mcp login `` during forced re-auth. + """ + with self._entries_lock: + self._entries.pop(server_name, None) + + from tools.mcp_oauth import remove_oauth_tokens + remove_oauth_tokens(server_name) + logger.info( + "MCP OAuth '%s': evicted from cache and removed from disk", + server_name, + ) + + # -- Disk watch ---------------------------------------------------------- + + async def invalidate_if_disk_changed(self, server_name: str) -> bool: + """If the tokens file on disk has a newer mtime than last-seen, force + the MCP SDK provider to reload its in-memory state. + + Returns True if the cache was invalidated (mtime differed). This is + the core fix for the external-refresh workflow: a cron job writes + fresh tokens to disk, and on the next tool call the running MCP + session picks them up without a restart. 
+ """ + from tools.mcp_oauth import _get_token_dir, _safe_filename + + entry = self._entries.get(server_name) + if entry is None or entry.provider is None: + return False + + async with entry.lock: + tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json" + try: + mtime_ns = tokens_path.stat().st_mtime_ns + except (FileNotFoundError, OSError): + return False + + if mtime_ns != entry.last_mtime_ns: + old = entry.last_mtime_ns + entry.last_mtime_ns = mtime_ns + # Force the SDK's OAuthClientProvider to reload from storage + # on its next auth flow. `_initialized` is private API but + # stable across the MCP SDK versions we pin (>=1.26.0). + if hasattr(entry.provider, "_initialized"): + entry.provider._initialized = False # noqa: SLF001 + logger.info( + "MCP OAuth '%s': tokens file changed (mtime %d -> %d), " + "forcing reload", + server_name, old, mtime_ns, + ) + return True + return False + + # -- 401 handler (dedup'd) ----------------------------------------------- + + async def handle_401( + self, + server_name: str, + failed_access_token: Optional[str] = None, + ) -> bool: + """Handle a 401 from a tool call, deduplicated across concurrent callers. + + Returns: + True if a (possibly new) access token is now available — caller + should trigger a reconnect and retry the operation. + False if no recovery path exists — caller should surface a + ``needs_reauth`` error to the model so it stops hallucinating + manual refresh attempts. + + Thundering-herd protection: if N concurrent tool calls hit 401 with + the same ``failed_access_token``, only one recovery attempt fires. + Others await the same future. 
+ """ + entry = self._entries.get(server_name) + if entry is None or entry.provider is None: + return False + + key = failed_access_token or "" + loop = asyncio.get_running_loop() + + async with entry.lock: + pending = entry.pending_401.get(key) + if pending is None: + pending = loop.create_future() + entry.pending_401[key] = pending + + async def _do_handle() -> None: + try: + # Step 1: Did disk change? Picks up external refresh. + disk_changed = await self.invalidate_if_disk_changed( + server_name + ) + if disk_changed: + if not pending.done(): + pending.set_result(True) + return + + # Step 2: No disk change — if the SDK can refresh + # in-place, let the caller retry. The SDK's httpx.Auth + # flow will issue the refresh on the next request. + provider = entry.provider + ctx = getattr(provider, "context", None) + can_refresh = False + if ctx is not None: + can_refresh_fn = getattr(ctx, "can_refresh_token", None) + if callable(can_refresh_fn): + try: + can_refresh = bool(can_refresh_fn()) + except Exception: + can_refresh = False + if not pending.done(): + pending.set_result(can_refresh) + except Exception as exc: # pragma: no cover — defensive + logger.warning( + "MCP OAuth '%s': 401 handler failed: %s", + server_name, exc, + ) + if not pending.done(): + pending.set_result(False) + finally: + entry.pending_401.pop(key, None) + + asyncio.create_task(_do_handle()) + + try: + return await pending + except Exception as exc: # pragma: no cover — defensive + logger.warning( + "MCP OAuth '%s': awaiting 401 handler failed: %s", + server_name, exc, + ) + return False + + +# --------------------------------------------------------------------------- +# Module-level singleton +# --------------------------------------------------------------------------- + + +_MANAGER: Optional[MCPOAuthManager] = None +_MANAGER_LOCK = threading.Lock() + + +def get_manager() -> MCPOAuthManager: + """Return the process-wide :class:`MCPOAuthManager` singleton.""" + global _MANAGER + with 
_MANAGER_LOCK: + if _MANAGER is None: + _MANAGER = MCPOAuthManager() + return _MANAGER + + +def reset_manager_for_tests() -> None: + """Test-only helper: drop the singleton so fixtures start clean.""" + global _MANAGER + with _MANAGER_LOCK: + _MANAGER = None diff --git a/tools/mcp_tool.py b/tools/mcp_tool.py index a73aa4381..e5e856d0b 100644 --- a/tools/mcp_tool.py +++ b/tools/mcp_tool.py @@ -783,7 +783,8 @@ class MCPServerTask: __slots__ = ( "name", "session", "tool_timeout", - "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config", + "_task", "_ready", "_shutdown_event", "_reconnect_event", + "_tools", "_error", "_config", "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock", ) @@ -794,6 +795,12 @@ class MCPServerTask: self._task: Optional[asyncio.Task] = None self._ready = asyncio.Event() self._shutdown_event = asyncio.Event() + # Set by tool handlers on auth failure after manager.handle_401() + # confirms recovery is viable. When set, _run_http / _run_stdio + # exit their async-with blocks cleanly (no exception), and the + # outer run() loop re-enters the transport so the MCP session is + # rebuilt with fresh credentials. + self._reconnect_event = asyncio.Event() self._tools: list = [] self._error: Optional[Exception] = None self._config: dict = {} @@ -887,6 +894,40 @@ class MCPServerTask: self.name, len(self._registered_tool_names), ) + async def _wait_for_lifecycle_event(self) -> str: + """Block until either _shutdown_event or _reconnect_event fires. + + Returns: + "shutdown" if the server should exit the run loop entirely. + "reconnect" if the server should tear down the current MCP + session and re-enter the transport (fresh OAuth + tokens, new session ID, etc.). The reconnect event + is cleared before return so the next cycle starts + with a fresh signal. + + Shutdown takes precedence if both events are set simultaneously. 
+ """ + shutdown_task = asyncio.create_task(self._shutdown_event.wait()) + reconnect_task = asyncio.create_task(self._reconnect_event.wait()) + try: + await asyncio.wait( + {shutdown_task, reconnect_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + finally: + for t in (shutdown_task, reconnect_task): + if not t.done(): + t.cancel() + try: + await t + except (asyncio.CancelledError, Exception): + pass + + if self._shutdown_event.is_set(): + return "shutdown" + self._reconnect_event.clear() + return "reconnect" + async def _run_stdio(self, config: dict): """Run the server using stdio transport.""" command = config.get("command") @@ -932,7 +973,10 @@ class MCPServerTask: self.session = session await self._discover_tools() self._ready.set() - await self._shutdown_event.wait() + # stdio transport does not use OAuth, but we still honor + # _reconnect_event (e.g. future manual /mcp refresh) for + # consistency with _run_http. + await self._wait_for_lifecycle_event() # Context exited cleanly — subprocess was terminated by the SDK. if new_pids: with _lock: @@ -951,16 +995,18 @@ class MCPServerTask: headers = dict(config.get("headers") or {}) connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT) - # OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK. - # If OAuth setup fails (e.g. non-interactive environment without - # cached tokens), re-raise so this server is reported as failed - # without blocking other MCP servers from connecting. + # OAuth 2.1 PKCE: route through the central MCPOAuthManager so the + # same provider instance is reused across reconnects, pre-flow + # disk-watch is active, and config-time CLI code paths share state. + # If OAuth setup fails (e.g. non-interactive env without cached + # tokens), re-raise so this server is reported as failed without + # blocking other MCP servers from connecting. 
_oauth_auth = None if self._auth_type == "oauth": try: - from tools.mcp_oauth import build_oauth_auth - _oauth_auth = build_oauth_auth( - self.name, url, config.get("oauth") + from tools.mcp_oauth_manager import get_manager + _oauth_auth = get_manager().get_or_build_provider( + self.name, url, config.get("oauth"), ) except Exception as exc: logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc) @@ -995,7 +1041,12 @@ class MCPServerTask: self.session = session await self._discover_tools() self._ready.set() - await self._shutdown_event.wait() + reason = await self._wait_for_lifecycle_event() + if reason == "reconnect": + logger.info( + "MCP server '%s': reconnect requested — " + "tearing down HTTP session", self.name, + ) else: # Deprecated API (mcp < 1.24.0): manages httpx client internally. _http_kwargs: dict = { @@ -1012,7 +1063,12 @@ class MCPServerTask: self.session = session await self._discover_tools() self._ready.set() - await self._shutdown_event.wait() + reason = await self._wait_for_lifecycle_event() + if reason == "reconnect": + logger.info( + "MCP server '%s': reconnect requested — " + "tearing down legacy HTTP session", self.name, + ) async def _discover_tools(self): """Discover tools from the connected session.""" @@ -1060,8 +1116,25 @@ class MCPServerTask: await self._run_http(config) else: await self._run_stdio(config) - # Normal exit (shutdown requested) -- break out - break + # Transport returned cleanly. Two cases: + # - _shutdown_event was set: exit the run loop entirely. + # - _reconnect_event was set (auth recovery): loop back and + # rebuild the MCP session with fresh credentials. Do NOT + # touch the retry counters — this is not a failure. + if self._shutdown_event.is_set(): + break + logger.info( + "MCP server '%s': reconnecting (OAuth recovery or " + "manual refresh)", + self.name, + ) + # Reset the session reference; _run_http/_run_stdio will + # repopulate it on successful re-entry. 
+ self.session = None + # Keep _ready set across reconnects so tool handlers can + # still detect a transient in-flight state — it'll be + # re-set after the fresh session initializes. + continue except Exception as exc: self.session = None @@ -1141,6 +1214,12 @@ class MCPServerTask: from tools.registry import registry self._shutdown_event.set() + # Defensive: if _wait_for_lifecycle_event is blocking, we need ANY + # event to unblock it. _shutdown_event alone is sufficient (the + # helper checks shutdown first), but setting reconnect too ensures + # there's no race where the helper misses the shutdown flag after + # returning "reconnect". + self._reconnect_event.set() if self._task and not self._task.done(): try: await asyncio.wait_for(self._task, timeout=10) @@ -1174,6 +1253,175 @@ _servers: Dict[str, MCPServerTask] = {} _server_error_counts: Dict[str, int] = {} _CIRCUIT_BREAKER_THRESHOLD = 3 +# --------------------------------------------------------------------------- +# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation) +# --------------------------------------------------------------------------- + +# Cached tuple of auth-related exception types. Lazy so this module +# imports cleanly when the MCP SDK OAuth module is missing. +_AUTH_ERROR_TYPES: tuple = () + + +def _get_auth_error_types() -> tuple: + """Return a tuple of exception types that indicate MCP OAuth failure. + + Cached after first call. Includes: + - ``mcp.client.auth.OAuthFlowError`` / ``OAuthTokenError`` — raised by + the SDK's auth flow when discovery, refresh, or full re-auth fails. + - ``mcp.client.auth.UnauthorizedError`` (older MCP SDKs) — kept as an + optional import for forward/backward compatibility. + - ``tools.mcp_oauth.OAuthNonInteractiveError`` — raised by our callback + handler when no user is present to complete a browser flow. + - ``httpx.HTTPStatusError`` — caller must additionally check + ``status_code == 401`` via :func:`_is_auth_error`. 
+ """ + global _AUTH_ERROR_TYPES + if _AUTH_ERROR_TYPES: + return _AUTH_ERROR_TYPES + types: list = [] + try: + from mcp.client.auth import OAuthFlowError, OAuthTokenError + types.extend([OAuthFlowError, OAuthTokenError]) + except ImportError: + pass + try: + # Older MCP SDK variants exported this + from mcp.client.auth import UnauthorizedError # type: ignore + types.append(UnauthorizedError) + except ImportError: + pass + try: + from tools.mcp_oauth import OAuthNonInteractiveError + types.append(OAuthNonInteractiveError) + except ImportError: + pass + try: + import httpx + types.append(httpx.HTTPStatusError) + except ImportError: + pass + _AUTH_ERROR_TYPES = tuple(types) + return _AUTH_ERROR_TYPES + + +def _is_auth_error(exc: BaseException) -> bool: + """Return True if ``exc`` indicates an MCP OAuth failure. + + ``httpx.HTTPStatusError`` is only treated as auth-related when the + response status code is 401. Other HTTP errors fall through to the + generic error path in the tool handlers. + """ + types = _get_auth_error_types() + if not types or not isinstance(exc, types): + return False + try: + import httpx + if isinstance(exc, httpx.HTTPStatusError): + return getattr(exc.response, "status_code", None) == 401 + except ImportError: + pass + return True + + +def _handle_auth_error_and_retry( + server_name: str, + exc: BaseException, + retry_call, + op_description: str, +): + """Attempt auth recovery and one retry; return None to fall through. + + Called by the 5 MCP tool handlers when ``session.()`` raises an + auth-related exception. Workflow: + + 1. Ask :class:`tools.mcp_oauth_manager.MCPOAuthManager.handle_401` if + recovery is viable (i.e., disk has fresh tokens, or the SDK can + refresh in-place). + 2. If yes, set the server's ``_reconnect_event`` so the server task + tears down the current MCP session and rebuilds it with fresh + credentials. Wait briefly for ``_ready`` to re-fire. + 3. Retry the operation once. 
Return the retry result if it produced + a non-error JSON payload. Otherwise return the ``needs_reauth`` + error dict so the model stops hallucinating manual refresh. + 4. Return None if ``exc`` is not an auth error, signalling the + caller to use the generic error path. + + Args: + server_name: Name of the MCP server that raised. + exc: The exception from the failed tool call. + retry_call: Zero-arg callable that re-runs the tool call, returning + the same JSON string format as the handler. + op_description: Human-readable name of the operation (for logs). + + Returns: + A JSON string if auth recovery was attempted, or None to fall + through to the caller's generic error path. + """ + if not _is_auth_error(exc): + return None + + from tools.mcp_oauth_manager import get_manager + manager = get_manager() + + async def _recover(): + return await manager.handle_401(server_name, None) + + try: + recovered = _run_on_mcp_loop(_recover(), timeout=10) + except Exception as rec_exc: + logger.warning( + "MCP OAuth '%s': recovery attempt failed: %s", + server_name, rec_exc, + ) + recovered = False + + if recovered: + with _lock: + srv = _servers.get(server_name) + if srv is not None and hasattr(srv, "_reconnect_event"): + loop = _mcp_loop + if loop is not None and loop.is_running(): + loop.call_soon_threadsafe(srv._reconnect_event.set) + # Wait briefly for the session to come back ready. Bounded + # so that a stuck reconnect falls through to the error + # path rather than hanging the caller. 
+ deadline = time.monotonic() + 15 + while time.monotonic() < deadline: + if srv.session is not None and srv._ready.is_set(): + break + time.sleep(0.25) # NOTE(review): this poll appears to sit inside the `with _lock:` block — holding the global registry lock for up to 15 s blocks every other handler thread, and can deadlock if the reconnecting server task itself needs _lock to re-register tools; confirm the indentation and move the wait outside the lock. + + try: + result = retry_call() + try: + parsed = json.loads(result) + if "error" not in parsed: + _server_error_counts[server_name] = 0 + return result + except (json.JSONDecodeError, TypeError): + _server_error_counts[server_name] = 0 + return result + except Exception as retry_exc: + logger.warning( + "MCP %s/%s retry after auth recovery failed: %s", + server_name, op_description, retry_exc, + ) + + # No recovery available, or retry also failed: surface a structured + # needs_reauth error. Bumps the circuit breaker so the model stops + # retrying the tool. + _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1 + return json.dumps({ + "error": ( + f"MCP server '{server_name}' requires re-authentication. " + f"Run `hermes mcp login {server_name}` (or delete the tokens " + f"file under ~/.hermes/mcp-tokens/ and restart). Do NOT retry " + f"this tool — ask the user to re-authenticate." + ), + "needs_reauth": True, + "server": server_name, + }, ensure_ascii=False) + # Dedicated event loop running in a background daemon thread. 
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None _mcp_thread: Optional[threading.Thread] = None @@ -1420,8 +1668,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): return json.dumps({"result": structured}, ensure_ascii=False) return json.dumps({"result": text_result}, ensure_ascii=False) + def _call_once(): + return _run_on_mcp_loop(_call(), timeout=tool_timeout) + try: - result = _run_on_mcp_loop(_call(), timeout=tool_timeout) + result = _call_once() # Check if the MCP tool itself returned an error try: parsed = json.loads(result) @@ -1435,6 +1686,16 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float): except InterruptedError: return _interrupted_call_result() except Exception as exc: + # Auth-specific recovery path: consult the manager, signal + # reconnect if viable, retry once. Returns None to fall + # through for non-auth exceptions. + recovered = _handle_auth_error_and_retry( + server_name, exc, _call_once, + f"tools/call {tool_name}", + ) + if recovered is not None: + return recovered + _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1 logger.error( "MCP tool %s/%s call failed: %s", @@ -1476,11 +1737,19 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float): resources.append(entry) return json.dumps({"resources": resources}, ensure_ascii=False) - try: + def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) + + try: + return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: + recovered = _handle_auth_error_and_retry( + server_name, exc, _call_once, "resources/list", + ) + if recovered is not None: + return recovered logger.error( "MCP %s/list_resources failed: %s", server_name, exc, ) @@ -1522,11 +1791,19 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float): parts.append(f"[binary data, {len(block.blob)} bytes]") return json.dumps({"result": 
"\n".join(parts) if parts else ""}, ensure_ascii=False) - try: + def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) + + try: + return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: + recovered = _handle_auth_error_and_retry( + server_name, exc, _call_once, "resources/read", + ) + if recovered is not None: + return recovered logger.error( "MCP %s/read_resource failed: %s", server_name, exc, ) @@ -1571,11 +1848,19 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float): prompts.append(entry) return json.dumps({"prompts": prompts}, ensure_ascii=False) - try: + def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) + + try: + return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: + recovered = _handle_auth_error_and_retry( + server_name, exc, _call_once, "prompts/list", + ) + if recovered is not None: + return recovered logger.error( "MCP %s/list_prompts failed: %s", server_name, exc, ) @@ -1628,11 +1913,19 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float): resp["description"] = result.description return json.dumps(resp, ensure_ascii=False) - try: + def _call_once(): return _run_on_mcp_loop(_call(), timeout=tool_timeout) + + try: + return _call_once() except InterruptedError: return _interrupted_call_result() except Exception as exc: + recovered = _handle_auth_error_and_retry( + server_name, exc, _call_once, "prompts/get", + ) + if recovered is not None: + return recovered logger.error( "MCP %s/get_prompt failed: %s", server_name, exc, )