fix(mcp): consolidate OAuth handling, pick up external token refreshes (#11383)

* feat(mcp-oauth): scaffold MCPOAuthManager

Central manager for per-server MCP OAuth state. Provides
get_or_build_provider (cached), remove (evicts cache + deletes
disk), invalidate_if_disk_changed (mtime watch, core fix for
external-refresh workflow), and handle_401 (dedup'd recovery).

No behavior change yet — existing call sites still use
build_oauth_auth directly. Task 1 of 8 in the MCP OAuth
consolidation (fixes Cthulhu's BetterStack reliability issues).

* feat(mcp-oauth): add HermesMCPOAuthProvider with pre-flow disk watch

Subclasses the MCP SDK's OAuthClientProvider to inject a disk
mtime check before every async_auth_flow, via the central
manager. When a subclass instance is used, external token
refreshes (cron, another CLI instance) are picked up before
the next API call.

Still dead code: the manager's _build_provider still delegates
to build_oauth_auth and returns the plain OAuthClientProvider.
Task 4 wires this subclass in. Task 2 of 8.
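The disk watch itself reduces to an mtime comparison before each auth flow. A minimal synchronous sketch of that check (the real provider hooks it into `async_auth_flow`; `DiskWatchingProvider` and its fields are illustrative, not the shipped class):

```python
import os


class DiskWatchingProvider:
    """Sketch of the pre-flow disk watch: before each auth flow, compare
    the tokens file's mtime against the last-seen baseline and force a
    storage re-read when it moved. Stand-in for HermesMCPOAuthProvider."""

    def __init__(self, tokens_path: str) -> None:
        self._tokens_path = tokens_path
        self._last_mtime_ns = 0
        self._initialized = False  # mirrors the SDK flag gating storage reads

    def invalidate_if_disk_changed(self) -> bool:
        try:
            mtime_ns = os.stat(self._tokens_path).st_mtime_ns
        except FileNotFoundError:
            return False  # pre-auth state: nothing on disk yet
        if mtime_ns != self._last_mtime_ns:
            # First call records the baseline (0 -> real mtime) and also
            # reports True; subsequent calls only fire on a real change.
            self._last_mtime_ns = mtime_ns
            self._initialized = False  # next auth flow re-reads storage
            return True
        return False
```

An externally rewritten tokens file bumps the mtime, the flag flips, and the next flow reloads from disk instead of trusting stale in-memory tokens.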

* refactor(mcp-oauth): extract build_oauth_auth helpers

Decomposes build_oauth_auth into _configure_callback_port,
_build_client_metadata, _maybe_preregister_client, and
_parse_base_url. Public API preserved. These helpers let
MCPOAuthManager._build_provider reuse the same logic in Task 4
instead of duplicating the construction dance.

Also updates the SDK version hint in the warning from 1.10.0 to
1.26.0 (which is what we actually require for the OAuth types
used here). Task 3 of 8.
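For illustration, the base-URL extraction can be done with `urllib.parse`; this sketch matches the behaviour the helper tests expect, though the shipped `_parse_base_url` may differ in detail:

```python
from urllib.parse import urlsplit


def parse_base_url(url: str) -> str:
    """Drop path, query, and fragment, keeping scheme://host[:port]; this
    is the base used for OAuth metadata discovery. Illustrative stand-in
    for the private _parse_base_url helper."""
    parts = urlsplit(url)
    return f"{parts.scheme}://{parts.netloc}"
```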

* feat(mcp-oauth): manager now builds HermesMCPOAuthProvider directly

_build_provider constructs the disk-watching subclass using the
helpers from Task 3, instead of delegating to the plain
build_oauth_auth factory. Any consumer using the manager now gets
pre-flow disk-freshness checks automatically.

build_oauth_auth is preserved as the public API for backwards
compatibility. The code path is now:

    MCPOAuthManager.get_or_build_provider  ->
      _build_provider  ->
        _configure_callback_port
        _build_client_metadata
        _maybe_preregister_client
        _parse_base_url
        HermesMCPOAuthProvider(...)

Task 4 of 8.
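The caching contract at the top of that path can be sketched independently of OAuth (hypothetical `ProviderCache`; the real manager's entries also carry disk-watch state, and `remove` additionally deletes the on-disk token files):

```python
class ProviderCache:
    """Sketch of the manager's provider cache: one entry per server name,
    rebuilt when the URL changes. Names are illustrative."""

    def __init__(self, factory) -> None:
        self._factory = factory  # builds a provider for (name, url)
        self._entries: dict[str, tuple[str, object]] = {}

    def get_or_build_provider(self, name: str, url: str):
        entry = self._entries.get(name)
        if entry is not None and entry[0] == url:
            return entry[1]  # cache hit: same server, same URL
        provider = self._factory(name, url)
        self._entries[name] = (url, provider)
        return provider

    def remove(self, name: str) -> None:
        self._entries.pop(name, None)  # real manager also wipes disk files
```

Keeping the same instance across reconnects is what preserves the disk-watch baseline mtime.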

* feat(mcp): wire OAuth manager + add _reconnect_event

MCPServerTask gains _reconnect_event alongside _shutdown_event.
When set, _run_http / _run_stdio exit their async-with blocks
cleanly (no exception), and the outer run() loop re-enters the
transport to rebuild the MCP session with fresh credentials.
This is the recovery path for OAuth failures that the SDK's
in-place httpx.Auth cannot handle (e.g. cron externally consumed
the refresh_token, or server-side session invalidation).

_run_http now asks MCPOAuthManager for the OAuth provider
instead of calling build_oauth_auth directly. Config-time,
runtime, and reconnect paths all share one provider instance
with pre-flow disk-watch active.

shutdown() defensively sets both events so there is no race
between reconnect and shutdown signalling.

Task 5 of 8.
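The two-event wait can be sketched like this (hypothetical `LifecycleSignals` class; the real MCPServerTask wires the equivalent logic into its transport loop):

```python
import asyncio


class LifecycleSignals:
    """Sketch of the two-event lifecycle used by the run() loop. Shutdown
    wins over reconnect, and a consumed reconnect is cleared so the next
    cycle starts fresh."""

    def __init__(self) -> None:
        self._shutdown_event = asyncio.Event()
        self._reconnect_event = asyncio.Event()

    async def wait_for_lifecycle_event(self) -> str:
        shutdown = asyncio.create_task(self._shutdown_event.wait())
        reconnect = asyncio.create_task(self._reconnect_event.wait())
        done, pending = await asyncio.wait(
            {shutdown, reconnect}, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()
        await asyncio.gather(*pending, return_exceptions=True)
        # Shutdown takes precedence when both events fired in the same tick.
        if self._shutdown_event.is_set():
            return "shutdown"
        self._reconnect_event.clear()
        return "reconnect"
```

Checking the shutdown flag after the race, rather than which waiter won, is what makes the defensive both-events-set case in shutdown() safe.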

* feat(mcp): detect auth failures in tool handlers, trigger reconnect

All 5 MCP tool handlers (tool call, list_resources, read_resource,
list_prompts, get_prompt) now detect auth failures and route
through MCPOAuthManager.handle_401:

  1. If the manager says recovery is viable (disk has fresh tokens,
     or SDK can refresh in-place), signal MCPServerTask._reconnect_event
     to tear down and rebuild the MCP session with fresh credentials,
     then retry the tool call once.

  2. If no recovery path exists, return a structured needs_reauth
     JSON error so the model stops hallucinating manual refresh
     attempts (the 'let me curl the token endpoint' loop Cthulhu
     pasted from Discord).

_is_auth_error catches OAuthFlowError, OAuthTokenError,
OAuthNonInteractiveError, and httpx.HTTPStatusError(401). Non-auth
exceptions still surface via the generic error path unchanged.

Task 6 of 8.
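The detection-plus-structured-error pair might look like the sketch below. The exception classes here are stand-ins for the real SDK/httpx types (defined locally only to keep the sketch self-contained), and the payload fields beyond `needs_reauth`/`server` are assumptions:

```python
import json


# Stand-ins for the real exception types (the genuine ones live in the
# MCP SDK and in tools.mcp_oauth).
class OAuthFlowError(Exception): ...
class OAuthTokenError(Exception): ...
class OAuthNonInteractiveError(Exception): ...

_AUTH_EXC_TYPES = (OAuthFlowError, OAuthTokenError, OAuthNonInteractiveError)


def is_auth_error(exc: BaseException) -> bool:
    """Sketch of _is_auth_error: known OAuth exceptions, or any HTTP error
    whose attached response carries a 401 status (duck-typed, so it covers
    httpx.HTTPStatusError without importing httpx here)."""
    if isinstance(exc, _AUTH_EXC_TYPES):
        return True
    response = getattr(exc, "response", None)
    return getattr(response, "status_code", None) == 401


def needs_reauth_payload(server: str) -> str:
    """Structured error returned when no recovery path exists. Field names
    beyond needs_reauth/server are assumptions for this sketch."""
    return json.dumps({
        "error": f"OAuth session for '{server}' requires re-authentication.",
        "needs_reauth": True,
        "server": server,
        "action": f"Ask the user to run: hermes mcp login {server}",
    })
```

A machine-readable `needs_reauth: true` gives the model a stopping point instead of an open-ended error string to improvise against.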

* feat(mcp-cli): route add/remove through manager, add 'hermes mcp login'

cmd_mcp_add and cmd_mcp_remove now go through MCPOAuthManager
instead of calling build_oauth_auth / remove_oauth_tokens
directly. This means CLI config-time state and runtime MCP
session state are backed by the same provider cache — removing
a server evicts the live provider, adding a server populates
the same cache the MCP session will read from.

New 'hermes mcp login <name>' command:
  - Wipes both the on-disk tokens file and the in-memory
    MCPOAuthManager cache
  - Triggers a fresh OAuth browser flow via the existing probe
    path
  - Intended target for the needs_reauth error Task 6 returns
    to the model

Task 7 of 8.

* test(mcp-oauth): end-to-end integration tests

Five new tests exercising the full consolidation with real file
I/O and real imports (no transport mocks):

  1. external_refresh_picked_up_without_restart — Cthulhu's cron
     workflow. External process writes fresh tokens to disk;
     on the next auth flow the manager's mtime-watch flips
     _initialized and the SDK re-reads from storage.

  2. handle_401_deduplicates_concurrent_callers — 10 concurrent
     handlers for the same failed token fire exactly ONE recovery
     attempt (thundering-herd protection).

  3. handle_401_returns_false_when_no_provider — defensive path
     for unknown servers.

  4. invalidate_if_disk_changed_handles_missing_file — pre-auth
     state returns False cleanly.

  5. provider_is_reused_across_reconnects — cache stickiness so
     reconnects preserve the disk-watch baseline mtime.

Task 8 of 8 — consolidation complete.
Teknium, 2026-04-16 21:57:10 -07:00 (committed by GitHub)
parent 436a7359cd
commit 70768665a4
11 changed files with 1566 additions and 90 deletions


@@ -5904,6 +5904,12 @@ Examples:
     mcp_cfg_p = mcp_sub.add_parser("configure", aliases=["config"], help="Toggle tool selection")
     mcp_cfg_p.add_argument("name", help="Server name to configure")
+    mcp_login_p = mcp_sub.add_parser(
+        "login",
+        help="Force re-authentication for an OAuth-based MCP server",
+    )
+    mcp_login_p.add_argument("name", help="Server name to re-authenticate")
 
 def cmd_mcp(args):
     from hermes_cli.mcp_config import mcp_command
     mcp_command(args)


@@ -279,8 +279,8 @@ def cmd_mcp_add(args):
         _info(f"Starting OAuth flow for '{name}'...")
         oauth_ok = False
         try:
-            from tools.mcp_oauth import build_oauth_auth
-            oauth_auth = build_oauth_auth(name, url)
+            from tools.mcp_oauth_manager import get_manager
+            oauth_auth = get_manager().get_or_build_provider(name, url, None)
             if oauth_auth:
                 server_config["auth"] = "oauth"
                 _success("OAuth configured (tokens will be acquired on first connection)")
@@ -428,10 +428,12 @@ def cmd_mcp_remove(args):
     _remove_mcp_server(name)
     _success(f"Removed '{name}' from config")
 
-    # Clean up OAuth tokens if they exist
+    # Clean up OAuth tokens if they exist — route through MCPOAuthManager so
+    # any provider instance cached in the current process (e.g. from an
+    # earlier `hermes mcp test` in the same session) is evicted too.
     try:
-        from tools.mcp_oauth import remove_oauth_tokens
-        remove_oauth_tokens(name)
+        from tools.mcp_oauth_manager import get_manager
+        get_manager().remove(name)
         _success("Cleaned up OAuth tokens")
     except Exception:
         pass
@@ -577,6 +579,63 @@ def _interpolate_value(value: str) -> str:
     return re.sub(r"\$\{(\w+)\}", _replace, value)
 
+# ─── hermes mcp login ────────────────────────────────────────────────────────
+
+def cmd_mcp_login(args):
+    """Force re-authentication for an OAuth-based MCP server.
+
+    Deletes cached tokens (both on disk and in the running process's
+    MCPOAuthManager cache) and triggers a fresh OAuth flow via the
+    existing probe path.
+
+    Use this when:
+      - Tokens are stuck in a bad state (server revoked, refresh token
+        consumed by an external process, etc.)
+      - You want to re-authenticate to change scopes or account
+      - A tool call returned ``needs_reauth: true``
+    """
+    name = args.name
+    servers = _get_mcp_servers()
+    if name not in servers:
+        _error(f"Server '{name}' not found in config.")
+        if servers:
+            _info(f"Available servers: {', '.join(servers)}")
+        return
+    server_config = servers[name]
+    url = server_config.get("url")
+    if not url:
+        _error(f"Server '{name}' has no URL — not an OAuth-capable server")
+        return
+    if server_config.get("auth") != "oauth":
+        _error(f"Server '{name}' is not configured for OAuth (auth={server_config.get('auth')})")
+        _info("Use `hermes mcp remove` + `hermes mcp add` to reconfigure auth.")
+        return
+    # Wipe both disk and in-memory cache so the next probe forces a fresh
+    # OAuth flow.
+    try:
+        from tools.mcp_oauth_manager import get_manager
+        mgr = get_manager()
+        mgr.remove(name)
+    except Exception as exc:
+        _warning(f"Could not clear existing OAuth state: {exc}")
+    print()
+    _info(f"Starting OAuth flow for '{name}'...")
+    # Probe triggers the OAuth flow (browser redirect + callback capture).
+    try:
+        tools = _probe_single_server(name, server_config)
+        if tools:
+            _success(f"Authenticated — {len(tools)} tool(s) available")
+        else:
+            _success("Authenticated (server reported no tools)")
+    except Exception as exc:
+        _error(f"Authentication failed: {exc}")
+
 # ─── hermes mcp configure ────────────────────────────────────────────────────
 
 def cmd_mcp_configure(args):
@@ -696,6 +755,7 @@ def mcp_command(args):
         "test": cmd_mcp_test,
         "configure": cmd_mcp_configure,
         "config": cmd_mcp_configure,
+        "login": cmd_mcp_login,
     }
 
     handler = handlers.get(action)
@@ -713,4 +773,5 @@ def mcp_command(args):
     _info("hermes mcp list              List servers")
     _info("hermes mcp test <name>       Test connection")
     _info("hermes mcp configure <name>  Toggle tools")
+    _info("hermes mcp login <name>      Re-authenticate OAuth")
     print()


@@ -539,3 +539,64 @@ class TestDispatcher:
         mcp_command(_make_args(mcp_action=None))
         out = capsys.readouterr().out
         assert "Commands:" in out or "No MCP servers" in out
+
+
+# ---------------------------------------------------------------------------
+# Tests: Task 7 consolidation — cmd_mcp_remove evicts manager cache,
+# cmd_mcp_login forces re-auth
+# ---------------------------------------------------------------------------
+
+class TestMcpRemoveEvictsManager:
+    def test_remove_evicts_in_memory_provider(self, tmp_path, capsys, monkeypatch):
+        """After cmd_mcp_remove, the MCPOAuthManager no longer caches the provider."""
+        _seed_config(tmp_path, {
+            "oauth-srv": {"url": "https://example.com/mcp", "auth": "oauth"},
+        })
+        monkeypatch.setattr("builtins.input", lambda _: "y")
+        monkeypatch.setattr(
+            "hermes_cli.mcp_config.get_hermes_home", lambda: tmp_path
+        )
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
+        reset_manager_for_tests()
+        mgr = get_manager()
+        mgr.get_or_build_provider(
+            "oauth-srv", "https://example.com/mcp", None,
+        )
+        assert "oauth-srv" in mgr._entries
+        from hermes_cli.mcp_config import cmd_mcp_remove
+        cmd_mcp_remove(_make_args(name="oauth-srv"))
+        assert "oauth-srv" not in mgr._entries
+
+
+class TestMcpLogin:
+    def test_login_rejects_unknown_server(self, tmp_path, capsys):
+        _seed_config(tmp_path, {})
+        from hermes_cli.mcp_config import cmd_mcp_login
+        cmd_mcp_login(_make_args(name="ghost"))
+        out = capsys.readouterr().out
+        assert "not found" in out
+
+    def test_login_rejects_non_oauth_server(self, tmp_path, capsys):
+        _seed_config(tmp_path, {
+            "srv": {"url": "https://example.com/mcp", "auth": "header"},
+        })
+        from hermes_cli.mcp_config import cmd_mcp_login
+        cmd_mcp_login(_make_args(name="srv"))
+        out = capsys.readouterr().out
+        assert "not configured for OAuth" in out
+
+    def test_login_rejects_stdio_server(self, tmp_path, capsys):
+        _seed_config(tmp_path, {
+            "srv": {"command": "npx", "args": ["some-server"]},
+        })
+        from hermes_cli.mcp_config import cmd_mcp_login
+        cmd_mcp_login(_make_args(name="srv"))
+        out = capsys.readouterr().out
+        assert "no URL" in out or "not an OAuth" in out


@@ -431,3 +431,71 @@ class TestBuildOAuthAuthNonInteractive:
         assert auth is not None
         assert "no cached tokens found" not in caplog.text.lower()
+
+
+# ---------------------------------------------------------------------------
+# Extracted helper tests (Task 3 of MCP OAuth consolidation)
+# ---------------------------------------------------------------------------
+
+def test_build_client_metadata_basic():
+    """_build_client_metadata returns metadata with expected defaults."""
+    from tools.mcp_oauth import _build_client_metadata, _configure_callback_port
+    cfg = {"client_name": "Test Client"}
+    _configure_callback_port(cfg)
+    md = _build_client_metadata(cfg)
+    assert md.client_name == "Test Client"
+    assert "authorization_code" in md.grant_types
+    assert "refresh_token" in md.grant_types
+
+
+def test_build_client_metadata_without_secret_is_public():
+    """Without client_secret, token endpoint auth is 'none' (public client)."""
+    from tools.mcp_oauth import _build_client_metadata, _configure_callback_port
+    cfg = {}
+    _configure_callback_port(cfg)
+    md = _build_client_metadata(cfg)
+    assert md.token_endpoint_auth_method == "none"
+
+
+def test_build_client_metadata_with_secret_is_confidential():
+    """With client_secret, token endpoint auth is 'client_secret_post'."""
+    from tools.mcp_oauth import _build_client_metadata, _configure_callback_port
+    cfg = {"client_secret": "shh"}
+    _configure_callback_port(cfg)
+    md = _build_client_metadata(cfg)
+    assert md.token_endpoint_auth_method == "client_secret_post"
+
+
+def test_configure_callback_port_picks_free_port():
+    """_configure_callback_port(0) picks a free port in the ephemeral range."""
+    from tools.mcp_oauth import _configure_callback_port
+    cfg = {"redirect_port": 0}
+    port = _configure_callback_port(cfg)
+    assert 1024 < port < 65536
+    assert cfg["_resolved_port"] == port
+
+
+def test_configure_callback_port_uses_explicit_port():
+    """An explicit redirect_port is preserved."""
+    from tools.mcp_oauth import _configure_callback_port
+    cfg = {"redirect_port": 54321}
+    port = _configure_callback_port(cfg)
+    assert port == 54321
+    assert cfg["_resolved_port"] == 54321
+
+
+def test_parse_base_url_strips_path():
+    """_parse_base_url drops path components for OAuth discovery."""
+    from tools.mcp_oauth import _parse_base_url
+    assert _parse_base_url("https://example.com/mcp/v1") == "https://example.com"
+    assert _parse_base_url("https://example.com") == "https://example.com"
+    assert _parse_base_url("https://host.example.com:8080/api") == "https://host.example.com:8080"


@@ -0,0 +1,193 @@
"""End-to-end integration tests for the MCP OAuth consolidation.

Exercises the full chain (manager, provider subclass, disk watch, 401
dedup) with real file I/O and real imports (no transport mocks, no
subprocesses). These are the tests that would catch Cthulhu's original
BetterStack bug: an external process rewrites the tokens file on disk,
and the running Hermes session picks up the new tokens on the next auth
flow without requiring a restart.
"""
import asyncio
import json
import os
import time

import pytest

pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")


@pytest.mark.asyncio
async def test_external_refresh_picked_up_without_restart(tmp_path, monkeypatch):
    """Simulate Cthulhu's cron workflow end-to-end.

    1. A running Hermes session has OAuth tokens loaded in memory.
    2. An external process (cron) writes fresh tokens to disk.
    3. On the next auth flow, the manager's disk-watch invalidates the
       in-memory state so the SDK re-reads from storage.
    4. ``provider.context.current_tokens`` now reflects the new tokens
       with no process restart required.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
    reset_manager_for_tests()

    token_dir = tmp_path / "mcp-tokens"
    token_dir.mkdir(parents=True)
    tokens_file = token_dir / "srv.json"
    client_info_file = token_dir / "srv.client.json"

    # Pre-seed the baseline state: valid tokens the session loaded at startup.
    tokens_file.write_text(json.dumps({
        "access_token": "OLD_ACCESS",
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": "OLD_REFRESH",
    }))
    client_info_file.write_text(json.dumps({
        "client_id": "test-client",
        "redirect_uris": ["http://127.0.0.1:12345/callback"],
        "grant_types": ["authorization_code", "refresh_token"],
        "response_types": ["code"],
        "token_endpoint_auth_method": "none",
    }))

    mgr = MCPOAuthManager()
    provider = mgr.get_or_build_provider(
        "srv", "https://example.com/mcp", None,
    )
    assert provider is not None

    # The SDK's _initialize reads tokens from storage into memory. This
    # is what happens on the first http request under normal operation.
    await provider._initialize()
    assert provider.context.current_tokens.access_token == "OLD_ACCESS"

    # Now record the baseline mtime in the manager (this happens
    # automatically via the HermesMCPOAuthProvider.async_auth_flow
    # pre-hook on the first real request, but we exercise it directly
    # here for test determinism).
    await mgr.invalidate_if_disk_changed("srv")

    # EXTERNAL PROCESS: cron rewrites the tokens file with fresh creds.
    # The old refresh_token has been consumed by this external exchange.
    future_mtime = time.time() + 1
    tokens_file.write_text(json.dumps({
        "access_token": "NEW_ACCESS",
        "token_type": "Bearer",
        "expires_in": 3600,
        "refresh_token": "NEW_REFRESH",
    }))
    os.utime(tokens_file, (future_mtime, future_mtime))

    # The next auth flow should detect the mtime change and reload.
    changed = await mgr.invalidate_if_disk_changed("srv")
    assert changed, "manager must detect the disk mtime change"
    assert provider._initialized is False, "_initialized must flip so SDK re-reads storage"

    # Simulate the next async_auth_flow: _initialize runs because _initialized=False.
    await provider._initialize()
    assert provider.context.current_tokens.access_token == "NEW_ACCESS"
    assert provider.context.current_tokens.refresh_token == "NEW_REFRESH"


@pytest.mark.asyncio
async def test_handle_401_deduplicates_concurrent_callers(tmp_path, monkeypatch):
    """Ten concurrent 401 handlers for the same token should fire one recovery.

    Mirrors Claude Code's pending401Handlers dedup pattern — prevents N MCP
    tool calls hitting 401 simultaneously from all independently clearing
    caches and re-reading the keychain (which thrashes the storage and
    bogs down startup per CC-1096).
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
    reset_manager_for_tests()

    token_dir = tmp_path / "mcp-tokens"
    token_dir.mkdir(parents=True)
    (token_dir / "srv.json").write_text(json.dumps({
        "access_token": "TOK",
        "token_type": "Bearer",
        "expires_in": 3600,
    }))

    mgr = MCPOAuthManager()
    provider = mgr.get_or_build_provider(
        "srv", "https://example.com/mcp", None,
    )
    assert provider is not None

    # Count how many times invalidate_if_disk_changed is called — proxy for
    # how many actual recovery attempts fire.
    call_count = 0
    real_invalidate = mgr.invalidate_if_disk_changed

    async def counting(name):
        nonlocal call_count
        call_count += 1
        return await real_invalidate(name)

    monkeypatch.setattr(mgr, "invalidate_if_disk_changed", counting)

    # Fire 10 concurrent handlers with the same failed token.
    results = await asyncio.gather(*(
        mgr.handle_401("srv", "SAME_FAILED_TOKEN") for _ in range(10)
    ))

    # All callers get the same result (the shared future's resolution).
    assert all(r == results[0] for r in results), "dedup must return identical result"
    # Exactly ONE recovery ran — the rest awaited the same pending future.
    assert call_count == 1, f"expected 1 recovery attempt, got {call_count}"


@pytest.mark.asyncio
async def test_handle_401_returns_false_when_no_provider(tmp_path, monkeypatch):
    """handle_401 for an unknown server returns False cleanly."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
    reset_manager_for_tests()
    mgr = MCPOAuthManager()
    result = await mgr.handle_401("nonexistent", "any_token")
    assert result is False


@pytest.mark.asyncio
async def test_invalidate_if_disk_changed_handles_missing_file(tmp_path, monkeypatch):
    """invalidate_if_disk_changed returns False when tokens file doesn't exist."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
    reset_manager_for_tests()
    mgr = MCPOAuthManager()
    mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    # No tokens file exists yet — this is the pre-auth state
    result = await mgr.invalidate_if_disk_changed("srv")
    assert result is False


@pytest.mark.asyncio
async def test_provider_is_reused_across_reconnects(tmp_path, monkeypatch):
    """The manager caches providers; multiple reconnects reuse the same instance.

    This is what makes the disk-watch stick across reconnects: tearing down
    the MCP session and rebuilding it (Task 5's _reconnect_event path) must
    not create a new provider, otherwise ``last_mtime_ns`` resets and the
    first post-reconnect auth flow would spuriously "detect" a change.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
    reset_manager_for_tests()
    mgr = MCPOAuthManager()
    p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    # Simulate a reconnect: _run_http calls get_or_build_provider again
    p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert p1 is p2, "manager must cache the provider across reconnects"


@@ -0,0 +1,141 @@
"""Tests for the MCP OAuth manager (tools/mcp_oauth_manager.py).

The manager consolidates the eight scattered MCP-OAuth call sites into a
single object with disk-mtime watch, dedup'd 401 handling, and a provider
cache. See `tools/mcp_oauth_manager.py` for design rationale.
"""
import json
import os
import time

import pytest

pytest.importorskip(
    "mcp.client.auth.oauth2",
    reason="MCP SDK 1.26.0+ required for OAuth support",
)


def test_manager_is_singleton():
    """get_manager() returns the same instance across calls."""
    from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
    reset_manager_for_tests()
    m1 = get_manager()
    m2 = get_manager()
    assert m1 is m2


def test_manager_get_or_build_provider_caches(tmp_path, monkeypatch):
    """Calling get_or_build_provider twice with same name returns same provider."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager
    mgr = MCPOAuthManager()
    p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert p1 is p2


def test_manager_get_or_build_rebuilds_on_url_change(tmp_path, monkeypatch):
    """Changing the URL discards the cached provider."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager
    mgr = MCPOAuthManager()
    p1 = mgr.get_or_build_provider("srv", "https://a.example.com/mcp", None)
    p2 = mgr.get_or_build_provider("srv", "https://b.example.com/mcp", None)
    assert p1 is not p2


def test_manager_remove_evicts_cache(tmp_path, monkeypatch):
    """remove(name) evicts the provider from cache AND deletes disk files."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager
    # Pre-seed tokens on disk
    token_dir = tmp_path / "mcp-tokens"
    token_dir.mkdir(parents=True)
    (token_dir / "srv.json").write_text(json.dumps({
        "access_token": "TOK",
        "token_type": "Bearer",
    }))
    mgr = MCPOAuthManager()
    p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert p1 is not None
    assert (token_dir / "srv.json").exists()
    mgr.remove("srv")
    assert not (token_dir / "srv.json").exists()
    p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert p1 is not p2


def test_hermes_provider_subclass_exists():
    """HermesMCPOAuthProvider is defined and subclasses OAuthClientProvider."""
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS
    from mcp.client.auth.oauth2 import OAuthClientProvider
    assert _HERMES_PROVIDER_CLS is not None
    assert issubclass(_HERMES_PROVIDER_CLS, OAuthClientProvider)


@pytest.mark.asyncio
async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch):
    """When the tokens file mtime changes, provider._initialized flips False.

    This is the behaviour Claude Code ships as
    invalidateOAuthCacheIfDiskChanged (CC-1096 / GH#24317) and is the core
    fix for Cthulhu's external-cron refresh workflow.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
    reset_manager_for_tests()
    token_dir = tmp_path / "mcp-tokens"
    token_dir.mkdir(parents=True)
    tokens_file = token_dir / "srv.json"
    tokens_file.write_text(json.dumps({
        "access_token": "OLD",
        "token_type": "Bearer",
    }))
    mgr = MCPOAuthManager()
    provider = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert provider is not None
    # First call: records mtime (zero -> real) -> returns True
    changed1 = await mgr.invalidate_if_disk_changed("srv")
    assert changed1 is True
    # No file change -> False
    changed2 = await mgr.invalidate_if_disk_changed("srv")
    assert changed2 is False
    # Touch file with a newer mtime
    future_mtime = time.time() + 10
    os.utime(tokens_file, (future_mtime, future_mtime))
    changed3 = await mgr.invalidate_if_disk_changed("srv")
    assert changed3 is True
    # _initialized flipped — next async_auth_flow will re-read from disk
    assert provider._initialized is False


def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch):
    """get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider."""
    from tools.mcp_oauth_manager import (
        MCPOAuthManager, _HERMES_PROVIDER_CLS, reset_manager_for_tests,
    )
    reset_manager_for_tests()
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    mgr = MCPOAuthManager()
    provider = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert _HERMES_PROVIDER_CLS is not None
    assert isinstance(provider, _HERMES_PROVIDER_CLS)
    assert provider._hermes_server_name == "srv"


@@ -0,0 +1,57 @@
"""Tests for the MCPServerTask reconnect signal.

When the OAuth layer cannot recover in-place (e.g., external refresh of a
single-use refresh_token made the SDK's in-memory refresh fail), the tool
handler signals MCPServerTask to tear down the current MCP session and
reconnect with fresh credentials. This file exercises the signal plumbing
in isolation from the full stdio/http transport machinery.
"""
import asyncio

import pytest


@pytest.mark.asyncio
async def test_reconnect_event_attribute_exists():
    """MCPServerTask has a _reconnect_event alongside _shutdown_event."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")
    assert hasattr(task, "_reconnect_event")
    assert isinstance(task._reconnect_event, asyncio.Event)
    assert not task._reconnect_event.is_set()


@pytest.mark.asyncio
async def test_wait_for_lifecycle_event_returns_reconnect():
    """When _reconnect_event fires, helper returns 'reconnect' and clears it."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")
    task._reconnect_event.set()
    reason = await task._wait_for_lifecycle_event()
    assert reason == "reconnect"
    # Should have cleared so the next cycle starts fresh
    assert not task._reconnect_event.is_set()


@pytest.mark.asyncio
async def test_wait_for_lifecycle_event_returns_shutdown():
    """When _shutdown_event fires, helper returns 'shutdown'."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")
    task._shutdown_event.set()
    reason = await task._wait_for_lifecycle_event()
    assert reason == "shutdown"


@pytest.mark.asyncio
async def test_wait_for_lifecycle_event_shutdown_wins_when_both_set():
    """If both events are set simultaneously, shutdown takes precedence."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")
    task._shutdown_event.set()
    task._reconnect_event.set()
    reason = await task._wait_for_lifecycle_event()
    assert reason == "shutdown"


@@ -0,0 +1,139 @@
"""Tests for MCP tool-handler auth-failure detection.

When a tool call raises UnauthorizedError / OAuthNonInteractiveError /
httpx.HTTPStatusError(401), the handler should:

1. Ask MCPOAuthManager.handle_401 if recovery is viable.
2. If yes, trigger MCPServerTask._reconnect_event and retry once.
3. If no, return a structured needs_reauth error so the model stops
   hallucinating manual refresh attempts.
"""
import json
from unittest.mock import MagicMock

import pytest

pytest.importorskip("mcp.client.auth.oauth2")


def test_is_auth_error_detects_oauth_flow_error():
    from tools.mcp_tool import _is_auth_error
    from mcp.client.auth import OAuthFlowError
    assert _is_auth_error(OAuthFlowError("expired")) is True


def test_is_auth_error_detects_oauth_non_interactive():
    from tools.mcp_tool import _is_auth_error
    from tools.mcp_oauth import OAuthNonInteractiveError
    assert _is_auth_error(OAuthNonInteractiveError("no browser")) is True


def test_is_auth_error_detects_httpx_401():
    from tools.mcp_tool import _is_auth_error
    import httpx
    response = MagicMock()
    response.status_code = 401
    exc = httpx.HTTPStatusError("unauth", request=MagicMock(), response=response)
    assert _is_auth_error(exc) is True


def test_is_auth_error_rejects_httpx_500():
    from tools.mcp_tool import _is_auth_error
    import httpx
    response = MagicMock()
    response.status_code = 500
    exc = httpx.HTTPStatusError("oops", request=MagicMock(), response=response)
    assert _is_auth_error(exc) is False


def test_is_auth_error_rejects_generic_exception():
    from tools.mcp_tool import _is_auth_error
    assert _is_auth_error(ValueError("not auth")) is False
    assert _is_auth_error(RuntimeError("not auth")) is False


def test_call_tool_handler_returns_needs_reauth_on_unrecoverable_401(monkeypatch, tmp_path):
    """When session.call_tool raises 401 and handle_401 returns False,
    handler returns a structured needs_reauth error (not a generic failure)."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_tool import _make_tool_handler
    from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
    from mcp.client.auth import OAuthFlowError
    reset_manager_for_tests()

    # Stub server
    server = MagicMock()
    server.name = "srv"
    session = MagicMock()

    async def _call_tool_raises(*a, **kw):
        raise OAuthFlowError("token expired")

    session.call_tool = _call_tool_raises
    server.session = session
    server._reconnect_event = MagicMock()
    server._ready = MagicMock()
    server._ready.is_set.return_value = True

    from tools import mcp_tool
    mcp_tool._servers["srv"] = server
    mcp_tool._server_error_counts.pop("srv", None)
    # Ensure the MCP loop exists (run_on_mcp_loop needs it)
    mcp_tool._ensure_mcp_loop()

    # Force handle_401 to return False (no recovery available)
    mgr = get_manager()

    async def _h401(name, token=None):
        return False

    monkeypatch.setattr(mgr, "handle_401", _h401)

    try:
        handler = _make_tool_handler("srv", "tool1", 10.0)
        result = handler({"arg": "v"})
        parsed = json.loads(result)
        assert parsed.get("needs_reauth") is True, f"expected needs_reauth, got: {parsed}"
        assert parsed.get("server") == "srv"
        assert "re-auth" in parsed.get("error", "").lower() or "reauth" in parsed.get("error", "").lower()
    finally:
        mcp_tool._servers.pop("srv", None)
        mcp_tool._server_error_counts.pop("srv", None)


def test_call_tool_handler_non_auth_error_still_generic(monkeypatch, tmp_path):
    """Non-auth exceptions still surface via the generic error path, not needs_reauth."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_tool import _make_tool_handler

    server = MagicMock()
    server.name = "srv"
    session = MagicMock()

    async def _raises(*a, **kw):
        raise RuntimeError("unrelated")

    session.call_tool = _raises
    server.session = session

    from tools import mcp_tool
    mcp_tool._servers["srv"] = server
    mcp_tool._server_error_counts.pop("srv", None)
    mcp_tool._ensure_mcp_loop()

    try:
        handler = _make_tool_handler("srv", "tool1", 10.0)
        result = handler({"arg": "v"})
        parsed = json.loads(result)
        assert "needs_reauth" not in parsed
        assert "MCP call failed" in parsed.get("error", "")
    finally:
        mcp_tool._servers.pop("srv", None)
        mcp_tool._server_error_counts.pop("srv", None)

tools/mcp_oauth.py
@@ -375,6 +375,103 @@ def remove_oauth_tokens(server_name: str) -> None:
logger.info("OAuth tokens removed for '%s'", server_name)
# ---------------------------------------------------------------------------
# Extracted helpers (Task 3 of MCP OAuth consolidation)
#
# These compose into ``build_oauth_auth`` below, and are also used by
# ``tools.mcp_oauth_manager.MCPOAuthManager._build_provider`` so the two
# construction paths share one implementation.
# ---------------------------------------------------------------------------
def _configure_callback_port(cfg: dict) -> int:
"""Pick or validate the OAuth callback port.
Stores the resolved port into ``cfg['_resolved_port']`` so sibling
helpers (and the manager) can read it from the same dict. Returns the
resolved port.
NOTE: also sets the legacy module-level ``_oauth_port`` so existing
calls to ``_wait_for_callback`` keep working. The legacy global is
the root cause of issue #5344 (port collision on concurrent OAuth
flows); replacing it with a ContextVar is out of scope for this
consolidation PR.
"""
global _oauth_port
requested = int(cfg.get("redirect_port", 0))
port = _find_free_port() if requested == 0 else requested
cfg["_resolved_port"] = port
_oauth_port = port # legacy consumer: _wait_for_callback reads this
return port
def _build_client_metadata(cfg: dict) -> "OAuthClientMetadata":
"""Build OAuthClientMetadata from the oauth config dict.
Requires ``cfg['_resolved_port']`` to have been populated by
:func:`_configure_callback_port` first.
"""
port = cfg.get("_resolved_port")
if port is None:
raise ValueError(
"_configure_callback_port() must be called before _build_client_metadata()"
)
client_name = cfg.get("client_name", "Hermes Agent")
scope = cfg.get("scope")
redirect_uri = f"http://127.0.0.1:{port}/callback"
metadata_kwargs: dict[str, Any] = {
"client_name": client_name,
"redirect_uris": [AnyUrl(redirect_uri)],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "none",
}
if scope:
metadata_kwargs["scope"] = scope
if cfg.get("client_secret"):
metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
return OAuthClientMetadata.model_validate(metadata_kwargs)
def _maybe_preregister_client(
storage: "HermesTokenStorage",
cfg: dict,
client_metadata: "OAuthClientMetadata",
) -> None:
"""If cfg has a pre-registered client_id, persist it to storage."""
client_id = cfg.get("client_id")
if not client_id:
return
port = cfg["_resolved_port"]
redirect_uri = f"http://127.0.0.1:{port}/callback"
info_dict: dict[str, Any] = {
"client_id": client_id,
"redirect_uris": [redirect_uri],
"grant_types": client_metadata.grant_types,
"response_types": client_metadata.response_types,
"token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
}
if cfg.get("client_secret"):
info_dict["client_secret"] = cfg["client_secret"]
if cfg.get("client_name"):
info_dict["client_name"] = cfg["client_name"]
if cfg.get("scope"):
info_dict["scope"] = cfg["scope"]
client_info = OAuthClientInformationFull.model_validate(info_dict)
_write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
logger.debug("Pre-registered client_id=%s for '%s'", client_id, storage._server_name)
def _parse_base_url(server_url: str) -> str:
"""Strip path component from server URL, returning the base origin."""
parsed = urlparse(server_url)
return f"{parsed.scheme}://{parsed.netloc}"
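The ordering contract these helpers share (the callback port must land in the cfg dict before metadata or pre-registration code reads it) can be sketched with simplified stand-ins. These are illustrative substitutes, not the real `tools.mcp_oauth` helpers:

```python
# Simplified stand-ins for the shared-cfg contract above (hypothetical:
# the real helpers also build SDK metadata objects and touch storage).
def configure_callback_port(cfg: dict) -> int:
    # 0 / absent means "pick a free port"; 8765 stands in for _find_free_port()
    port = int(cfg.get("redirect_port", 0)) or 8765
    cfg["_resolved_port"] = port  # sibling helpers read from the same dict
    return port

def build_redirect_uri(cfg: dict) -> str:
    port = cfg.get("_resolved_port")
    if port is None:
        raise ValueError("configure_callback_port() must be called first")
    return f"http://127.0.0.1:{port}/callback"

cfg = {"redirect_port": 9000}
configure_callback_port(cfg)
assert build_redirect_uri(cfg) == "http://127.0.0.1:9000/callback"
```

Calling `build_redirect_uri` on a dict that never went through port resolution raises, mirroring the `ValueError` guard in `_build_client_metadata`.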
def build_oauth_auth(
server_name: str,
server_url: str,
@@ -382,7 +479,9 @@ def build_oauth_auth(
) -> "OAuthClientProvider | None":
"""Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.
Public API preserved for backwards compatibility. New code should use
:func:`tools.mcp_oauth_manager.get_manager` so OAuth state is shared
across config-time, runtime, and reconnect paths.
Args:
server_name: Server key in mcp_servers config (used for storage).
@@ -396,87 +495,32 @@ def build_oauth_auth(
if not _OAUTH_AVAILABLE:
logger.warning(
"MCP OAuth requested for '%s' but SDK auth types are not available. "
"Install with: pip install 'mcp>=1.26.0'",
server_name,
)
return None
cfg = dict(oauth_config or {})  # copy — we mutate _resolved_port
storage = HermesTokenStorage(server_name)
if not _is_interactive() and not storage.has_cached_tokens():
logger.warning(
"MCP OAuth for '%s': non-interactive environment and no cached tokens "
"found. The OAuth flow requires browser authorization. Run "
"interactively first to complete the initial authorization, then "
"cached tokens will be reused.",
server_name,
)
_configure_callback_port(cfg)
client_metadata = _build_client_metadata(cfg)
_maybe_preregister_client(storage, cfg, client_metadata)
return OAuthClientProvider(
server_url=_parse_base_url(server_url),
client_metadata=client_metadata,
storage=storage,
redirect_handler=_redirect_handler,
callback_handler=_wait_for_callback,
timeout=float(cfg.get("timeout", 300)),
)

tools/mcp_oauth_manager.py (new file)
@@ -0,0 +1,413 @@
#!/usr/bin/env python3
"""Central manager for per-server MCP OAuth state.
One instance shared across the process. Holds per-server OAuth provider
instances and coordinates:
- **Cross-process token reload** via mtime-based disk watch. When an external
process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
picks them up without requiring a process restart.
- **401 deduplication** via in-flight futures. When N concurrent tool calls
all hit 401 with the same access_token, only one recovery attempt fires;
the rest await the same result.
- **Reconnect signalling** for long-lived MCP sessions. The manager itself
does not drive reconnection (the `MCPServerTask` in `mcp_tool.py` does),
but the manager is the single source of truth that decides when a
reconnect is warranted.
Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths
go through `get_manager()`.
Design reference:
- Claude Code's ``invalidateOAuthCacheIfDiskChanged``
(``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
external-refresh staleness bug class.
- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
(``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
lazy refresh rather than calling refresh before every op, because one
``stat()`` per tool call is cheaper than an ``await`` + potential refresh
round-trip, and the SDK's in-memory expiry path is already correct.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from dataclasses import dataclass, field
from typing import Any, Optional
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Per-server entry
# ---------------------------------------------------------------------------
@dataclass
class _ProviderEntry:
"""Per-server OAuth state tracked by the manager.
Fields:
server_url: The MCP server URL used to build the provider. Tracked
so we can discard a cached provider if the URL changes.
oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
provider: The ``httpx.Auth``-compatible provider wrapping the MCP
SDK. None until first use.
last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
to detect external refreshes.
lock: Serialises concurrent access to this entry's state. Bound to
whichever asyncio loop first awaits it (the MCP event loop).
pending_401: In-flight 401-handler futures keyed by the failed
access_token, for deduplicating thundering-herd 401s. Mirrors
Claude Code's ``pending401Handlers`` map.
"""
server_url: str
oauth_config: Optional[dict]
provider: Optional[Any] = None
last_mtime_ns: int = 0
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)
# ---------------------------------------------------------------------------
# HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch
# ---------------------------------------------------------------------------
def _make_hermes_provider_class() -> Optional[type]:
"""Lazy-import the SDK base class and return our subclass.
Wrapped in a function so this module imports cleanly even when the
MCP SDK's OAuth module is unavailable (e.g. older mcp versions).
"""
try:
from mcp.client.auth.oauth2 import OAuthClientProvider
except ImportError: # pragma: no cover — SDK required in CI
return None
class HermesMCPOAuthProvider(OAuthClientProvider):
"""OAuthClientProvider with pre-flow disk-mtime reload.
Before every ``async_auth_flow`` invocation, asks the manager to
check whether the tokens file on disk has been modified externally.
If so, the manager resets ``_initialized`` so the next flow
re-reads from storage.
This makes external-process refreshes (cron, another CLI instance)
visible to the running MCP session without requiring a restart.
Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
(``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
"""
def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
super().__init__(*args, **kwargs)
self._hermes_server_name = server_name
async def async_auth_flow(self, request): # type: ignore[override]
# Pre-flow hook: ask the manager to refresh from disk if needed.
# Any failure here is non-fatal — we just log and proceed with
# whatever state the SDK already has.
try:
await get_manager().invalidate_if_disk_changed(
self._hermes_server_name
)
except Exception as exc: # pragma: no cover — defensive
logger.debug(
"MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
self._hermes_server_name, exc,
)
# Delegate to the SDK's auth flow
async for item in super().async_auth_flow(request):
yield item
return HermesMCPOAuthProvider
# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()
# ---------------------------------------------------------------------------
# Manager
# ---------------------------------------------------------------------------
class MCPOAuthManager:
"""Single source of truth for per-server MCP OAuth state.
Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
get-or-create semantics. Per-entry state is guarded by the entry's own
``asyncio.Lock`` (used from the MCP event loop thread).
"""
def __init__(self) -> None:
self._entries: dict[str, _ProviderEntry] = {}
self._entries_lock = threading.Lock()
# -- Provider construction / caching -------------------------------------
def get_or_build_provider(
self,
server_name: str,
server_url: str,
oauth_config: Optional[dict],
) -> Optional[Any]:
"""Return a cached OAuth provider for ``server_name`` or build one.
Idempotent: repeat calls with the same name return the same instance.
If ``server_url`` changes for a given name, the cached entry is
discarded and a fresh provider is built.
Returns None if the MCP SDK's OAuth support is unavailable.
"""
with self._entries_lock:
entry = self._entries.get(server_name)
if entry is not None and entry.server_url != server_url:
logger.info(
"MCP OAuth '%s': URL changed from %s to %s, discarding cache",
server_name, entry.server_url, server_url,
)
entry = None
if entry is None:
entry = _ProviderEntry(
server_url=server_url,
oauth_config=oauth_config,
)
self._entries[server_name] = entry
if entry.provider is None:
entry.provider = self._build_provider(server_name, entry)
return entry.provider
def _build_provider(
self,
server_name: str,
entry: _ProviderEntry,
) -> Optional[Any]:
"""Build the underlying OAuth provider.
Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
disk-watch hook so external token refreshes (cron, other CLI
instances) are visible to running MCP sessions.
Returns None if the MCP SDK's OAuth support is unavailable.
"""
if _HERMES_PROVIDER_CLS is None:
logger.warning(
"MCP OAuth '%s': SDK auth module unavailable", server_name,
)
return None
# Local imports avoid circular deps at module import time.
from tools.mcp_oauth import (
HermesTokenStorage,
_OAUTH_AVAILABLE,
_build_client_metadata,
_configure_callback_port,
_is_interactive,
_maybe_preregister_client,
_parse_base_url,
_redirect_handler,
_wait_for_callback,
)
if not _OAUTH_AVAILABLE:
return None
cfg = dict(entry.oauth_config or {})
storage = HermesTokenStorage(server_name)
if not _is_interactive() and not storage.has_cached_tokens():
logger.warning(
"MCP OAuth for '%s': non-interactive environment and no "
"cached tokens found. Run interactively first to complete "
"initial authorization.",
server_name,
)
_configure_callback_port(cfg)
client_metadata = _build_client_metadata(cfg)
_maybe_preregister_client(storage, cfg, client_metadata)
return _HERMES_PROVIDER_CLS(
server_name=server_name,
server_url=_parse_base_url(entry.server_url),
client_metadata=client_metadata,
storage=storage,
redirect_handler=_redirect_handler,
callback_handler=_wait_for_callback,
timeout=float(cfg.get("timeout", 300)),
)
def remove(self, server_name: str) -> None:
"""Evict the provider from cache AND delete tokens from disk.
Called by ``hermes mcp remove <name>`` and (indirectly) by
``hermes mcp login <name>`` during forced re-auth.
"""
with self._entries_lock:
self._entries.pop(server_name, None)
from tools.mcp_oauth import remove_oauth_tokens
remove_oauth_tokens(server_name)
logger.info(
"MCP OAuth '%s': evicted from cache and removed from disk",
server_name,
)
# -- Disk watch ----------------------------------------------------------
async def invalidate_if_disk_changed(self, server_name: str) -> bool:
"""If the tokens file on disk has a newer mtime than last-seen, force
the MCP SDK provider to reload its in-memory state.
Returns True if the cache was invalidated (mtime differed). This is
the core fix for the external-refresh workflow: a cron job writes
fresh tokens to disk, and on the next tool call the running MCP
session picks them up without a restart.
"""
from tools.mcp_oauth import _get_token_dir, _safe_filename
entry = self._entries.get(server_name)
if entry is None or entry.provider is None:
return False
async with entry.lock:
tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
try:
mtime_ns = tokens_path.stat().st_mtime_ns
except OSError:  # includes FileNotFoundError
return False
if mtime_ns != entry.last_mtime_ns:
old = entry.last_mtime_ns
entry.last_mtime_ns = mtime_ns
# Force the SDK's OAuthClientProvider to reload from storage
# on its next auth flow. `_initialized` is private API but
# stable across the MCP SDK versions we pin (>=1.26.0).
if hasattr(entry.provider, "_initialized"):
entry.provider._initialized = False # noqa: SLF001
logger.info(
"MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
"forcing reload",
server_name, old, mtime_ns,
)
return True
return False
# -- 401 handler (dedup'd) -----------------------------------------------
async def handle_401(
self,
server_name: str,
failed_access_token: Optional[str] = None,
) -> bool:
"""Handle a 401 from a tool call, deduplicated across concurrent callers.
Returns:
True if a (possibly new) access token is now available; the caller
should trigger a reconnect and retry the operation.
False if no recovery path exists; the caller should surface a
``needs_reauth`` error to the model so it stops hallucinating
manual refresh attempts.
Thundering-herd protection: if N concurrent tool calls hit 401 with
the same ``failed_access_token``, only one recovery attempt fires.
Others await the same future.
"""
entry = self._entries.get(server_name)
if entry is None or entry.provider is None:
return False
key = failed_access_token or "<unknown>"
loop = asyncio.get_running_loop()
async with entry.lock:
pending = entry.pending_401.get(key)
if pending is None:
pending = loop.create_future()
entry.pending_401[key] = pending
async def _do_handle() -> None:
try:
# Step 1: Did disk change? Picks up external refresh.
disk_changed = await self.invalidate_if_disk_changed(
server_name
)
if disk_changed:
if not pending.done():
pending.set_result(True)
return
# Step 2: No disk change — if the SDK can refresh
# in-place, let the caller retry. The SDK's httpx.Auth
# flow will issue the refresh on the next request.
provider = entry.provider
ctx = getattr(provider, "context", None)
can_refresh = False
if ctx is not None:
can_refresh_fn = getattr(ctx, "can_refresh_token", None)
if callable(can_refresh_fn):
try:
can_refresh = bool(can_refresh_fn())
except Exception:
can_refresh = False
if not pending.done():
pending.set_result(can_refresh)
except Exception as exc: # pragma: no cover — defensive
logger.warning(
"MCP OAuth '%s': 401 handler failed: %s",
server_name, exc,
)
if not pending.done():
pending.set_result(False)
finally:
entry.pending_401.pop(key, None)
asyncio.create_task(_do_handle())
try:
return await pending
except Exception as exc: # pragma: no cover — defensive
logger.warning(
"MCP OAuth '%s': awaiting 401 handler failed: %s",
server_name, exc,
)
return False
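The thundering-herd dedup above can be reduced to a small self-contained sketch (a plain dict of futures, not the manager's actual entry/lock machinery):

```python
import asyncio

# Sketch of 401 dedup: concurrent failures keyed by the same token share
# one in-flight future, so the expensive recovery runs exactly once.
pending: dict[str, "asyncio.Future[bool]"] = {}
attempts = 0

async def recover(key: str) -> bool:
    fut = pending.get(key)
    if fut is None:
        loop = asyncio.get_running_loop()
        fut = pending[key] = loop.create_future()

        async def _do() -> None:
            global attempts
            attempts += 1                  # the one real recovery attempt
            await asyncio.sleep(0.01)      # stand-in for disk check / refresh
            fut.set_result(True)
            pending.pop(key, None)

        asyncio.create_task(_do())
    return await fut

async def main() -> list:
    # Five "concurrent 401s" carrying the same failed token.
    return await asyncio.gather(*(recover("tok-abc") for _ in range(5)))

results = asyncio.run(main())
assert results == [True] * 5 and attempts == 1
```

The first caller installs the future synchronously before its first `await`, so every later caller finds it and waits rather than starting a second recovery.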
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------
_MANAGER: Optional[MCPOAuthManager] = None
_MANAGER_LOCK = threading.Lock()
def get_manager() -> MCPOAuthManager:
"""Return the process-wide :class:`MCPOAuthManager` singleton."""
global _MANAGER
with _MANAGER_LOCK:
if _MANAGER is None:
_MANAGER = MCPOAuthManager()
return _MANAGER
def reset_manager_for_tests() -> None:
"""Test-only helper: drop the singleton so fixtures start clean."""
global _MANAGER
with _MANAGER_LOCK:
_MANAGER = None
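The mtime watch that `invalidate_if_disk_changed` builds on can be exercised standalone. This is a minimal sketch of the stat-and-compare step only, not the manager API:

```python
import os
import tempfile

def disk_changed(path: str, last_mtime_ns: int) -> tuple:
    """One stat() per call: report whether the file was rewritten."""
    try:
        mtime_ns = os.stat(path).st_mtime_ns
    except OSError:  # missing file: nothing to invalidate
        return False, last_mtime_ns
    return mtime_ns != last_mtime_ns, mtime_ns

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    f.write("{}")
    tokens_path = f.name

changed, seen = disk_changed(tokens_path, 0)      # first observation
assert changed                                    # differs from the 0 sentinel
changed, seen = disk_changed(tokens_path, seen)
assert not changed                                # no external writes yet

# Simulate an external refresh (e.g. a cron job rewriting the tokens file).
os.utime(tokens_path, ns=(seen + 1_000_000, seen + 1_000_000))
changed, seen = disk_changed(tokens_path, seen)
assert changed
os.unlink(tokens_path)
```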

tools/mcp_tool.py
@@ -783,7 +783,8 @@ class MCPServerTask:
__slots__ = (
"name", "session", "tool_timeout",
"_task", "_ready", "_shutdown_event", "_reconnect_event",
"_tools", "_error", "_config",
"_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
)
@@ -794,6 +795,12 @@
self._task: Optional[asyncio.Task] = None
self._ready = asyncio.Event()
self._shutdown_event = asyncio.Event()
# Set by tool handlers on auth failure after manager.handle_401()
# confirms recovery is viable. When set, _run_http / _run_stdio
# exit their async-with blocks cleanly (no exception), and the
# outer run() loop re-enters the transport so the MCP session is
# rebuilt with fresh credentials.
self._reconnect_event = asyncio.Event()
self._tools: list = []
self._error: Optional[Exception] = None
self._config: dict = {}
@@ -887,6 +894,40 @@
self.name, len(self._registered_tool_names),
)
async def _wait_for_lifecycle_event(self) -> str:
"""Block until either _shutdown_event or _reconnect_event fires.
Returns:
"shutdown" if the server should exit the run loop entirely.
"reconnect" if the server should tear down the current MCP
session and re-enter the transport (fresh OAuth
tokens, new session ID, etc.). The reconnect event
is cleared before return so the next cycle starts
with a fresh signal.
Shutdown takes precedence if both events are set simultaneously.
"""
shutdown_task = asyncio.create_task(self._shutdown_event.wait())
reconnect_task = asyncio.create_task(self._reconnect_event.wait())
try:
await asyncio.wait(
{shutdown_task, reconnect_task},
return_when=asyncio.FIRST_COMPLETED,
)
finally:
for t in (shutdown_task, reconnect_task):
if not t.done():
t.cancel()
try:
await t
except (asyncio.CancelledError, Exception):
pass
if self._shutdown_event.is_set():
return "shutdown"
self._reconnect_event.clear()
return "reconnect"
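The shutdown-wins race described above can be demonstrated with plain `asyncio.Event` objects (a behavioral sketch, independent of `MCPServerTask`):

```python
import asyncio

async def wait_for_lifecycle(shutdown: asyncio.Event, reconnect: asyncio.Event) -> str:
    # Race the two events; the first one to fire wakes us up.
    s = asyncio.create_task(shutdown.wait())
    r = asyncio.create_task(reconnect.wait())
    try:
        await asyncio.wait({s, r}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for t in (s, r):
            t.cancel()
            try:
                await t
            except asyncio.CancelledError:
                pass
    if shutdown.is_set():          # shutdown takes precedence on ties
        return "shutdown"
    reconnect.clear()              # next cycle starts with a fresh signal
    return "reconnect"

async def main() -> list:
    shutdown, reconnect = asyncio.Event(), asyncio.Event()
    reconnect.set()
    first = await wait_for_lifecycle(shutdown, reconnect)
    shutdown.set()
    reconnect.set()                # both set: shutdown still wins
    second = await wait_for_lifecycle(shutdown, reconnect)
    return [first, second]

assert asyncio.run(main()) == ["reconnect", "shutdown"]
```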
async def _run_stdio(self, config: dict):
"""Run the server using stdio transport."""
command = config.get("command")
@@ -932,7 +973,10 @@
self.session = session
await self._discover_tools()
self._ready.set()
# stdio transport does not use OAuth, but we still honor
# _reconnect_event (e.g. future manual /mcp refresh) for
# consistency with _run_http.
await self._wait_for_lifecycle_event()
# Context exited cleanly — subprocess was terminated by the SDK.
if new_pids:
with _lock:
@@ -951,16 +995,18 @@
headers = dict(config.get("headers") or {})
connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)
# OAuth 2.1 PKCE: route through the central MCPOAuthManager so the
# same provider instance is reused across reconnects, pre-flow
# disk-watch is active, and config-time CLI code paths share state.
# If OAuth setup fails (e.g. non-interactive env without cached
# tokens), re-raise so this server is reported as failed without
# blocking other MCP servers from connecting.
_oauth_auth = None
if self._auth_type == "oauth":
try:
from tools.mcp_oauth_manager import get_manager
_oauth_auth = get_manager().get_or_build_provider(
self.name, url, config.get("oauth"),
)
except Exception as exc:
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
@@ -995,7 +1041,12 @@
self.session = session
await self._discover_tools()
self._ready.set()
reason = await self._wait_for_lifecycle_event()
if reason == "reconnect":
logger.info(
"MCP server '%s': reconnect requested — "
"tearing down HTTP session", self.name,
)
else:
# Deprecated API (mcp < 1.24.0): manages httpx client internally.
_http_kwargs: dict = {
@@ -1012,7 +1063,12 @@
self.session = session
await self._discover_tools()
self._ready.set()
reason = await self._wait_for_lifecycle_event()
if reason == "reconnect":
logger.info(
"MCP server '%s': reconnect requested — "
"tearing down legacy HTTP session", self.name,
)
async def _discover_tools(self):
"""Discover tools from the connected session."""
@@ -1060,8 +1116,25 @@
await self._run_http(config)
else:
await self._run_stdio(config)
# Transport returned cleanly. Two cases:
# - _shutdown_event was set: exit the run loop entirely.
# - _reconnect_event was set (auth recovery): loop back and
# rebuild the MCP session with fresh credentials. Do NOT
# touch the retry counters — this is not a failure.
if self._shutdown_event.is_set():
break
logger.info(
"MCP server '%s': reconnecting (OAuth recovery or "
"manual refresh)",
self.name,
)
# Reset the session reference; _run_http/_run_stdio will
# repopulate it on successful re-entry.
self.session = None
# Keep _ready set across reconnects so tool handlers can
# still detect a transient in-flight state — it'll be
# re-set after the fresh session initializes.
continue
except Exception as exc:
self.session = None
@@ -1141,6 +1214,12 @@
from tools.registry import registry
self._shutdown_event.set()
# Defensive: if _wait_for_lifecycle_event is blocking, we need ANY
# event to unblock it. _shutdown_event alone is sufficient (the
# helper checks shutdown first), but setting reconnect too ensures
# there's no race where the helper misses the shutdown flag after
# returning "reconnect".
self._reconnect_event.set()
if self._task and not self._task.done():
try:
await asyncio.wait_for(self._task, timeout=10)
@@ -1174,6 +1253,175 @@
_servers: Dict[str, MCPServerTask] = {}
_server_error_counts: Dict[str, int] = {}
_CIRCUIT_BREAKER_THRESHOLD = 3
# ---------------------------------------------------------------------------
# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
# ---------------------------------------------------------------------------
# Cached tuple of auth-related exception types. Lazy so this module
# imports cleanly when the MCP SDK OAuth module is missing.
_AUTH_ERROR_TYPES: tuple = ()
def _get_auth_error_types() -> tuple:
"""Return a tuple of exception types that indicate MCP OAuth failure.
Cached after first call. Includes:
- ``mcp.client.auth.OAuthFlowError`` / ``OAuthTokenError`` raised by
the SDK's auth flow when discovery, refresh, or full re-auth fails.
- ``mcp.client.auth.UnauthorizedError`` (older MCP SDKs) kept as an
optional import for forward/backward compatibility.
- ``tools.mcp_oauth.OAuthNonInteractiveError`` raised by our callback
handler when no user is present to complete a browser flow.
- ``httpx.HTTPStatusError``; the caller must additionally check
``status_code == 401`` via :func:`_is_auth_error`.
"""
global _AUTH_ERROR_TYPES
if _AUTH_ERROR_TYPES:
return _AUTH_ERROR_TYPES
types: list = []
try:
from mcp.client.auth import OAuthFlowError, OAuthTokenError
types.extend([OAuthFlowError, OAuthTokenError])
except ImportError:
pass
try:
# Older MCP SDK variants exported this
from mcp.client.auth import UnauthorizedError # type: ignore
types.append(UnauthorizedError)
except ImportError:
pass
try:
from tools.mcp_oauth import OAuthNonInteractiveError
types.append(OAuthNonInteractiveError)
except ImportError:
pass
try:
import httpx
types.append(httpx.HTTPStatusError)
except ImportError:
pass
_AUTH_ERROR_TYPES = tuple(types)
return _AUTH_ERROR_TYPES
def _is_auth_error(exc: BaseException) -> bool:
"""Return True if ``exc`` indicates an MCP OAuth failure.
``httpx.HTTPStatusError`` is only treated as auth-related when the
response status code is 401. Other HTTP errors fall through to the
generic error path in the tool handlers.
"""
types = _get_auth_error_types()
if not types or not isinstance(exc, types):
return False
try:
import httpx
if isinstance(exc, httpx.HTTPStatusError):
return getattr(exc.response, "status_code", None) == 401
except ImportError:
pass
return True


def _handle_auth_error_and_retry(
    server_name: str,
    exc: BaseException,
    retry_call,
    op_description: str,
):
    """Attempt auth recovery and one retry; return None to fall through.

    Called by the 5 MCP tool handlers when ``session.<op>()`` raises an
    auth-related exception. Workflow:

    1. Ask :class:`tools.mcp_oauth_manager.MCPOAuthManager.handle_401` if
       recovery is viable (i.e., disk has fresh tokens, or the SDK can
       refresh in-place).
    2. If yes, set the server's ``_reconnect_event`` so the server task
       tears down the current MCP session and rebuilds it with fresh
       credentials. Wait briefly for ``_ready`` to re-fire.
    3. Retry the operation once. Return the retry result if it produced
       a non-error JSON payload. Otherwise return the ``needs_reauth``
       error dict so the model stops hallucinating manual refresh.
    4. Return None if ``exc`` is not an auth error, signalling the
       caller to use the generic error path.

    Args:
        server_name: Name of the MCP server that raised.
        exc: The exception from the failed tool call.
        retry_call: Zero-arg callable that re-runs the tool call, returning
            the same JSON string format as the handler.
        op_description: Human-readable name of the operation (for logs).

    Returns:
        A JSON string if auth recovery was attempted, or None to fall
        through to the caller's generic error path.
    """
    if not _is_auth_error(exc):
        return None

    from tools.mcp_oauth_manager import get_manager

    manager = get_manager()

    async def _recover():
        return await manager.handle_401(server_name, None)

    try:
        recovered = _run_on_mcp_loop(_recover(), timeout=10)
    except Exception as rec_exc:
        logger.warning(
            "MCP OAuth '%s': recovery attempt failed: %s",
            server_name, rec_exc,
        )
        recovered = False

    if recovered:
        with _lock:
            srv = _servers.get(server_name)
        if srv is not None and hasattr(srv, "_reconnect_event"):
            loop = _mcp_loop
            if loop is not None and loop.is_running():
                loop.call_soon_threadsafe(srv._reconnect_event.set)
                # Wait briefly for the session to come back ready. Bounded
                # so that a stuck reconnect falls through to the error
                # path rather than hanging the caller.
                deadline = time.monotonic() + 15
                while time.monotonic() < deadline:
                    if srv.session is not None and srv._ready.is_set():
                        break
                    time.sleep(0.25)
        try:
            result = retry_call()
            try:
                parsed = json.loads(result)
                if "error" not in parsed:
                    _server_error_counts[server_name] = 0
                    return result
            except (json.JSONDecodeError, TypeError):
                _server_error_counts[server_name] = 0
                return result
        except Exception as retry_exc:
            logger.warning(
                "MCP %s/%s retry after auth recovery failed: %s",
                server_name, op_description, retry_exc,
            )

    # No recovery available, or retry also failed: surface a structured
    # needs_reauth error. Bumps the circuit breaker so the model stops
    # retrying the tool.
    _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
    return json.dumps({
        "error": (
            f"MCP server '{server_name}' requires re-authentication. "
            f"Run `hermes mcp login {server_name}` (or delete the tokens "
            "file under ~/.hermes/mcp-tokens/ and restart). Do NOT retry "
            "this tool — ask the user to re-authenticate."
        ),
        "needs_reauth": True,
        "server": server_name,
    }, ensure_ascii=False)
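The bounded readiness wait in step 2 above is worth isolating: poll until the session's ready event fires, but never past a deadline, so a stuck reconnect degrades to the error path instead of hanging the caller. A minimal sketch of that pattern (``wait_ready`` is an illustrative helper, not part of the module):

```python
import threading
import time

def wait_ready(ready: threading.Event, timeout: float = 15.0,
               poll: float = 0.25) -> bool:
    """Poll `ready` until set or the deadline passes; never blocks longer."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if ready.is_set():
            return True
        time.sleep(poll)
    return ready.is_set()

ev = threading.Event()
threading.Timer(0.3, ev.set).start()  # simulate the reconnect completing
print(wait_ready(ev, timeout=2.0, poll=0.05))  # True
```

``time.monotonic`` is used rather than ``time.time`` so the deadline is immune to wall-clock adjustments, matching the production loop.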
# Dedicated event loop running in a background daemon thread.
_mcp_loop: Optional[asyncio.AbstractEventLoop] = None
_mcp_thread: Optional[threading.Thread] = None
@@ -1420,8 +1668,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
                 return json.dumps({"result": structured}, ensure_ascii=False)
             return json.dumps({"result": text_result}, ensure_ascii=False)
 
+        def _call_once():
+            return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
         try:
-            result = _run_on_mcp_loop(_call(), timeout=tool_timeout)
+            result = _call_once()
             # Check if the MCP tool itself returned an error
             try:
                 parsed = json.loads(result)
@@ -1435,6 +1686,16 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            # Auth-specific recovery path: consult the manager, signal
+            # reconnect if viable, retry once. Returns None to fall
+            # through for non-auth exceptions.
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once,
+                f"tools/call {tool_name}",
+            )
+            if recovered is not None:
+                return recovered
+
             _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
             logger.error(
                 "MCP tool %s/%s call failed: %s",
@@ -1476,11 +1737,19 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
                 resources.append(entry)
             return json.dumps({"resources": resources}, ensure_ascii=False)
 
-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "resources/list",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/list_resources failed: %s", server_name, exc,
             )
@@ -1522,11 +1791,19 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
                     parts.append(f"[binary data, {len(block.blob)} bytes]")
             return json.dumps({"result": "\n".join(parts) if parts else ""}, ensure_ascii=False)
 
-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "resources/read",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/read_resource failed: %s", server_name, exc,
             )
@@ -1571,11 +1848,19 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
                 prompts.append(entry)
             return json.dumps({"prompts": prompts}, ensure_ascii=False)
 
-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "prompts/list",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/list_prompts failed: %s", server_name, exc,
             )
@@ -1628,11 +1913,19 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
                 resp["description"] = result.description
             return json.dumps(resp, ensure_ascii=False)
 
-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "prompts/get",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/get_prompt failed: %s", server_name, exc,
             )
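The five hunks above all apply the same shape: run the call once, and on failure consult a recovery hook that may retry exactly once. A generic sketch of that wrapper (``call_with_auth_retry``, ``flaky``, and ``recover`` are illustrative names, not part of the patch):

```python
import json

def call_with_auth_retry(call_once, recover) -> str:
    """Run once; on failure let `recover` attempt one retry, else error out."""
    try:
        return call_once()
    except Exception as exc:
        recovered = recover(exc, call_once)
        if recovered is not None:
            return recovered
        return json.dumps({"error": str(exc)})

attempts = []

def flaky():
    # Fails once with an auth-style error, then succeeds.
    attempts.append(1)
    if len(attempts) == 1:
        raise PermissionError("401")
    return json.dumps({"result": "ok"})

def recover(exc, retry):
    # Retry only for auth-style failures; None falls through to the
    # generic error path, mirroring _handle_auth_error_and_retry.
    return retry() if isinstance(exc, PermissionError) else None

print(call_with_auth_retry(flaky, recover))  # prints {"result": "ok"}
```

Keeping the retry to a single attempt matters here: a second auth failure means the tokens really are stale, and the structured ``needs_reauth`` payload is the right answer, not another loop.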