mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(mcp): consolidate OAuth handling, pick up external token refreshes (#11383)
* feat(mcp-oauth): scaffold MCPOAuthManager
Central manager for per-server MCP OAuth state. Provides
get_or_build_provider (cached), remove (evicts cache + deletes
disk), invalidate_if_disk_changed (mtime watch, core fix for
external-refresh workflow), and handle_401 (dedup'd recovery).
No behavior change yet — existing call sites still use
build_oauth_auth directly. Task 1 of 8 in the MCP OAuth
consolidation (fixes Cthulhu's BetterStack reliability issues).
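For orientation, a compressed sketch of the manager surface described above: the provider cache, eviction, and mtime-watch bookkeeping. Only the names (MCPOAuthManager, get_or_build_provider, remove, invalidate_if_disk_changed, _entries) come from this commit message; the token-file layout and the sync (rather than async) methods are simplifications for illustration, not the real tools/mcp_oauth_manager.py.

```python
import os
from dataclasses import dataclass


@dataclass
class _Entry:
    provider: object
    url: str
    last_mtime_ns: int = 0  # baseline for the disk watch


class MCPOAuthManager:
    def __init__(self, token_dir="mcp-tokens"):  # hypothetical layout
        self._entries = {}
        self._token_dir = token_dir

    def get_or_build_provider(self, name, url, build=None):
        entry = self._entries.get(name)
        if entry is not None and entry.url == url:
            return entry.provider  # cached: reconnects reuse it
        provider = build(name, url) if build else object()
        self._entries[name] = _Entry(provider, url)
        return provider

    def remove(self, name):
        self._entries.pop(name, None)  # evict in-memory cache
        path = os.path.join(self._token_dir, f"{name}.json")
        if os.path.exists(path):
            os.remove(path)  # delete on-disk tokens

    def invalidate_if_disk_changed(self, name):
        # Returns True when the tokens file's mtime moved past the
        # recorded baseline (including the first sighting of the file).
        entry = self._entries.get(name)
        if entry is None:
            return False
        path = os.path.join(self._token_dir, f"{name}.json")
        try:
            mtime = os.stat(path).st_mtime_ns
        except FileNotFoundError:
            return False  # pre-auth state: no tokens yet
        if mtime != entry.last_mtime_ns:
            entry.last_mtime_ns = mtime
            return True  # caller invalidates the provider's memory state
        return False
```

The real method is async and also flips the provider's `_initialized` flag; the sketch only shows the cache and mtime logic.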
* feat(mcp-oauth): add HermesMCPOAuthProvider with pre-flow disk watch
Subclasses the MCP SDK's OAuthClientProvider to inject a disk
mtime check before every async_auth_flow, via the central
manager. When a subclass instance is used, external token
refreshes (cron, another CLI instance) are picked up before
the next API call.
Still dead code: the manager's _build_provider still delegates
to build_oauth_auth and returns the plain OAuthClientProvider.
Task 4 wires this subclass in. Task 2 of 8.
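A rough sketch of the pre-flow hook. The `OAuthClientProvider` below is a stand-in, not the MCP SDK class: it only models the `_initialized` flag and the storage re-read that the hook relies on. In the real code the manager flips that flag inside invalidate_if_disk_changed; here the subclass does it itself so the sketch is self-contained.

```python
import asyncio


class OAuthClientProvider:
    """Stand-in for the MCP SDK base class (illustration only)."""

    def __init__(self):
        self._initialized = True  # tokens already loaded in memory
        self.reloads = 0          # how many times storage was re-read

    async def _initialize(self):
        # The real SDK reads tokens from TokenStorage here.
        self._initialized = True
        self.reloads += 1

    async def async_auth_flow(self, request):
        if not self._initialized:
            await self._initialize()
        return request


class HermesMCPOAuthProvider(OAuthClientProvider):
    def __init__(self, server_name, manager):
        super().__init__()
        self._hermes_server_name = server_name
        self._manager = manager

    async def async_auth_flow(self, request):
        # Pre-flow disk watch: if another process (cron, a second CLI
        # instance) rewrote the tokens file, force a storage re-read
        # before the request goes out.
        if await self._manager.invalidate_if_disk_changed(self._hermes_server_name):
            self._initialized = False
        return await super().async_auth_flow(request)
```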
* refactor(mcp-oauth): extract build_oauth_auth helpers
Decomposes build_oauth_auth into _configure_callback_port,
_build_client_metadata, _maybe_preregister_client, and
_parse_base_url. Public API preserved. These helpers let
MCPOAuthManager._build_provider reuse the same logic in Task 4
instead of duplicating the construction dance.
Also updates the SDK version hint in the warning from 1.10.0 to
1.26.0 (which is what we actually require for the OAuth types
used here). Task 3 of 8.
* feat(mcp-oauth): manager now builds HermesMCPOAuthProvider directly
_build_provider constructs the disk-watching subclass using the
helpers from Task 3, instead of delegating to the plain
build_oauth_auth factory. Any consumer using the manager now gets
pre-flow disk-freshness checks automatically.
build_oauth_auth is preserved as the public API for backwards
compatibility. The code path is now:
MCPOAuthManager.get_or_build_provider ->
_build_provider ->
_configure_callback_port
_build_client_metadata
_maybe_preregister_client
_parse_base_url
HermesMCPOAuthProvider(...)
Task 4 of 8.
* feat(mcp): wire OAuth manager + add _reconnect_event
MCPServerTask gains _reconnect_event alongside _shutdown_event.
When set, _run_http / _run_stdio exit their async-with blocks
cleanly (no exception), and the outer run() loop re-enters the
transport to rebuild the MCP session with fresh credentials.
This is the recovery path for OAuth failures that the SDK's
in-place httpx.Auth cannot handle (e.g. cron externally consumed
the refresh_token, or server-side session invalidation).
_run_http now asks MCPOAuthManager for the OAuth provider
instead of calling build_oauth_auth directly. Config-time,
runtime, and reconnect paths all share one provider instance
with pre-flow disk-watch active.
shutdown() defensively sets both events so there is no race
between reconnect and shutdown signalling.
Task 5 of 8.
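The event plumbing above can be sketched with a toy stand-in for MCPServerTask. The real _run_http wraps the MCP transport in async-with blocks; here a plain wait models "session is up" so only the signal flow is shown.

```python
import asyncio


class MCPServerTask:
    """Toy stand-in: models only the reconnect/shutdown signalling."""

    def __init__(self):
        self._shutdown_event = asyncio.Event()
        self._reconnect_event = asyncio.Event()
        self.sessions_built = 0

    async def _run_http(self):
        # Stand-in for the async-with transport block: build a session,
        # then hold it open until asked to reconnect or shut down.
        self.sessions_built += 1
        waiters = [
            asyncio.create_task(self._reconnect_event.wait()),
            asyncio.create_task(self._shutdown_event.wait()),
        ]
        _, pending = await asyncio.wait(waiters, return_when=asyncio.FIRST_COMPLETED)
        for t in pending:
            t.cancel()
        # Exit cleanly (no exception): run() decides what happens next.

    async def run(self):
        while not self._shutdown_event.is_set():
            self._reconnect_event.clear()
            await self._run_http()  # re-entering rebuilds the session

    def shutdown(self):
        # Defensively set both events so a pending reconnect wait
        # cannot race a shutdown request.
        self._reconnect_event.set()
        self._shutdown_event.set()
```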
* feat(mcp): detect auth failures in tool handlers, trigger reconnect
All 5 MCP tool handlers (tool call, list_resources, read_resource,
list_prompts, get_prompt) now detect auth failures and route
through MCPOAuthManager.handle_401:
1. If the manager says recovery is viable (disk has fresh tokens,
or SDK can refresh in-place), signal MCPServerTask._reconnect_event
to tear down and rebuild the MCP session with fresh credentials,
then retry the tool call once.
2. If no recovery path exists, return a structured needs_reauth
JSON error so the model stops hallucinating manual refresh
attempts (the 'let me curl the token endpoint' loop Cthulhu
pasted from Discord).
_is_auth_error catches OAuthFlowError, OAuthTokenError,
OAuthNonInteractiveError, and httpx.HTTPStatusError(401). Non-auth
exceptions still surface via the generic error path unchanged.
Task 6 of 8.
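A hedged sketch of the detect-and-retry shape, not the actual handler code: the exception names are reduced to two local stand-ins (the real _is_auth_error also matches OAuthNonInteractiveError and httpx 401s), and handle_401 is modeled as a zero-argument coroutine rather than the manager method.

```python
import asyncio
import json


class OAuthFlowError(Exception):
    pass


class OAuthTokenError(Exception):
    pass


def _is_auth_error(exc):
    # Reduced stand-in; the real predicate covers more types.
    return isinstance(exc, (OAuthFlowError, OAuthTokenError))


async def call_with_auth_recovery(call, handle_401, signal_reconnect):
    try:
        return await call()
    except Exception as exc:
        if not _is_auth_error(exc):
            raise                      # generic error path, unchanged
        if await handle_401():
            signal_reconnect()         # tear down / rebuild the session
            return await call()        # retry exactly once
        # No recovery path: structured error instead of letting the
        # model improvise token refreshes.
        return json.dumps({
            "error": "authentication expired",
            "needs_reauth": True,
            "hint": "run `hermes mcp login <name>`",
        })
```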
* feat(mcp-cli): route add/remove through manager, add 'hermes mcp login'
cmd_mcp_add and cmd_mcp_remove now go through MCPOAuthManager
instead of calling build_oauth_auth / remove_oauth_tokens
directly. This means CLI config-time state and runtime MCP
session state are backed by the same provider cache — removing
a server evicts the live provider, adding a server populates
the same cache the MCP session will read from.
New 'hermes mcp login <name>' command:
- Wipes both the on-disk tokens file and the in-memory
MCPOAuthManager cache
- Triggers a fresh OAuth browser flow via the existing probe
path
- Intended target for the needs_reauth error Task 6 returns
to the model
Task 7 of 8.
* test(mcp-oauth): end-to-end integration tests
Five new tests exercising the full consolidation with real file
I/O and real imports (no transport mocks):
1. external_refresh_picked_up_without_restart — Cthulhu's cron
workflow. External process writes fresh tokens to disk;
on the next auth flow the manager's mtime-watch flips
_initialized and the SDK re-reads from storage.
2. handle_401_deduplicates_concurrent_callers — 10 concurrent
handlers for the same failed token fire exactly ONE recovery
attempt (thundering-herd protection).
3. handle_401_returns_false_when_no_provider — defensive path
for unknown servers.
4. invalidate_if_disk_changed_handles_missing_file — pre-auth
state returns False cleanly.
5. provider_is_reused_across_reconnects — cache stickiness so
reconnects preserve the disk-watch baseline mtime.
Task 8 of 8 — consolidation complete.
This commit is contained in:
parent 436a7359cd
commit 70768665a4
11 changed files with 1566 additions and 90 deletions
@@ -5904,6 +5904,12 @@ Examples:
     mcp_cfg_p = mcp_sub.add_parser("configure", aliases=["config"], help="Toggle tool selection")
     mcp_cfg_p.add_argument("name", help="Server name to configure")

+    mcp_login_p = mcp_sub.add_parser(
+        "login",
+        help="Force re-authentication for an OAuth-based MCP server",
+    )
+    mcp_login_p.add_argument("name", help="Server name to re-authenticate")
+
 def cmd_mcp(args):
     from hermes_cli.mcp_config import mcp_command
     mcp_command(args)
@@ -279,8 +279,8 @@ def cmd_mcp_add(args):
         _info(f"Starting OAuth flow for '{name}'...")
         oauth_ok = False
         try:
-            from tools.mcp_oauth import build_oauth_auth
-            oauth_auth = build_oauth_auth(name, url)
+            from tools.mcp_oauth_manager import get_manager
+            oauth_auth = get_manager().get_or_build_provider(name, url, None)
             if oauth_auth:
                 server_config["auth"] = "oauth"
                 _success("OAuth configured (tokens will be acquired on first connection)")
@@ -428,10 +428,12 @@ def cmd_mcp_remove(args):
     _remove_mcp_server(name)
     _success(f"Removed '{name}' from config")

-    # Clean up OAuth tokens if they exist
+    # Clean up OAuth tokens if they exist — route through MCPOAuthManager so
+    # any provider instance cached in the current process (e.g. from an
+    # earlier `hermes mcp test` in the same session) is evicted too.
     try:
-        from tools.mcp_oauth import remove_oauth_tokens
-        remove_oauth_tokens(name)
+        from tools.mcp_oauth_manager import get_manager
+        get_manager().remove(name)
         _success("Cleaned up OAuth tokens")
     except Exception:
         pass
@@ -577,6 +579,63 @@ def _interpolate_value(value: str) -> str:
     return re.sub(r"\$\{(\w+)\}", _replace, value)


+# ─── hermes mcp login ────────────────────────────────────────────────────────
+
+def cmd_mcp_login(args):
+    """Force re-authentication for an OAuth-based MCP server.
+
+    Deletes cached tokens (both on disk and in the running process's
+    MCPOAuthManager cache) and triggers a fresh OAuth flow via the
+    existing probe path.
+
+    Use this when:
+    - Tokens are stuck in a bad state (server revoked, refresh token
+      consumed by an external process, etc.)
+    - You want to re-authenticate to change scopes or account
+    - A tool call returned ``needs_reauth: true``
+    """
+    name = args.name
+    servers = _get_mcp_servers()
+
+    if name not in servers:
+        _error(f"Server '{name}' not found in config.")
+        if servers:
+            _info(f"Available servers: {', '.join(servers)}")
+        return
+
+    server_config = servers[name]
+    url = server_config.get("url")
+    if not url:
+        _error(f"Server '{name}' has no URL — not an OAuth-capable server")
+        return
+    if server_config.get("auth") != "oauth":
+        _error(f"Server '{name}' is not configured for OAuth (auth={server_config.get('auth')})")
+        _info("Use `hermes mcp remove` + `hermes mcp add` to reconfigure auth.")
+        return
+
+    # Wipe both disk and in-memory cache so the next probe forces a fresh
+    # OAuth flow.
+    try:
+        from tools.mcp_oauth_manager import get_manager
+        mgr = get_manager()
+        mgr.remove(name)
+    except Exception as exc:
+        _warning(f"Could not clear existing OAuth state: {exc}")
+
+    print()
+    _info(f"Starting OAuth flow for '{name}'...")
+
+    # Probe triggers the OAuth flow (browser redirect + callback capture).
+    try:
+        tools = _probe_single_server(name, server_config)
+        if tools:
+            _success(f"Authenticated — {len(tools)} tool(s) available")
+        else:
+            _success("Authenticated (server reported no tools)")
+    except Exception as exc:
+        _error(f"Authentication failed: {exc}")
+
+
 # ─── hermes mcp configure ────────────────────────────────────────────────────

 def cmd_mcp_configure(args):
@@ -696,6 +755,7 @@ def mcp_command(args):
         "test": cmd_mcp_test,
         "configure": cmd_mcp_configure,
         "config": cmd_mcp_configure,
+        "login": cmd_mcp_login,
     }

     handler = handlers.get(action)
@@ -713,4 +773,5 @@ def mcp_command(args):
     _info("hermes mcp list              List servers")
     _info("hermes mcp test <name>       Test connection")
     _info("hermes mcp configure <name>  Toggle tools")
+    _info("hermes mcp login <name>      Re-authenticate OAuth")
     print()
@@ -539,3 +539,64 @@ class TestDispatcher:
         mcp_command(_make_args(mcp_action=None))
         out = capsys.readouterr().out
         assert "Commands:" in out or "No MCP servers" in out
+
+
+# ---------------------------------------------------------------------------
+# Tests: Task 7 consolidation — cmd_mcp_remove evicts manager cache,
+# cmd_mcp_login forces re-auth
+# ---------------------------------------------------------------------------
+
+
+class TestMcpRemoveEvictsManager:
+    def test_remove_evicts_in_memory_provider(self, tmp_path, capsys, monkeypatch):
+        """After cmd_mcp_remove, the MCPOAuthManager no longer caches the provider."""
+        _seed_config(tmp_path, {
+            "oauth-srv": {"url": "https://example.com/mcp", "auth": "oauth"},
+        })
+        monkeypatch.setattr("builtins.input", lambda _: "y")
+        monkeypatch.setattr(
+            "hermes_cli.mcp_config.get_hermes_home", lambda: tmp_path
+        )
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+
+        from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
+        reset_manager_for_tests()
+
+        mgr = get_manager()
+        mgr.get_or_build_provider(
+            "oauth-srv", "https://example.com/mcp", None,
+        )
+        assert "oauth-srv" in mgr._entries
+
+        from hermes_cli.mcp_config import cmd_mcp_remove
+        cmd_mcp_remove(_make_args(name="oauth-srv"))
+
+        assert "oauth-srv" not in mgr._entries
+
+
+class TestMcpLogin:
+    def test_login_rejects_unknown_server(self, tmp_path, capsys):
+        _seed_config(tmp_path, {})
+        from hermes_cli.mcp_config import cmd_mcp_login
+        cmd_mcp_login(_make_args(name="ghost"))
+        out = capsys.readouterr().out
+        assert "not found" in out
+
+    def test_login_rejects_non_oauth_server(self, tmp_path, capsys):
+        _seed_config(tmp_path, {
+            "srv": {"url": "https://example.com/mcp", "auth": "header"},
+        })
+        from hermes_cli.mcp_config import cmd_mcp_login
+        cmd_mcp_login(_make_args(name="srv"))
+        out = capsys.readouterr().out
+        assert "not configured for OAuth" in out
+
+    def test_login_rejects_stdio_server(self, tmp_path, capsys):
+        _seed_config(tmp_path, {
+            "srv": {"command": "npx", "args": ["some-server"]},
+        })
+        from hermes_cli.mcp_config import cmd_mcp_login
+        cmd_mcp_login(_make_args(name="srv"))
+        out = capsys.readouterr().out
+        assert "no URL" in out or "not an OAuth" in out
@@ -431,3 +431,71 @@ class TestBuildOAuthAuthNonInteractive:

     assert auth is not None
     assert "no cached tokens found" not in caplog.text.lower()
+
+
+# ---------------------------------------------------------------------------
+# Extracted helper tests (Task 3 of MCP OAuth consolidation)
+# ---------------------------------------------------------------------------
+
+
+def test_build_client_metadata_basic():
+    """_build_client_metadata returns metadata with expected defaults."""
+    from tools.mcp_oauth import _build_client_metadata, _configure_callback_port
+
+    cfg = {"client_name": "Test Client"}
+    _configure_callback_port(cfg)
+    md = _build_client_metadata(cfg)
+
+    assert md.client_name == "Test Client"
+    assert "authorization_code" in md.grant_types
+    assert "refresh_token" in md.grant_types
+
+
+def test_build_client_metadata_without_secret_is_public():
+    """Without client_secret, token endpoint auth is 'none' (public client)."""
+    from tools.mcp_oauth import _build_client_metadata, _configure_callback_port
+
+    cfg = {}
+    _configure_callback_port(cfg)
+    md = _build_client_metadata(cfg)
+    assert md.token_endpoint_auth_method == "none"
+
+
+def test_build_client_metadata_with_secret_is_confidential():
+    """With client_secret, token endpoint auth is 'client_secret_post'."""
+    from tools.mcp_oauth import _build_client_metadata, _configure_callback_port
+
+    cfg = {"client_secret": "shh"}
+    _configure_callback_port(cfg)
+    md = _build_client_metadata(cfg)
+    assert md.token_endpoint_auth_method == "client_secret_post"
+
+
+def test_configure_callback_port_picks_free_port():
+    """_configure_callback_port(0) picks a free port in the ephemeral range."""
+    from tools.mcp_oauth import _configure_callback_port
+
+    cfg = {"redirect_port": 0}
+    port = _configure_callback_port(cfg)
+    assert 1024 < port < 65536
+    assert cfg["_resolved_port"] == port
+
+
+def test_configure_callback_port_uses_explicit_port():
+    """An explicit redirect_port is preserved."""
+    from tools.mcp_oauth import _configure_callback_port
+
+    cfg = {"redirect_port": 54321}
+    port = _configure_callback_port(cfg)
+    assert port == 54321
+    assert cfg["_resolved_port"] == 54321
+
+
+def test_parse_base_url_strips_path():
+    """_parse_base_url drops path components for OAuth discovery."""
+    from tools.mcp_oauth import _parse_base_url
+
+    assert _parse_base_url("https://example.com/mcp/v1") == "https://example.com"
+    assert _parse_base_url("https://example.com") == "https://example.com"
+    assert _parse_base_url("https://host.example.com:8080/api") == "https://host.example.com:8080"
tests/tools/test_mcp_oauth_integration.py (new file, 193 lines)

@@ -0,0 +1,193 @@
+"""End-to-end integration tests for the MCP OAuth consolidation.
+
+Exercises the full chain — manager, provider subclass, disk watch, 401
+dedup — with real file I/O and real imports (no transport mocks, no
+subprocesses). These are the tests that would catch Cthulhu's original
+BetterStack bug: an external process rewrites the tokens file on disk,
+and the running Hermes session picks up the new tokens on the next auth
+flow without requiring a restart.
+"""
+import asyncio
+import json
+import os
+import time
+
+import pytest
+
+
+pytest.importorskip("mcp.client.auth.oauth2", reason="MCP SDK 1.26.0+ required")
+
+
+@pytest.mark.asyncio
+async def test_external_refresh_picked_up_without_restart(tmp_path, monkeypatch):
+    """Simulate Cthulhu's cron workflow end-to-end.
+
+    1. A running Hermes session has OAuth tokens loaded in memory.
+    2. An external process (cron) writes fresh tokens to disk.
+    3. On the next auth flow, the manager's disk-watch invalidates the
+       in-memory state so the SDK re-reads from storage.
+    4. ``provider.context.current_tokens`` now reflects the new tokens
+       with no process restart required.
+    """
+    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+
+    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
+    reset_manager_for_tests()
+
+    token_dir = tmp_path / "mcp-tokens"
+    token_dir.mkdir(parents=True)
+    tokens_file = token_dir / "srv.json"
+    client_info_file = token_dir / "srv.client.json"
+
+    # Pre-seed the baseline state: valid tokens the session loaded at startup.
+    tokens_file.write_text(json.dumps({
+        "access_token": "OLD_ACCESS",
+        "token_type": "Bearer",
+        "expires_in": 3600,
+        "refresh_token": "OLD_REFRESH",
+    }))
+    client_info_file.write_text(json.dumps({
+        "client_id": "test-client",
+        "redirect_uris": ["http://127.0.0.1:12345/callback"],
+        "grant_types": ["authorization_code", "refresh_token"],
+        "response_types": ["code"],
+        "token_endpoint_auth_method": "none",
+    }))
+
+    mgr = MCPOAuthManager()
+    provider = mgr.get_or_build_provider(
+        "srv", "https://example.com/mcp", None,
+    )
+    assert provider is not None
+
+    # The SDK's _initialize reads tokens from storage into memory. This
+    # is what happens on the first http request under normal operation.
+    await provider._initialize()
+    assert provider.context.current_tokens.access_token == "OLD_ACCESS"
+
+    # Now record the baseline mtime in the manager (this happens
+    # automatically via the HermesMCPOAuthProvider.async_auth_flow
+    # pre-hook on the first real request, but we exercise it directly
+    # here for test determinism).
+    await mgr.invalidate_if_disk_changed("srv")
+
+    # EXTERNAL PROCESS: cron rewrites the tokens file with fresh creds.
+    # The old refresh_token has been consumed by this external exchange.
+    future_mtime = time.time() + 1
+    tokens_file.write_text(json.dumps({
+        "access_token": "NEW_ACCESS",
+        "token_type": "Bearer",
+        "expires_in": 3600,
+        "refresh_token": "NEW_REFRESH",
+    }))
+    os.utime(tokens_file, (future_mtime, future_mtime))
+
+    # The next auth flow should detect the mtime change and reload.
+    changed = await mgr.invalidate_if_disk_changed("srv")
+    assert changed, "manager must detect the disk mtime change"
+    assert provider._initialized is False, "_initialized must flip so SDK re-reads storage"
+
+    # Simulate the next async_auth_flow: _initialize runs because _initialized=False.
+    await provider._initialize()
+    assert provider.context.current_tokens.access_token == "NEW_ACCESS"
+    assert provider.context.current_tokens.refresh_token == "NEW_REFRESH"
+
+
+@pytest.mark.asyncio
+async def test_handle_401_deduplicates_concurrent_callers(tmp_path, monkeypatch):
+    """Ten concurrent 401 handlers for the same token should fire one recovery.
+
+    Mirrors Claude Code's pending401Handlers dedup pattern — prevents N MCP
+    tool calls hitting 401 simultaneously from all independently clearing
+    caches and re-reading the keychain (which thrashes the storage and
+    bogs down startup per CC-1096).
+    """
+    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+
+    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
+    reset_manager_for_tests()
+
+    token_dir = tmp_path / "mcp-tokens"
+    token_dir.mkdir(parents=True)
+    (token_dir / "srv.json").write_text(json.dumps({
+        "access_token": "TOK",
+        "token_type": "Bearer",
+        "expires_in": 3600,
+    }))
+
+    mgr = MCPOAuthManager()
+    provider = mgr.get_or_build_provider(
+        "srv", "https://example.com/mcp", None,
+    )
+    assert provider is not None
+
+    # Count how many times invalidate_if_disk_changed is called — proxy for
+    # how many actual recovery attempts fire.
+    call_count = 0
+    real_invalidate = mgr.invalidate_if_disk_changed
+
+    async def counting(name):
+        nonlocal call_count
+        call_count += 1
+        return await real_invalidate(name)
+
+    monkeypatch.setattr(mgr, "invalidate_if_disk_changed", counting)
+
+    # Fire 10 concurrent handlers with the same failed token.
+    results = await asyncio.gather(*(
+        mgr.handle_401("srv", "SAME_FAILED_TOKEN") for _ in range(10)
+    ))
+
+    # All callers get the same result (the shared future's resolution).
+    assert all(r == results[0] for r in results), "dedup must return identical result"
+    # Exactly ONE recovery ran — the rest awaited the same pending future.
+    assert call_count == 1, f"expected 1 recovery attempt, got {call_count}"
+
+
+@pytest.mark.asyncio
+async def test_handle_401_returns_false_when_no_provider(tmp_path, monkeypatch):
+    """handle_401 for an unknown server returns False cleanly."""
+    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
+    reset_manager_for_tests()
+
+    mgr = MCPOAuthManager()
+    result = await mgr.handle_401("nonexistent", "any_token")
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_invalidate_if_disk_changed_handles_missing_file(tmp_path, monkeypatch):
+    """invalidate_if_disk_changed returns False when tokens file doesn't exist."""
+    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
+    reset_manager_for_tests()
+
+    mgr = MCPOAuthManager()
+    mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
+
+    # No tokens file exists yet — this is the pre-auth state
+    result = await mgr.invalidate_if_disk_changed("srv")
+    assert result is False
+
+
+@pytest.mark.asyncio
+async def test_provider_is_reused_across_reconnects(tmp_path, monkeypatch):
+    """The manager caches providers; multiple reconnects reuse the same instance.
+
+    This is what makes the disk-watch stick across reconnects: tearing down
+    the MCP session and rebuilding it (Task 5's _reconnect_event path) must
+    not create a new provider, otherwise ``last_mtime_ns`` resets and the
+    first post-reconnect auth flow would spuriously "detect" a change.
+    """
+    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
+    reset_manager_for_tests()
+
+    mgr = MCPOAuthManager()
+    p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
+
+    # Simulate a reconnect: _run_http calls get_or_build_provider again
+    p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
+
+    assert p1 is p2, "manager must cache the provider across reconnects"
141
tests/tools/test_mcp_oauth_manager.py
Normal file
141
tests/tools/test_mcp_oauth_manager.py
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""Tests for the MCP OAuth manager (tools/mcp_oauth_manager.py).
|
||||||
|
|
||||||
|
The manager consolidates the eight scattered MCP-OAuth call sites into a
|
||||||
|
single object with disk-mtime watch, dedup'd 401 handling, and a provider
|
||||||
|
cache. See `tools/mcp_oauth_manager.py` for design rationale.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip(
|
||||||
|
"mcp.client.auth.oauth2",
|
||||||
|
reason="MCP SDK 1.26.0+ required for OAuth support",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_is_singleton():
|
||||||
|
"""get_manager() returns the same instance across calls."""
|
||||||
|
from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
|
||||||
|
reset_manager_for_tests()
|
||||||
|
m1 = get_manager()
|
||||||
|
m2 = get_manager()
|
||||||
|
assert m1 is m2
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_get_or_build_provider_caches(tmp_path, monkeypatch):
|
||||||
|
"""Calling get_or_build_provider twice with same name returns same provider."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from tools.mcp_oauth_manager import MCPOAuthManager
|
||||||
|
|
||||||
|
mgr = MCPOAuthManager()
|
||||||
|
p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
|
||||||
|
p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
|
||||||
|
assert p1 is p2
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_get_or_build_rebuilds_on_url_change(tmp_path, monkeypatch):
|
||||||
|
"""Changing the URL discards the cached provider."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from tools.mcp_oauth_manager import MCPOAuthManager
|
||||||
|
|
||||||
|
mgr = MCPOAuthManager()
|
||||||
|
p1 = mgr.get_or_build_provider("srv", "https://a.example.com/mcp", None)
|
||||||
|
p2 = mgr.get_or_build_provider("srv", "https://b.example.com/mcp", None)
|
||||||
|
assert p1 is not p2
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_remove_evicts_cache(tmp_path, monkeypatch):
|
||||||
|
"""remove(name) evicts the provider from cache AND deletes disk files."""
|
||||||
|
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||||
|
from tools.mcp_oauth_manager import MCPOAuthManager
|
||||||
|
|
||||||
|
# Pre-seed tokens on disk
|
||||||
|
token_dir = tmp_path / "mcp-tokens"
|
||||||
|
token_dir.mkdir(parents=True)
|
||||||
|
(token_dir / "srv.json").write_text(json.dumps({
|
||||||
|
"access_token": "TOK",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
}))
|
||||||
|
|
||||||
|
mgr = MCPOAuthManager()
|
    p1 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert p1 is not None
    assert (token_dir / "srv.json").exists()

    mgr.remove("srv")

    assert not (token_dir / "srv.json").exists()
    p2 = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert p1 is not p2


def test_hermes_provider_subclass_exists():
    """HermesMCPOAuthProvider is defined and subclasses OAuthClientProvider."""
    from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS
    from mcp.client.auth.oauth2 import OAuthClientProvider

    assert _HERMES_PROVIDER_CLS is not None
    assert issubclass(_HERMES_PROVIDER_CLS, OAuthClientProvider)


@pytest.mark.asyncio
async def test_disk_watch_invalidates_on_mtime_change(tmp_path, monkeypatch):
    """When the tokens file mtime changes, provider._initialized flips False.

    This is the behaviour Claude Code ships as
    invalidateOAuthCacheIfDiskChanged (CC-1096 / GH#24317) and is the core
    fix for Cthulhu's external-cron refresh workflow.
    """
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests

    reset_manager_for_tests()

    token_dir = tmp_path / "mcp-tokens"
    token_dir.mkdir(parents=True)
    tokens_file = token_dir / "srv.json"
    tokens_file.write_text(json.dumps({
        "access_token": "OLD",
        "token_type": "Bearer",
    }))

    mgr = MCPOAuthManager()
    provider = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)
    assert provider is not None

    # First call: records mtime (zero -> real) -> returns True
    changed1 = await mgr.invalidate_if_disk_changed("srv")
    assert changed1 is True

    # No file change -> False
    changed2 = await mgr.invalidate_if_disk_changed("srv")
    assert changed2 is False

    # Touch file with a newer mtime
    future_mtime = time.time() + 10
    os.utime(tokens_file, (future_mtime, future_mtime))

    changed3 = await mgr.invalidate_if_disk_changed("srv")
    assert changed3 is True
    # _initialized flipped — next async_auth_flow will re-read from disk
    assert provider._initialized is False


def test_manager_builds_hermes_provider_subclass(tmp_path, monkeypatch):
    """get_or_build_provider returns HermesMCPOAuthProvider, not plain OAuthClientProvider."""
    from tools.mcp_oauth_manager import (
        MCPOAuthManager, _HERMES_PROVIDER_CLS, reset_manager_for_tests,
    )
    reset_manager_for_tests()
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    mgr = MCPOAuthManager()
    provider = mgr.get_or_build_provider("srv", "https://example.com/mcp", None)

    assert _HERMES_PROVIDER_CLS is not None
    assert isinstance(provider, _HERMES_PROVIDER_CLS)
    assert provider._hermes_server_name == "srv"
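The mtime-watch mechanism these tests exercise can be sketched in isolation. This is a minimal, hypothetical sketch (the `DiskWatch` name and shape are not from the Hermes codebase): remember the last-seen `st_mtime_ns` of the tokens file and report a change whenever the current value differs, treating "never read" as zero so the first check always fires.

```python
import os
import tempfile
from pathlib import Path


class DiskWatch:
    """Hypothetical sketch of an mtime-based external-change detector."""

    def __init__(self, path: Path) -> None:
        self.path = path
        self.last_mtime_ns = 0  # zero == never read, so the first check fires

    def changed(self) -> bool:
        try:
            mtime_ns = self.path.stat().st_mtime_ns
        except OSError:
            return False  # missing file: nothing to reload
        if mtime_ns != self.last_mtime_ns:
            self.last_mtime_ns = mtime_ns
            return True
        return False


tokens = Path(tempfile.mkdtemp()) / "tokens.json"
tokens.write_text("{}")
watch = DiskWatch(tokens)
first = watch.changed()   # zero -> real mtime: change reported
second = watch.changed()  # no write in between: no change
# Simulate an external writer (cron, another CLI) touching the file.
os.utime(tokens, ns=(watch.last_mtime_ns + 10, watch.last_mtime_ns + 10))
third = watch.changed()
print(first, second, third)
```

The real manager additionally resets the SDK provider's in-memory state when a change is reported; the sketch only covers the detection half.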
57 tests/tools/test_mcp_reconnect_signal.py Normal file
@ -0,0 +1,57 @@
"""Tests for the MCPServerTask reconnect signal.

When the OAuth layer cannot recover in-place (e.g., external refresh of a
single-use refresh_token made the SDK's in-memory refresh fail), the tool
handler signals MCPServerTask to tear down the current MCP session and
reconnect with fresh credentials. This file exercises the signal plumbing
in isolation from the full stdio/http transport machinery.
"""
import asyncio

import pytest


@pytest.mark.asyncio
async def test_reconnect_event_attribute_exists():
    """MCPServerTask has a _reconnect_event alongside _shutdown_event."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")
    assert hasattr(task, "_reconnect_event")
    assert isinstance(task._reconnect_event, asyncio.Event)
    assert not task._reconnect_event.is_set()


@pytest.mark.asyncio
async def test_wait_for_lifecycle_event_returns_reconnect():
    """When _reconnect_event fires, helper returns 'reconnect' and clears it."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")

    task._reconnect_event.set()
    reason = await task._wait_for_lifecycle_event()
    assert reason == "reconnect"
    # Should have cleared so the next cycle starts fresh
    assert not task._reconnect_event.is_set()


@pytest.mark.asyncio
async def test_wait_for_lifecycle_event_returns_shutdown():
    """When _shutdown_event fires, helper returns 'shutdown'."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")

    task._shutdown_event.set()
    reason = await task._wait_for_lifecycle_event()
    assert reason == "shutdown"


@pytest.mark.asyncio
async def test_wait_for_lifecycle_event_shutdown_wins_when_both_set():
    """If both events are set simultaneously, shutdown takes precedence."""
    from tools.mcp_tool import MCPServerTask
    task = MCPServerTask("test")

    task._shutdown_event.set()
    task._reconnect_event.set()
    reason = await task._wait_for_lifecycle_event()
    assert reason == "shutdown"
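The lifecycle helper under test can be approximated as a race between two `asyncio.Event`s with a precedence rule. The sketch below is hypothetical (a standalone `Lifecycle` class, not the Hermes `MCPServerTask`): wait until either event fires, let shutdown win ties, and clear the reconnect flag on the reconnect path so it behaves as a one-shot signal.

```python
import asyncio


class Lifecycle:
    """Hypothetical sketch of a two-event lifecycle wait with shutdown precedence."""

    def __init__(self) -> None:
        self._shutdown_event = asyncio.Event()
        self._reconnect_event = asyncio.Event()

    async def wait_for_lifecycle_event(self) -> str:
        shutdown = asyncio.ensure_future(self._shutdown_event.wait())
        reconnect = asyncio.ensure_future(self._reconnect_event.wait())
        done, pending = await asyncio.wait(
            {shutdown, reconnect}, return_when=asyncio.FIRST_COMPLETED
        )
        for task in pending:
            task.cancel()  # don't leak the losing waiter
        if self._shutdown_event.is_set():
            return "shutdown"  # shutdown wins when both are set
        self._reconnect_event.clear()  # one-shot: next cycle starts fresh
        return "reconnect"


async def demo() -> list:
    out = []
    lc = Lifecycle()
    lc.  _reconnect_event.set() if False else lc._reconnect_event.set()
    out.append(await lc.wait_for_lifecycle_event())
    lc._shutdown_event.set()
    lc._reconnect_event.set()
    out.append(await lc.wait_for_lifecycle_event())
    return out

results = asyncio.run(demo())
print(results)
```

Checking `is_set()` after the race (rather than inspecting which waiter finished first) is what makes the tie-break deterministic when both events are already set on entry.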
139 tests/tools/test_mcp_tool_401_handling.py Normal file
@ -0,0 +1,139 @@
"""Tests for MCP tool-handler auth-failure detection.

When a tool call raises UnauthorizedError / OAuthNonInteractiveError /
httpx.HTTPStatusError(401), the handler should:
1. Ask MCPOAuthManager.handle_401 if recovery is viable.
2. If yes, trigger MCPServerTask._reconnect_event and retry once.
3. If no, return a structured needs_reauth error so the model stops
   hallucinating manual refresh attempts.
"""
import json
from unittest.mock import MagicMock

import pytest


pytest.importorskip("mcp.client.auth.oauth2")


def test_is_auth_error_detects_oauth_flow_error():
    from tools.mcp_tool import _is_auth_error
    from mcp.client.auth import OAuthFlowError

    assert _is_auth_error(OAuthFlowError("expired")) is True


def test_is_auth_error_detects_oauth_non_interactive():
    from tools.mcp_tool import _is_auth_error
    from tools.mcp_oauth import OAuthNonInteractiveError

    assert _is_auth_error(OAuthNonInteractiveError("no browser")) is True


def test_is_auth_error_detects_httpx_401():
    from tools.mcp_tool import _is_auth_error
    import httpx

    response = MagicMock()
    response.status_code = 401
    exc = httpx.HTTPStatusError("unauth", request=MagicMock(), response=response)
    assert _is_auth_error(exc) is True


def test_is_auth_error_rejects_httpx_500():
    from tools.mcp_tool import _is_auth_error
    import httpx

    response = MagicMock()
    response.status_code = 500
    exc = httpx.HTTPStatusError("oops", request=MagicMock(), response=response)
    assert _is_auth_error(exc) is False


def test_is_auth_error_rejects_generic_exception():
    from tools.mcp_tool import _is_auth_error
    assert _is_auth_error(ValueError("not auth")) is False
    assert _is_auth_error(RuntimeError("not auth")) is False


def test_call_tool_handler_returns_needs_reauth_on_unrecoverable_401(monkeypatch, tmp_path):
    """When session.call_tool raises 401 and handle_401 returns False,
    handler returns a structured needs_reauth error (not a generic failure)."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))

    from tools.mcp_tool import _make_tool_handler
    from tools.mcp_oauth_manager import get_manager, reset_manager_for_tests
    from mcp.client.auth import OAuthFlowError

    reset_manager_for_tests()

    # Stub server
    server = MagicMock()
    server.name = "srv"
    session = MagicMock()

    async def _call_tool_raises(*a, **kw):
        raise OAuthFlowError("token expired")

    session.call_tool = _call_tool_raises
    server.session = session
    server._reconnect_event = MagicMock()
    server._ready = MagicMock()
    server._ready.is_set.return_value = True

    from tools import mcp_tool
    mcp_tool._servers["srv"] = server
    mcp_tool._server_error_counts.pop("srv", None)

    # Ensure the MCP loop exists (run_on_mcp_loop needs it)
    mcp_tool._ensure_mcp_loop()

    # Force handle_401 to return False (no recovery available)
    mgr = get_manager()

    async def _h401(name, token=None):
        return False

    monkeypatch.setattr(mgr, "handle_401", _h401)

    try:
        handler = _make_tool_handler("srv", "tool1", 10.0)
        result = handler({"arg": "v"})
        parsed = json.loads(result)
        assert parsed.get("needs_reauth") is True, f"expected needs_reauth, got: {parsed}"
        assert parsed.get("server") == "srv"
        assert "re-auth" in parsed.get("error", "").lower() or "reauth" in parsed.get("error", "").lower()
    finally:
        mcp_tool._servers.pop("srv", None)
        mcp_tool._server_error_counts.pop("srv", None)


def test_call_tool_handler_non_auth_error_still_generic(monkeypatch, tmp_path):
    """Non-auth exceptions still surface via the generic error path, not needs_reauth."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    from tools.mcp_tool import _make_tool_handler

    server = MagicMock()
    server.name = "srv"
    session = MagicMock()

    async def _raises(*a, **kw):
        raise RuntimeError("unrelated")

    session.call_tool = _raises
    server.session = session

    from tools import mcp_tool
    mcp_tool._servers["srv"] = server
    mcp_tool._server_error_counts.pop("srv", None)
    mcp_tool._ensure_mcp_loop()

    try:
        handler = _make_tool_handler("srv", "tool1", 10.0)
        result = handler({"arg": "v"})
        parsed = json.loads(result)
        assert "needs_reauth" not in parsed
        assert "MCP call failed" in parsed.get("error", "")
    finally:
        mcp_tool._servers.pop("srv", None)
        mcp_tool._server_error_counts.pop("srv", None)
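The 401 deduplication these handlers rely on (one recovery attempt shared by all concurrent callers that failed with the same token) can be sketched with in-flight futures. This is a hypothetical standalone sketch, not the Hermes `handle_401` implementation: the first caller for a given failed token installs a future and runs recovery; later callers find the pending future and await its result instead of starting their own attempt.

```python
import asyncio


class Recoverer:
    """Hypothetical sketch of 401 dedup via in-flight futures."""

    def __init__(self) -> None:
        self.pending: dict = {}  # failed token -> asyncio.Future[bool]
        self.attempts = 0

    async def _recover(self, token: str) -> bool:
        # Stand-in for a real token refresh round-trip.
        self.attempts += 1
        await asyncio.sleep(0.01)
        return True

    async def handle_401(self, failed_token: str) -> bool:
        fut = self.pending.get(failed_token)
        if fut is not None:
            # Piggyback on the attempt already in flight.
            return await asyncio.shield(fut)
        fut = asyncio.get_running_loop().create_future()
        self.pending[failed_token] = fut
        try:
            result = await self._recover(failed_token)
            fut.set_result(result)
            return result
        finally:
            self.pending.pop(failed_token, None)


async def demo():
    r = Recoverer()
    results = await asyncio.gather(*(r.handle_401("tok") for _ in range(5)))
    return list(results), r.attempts

results, attempts = asyncio.run(demo())
print(results, attempts)
```

Five concurrent callers all observe the same result while only one recovery attempt actually runs, which is the thundering-herd property the diff's manager docstring describes.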
@ -375,6 +375,103 @@ def remove_oauth_tokens(server_name: str) -> None:
     logger.info("OAuth tokens removed for '%s'", server_name)
+
+
+# ---------------------------------------------------------------------------
+# Extracted helpers (Task 3 of MCP OAuth consolidation)
+#
+# These compose into ``build_oauth_auth`` below, and are also used by
+# ``tools.mcp_oauth_manager.MCPOAuthManager._build_provider`` so the two
+# construction paths share one implementation.
+# ---------------------------------------------------------------------------
+
+
+def _configure_callback_port(cfg: dict) -> int:
+    """Pick or validate the OAuth callback port.
+
+    Stores the resolved port into ``cfg['_resolved_port']`` so sibling
+    helpers (and the manager) can read it from the same dict. Returns the
+    resolved port.
+
+    NOTE: also sets the legacy module-level ``_oauth_port`` so existing
+    calls to ``_wait_for_callback`` keep working. The legacy global is
+    the root cause of issue #5344 (port collision on concurrent OAuth
+    flows); replacing it with a ContextVar is out of scope for this
+    consolidation PR.
+    """
+    global _oauth_port
+    requested = int(cfg.get("redirect_port", 0))
+    port = _find_free_port() if requested == 0 else requested
+    cfg["_resolved_port"] = port
+    _oauth_port = port  # legacy consumer: _wait_for_callback reads this
+    return port
+
+
+def _build_client_metadata(cfg: dict) -> "OAuthClientMetadata":
+    """Build OAuthClientMetadata from the oauth config dict.
+
+    Requires ``cfg['_resolved_port']`` to have been populated by
+    :func:`_configure_callback_port` first.
+    """
+    port = cfg.get("_resolved_port")
+    if port is None:
+        raise ValueError(
+            "_configure_callback_port() must be called before _build_client_metadata()"
+        )
+    client_name = cfg.get("client_name", "Hermes Agent")
+    scope = cfg.get("scope")
+    redirect_uri = f"http://127.0.0.1:{port}/callback"
+
+    metadata_kwargs: dict[str, Any] = {
+        "client_name": client_name,
+        "redirect_uris": [AnyUrl(redirect_uri)],
+        "grant_types": ["authorization_code", "refresh_token"],
+        "response_types": ["code"],
+        "token_endpoint_auth_method": "none",
+    }
+    if scope:
+        metadata_kwargs["scope"] = scope
+    if cfg.get("client_secret"):
+        metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
+
+    return OAuthClientMetadata.model_validate(metadata_kwargs)
+
+
+def _maybe_preregister_client(
+    storage: "HermesTokenStorage",
+    cfg: dict,
+    client_metadata: "OAuthClientMetadata",
+) -> None:
+    """If cfg has a pre-registered client_id, persist it to storage."""
+    client_id = cfg.get("client_id")
+    if not client_id:
+        return
+    port = cfg["_resolved_port"]
+    redirect_uri = f"http://127.0.0.1:{port}/callback"
+
+    info_dict: dict[str, Any] = {
+        "client_id": client_id,
+        "redirect_uris": [redirect_uri],
+        "grant_types": client_metadata.grant_types,
+        "response_types": client_metadata.response_types,
+        "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
+    }
+    if cfg.get("client_secret"):
+        info_dict["client_secret"] = cfg["client_secret"]
+    if cfg.get("client_name"):
+        info_dict["client_name"] = cfg["client_name"]
+    if cfg.get("scope"):
+        info_dict["scope"] = cfg["scope"]
+
+    client_info = OAuthClientInformationFull.model_validate(info_dict)
+    _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
+    logger.debug("Pre-registered client_id=%s for '%s'", client_id, storage._server_name)
+
+
+def _parse_base_url(server_url: str) -> str:
+    """Strip path component from server URL, returning the base origin."""
+    parsed = urlparse(server_url)
+    return f"{parsed.scheme}://{parsed.netloc}"
+
+
 def build_oauth_auth(
     server_name: str,
     server_url: str,
@ -382,7 +479,9 @@ def build_oauth_auth(
 ) -> "OAuthClientProvider | None":
     """Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.

-    Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config.
+    Public API preserved for backwards compatibility. New code should use
+    :func:`tools.mcp_oauth_manager.get_manager` so OAuth state is shared
+    across config-time, runtime, and reconnect paths.

     Args:
         server_name: Server key in mcp_servers config (used for storage).
@ -396,87 +495,32 @@ def build_oauth_auth(
     if not _OAUTH_AVAILABLE:
         logger.warning(
             "MCP OAuth requested for '%s' but SDK auth types are not available. "
-            "Install with: pip install 'mcp>=1.10.0'",
+            "Install with: pip install 'mcp>=1.26.0'",
             server_name,
         )
         return None

-    global _oauth_port
-
-    cfg = oauth_config or {}
+    cfg = dict(oauth_config or {})  # copy — we mutate _resolved_port

-    # --- Storage ---
     storage = HermesTokenStorage(server_name)

-    # --- Non-interactive warning ---
     if not _is_interactive() and not storage.has_cached_tokens():
         logger.warning(
-            "MCP OAuth for '%s': non-interactive environment and no cached tokens found. "
-            "The OAuth flow requires browser authorization. Run interactively first "
-            "to complete the initial authorization, then cached tokens will be reused.",
+            "MCP OAuth for '%s': non-interactive environment and no cached tokens "
+            "found. The OAuth flow requires browser authorization. Run "
+            "interactively first to complete the initial authorization, then "
+            "cached tokens will be reused.",
             server_name,
         )

-    # --- Pick callback port ---
-    redirect_port = int(cfg.get("redirect_port", 0))
-    if redirect_port == 0:
-        redirect_port = _find_free_port()
-    _oauth_port = redirect_port
-
-    # --- Client metadata ---
-    client_name = cfg.get("client_name", "Hermes Agent")
-    scope = cfg.get("scope")
-    redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"
-
-    metadata_kwargs: dict[str, Any] = {
-        "client_name": client_name,
-        "redirect_uris": [AnyUrl(redirect_uri)],
-        "grant_types": ["authorization_code", "refresh_token"],
-        "response_types": ["code"],
-        "token_endpoint_auth_method": "none",
-    }
-    if scope:
-        metadata_kwargs["scope"] = scope
-
-    client_secret = cfg.get("client_secret")
-    if client_secret:
-        metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
-
-    client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs)
-
-    # --- Pre-registered client ---
-    client_id = cfg.get("client_id")
-    if client_id:
-        info_dict: dict[str, Any] = {
-            "client_id": client_id,
-            "redirect_uris": [redirect_uri],
-            "grant_types": client_metadata.grant_types,
-            "response_types": client_metadata.response_types,
-            "token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
-        }
-        if client_secret:
-            info_dict["client_secret"] = client_secret
-        if client_name:
-            info_dict["client_name"] = client_name
-        if scope:
-            info_dict["scope"] = scope
-
-        client_info = OAuthClientInformationFull.model_validate(info_dict)
-        _write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
-        logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name)
-
-    # --- Base URL for discovery ---
-    parsed = urlparse(server_url)
-    base_url = f"{parsed.scheme}://{parsed.netloc}"
-
-    # --- Build provider ---
-    provider = OAuthClientProvider(
-        server_url=base_url,
+    _configure_callback_port(cfg)
+    client_metadata = _build_client_metadata(cfg)
+    _maybe_preregister_client(storage, cfg, client_metadata)
+
+    return OAuthClientProvider(
+        server_url=_parse_base_url(server_url),
         client_metadata=client_metadata,
         storage=storage,
         redirect_handler=_redirect_handler,
         callback_handler=_wait_for_callback,
         timeout=float(cfg.get("timeout", 300)),
     )
-
-    return provider
413 tools/mcp_oauth_manager.py Normal file
@ -0,0 +1,413 @@
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Central manager for per-server MCP OAuth state.
|
||||||
|
|
||||||
|
One instance shared across the process. Holds per-server OAuth provider
|
||||||
|
instances and coordinates:
|
||||||
|
|
||||||
|
- **Cross-process token reload** via mtime-based disk watch. When an external
|
||||||
|
process (e.g. a user cron job) refreshes tokens on disk, the next auth flow
|
||||||
|
picks them up without requiring a process restart.
|
||||||
|
- **401 deduplication** via in-flight futures. When N concurrent tool calls
|
||||||
|
all hit 401 with the same access_token, only one recovery attempt fires;
|
||||||
|
the rest await the same result.
|
||||||
|
- **Reconnect signalling** for long-lived MCP sessions. The manager itself
|
||||||
|
does not drive reconnection — the `MCPServerTask` in `mcp_tool.py` does —
|
||||||
|
but the manager is the single source of truth that decides when reconnect
|
||||||
|
is warranted.
|
||||||
|
|
||||||
|
Replaces what used to be scattered across eight call sites in `mcp_oauth.py`,
|
||||||
|
`mcp_tool.py`, and `hermes_cli/mcp_config.py`. This module is the ONLY place
|
||||||
|
that instantiates the MCP SDK's `OAuthClientProvider` — all other code paths
|
||||||
|
go through `get_manager()`.
|
||||||
|
|
||||||
|
Design reference:
|
||||||
|
|
||||||
|
- Claude Code's ``invalidateOAuthCacheIfDiskChanged``
|
||||||
|
(``claude-code/src/utils/auth.ts:1320``, CC-1096 / GH#24317). Identical
|
||||||
|
external-refresh staleness bug class.
|
||||||
|
- Codex's ``refresh_oauth_if_needed`` / ``persist_if_needed``
|
||||||
|
(``codex-rs/rmcp-client/src/rmcp_client.rs:805``). We lean on the MCP SDK's
|
||||||
|
lazy refresh rather than calling refresh before every op, because one
|
||||||
|
``stat()`` per tool call is cheaper than an ``await`` + potential refresh
|
||||||
|
round-trip, and the SDK's in-memory expiry path is already correct.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-server entry
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _ProviderEntry:
|
||||||
|
"""Per-server OAuth state tracked by the manager.
|
||||||
|
|
||||||
|
Fields:
|
||||||
|
server_url: The MCP server URL used to build the provider. Tracked
|
||||||
|
so we can discard a cached provider if the URL changes.
|
||||||
|
oauth_config: Optional dict from ``mcp_servers.<name>.oauth``.
|
||||||
|
provider: The ``httpx.Auth``-compatible provider wrapping the MCP
|
||||||
|
SDK. None until first use.
|
||||||
|
last_mtime_ns: Last-seen ``st_mtime_ns`` of the on-disk tokens file.
|
||||||
|
Zero if never read. Used by :meth:`MCPOAuthManager.invalidate_if_disk_changed`
|
||||||
|
to detect external refreshes.
|
||||||
|
lock: Serialises concurrent access to this entry's state. Bound to
|
||||||
|
whichever asyncio loop first awaits it (the MCP event loop).
|
||||||
|
pending_401: In-flight 401-handler futures keyed by the failed
|
||||||
|
access_token, for deduplicating thundering-herd 401s. Mirrors
|
||||||
|
Claude Code's ``pending401Handlers`` map.
|
||||||
|
"""
|
||||||
|
|
||||||
|
server_url: str
|
||||||
|
oauth_config: Optional[dict]
|
||||||
|
provider: Optional[Any] = None
|
||||||
|
last_mtime_ns: int = 0
|
||||||
|
lock: asyncio.Lock = field(default_factory=asyncio.Lock)
|
||||||
|
pending_401: dict[str, "asyncio.Future[bool]"] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# HermesMCPOAuthProvider — OAuthClientProvider subclass with disk-watch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hermes_provider_class() -> Optional[type]:
|
||||||
|
"""Lazy-import the SDK base class and return our subclass.
|
||||||
|
|
||||||
|
Wrapped in a function so this module imports cleanly even when the
|
||||||
|
MCP SDK's OAuth module is unavailable (e.g. older mcp versions).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from mcp.client.auth.oauth2 import OAuthClientProvider
|
||||||
|
except ImportError: # pragma: no cover — SDK required in CI
|
||||||
|
return None
|
||||||
|
|
||||||
|
class HermesMCPOAuthProvider(OAuthClientProvider):
|
||||||
|
"""OAuthClientProvider with pre-flow disk-mtime reload.
|
||||||
|
|
||||||
|
Before every ``async_auth_flow`` invocation, asks the manager to
|
||||||
|
check whether the tokens file on disk has been modified externally.
|
||||||
|
If so, the manager resets ``_initialized`` so the next flow
|
||||||
|
re-reads from storage.
|
||||||
|
|
||||||
|
This makes external-process refreshes (cron, another CLI instance)
|
||||||
|
visible to the running MCP session without requiring a restart.
|
||||||
|
|
||||||
|
Reference: Claude Code's ``invalidateOAuthCacheIfDiskChanged``
|
||||||
|
(``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._hermes_server_name = server_name
|
||||||
|
|
||||||
|
async def async_auth_flow(self, request): # type: ignore[override]
|
||||||
|
# Pre-flow hook: ask the manager to refresh from disk if needed.
|
||||||
|
# Any failure here is non-fatal — we just log and proceed with
|
||||||
|
# whatever state the SDK already has.
|
||||||
|
try:
|
||||||
|
await get_manager().invalidate_if_disk_changed(
|
||||||
|
self._hermes_server_name
|
||||||
|
)
|
||||||
|
except Exception as exc: # pragma: no cover — defensive
|
||||||
|
logger.debug(
|
||||||
|
"MCP OAuth '%s': pre-flow disk-watch failed (non-fatal): %s",
|
||||||
|
self._hermes_server_name, exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Delegate to the SDK's auth flow
|
||||||
|
async for item in super().async_auth_flow(request):
|
||||||
|
yield item
|
||||||
|
|
||||||
|
return HermesMCPOAuthProvider
|
||||||
|
|
||||||
|
|
||||||
|
# Cached at import time. Tested and used by :class:`MCPOAuthManager`.
|
||||||
|
_HERMES_PROVIDER_CLS: Optional[type] = _make_hermes_provider_class()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Manager
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class MCPOAuthManager:
|
||||||
|
"""Single source of truth for per-server MCP OAuth state.
|
||||||
|
|
||||||
|
Thread-safe: the ``_entries`` dict is guarded by ``_entries_lock`` for
|
||||||
|
get-or-create semantics. Per-entry state is guarded by the entry's own
|
||||||
|
``asyncio.Lock`` (used from the MCP event loop thread).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._entries: dict[str, _ProviderEntry] = {}
|
||||||
|
self._entries_lock = threading.Lock()
|
||||||
|
|
||||||
|
# -- Provider construction / caching -------------------------------------
|
||||||
|
|
||||||
|
def get_or_build_provider(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
server_url: str,
|
||||||
|
oauth_config: Optional[dict],
|
||||||
|
) -> Optional[Any]:
|
||||||
|
"""Return a cached OAuth provider for ``server_name`` or build one.
|
||||||
|
|
||||||
|
Idempotent: repeat calls with the same name return the same instance.
|
||||||
|
If ``server_url`` changes for a given name, the cached entry is
|
||||||
|
discarded and a fresh provider is built.
|
||||||
|
|
||||||
|
Returns None if the MCP SDK's OAuth support is unavailable.
|
||||||
|
"""
|
||||||
|
with self._entries_lock:
|
||||||
|
entry = self._entries.get(server_name)
|
||||||
|
if entry is not None and entry.server_url != server_url:
|
||||||
|
logger.info(
|
||||||
|
"MCP OAuth '%s': URL changed from %s to %s, discarding cache",
|
||||||
|
server_name, entry.server_url, server_url,
|
||||||
|
)
|
||||||
|
entry = None
|
||||||
|
|
||||||
|
if entry is None:
|
||||||
|
entry = _ProviderEntry(
|
||||||
|
server_url=server_url,
|
||||||
|
oauth_config=oauth_config,
|
||||||
|
)
|
||||||
|
self._entries[server_name] = entry
|
||||||
|
|
||||||
|
if entry.provider is None:
|
||||||
|
entry.provider = self._build_provider(server_name, entry)
|
||||||
|
|
||||||
|
return entry.provider
|
||||||
|
|
||||||
|
def _build_provider(
|
||||||
|
self,
|
||||||
|
server_name: str,
|
||||||
|
entry: _ProviderEntry,
|
||||||
|
) -> Optional[Any]:
|
||||||
|
"""Build the underlying OAuth provider.
|
||||||
|
|
||||||
|
Constructs :class:`HermesMCPOAuthProvider` directly using the helpers
|
||||||
|
extracted from ``tools.mcp_oauth``. The subclass injects a pre-flow
|
||||||
|
disk-watch hook so external token refreshes (cron, other CLI
|
||||||
|
instances) are visible to running MCP sessions.
|
||||||
|
|
||||||
|
Returns None if the MCP SDK's OAuth support is unavailable.
|
||||||
|
"""
|
||||||
|
if _HERMES_PROVIDER_CLS is None:
|
||||||
|
logger.warning(
|
||||||
|
"MCP OAuth '%s': SDK auth module unavailable", server_name,
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Local imports avoid circular deps at module import time.
|
||||||
|
from tools.mcp_oauth import (
|
||||||
|
HermesTokenStorage,
|
||||||
|
_OAUTH_AVAILABLE,
|
||||||
|
_build_client_metadata,
|
||||||
|
_configure_callback_port,
|
||||||
|
_is_interactive,
|
||||||
|
_maybe_preregister_client,
|
||||||
|
_parse_base_url,
|
||||||
|
_redirect_handler,
|
||||||
|
_wait_for_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _OAUTH_AVAILABLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cfg = dict(entry.oauth_config or {})
|
||||||
|
storage = HermesTokenStorage(server_name)
|
||||||
|
|
||||||
|
if not _is_interactive() and not storage.has_cached_tokens():
|
||||||
|
logger.warning(
|
||||||
|
"MCP OAuth for '%s': non-interactive environment and no "
|
||||||
|
"cached tokens found. Run interactively first to complete "
|
||||||
|
"initial authorization.",
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
_configure_callback_port(cfg)
|
||||||
|
client_metadata = _build_client_metadata(cfg)
|
||||||
|
_maybe_preregister_client(storage, cfg, client_metadata)
|
||||||
|
|
||||||
|
return _HERMES_PROVIDER_CLS(
|
||||||
|
server_name=server_name,
|
||||||
|
server_url=_parse_base_url(entry.server_url),
|
||||||
|
client_metadata=client_metadata,
|
||||||
|
storage=storage,
|
||||||
|
redirect_handler=_redirect_handler,
|
||||||
|
callback_handler=_wait_for_callback,
|
||||||
|
timeout=float(cfg.get("timeout", 300)),
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove(self, server_name: str) -> None:
|
||||||
|
"""Evict the provider from cache AND delete tokens from disk.
|
||||||
|
|
||||||
|
Called by ``hermes mcp remove <name>`` and (indirectly) by
|
||||||
|
``hermes mcp login <name>`` during forced re-auth.
|
||||||
|
"""
|
||||||
|
with self._entries_lock:
|
||||||
|
self._entries.pop(server_name, None)
|
||||||
|
|
||||||
|
from tools.mcp_oauth import remove_oauth_tokens
|
||||||
|
remove_oauth_tokens(server_name)
|
||||||
|
logger.info(
|
||||||
|
"MCP OAuth '%s': evicted from cache and removed from disk",
|
||||||
|
server_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
    # -- Disk watch ----------------------------------------------------------

    async def invalidate_if_disk_changed(self, server_name: str) -> bool:
        """If the tokens file on disk has a newer mtime than last-seen, force
        the MCP SDK provider to reload its in-memory state.

        Returns True if the cache was invalidated (mtime differed). This is
        the core fix for the external-refresh workflow: a cron job writes
        fresh tokens to disk, and on the next tool call the running MCP
        session picks them up without a restart.
        """
        from tools.mcp_oauth import _get_token_dir, _safe_filename

        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        async with entry.lock:
            tokens_path = _get_token_dir() / f"{_safe_filename(server_name)}.json"
            try:
                mtime_ns = tokens_path.stat().st_mtime_ns
            except (FileNotFoundError, OSError):
                return False

            if mtime_ns != entry.last_mtime_ns:
                old = entry.last_mtime_ns
                entry.last_mtime_ns = mtime_ns
                # Force the SDK's OAuthClientProvider to reload from storage
                # on its next auth flow. `_initialized` is private API but
                # stable across the MCP SDK versions we pin (>=1.26.0).
                if hasattr(entry.provider, "_initialized"):
                    entry.provider._initialized = False  # noqa: SLF001
                logger.info(
                    "MCP OAuth '%s': tokens file changed (mtime %d -> %d), "
                    "forcing reload",
                    server_name, old, mtime_ns,
                )
                return True
            return False
    # -- 401 handler (dedup'd) -----------------------------------------------

    async def handle_401(
        self,
        server_name: str,
        failed_access_token: Optional[str] = None,
    ) -> bool:
        """Handle a 401 from a tool call, deduplicated across concurrent callers.

        Returns:
            True if a (possibly new) access token is now available — caller
            should trigger a reconnect and retry the operation.
            False if no recovery path exists — caller should surface a
            ``needs_reauth`` error to the model so it stops hallucinating
            manual refresh attempts.

        Thundering-herd protection: if N concurrent tool calls hit 401 with
        the same ``failed_access_token``, only one recovery attempt fires.
        Others await the same future.
        """
        entry = self._entries.get(server_name)
        if entry is None or entry.provider is None:
            return False

        key = failed_access_token or "<unknown>"
        loop = asyncio.get_running_loop()

        async with entry.lock:
            pending = entry.pending_401.get(key)
            if pending is None:
                pending = loop.create_future()
                entry.pending_401[key] = pending

                async def _do_handle() -> None:
                    try:
                        # Step 1: Did disk change? Picks up external refresh.
                        disk_changed = await self.invalidate_if_disk_changed(
                            server_name
                        )
                        if disk_changed:
                            if not pending.done():
                                pending.set_result(True)
                            return

                        # Step 2: No disk change — if the SDK can refresh
                        # in-place, let the caller retry. The SDK's httpx.Auth
                        # flow will issue the refresh on the next request.
                        provider = entry.provider
                        ctx = getattr(provider, "context", None)
                        can_refresh = False
                        if ctx is not None:
                            can_refresh_fn = getattr(ctx, "can_refresh_token", None)
                            if callable(can_refresh_fn):
                                try:
                                    can_refresh = bool(can_refresh_fn())
                                except Exception:
                                    can_refresh = False
                        if not pending.done():
                            pending.set_result(can_refresh)
                    except Exception as exc:  # pragma: no cover — defensive
                        logger.warning(
                            "MCP OAuth '%s': 401 handler failed: %s",
                            server_name, exc,
                        )
                        if not pending.done():
                            pending.set_result(False)
                    finally:
                        entry.pending_401.pop(key, None)

                asyncio.create_task(_do_handle())

        try:
            return await pending
        except Exception as exc:  # pragma: no cover — defensive
            logger.warning(
                "MCP OAuth '%s': awaiting 401 handler failed: %s",
                server_name, exc,
            )
            return False
# ---------------------------------------------------------------------------
# Module-level singleton
# ---------------------------------------------------------------------------


_MANAGER: Optional[MCPOAuthManager] = None
_MANAGER_LOCK = threading.Lock()


def get_manager() -> MCPOAuthManager:
    """Return the process-wide :class:`MCPOAuthManager` singleton."""
    global _MANAGER
    with _MANAGER_LOCK:
        if _MANAGER is None:
            _MANAGER = MCPOAuthManager()
        return _MANAGER


def reset_manager_for_tests() -> None:
    """Test-only helper: drop the singleton so fixtures start clean."""
    global _MANAGER
    with _MANAGER_LOCK:
        _MANAGER = None
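The disk-watch mechanism above boils down to snapshotting `st_mtime_ns` and treating any difference as "reload before next use". A minimal standalone sketch of that pattern (the `DiskWatch` class here is hypothetical and not part of the Hermes codebase, which tracks the mtime per server entry instead):

```python
import os
import tempfile

class DiskWatch:
    """Remember a file's st_mtime_ns; report when an external write happened."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.last_mtime_ns = 0

    def changed(self) -> bool:
        """Return True once per observed change to the watched file."""
        try:
            mtime_ns = os.stat(self.path).st_mtime_ns
        except OSError:
            return False  # missing file: nothing to pick up
        if mtime_ns != self.last_mtime_ns:
            self.last_mtime_ns = mtime_ns
            return True
        return False

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "tokens.json")
    watch = DiskWatch(path)
    assert not watch.changed()      # file does not exist yet
    with open(path, "w") as f:
        f.write("{}")
    assert watch.changed()          # first sighting counts as a change
    assert not watch.changed()      # unchanged since last check
    os.utime(path, ns=(1, 2))       # simulate an external token refresh
    assert watch.changed()
```

Using nanosecond mtime rather than content hashing keeps the pre-flow check cheap enough to run before every auth flow.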
@@ -783,7 +783,8 @@ class MCPServerTask:

     __slots__ = (
         "name", "session", "tool_timeout",
-        "_task", "_ready", "_shutdown_event", "_tools", "_error", "_config",
+        "_task", "_ready", "_shutdown_event", "_reconnect_event",
+        "_tools", "_error", "_config",
         "_sampling", "_registered_tool_names", "_auth_type", "_refresh_lock",
     )

@@ -794,6 +795,12 @@ class MCPServerTask:
         self._task: Optional[asyncio.Task] = None
         self._ready = asyncio.Event()
         self._shutdown_event = asyncio.Event()
+        # Set by tool handlers on auth failure after manager.handle_401()
+        # confirms recovery is viable. When set, _run_http / _run_stdio
+        # exit their async-with blocks cleanly (no exception), and the
+        # outer run() loop re-enters the transport so the MCP session is
+        # rebuilt with fresh credentials.
+        self._reconnect_event = asyncio.Event()
         self._tools: list = []
         self._error: Optional[Exception] = None
         self._config: dict = {}
@@ -887,6 +894,40 @@ class MCPServerTask:
             self.name, len(self._registered_tool_names),
         )
+
+    async def _wait_for_lifecycle_event(self) -> str:
+        """Block until either _shutdown_event or _reconnect_event fires.
+
+        Returns:
+            "shutdown" if the server should exit the run loop entirely.
+            "reconnect" if the server should tear down the current MCP
+                session and re-enter the transport (fresh OAuth
+                tokens, new session ID, etc.). The reconnect event
+                is cleared before return so the next cycle starts
+                with a fresh signal.
+
+        Shutdown takes precedence if both events are set simultaneously.
+        """
+        shutdown_task = asyncio.create_task(self._shutdown_event.wait())
+        reconnect_task = asyncio.create_task(self._reconnect_event.wait())
+        try:
+            await asyncio.wait(
+                {shutdown_task, reconnect_task},
+                return_when=asyncio.FIRST_COMPLETED,
+            )
+        finally:
+            for t in (shutdown_task, reconnect_task):
+                if not t.done():
+                    t.cancel()
+                    try:
+                        await t
+                    except (asyncio.CancelledError, Exception):
+                        pass
+
+        if self._shutdown_event.is_set():
+            return "shutdown"
+        self._reconnect_event.clear()
+        return "reconnect"

     async def _run_stdio(self, config: dict):
         """Run the server using stdio transport."""
         command = config.get("command")
@@ -932,7 +973,10 @@ class MCPServerTask:
                 self.session = session
                 await self._discover_tools()
                 self._ready.set()
-                await self._shutdown_event.wait()
+                # stdio transport does not use OAuth, but we still honor
+                # _reconnect_event (e.g. future manual /mcp refresh) for
+                # consistency with _run_http.
+                await self._wait_for_lifecycle_event()
             # Context exited cleanly — subprocess was terminated by the SDK.
             if new_pids:
                 with _lock:
@@ -951,16 +995,18 @@ class MCPServerTask:
         headers = dict(config.get("headers") or {})
         connect_timeout = config.get("connect_timeout", _DEFAULT_CONNECT_TIMEOUT)

-        # OAuth 2.1 PKCE: build httpx.Auth handler using the MCP SDK.
-        # If OAuth setup fails (e.g. non-interactive environment without
-        # cached tokens), re-raise so this server is reported as failed
-        # without blocking other MCP servers from connecting.
+        # OAuth 2.1 PKCE: route through the central MCPOAuthManager so the
+        # same provider instance is reused across reconnects, pre-flow
+        # disk-watch is active, and config-time CLI code paths share state.
+        # If OAuth setup fails (e.g. non-interactive env without cached
+        # tokens), re-raise so this server is reported as failed without
+        # blocking other MCP servers from connecting.
         _oauth_auth = None
         if self._auth_type == "oauth":
             try:
-                from tools.mcp_oauth import build_oauth_auth
-                _oauth_auth = build_oauth_auth(
-                    self.name, url, config.get("oauth")
+                from tools.mcp_oauth_manager import get_manager
+                _oauth_auth = get_manager().get_or_build_provider(
+                    self.name, url, config.get("oauth"),
                 )
             except Exception as exc:
                 logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
@@ -995,7 +1041,12 @@ class MCPServerTask:
                     self.session = session
                     await self._discover_tools()
                     self._ready.set()
-                    await self._shutdown_event.wait()
+                    reason = await self._wait_for_lifecycle_event()
+                    if reason == "reconnect":
+                        logger.info(
+                            "MCP server '%s': reconnect requested — "
+                            "tearing down HTTP session", self.name,
+                        )
         else:
             # Deprecated API (mcp < 1.24.0): manages httpx client internally.
             _http_kwargs: dict = {
@@ -1012,7 +1063,12 @@ class MCPServerTask:
                     self.session = session
                     await self._discover_tools()
                     self._ready.set()
-                    await self._shutdown_event.wait()
+                    reason = await self._wait_for_lifecycle_event()
+                    if reason == "reconnect":
+                        logger.info(
+                            "MCP server '%s': reconnect requested — "
+                            "tearing down legacy HTTP session", self.name,
+                        )

     async def _discover_tools(self):
         """Discover tools from the connected session."""
@@ -1060,8 +1116,25 @@ class MCPServerTask:
                     await self._run_http(config)
                 else:
                     await self._run_stdio(config)
-                # Normal exit (shutdown requested) -- break out
-                break
+                # Transport returned cleanly. Two cases:
+                # - _shutdown_event was set: exit the run loop entirely.
+                # - _reconnect_event was set (auth recovery): loop back and
+                #   rebuild the MCP session with fresh credentials. Do NOT
+                #   touch the retry counters — this is not a failure.
+                if self._shutdown_event.is_set():
+                    break
+                logger.info(
+                    "MCP server '%s': reconnecting (OAuth recovery or "
+                    "manual refresh)",
+                    self.name,
+                )
+                # Reset the session reference; _run_http/_run_stdio will
+                # repopulate it on successful re-entry.
+                self.session = None
+                # Keep _ready set across reconnects so tool handlers can
+                # still detect a transient in-flight state — it'll be
+                # re-set after the fresh session initializes.
+                continue
             except Exception as exc:
                 self.session = None
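The shutdown-versus-reconnect race that `_wait_for_lifecycle_event` resolves can be sketched in isolation. The `wait_for_lifecycle` function below is a hypothetical standalone analogue, not the Hermes implementation: race two `asyncio.Event` waiters, cancel the loser, and let shutdown win ties.

```python
import asyncio

async def wait_for_lifecycle(shutdown: asyncio.Event, reconnect: asyncio.Event) -> str:
    """Race two events; shutdown takes precedence when both are set."""
    shutdown_task = asyncio.create_task(shutdown.wait())
    reconnect_task = asyncio.create_task(reconnect.wait())
    try:
        await asyncio.wait(
            {shutdown_task, reconnect_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
    finally:
        # Cancel whichever waiter lost the race so it doesn't leak.
        for t in (shutdown_task, reconnect_task):
            if not t.done():
                t.cancel()
    if shutdown.is_set():
        return "shutdown"
    reconnect.clear()  # fresh signal for the next cycle
    return "reconnect"

async def main() -> None:
    shutdown, reconnect = asyncio.Event(), asyncio.Event()
    reconnect.set()
    assert await wait_for_lifecycle(shutdown, reconnect) == "reconnect"
    assert not reconnect.is_set()       # cleared for the next cycle
    shutdown.set()
    reconnect.set()                     # shutdown wins ties
    assert await wait_for_lifecycle(shutdown, reconnect) == "shutdown"

asyncio.run(main())
```

Checking `shutdown.is_set()` after the wait, rather than inspecting which task finished first, is what guarantees the stated precedence even when both events fire in the same loop iteration.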
@@ -1141,6 +1214,12 @@ class MCPServerTask:
         from tools.registry import registry

         self._shutdown_event.set()
+        # Defensive: if _wait_for_lifecycle_event is blocking, we need ANY
+        # event to unblock it. _shutdown_event alone is sufficient (the
+        # helper checks shutdown first), but setting reconnect too ensures
+        # there's no race where the helper misses the shutdown flag after
+        # returning "reconnect".
+        self._reconnect_event.set()
         if self._task and not self._task.done():
             try:
                 await asyncio.wait_for(self._task, timeout=10)
@@ -1174,6 +1253,175 @@ _servers: Dict[str, MCPServerTask] = {}
 _server_error_counts: Dict[str, int] = {}
 _CIRCUIT_BREAKER_THRESHOLD = 3

+# ---------------------------------------------------------------------------
+# Auth-failure detection helpers (Task 6 of MCP OAuth consolidation)
+# ---------------------------------------------------------------------------
+
+# Cached tuple of auth-related exception types. Lazy so this module
+# imports cleanly when the MCP SDK OAuth module is missing.
+_AUTH_ERROR_TYPES: tuple = ()
+
+
+def _get_auth_error_types() -> tuple:
+    """Return a tuple of exception types that indicate MCP OAuth failure.
+
+    Cached after first call. Includes:
+    - ``mcp.client.auth.OAuthFlowError`` / ``OAuthTokenError`` — raised by
+      the SDK's auth flow when discovery, refresh, or full re-auth fails.
+    - ``mcp.client.auth.UnauthorizedError`` (older MCP SDKs) — kept as an
+      optional import for forward/backward compatibility.
+    - ``tools.mcp_oauth.OAuthNonInteractiveError`` — raised by our callback
+      handler when no user is present to complete a browser flow.
+    - ``httpx.HTTPStatusError`` — caller must additionally check
+      ``status_code == 401`` via :func:`_is_auth_error`.
+    """
+    global _AUTH_ERROR_TYPES
+    if _AUTH_ERROR_TYPES:
+        return _AUTH_ERROR_TYPES
+    types: list = []
+    try:
+        from mcp.client.auth import OAuthFlowError, OAuthTokenError
+        types.extend([OAuthFlowError, OAuthTokenError])
+    except ImportError:
+        pass
+    try:
+        # Older MCP SDK variants exported this
+        from mcp.client.auth import UnauthorizedError  # type: ignore
+        types.append(UnauthorizedError)
+    except ImportError:
+        pass
+    try:
+        from tools.mcp_oauth import OAuthNonInteractiveError
+        types.append(OAuthNonInteractiveError)
+    except ImportError:
+        pass
+    try:
+        import httpx
+        types.append(httpx.HTTPStatusError)
+    except ImportError:
+        pass
+    _AUTH_ERROR_TYPES = tuple(types)
+    return _AUTH_ERROR_TYPES
+
+
+def _is_auth_error(exc: BaseException) -> bool:
+    """Return True if ``exc`` indicates an MCP OAuth failure.
+
+    ``httpx.HTTPStatusError`` is only treated as auth-related when the
+    response status code is 401. Other HTTP errors fall through to the
+    generic error path in the tool handlers.
+    """
+    types = _get_auth_error_types()
+    if not types or not isinstance(exc, types):
+        return False
+    try:
+        import httpx
+        if isinstance(exc, httpx.HTTPStatusError):
+            return getattr(exc.response, "status_code", None) == 401
+    except ImportError:
+        pass
+    return True
+
+
+def _handle_auth_error_and_retry(
+    server_name: str,
+    exc: BaseException,
+    retry_call,
+    op_description: str,
+):
+    """Attempt auth recovery and one retry; return None to fall through.
+
+    Called by the 5 MCP tool handlers when ``session.<op>()`` raises an
+    auth-related exception. Workflow:
+
+    1. Ask :class:`tools.mcp_oauth_manager.MCPOAuthManager.handle_401` if
+       recovery is viable (i.e., disk has fresh tokens, or the SDK can
+       refresh in-place).
+    2. If yes, set the server's ``_reconnect_event`` so the server task
+       tears down the current MCP session and rebuilds it with fresh
+       credentials. Wait briefly for ``_ready`` to re-fire.
+    3. Retry the operation once. Return the retry result if it produced
+       a non-error JSON payload. Otherwise return the ``needs_reauth``
+       error dict so the model stops hallucinating manual refresh.
+    4. Return None if ``exc`` is not an auth error, signalling the
+       caller to use the generic error path.
+
+    Args:
+        server_name: Name of the MCP server that raised.
+        exc: The exception from the failed tool call.
+        retry_call: Zero-arg callable that re-runs the tool call, returning
+            the same JSON string format as the handler.
+        op_description: Human-readable name of the operation (for logs).
+
+    Returns:
+        A JSON string if auth recovery was attempted, or None to fall
+        through to the caller's generic error path.
+    """
+    if not _is_auth_error(exc):
+        return None
+
+    from tools.mcp_oauth_manager import get_manager
+    manager = get_manager()
+
+    async def _recover():
+        return await manager.handle_401(server_name, None)
+
+    try:
+        recovered = _run_on_mcp_loop(_recover(), timeout=10)
+    except Exception as rec_exc:
+        logger.warning(
+            "MCP OAuth '%s': recovery attempt failed: %s",
+            server_name, rec_exc,
+        )
+        recovered = False
+
+    if recovered:
+        with _lock:
+            srv = _servers.get(server_name)
+        if srv is not None and hasattr(srv, "_reconnect_event"):
+            loop = _mcp_loop
+            if loop is not None and loop.is_running():
+                loop.call_soon_threadsafe(srv._reconnect_event.set)
+            # Wait briefly for the session to come back ready. Bounded
+            # so that a stuck reconnect falls through to the error
+            # path rather than hanging the caller.
+            deadline = time.monotonic() + 15
+            while time.monotonic() < deadline:
+                if srv.session is not None and srv._ready.is_set():
+                    break
+                time.sleep(0.25)
+
+        try:
+            result = retry_call()
+            try:
+                parsed = json.loads(result)
+                if "error" not in parsed:
+                    _server_error_counts[server_name] = 0
+                    return result
+            except (json.JSONDecodeError, TypeError):
+                _server_error_counts[server_name] = 0
+                return result
+        except Exception as retry_exc:
+            logger.warning(
+                "MCP %s/%s retry after auth recovery failed: %s",
+                server_name, op_description, retry_exc,
+            )
+
+    # No recovery available, or retry also failed: surface a structured
+    # needs_reauth error. Bumps the circuit breaker so the model stops
+    # retrying the tool.
+    _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
+    return json.dumps({
+        "error": (
+            f"MCP server '{server_name}' requires re-authentication. "
+            f"Run `hermes mcp login {server_name}` (or delete the tokens "
+            f"file under ~/.hermes/mcp-tokens/ and restart). Do NOT retry "
+            f"this tool — ask the user to re-authenticate."
+        ),
+        "needs_reauth": True,
+        "server": server_name,
+    }, ensure_ascii=False)
+
+
 # Dedicated event loop running in a background daemon thread.
 _mcp_loop: Optional[asyncio.AbstractEventLoop] = None
 _mcp_thread: Optional[threading.Thread] = None
@@ -1420,8 +1668,11 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
                 return json.dumps({"result": structured}, ensure_ascii=False)
             return json.dumps({"result": text_result}, ensure_ascii=False)

+        def _call_once():
+            return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
         try:
-            result = _run_on_mcp_loop(_call(), timeout=tool_timeout)
+            result = _call_once()
             # Check if the MCP tool itself returned an error
             try:
                 parsed = json.loads(result)
@@ -1435,6 +1686,16 @@ def _make_tool_handler(server_name: str, tool_name: str, tool_timeout: float):
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            # Auth-specific recovery path: consult the manager, signal
+            # reconnect if viable, retry once. Returns None to fall
+            # through for non-auth exceptions.
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once,
+                f"tools/call {tool_name}",
+            )
+            if recovered is not None:
+                return recovered
+
             _server_error_counts[server_name] = _server_error_counts.get(server_name, 0) + 1
             logger.error(
                 "MCP tool %s/%s call failed: %s",
@@ -1476,11 +1737,19 @@ def _make_list_resources_handler(server_name: str, tool_timeout: float):
                 resources.append(entry)
             return json.dumps({"resources": resources}, ensure_ascii=False)

-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "resources/list",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/list_resources failed: %s", server_name, exc,
             )
@@ -1522,11 +1791,19 @@ def _make_read_resource_handler(server_name: str, tool_timeout: float):
                     parts.append(f"[binary data, {len(block.blob)} bytes]")
             return json.dumps({"result": "\n".join(parts) if parts else ""}, ensure_ascii=False)

-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "resources/read",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/read_resource failed: %s", server_name, exc,
             )
@@ -1571,11 +1848,19 @@ def _make_list_prompts_handler(server_name: str, tool_timeout: float):
                 prompts.append(entry)
             return json.dumps({"prompts": prompts}, ensure_ascii=False)

-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "prompts/list",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/list_prompts failed: %s", server_name, exc,
             )
@@ -1628,11 +1913,19 @@ def _make_get_prompt_handler(server_name: str, tool_timeout: float):
                 resp["description"] = result.description
             return json.dumps(resp, ensure_ascii=False)

-        try:
+        def _call_once():
             return _run_on_mcp_loop(_call(), timeout=tool_timeout)
+
+        try:
+            return _call_once()
         except InterruptedError:
             return _interrupted_call_result()
         except Exception as exc:
+            recovered = _handle_auth_error_and_retry(
+                server_name, exc, _call_once, "prompts/get",
+            )
+            if recovered is not None:
+                return recovered
             logger.error(
                 "MCP %s/get_prompt failed: %s", server_name, exc,
             )
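The thundering-herd protection in `handle_401` follows a general pattern: concurrent callers that observe the same failed token share one in-flight recovery future instead of each firing their own attempt. A minimal standalone sketch under that assumption (the `Deduper` class and its trivial `_recover` body are hypothetical, standing in for the manager's real disk-check and refresh logic):

```python
import asyncio

class Deduper:
    """Share one in-flight recovery future per key across concurrent callers."""

    def __init__(self) -> None:
        self.pending: dict = {}
        self.attempts = 0

    async def _recover(self, key: str, fut: asyncio.Future) -> None:
        self.attempts += 1              # the expensive, once-only work
        await asyncio.sleep(0.01)       # stand-in for disk check / refresh
        if not fut.done():
            fut.set_result(True)
        self.pending.pop(key, None)     # next failure starts a fresh attempt

    async def handle(self, key: str) -> bool:
        fut = self.pending.get(key)
        if fut is None:
            # First caller for this key creates the future and the worker.
            fut = asyncio.get_running_loop().create_future()
            self.pending[key] = fut
            asyncio.create_task(self._recover(key, fut))
        return await fut                # everyone else just awaits it

async def main() -> None:
    d = Deduper()
    results = await asyncio.gather(*(d.handle("tok-1") for _ in range(5)))
    assert results == [True] * 5
    assert d.attempts == 1              # five concurrent callers, one recovery

asyncio.run(main())
```

Keying on the failed access token, as `handle_401` does, is what lets a genuinely new failure (a different token) start its own recovery rather than reusing a stale result.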