mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-27 11:22:03 +00:00
fix(mcp): auto-recover from invalid_client on stale OAuth client registration
Fixes #36767. Two complementary recoveries for the recurring "delete three cache files and re-auth by hand" ritual when an MCP server's dynamically-registered OAuth client goes dead server-side (IdP redeploy / DB wipe / rebrand): - Auto-heal (token-endpoint subset): HermesMCPOAuthProvider now sniffs auth-flow responses and, on a 400/401 `invalid_client` from the discovered token endpoint, backs up + deletes `<server>.client.json` and `.meta.json` and clears the in-memory client so the SDK re-runs RFC 7591 dynamic client registration on the next flow. Conservative by construction: only dynamically-registered (non config-supplied) clients, only the token endpoint, only on a word-boundary `invalid_client` match (so RFC 7591's `invalid_client_metadata` does not trip it); best-effort so a miss never breaks the live flow. Covers both code-exchange and refresh when the token endpoint was discovered. Tokens are preserved. - `hermes mcp reauth [<name>|--all]`: the reporter's primary symptom — the IdP's in-browser "Redirect URI Mismatch" — produces no HTTP signal (the SDK only sees a callback timeout), so it cannot be auto-detected. The new command re-auths one or ALL `auth: oauth` servers, serially: one browser flow at a time, which also fixes the startup popup storm when several servers are stale at once. Single-server reauth is factored out of `mcp login` and shared. Tests: +14 (poison helper x2; token-endpoint detection x5 incl. wrong-endpoint, success-response, pre-registered, and invalid_client_metadata negative guards; a bridge integration test driving the real async_auth_flow generator to prove the detection hook preserves the bidirectional asend() forwarding contract; reauth CLI x6). Verified against the pinned mcp==1.26.0: scripts/run_tests.sh 122/122 green for the touched suites; check-windows-footguns.py and ruff clean. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
6e4e5967f7
commit
075f93ad78
7 changed files with 558 additions and 27 deletions
|
|
@ -665,44 +665,28 @@ def cmd_mcp_test(args):
|
|||
|
||||
# ─── hermes mcp login ────────────────────────────────────────────────────────
|
||||
|
||||
def cmd_mcp_login(args):
|
||||
"""Force re-authentication for an OAuth-based MCP server.
|
||||
def _reauth_oauth_server(name: str, server_config: dict) -> bool:
|
||||
"""Force a fresh OAuth flow for one server. Returns True on success.
|
||||
|
||||
Deletes cached tokens (both on disk and in the running process's
|
||||
MCPOAuthManager cache) and triggers a fresh OAuth flow via the
|
||||
existing probe path.
|
||||
|
||||
Use this when:
|
||||
- Tokens are stuck in a bad state (server revoked, refresh token
|
||||
consumed by an external process, etc.)
|
||||
- You want to re-authenticate to change scopes or account
|
||||
- A tool call returned ``needs_reauth: true``
|
||||
Wipes cached OAuth state (disk + in-process MCPOAuthManager cache),
|
||||
re-probes to trigger the browser flow, and verifies a token actually
|
||||
landed before reporting success. Shared by ``hermes mcp login`` and
|
||||
``hermes mcp reauth`` so both behave identically for a single server.
|
||||
"""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
server_config = servers[name]
|
||||
url = server_config.get("url")
|
||||
if not url:
|
||||
_error(f"Server '{name}' has no URL — not an OAuth-capable server")
|
||||
return
|
||||
return False
|
||||
if server_config.get("auth") != "oauth":
|
||||
_error(f"Server '{name}' is not configured for OAuth (auth={server_config.get('auth')})")
|
||||
_info("Use `hermes mcp remove` + `hermes mcp add` to reconfigure auth.")
|
||||
return
|
||||
return False
|
||||
|
||||
# Wipe both disk and in-memory cache so the next probe forces a fresh
|
||||
# OAuth flow.
|
||||
try:
|
||||
from tools.mcp_oauth_manager import get_manager
|
||||
mgr = get_manager()
|
||||
mgr.remove(name)
|
||||
get_manager().remove(name)
|
||||
except Exception as exc:
|
||||
_warning(f"Could not clear existing OAuth state: {exc}")
|
||||
|
||||
|
|
@ -741,13 +725,90 @@ def cmd_mcp_login(args):
|
|||
print(color(f" client_secret: \"<your-oauth-client-secret>\"", Colors.DIM))
|
||||
print()
|
||||
_info("Then re-run `hermes mcp login " + name + "`.")
|
||||
return
|
||||
return False
|
||||
if tools:
|
||||
_success(f"Authenticated — {len(tools)} tool(s) available")
|
||||
else:
|
||||
_success("Authenticated (server reported no tools)")
|
||||
return True
|
||||
except Exception as exc:
|
||||
_error(f"Authentication failed: {exc}")
|
||||
return False
|
||||
|
||||
|
||||
def cmd_mcp_login(args):
|
||||
"""Force re-authentication for an OAuth-based MCP server.
|
||||
|
||||
Deletes cached tokens (both on disk and in the running process's
|
||||
MCPOAuthManager cache) and triggers a fresh OAuth flow via the
|
||||
existing probe path.
|
||||
|
||||
Use this when:
|
||||
- Tokens are stuck in a bad state (server revoked, refresh token
|
||||
consumed by an external process, etc.)
|
||||
- You want to re-authenticate to change scopes or account
|
||||
- A tool call returned ``needs_reauth: true``
|
||||
"""
|
||||
name = args.name
|
||||
servers = _get_mcp_servers()
|
||||
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
_reauth_oauth_server(name, servers[name])
|
||||
|
||||
|
||||
def cmd_mcp_reauth(args):
|
||||
"""Re-authenticate one OAuth MCP server, or all of them sequentially.
|
||||
|
||||
``hermes mcp reauth <name>`` re-auths a single server (same as ``login``).
|
||||
``hermes mcp reauth --all`` discovers every ``auth: oauth`` server in
|
||||
config and re-auths them ONE AT A TIME.
|
||||
|
||||
Serial-by-design: a human can only complete one browser OAuth flow at a
|
||||
time, so re-authing all servers concurrently would open N tabs at once
|
||||
and N-1 would time out. This is the self-service fix for the recurring
|
||||
stale-client ritual in GH#36767 (and avoids the startup popup storm when
|
||||
several servers go stale at once).
|
||||
"""
|
||||
servers = _get_mcp_servers()
|
||||
do_all = getattr(args, "all", False)
|
||||
name = getattr(args, "name", None)
|
||||
|
||||
if do_all:
|
||||
oauth_servers = [
|
||||
(n, c) for n, c in servers.items()
|
||||
if c.get("auth") == "oauth" and c.get("url")
|
||||
]
|
||||
if not oauth_servers:
|
||||
_info("No OAuth-based MCP servers found in config.")
|
||||
return
|
||||
print()
|
||||
_info(f"Re-authenticating {len(oauth_servers)} OAuth server(s) one at a time...")
|
||||
succeeded = 0
|
||||
for n, c in oauth_servers:
|
||||
print()
|
||||
print(color(f" ── {n} ──", Colors.CYAN + Colors.BOLD))
|
||||
if _reauth_oauth_server(n, c):
|
||||
succeeded += 1
|
||||
print()
|
||||
_success(f"Re-authenticated {succeeded}/{len(oauth_servers)} server(s)")
|
||||
return
|
||||
|
||||
if not name:
|
||||
_error("Specify a server name, or use --all to re-auth every OAuth server.")
|
||||
_info("Usage: hermes mcp reauth <name> | hermes mcp reauth --all")
|
||||
return
|
||||
if name not in servers:
|
||||
_error(f"Server '{name}' not found in config.")
|
||||
if servers:
|
||||
_info(f"Available servers: {', '.join(servers)}")
|
||||
return
|
||||
|
||||
_reauth_oauth_server(name, servers[name])
|
||||
|
||||
|
||||
# ─── hermes mcp configure ────────────────────────────────────────────────────
|
||||
|
|
@ -888,6 +949,7 @@ def mcp_command(args):
|
|||
"configure": cmd_mcp_configure,
|
||||
"config": cmd_mcp_configure,
|
||||
"login": cmd_mcp_login,
|
||||
"reauth": cmd_mcp_reauth,
|
||||
}
|
||||
|
||||
handler = handlers.get(action)
|
||||
|
|
@ -911,4 +973,5 @@ def mcp_command(args):
|
|||
_info("hermes mcp test <name> Test connection")
|
||||
_info("hermes mcp configure <name> Toggle tools")
|
||||
_info("hermes mcp login <name> Re-authenticate OAuth")
|
||||
_info("hermes mcp reauth <name> | --all Re-auth one or all OAuth servers")
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -86,6 +86,19 @@ def build_mcp_parser(subparsers, *, cmd_mcp: Callable) -> None:
|
|||
)
|
||||
mcp_login_p.add_argument("name", help="Server name to re-authenticate")
|
||||
|
||||
mcp_reauth_p = mcp_sub.add_parser(
|
||||
"reauth",
|
||||
help="Re-authenticate one OAuth MCP server, or all of them (--all)",
|
||||
)
|
||||
mcp_reauth_p.add_argument(
|
||||
"name", nargs="?", help="Server name to re-authenticate (omit with --all)"
|
||||
)
|
||||
mcp_reauth_p.add_argument(
|
||||
"--all",
|
||||
action="store_true",
|
||||
help="Re-authenticate every OAuth server in config, one at a time",
|
||||
)
|
||||
|
||||
# ── Catalog (Nous-approved MCPs shipped with the repo) ─────────────────
|
||||
mcp_sub.add_parser(
|
||||
"picker",
|
||||
|
|
|
|||
|
|
@ -754,3 +754,98 @@ class TestMcpLogin:
|
|||
|
||||
assert "Authenticated — 3 tool(s) available" in out
|
||||
assert "no OAuth token" not in out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests: cmd_mcp_reauth (GH#36767)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMcpReauth:
|
||||
def test_reauth_all_visits_only_oauth_servers_in_order(
|
||||
self, tmp_path, capsys, monkeypatch
|
||||
):
|
||||
"""--all re-auths every oauth server (skipping non-oauth), serially."""
|
||||
_seed_config(tmp_path, {
|
||||
"gh": {"url": "https://gh.example.com/mcp", "auth": "oauth"},
|
||||
"jira": {"url": "https://jira.example.com/mcp", "auth": "oauth"},
|
||||
"localstdio": {"command": "foo"}, # no url / no oauth → skipped
|
||||
"apikey": {"url": "https://k.example.com/mcp", "headers": {"x": "y"}},
|
||||
})
|
||||
visited = []
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._reauth_oauth_server",
|
||||
lambda name, cfg: visited.append(name) or True,
|
||||
)
|
||||
from hermes_cli.mcp_config import cmd_mcp_reauth
|
||||
|
||||
cmd_mcp_reauth(_make_args(name=None, all=True))
|
||||
out = capsys.readouterr().out
|
||||
|
||||
assert visited == ["gh", "jira"]
|
||||
assert "Re-authenticated 2/2 server(s)" in out
|
||||
|
||||
def test_reauth_all_reports_partial_failures(self, tmp_path, capsys, monkeypatch):
|
||||
"""A server that fails to re-auth is counted but doesn't abort the rest."""
|
||||
_seed_config(tmp_path, {
|
||||
"a": {"url": "https://a.example.com/mcp", "auth": "oauth"},
|
||||
"b": {"url": "https://b.example.com/mcp", "auth": "oauth"},
|
||||
})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._reauth_oauth_server",
|
||||
lambda name, cfg: name == "a", # only 'a' succeeds
|
||||
)
|
||||
from hermes_cli.mcp_config import cmd_mcp_reauth
|
||||
|
||||
cmd_mcp_reauth(_make_args(name=None, all=True))
|
||||
out = capsys.readouterr().out
|
||||
|
||||
assert "Re-authenticated 1/2 server(s)" in out
|
||||
|
||||
def test_reauth_all_no_oauth_servers(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {"localstdio": {"command": "foo"}})
|
||||
called = []
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._reauth_oauth_server",
|
||||
lambda name, cfg: called.append(name) or True,
|
||||
)
|
||||
from hermes_cli.mcp_config import cmd_mcp_reauth
|
||||
|
||||
cmd_mcp_reauth(_make_args(name=None, all=True))
|
||||
out = capsys.readouterr().out
|
||||
|
||||
assert "No OAuth-based MCP servers found" in out
|
||||
assert called == []
|
||||
|
||||
def test_reauth_single_server(self, tmp_path, capsys, monkeypatch):
|
||||
_seed_config(tmp_path, {
|
||||
"gh": {"url": "https://gh.example.com/mcp", "auth": "oauth"},
|
||||
})
|
||||
visited = []
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.mcp_config._reauth_oauth_server",
|
||||
lambda name, cfg: visited.append(name) or True,
|
||||
)
|
||||
from hermes_cli.mcp_config import cmd_mcp_reauth
|
||||
|
||||
cmd_mcp_reauth(_make_args(name="gh", all=False))
|
||||
assert visited == ["gh"]
|
||||
|
||||
def test_reauth_requires_name_or_all(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {
|
||||
"gh": {"url": "https://gh.example.com/mcp", "auth": "oauth"},
|
||||
})
|
||||
from hermes_cli.mcp_config import cmd_mcp_reauth
|
||||
|
||||
cmd_mcp_reauth(_make_args(name=None, all=False))
|
||||
out = capsys.readouterr().out
|
||||
assert "Specify a server name" in out
|
||||
|
||||
def test_reauth_unknown_server(self, tmp_path, capsys):
|
||||
_seed_config(tmp_path, {
|
||||
"gh": {"url": "https://gh.example.com/mcp", "auth": "oauth"},
|
||||
})
|
||||
from hermes_cli.mcp_config import cmd_mcp_reauth
|
||||
|
||||
cmd_mcp_reauth(_make_args(name="ghost", all=False))
|
||||
out = capsys.readouterr().out
|
||||
assert "not found" in out
|
||||
|
|
|
|||
|
|
@ -827,3 +827,34 @@ class TestWaitForCallbackSkipIntegration:
|
|||
asyncio.run(_wait_for_callback())
|
||||
err = capsys.readouterr().err
|
||||
assert "skip" in err.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# poison_client_registration (GH#36767)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPoisonClientRegistration:
|
||||
def test_poison_backs_up_and_removes_client_and_meta(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("srv")
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "srv.json").write_text('{"access_token": "keep-me"}')
|
||||
(d / "srv.client.json").write_text('{"client_id": "dead"}')
|
||||
(d / "srv.meta.json").write_text('{"token_endpoint": "https://idp/token"}')
|
||||
|
||||
removed = storage.poison_client_registration()
|
||||
|
||||
assert removed is True
|
||||
# Client + metadata gone, forcing re-registration on the next flow.
|
||||
assert not (d / "srv.client.json").exists()
|
||||
assert not (d / "srv.meta.json").exists()
|
||||
# Backup of the client file kept for recovery.
|
||||
assert (d / "srv.client.json.bak").read_text() == '{"client_id": "dead"}'
|
||||
# Tokens are intentionally preserved.
|
||||
assert (d / "srv.json").read_text() == '{"access_token": "keep-me"}'
|
||||
|
||||
def test_poison_noop_when_no_client_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("srv")
|
||||
assert storage.poison_client_registration() is False
|
||||
|
|
|
|||
|
|
@ -164,3 +164,192 @@ def test_manager_fails_fast_noninteractive_without_cached_tokens(tmp_path, monke
|
|||
mgr.get_or_build_provider("linear", "https://mcp.linear.app/mcp", None)
|
||||
|
||||
assert mgr._entries["linear"].provider is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# invalid_client auto-heal (GH#36767) — _maybe_flag_poisoned_client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def _fake_response(status, url, body):
|
||||
"""A minimal stand-in for the httpx.Response the SDK feeds our bridge."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status
|
||||
resp.request = SimpleNamespace(url=url)
|
||||
|
||||
async def _aread():
|
||||
return body
|
||||
|
||||
resp.aread = _aread
|
||||
return resp
|
||||
|
||||
|
||||
def _provider_with_token_endpoint(tmp_path, oauth_config, token_endpoint):
|
||||
from tools.mcp_oauth_manager import MCPOAuthManager, reset_manager_for_tests
|
||||
reset_manager_for_tests()
|
||||
mgr = MCPOAuthManager()
|
||||
provider = mgr.get_or_build_provider("srv", "https://mcp.example.com", oauth_config)
|
||||
provider.context.oauth_metadata = SimpleNamespace(token_endpoint=token_endpoint)
|
||||
provider._initialized = True
|
||||
return provider
|
||||
|
||||
|
||||
def test_invalid_client_at_token_endpoint_poisons(tmp_path, monkeypatch):
|
||||
"""400 invalid_client on the token endpoint deletes the dead client.json."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "srv.client.json").write_text('{"client_id": "dead"}')
|
||||
(d / "srv.meta.json").write_text("{}")
|
||||
provider = _provider_with_token_endpoint(
|
||||
tmp_path, {}, "https://idp.example.com/oauth/token"
|
||||
)
|
||||
resp = _fake_response(
|
||||
400, "https://idp.example.com/oauth/token", b'{"error":"invalid_client"}'
|
||||
)
|
||||
|
||||
asyncio.run(provider._maybe_flag_poisoned_client(resp))
|
||||
|
||||
assert not (d / "srv.client.json").exists()
|
||||
assert (d / "srv.client.json.bak").exists()
|
||||
assert provider._initialized is False
|
||||
assert provider.context.client_info is None
|
||||
|
||||
|
||||
def test_invalid_client_at_other_endpoint_is_ignored(tmp_path, monkeypatch):
|
||||
"""An invalid_client body from a non-token endpoint must not poison."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "srv.client.json").write_text('{"client_id": "live"}')
|
||||
provider = _provider_with_token_endpoint(
|
||||
tmp_path, {}, "https://idp.example.com/oauth/token"
|
||||
)
|
||||
resp = _fake_response(
|
||||
400, "https://mcp.example.com/messages", b'{"error":"invalid_client"}'
|
||||
)
|
||||
|
||||
asyncio.run(provider._maybe_flag_poisoned_client(resp))
|
||||
|
||||
assert (d / "srv.client.json").exists()
|
||||
assert provider._initialized is True
|
||||
|
||||
|
||||
def test_success_response_is_ignored(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "srv.client.json").write_text('{"client_id": "live"}')
|
||||
provider = _provider_with_token_endpoint(
|
||||
tmp_path, {}, "https://idp.example.com/oauth/token"
|
||||
)
|
||||
resp = _fake_response(
|
||||
200, "https://idp.example.com/oauth/token", b'{"access_token":"x"}'
|
||||
)
|
||||
|
||||
asyncio.run(provider._maybe_flag_poisoned_client(resp))
|
||||
|
||||
assert (d / "srv.client.json").exists()
|
||||
assert provider._initialized is True
|
||||
|
||||
|
||||
def test_preregistered_client_is_never_poisoned(tmp_path, monkeypatch):
|
||||
"""A config-supplied client_id is never auto-deleted (re-reg can't help)."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
provider = _provider_with_token_endpoint(
|
||||
tmp_path, {"client_id": "from-config"}, "https://idp.example.com/oauth/token"
|
||||
)
|
||||
d = tmp_path / "mcp-tokens"
|
||||
# _maybe_preregister_client wrote client.json from config during build.
|
||||
assert (d / "srv.client.json").exists()
|
||||
resp = _fake_response(
|
||||
400, "https://idp.example.com/oauth/token", b'{"error":"invalid_client"}'
|
||||
)
|
||||
|
||||
asyncio.run(provider._maybe_flag_poisoned_client(resp))
|
||||
|
||||
assert (d / "srv.client.json").exists()
|
||||
assert provider._initialized is True
|
||||
|
||||
|
||||
def test_invalid_client_metadata_does_not_trip(tmp_path, monkeypatch):
|
||||
"""RFC 7591 `invalid_client_metadata` must NOT be mistaken for invalid_client."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "srv.client.json").write_text('{"client_id": "live"}')
|
||||
provider = _provider_with_token_endpoint(
|
||||
tmp_path, {}, "https://idp.example.com/oauth/token"
|
||||
)
|
||||
resp = _fake_response(
|
||||
400, "https://idp.example.com/oauth/token", b'{"error":"invalid_client_metadata"}'
|
||||
)
|
||||
|
||||
asyncio.run(provider._maybe_flag_poisoned_client(resp))
|
||||
|
||||
assert (d / "srv.client.json").exists()
|
||||
assert provider._initialized is True
|
||||
|
||||
|
||||
class _FakeMeta:
|
||||
"""Metadata stub usable by both detection and the post-flow persist hook."""
|
||||
|
||||
def __init__(self, token_endpoint):
|
||||
self.token_endpoint = token_endpoint
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
return {"token_endpoint": self.token_endpoint}
|
||||
|
||||
|
||||
def test_bridge_forwards_requests_and_poisons_on_token_endpoint_400(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
"""Drive the REAL async_auth_flow bridge to prove the inserted detection
|
||||
hook does not break the bidirectional asend() forwarding contract — the
|
||||
genuinely fragile part. A patched SDK base generator stands in for the
|
||||
real OAuth flow so we control exactly which response the bridge sees.
|
||||
"""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
token_ep = "https://idp.example.com/oauth/token"
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "srv.client.json").write_text('{"client_id": "dead"}')
|
||||
|
||||
forwarded = []
|
||||
|
||||
async def fake_base_flow(self, request):
|
||||
# Mimic the SDK: yield the request, receive the response, then finish.
|
||||
forwarded.append(("out", request))
|
||||
response = yield request
|
||||
forwarded.append(("in", response))
|
||||
|
||||
from mcp.client.auth.oauth2 import OAuthClientProvider
|
||||
monkeypatch.setattr(OAuthClientProvider, "async_auth_flow", fake_base_flow)
|
||||
|
||||
provider = _provider_with_token_endpoint(tmp_path, {}, token_ep)
|
||||
provider.context.oauth_metadata = _FakeMeta(token_ep)
|
||||
|
||||
sentinel_request = object()
|
||||
poison_resp = _fake_response(400, token_ep, b'{"error":"invalid_client"}')
|
||||
|
||||
async def drive():
|
||||
gen = provider.async_auth_flow(sentinel_request)
|
||||
out0 = await gen.__anext__()
|
||||
assert out0 is sentinel_request # request forwarded unchanged
|
||||
try:
|
||||
await gen.asend(poison_resp)
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
asyncio.run(drive())
|
||||
|
||||
# The poison response reached the inner generator (forwarding intact)...
|
||||
assert ("in", poison_resp) in forwarded
|
||||
# ...and the detection hook fired.
|
||||
assert not (d / "srv.client.json").exists()
|
||||
assert provider._initialized is False
|
||||
assert provider.context.client_info is None
|
||||
|
|
|
|||
|
|
@ -343,6 +343,41 @@ class HermesTokenStorage:
|
|||
for p in (self._tokens_path(), self._client_info_path(), self._meta_path()):
|
||||
p.unlink(missing_ok=True)
|
||||
|
||||
def poison_client_registration(self) -> bool:
|
||||
"""Discard a dead dynamically-registered client so it gets re-created.
|
||||
|
||||
Called when the IdP rejects our cached ``client_id`` with
|
||||
``invalid_client`` on the token endpoint — proof the server-side
|
||||
registration is gone (IdP redeploy / DB wipe / rebrand). Deleting
|
||||
``client.json`` makes the MCP SDK's ``async_auth_flow`` take the
|
||||
``if not client_info`` branch and re-run RFC 7591 dynamic client
|
||||
registration on the next flow. The stale ``meta.json`` is dropped
|
||||
too so discovery re-runs against a freshly fetched document.
|
||||
|
||||
Tokens are intentionally left in place — the subsequent
|
||||
re-authorization overwrites them, and keeping them avoids losing a
|
||||
still-valid refresh token if the re-registration never completes.
|
||||
|
||||
A single ``.bak`` copy of the client file is kept for recovery.
|
||||
Returns True if a client file was present and removed.
|
||||
"""
|
||||
client_path = self._client_info_path()
|
||||
if not client_path.exists():
|
||||
return False
|
||||
backup = client_path.with_name(client_path.name + ".bak")
|
||||
try:
|
||||
backup.write_bytes(client_path.read_bytes())
|
||||
except OSError as exc: # non-fatal — proceed with the removal anyway
|
||||
logger.warning("Could not back up client info at %s: %s", client_path, exc)
|
||||
client_path.unlink(missing_ok=True)
|
||||
self._meta_path().unlink(missing_ok=True)
|
||||
logger.warning(
|
||||
"MCP OAuth '%s': cached client registration rejected as invalid_client; "
|
||||
"removed client.json + meta.json (backup at %s) to force re-registration",
|
||||
self._server_name, backup.name,
|
||||
)
|
||||
return True
|
||||
|
||||
def has_cached_tokens(self) -> bool:
|
||||
"""Return True if we have tokens on disk (may be expired)."""
|
||||
return self._tokens_path().exists()
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional
|
||||
|
|
@ -43,6 +44,26 @@ from typing import Any, Optional
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _same_endpoint(a: str, b: str) -> bool:
|
||||
"""Return True if two URLs target the same endpoint (ignoring query/fragment).
|
||||
|
||||
Compares scheme, host (case-insensitive), and path. Used to confirm a
|
||||
rejected response actually came from the OAuth token endpoint before we
|
||||
act on an ``invalid_client`` body.
|
||||
"""
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
try:
|
||||
pa, pb = urlsplit(a), urlsplit(b)
|
||||
except ValueError: # pragma: no cover — malformed URL
|
||||
return False
|
||||
return (
|
||||
pa.scheme == pb.scheme
|
||||
and pa.netloc.lower() == pb.netloc.lower()
|
||||
and pa.path.rstrip("/") == pb.path.rstrip("/")
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-server entry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -107,9 +128,21 @@ def _make_hermes_provider_class() -> Optional[type]:
|
|||
(``src/utils/auth.ts:1320``, CC-1096 / GH#24317).
|
||||
"""
|
||||
|
||||
def __init__(self, *args: Any, server_name: str = "", **kwargs: Any):
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
server_name: str = "",
|
||||
preregistered: bool = False,
|
||||
**kwargs: Any,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._hermes_server_name = server_name
|
||||
# When the client_id comes from config.yaml (pre-registered), an
|
||||
# invalid_client rejection means the *config* is wrong — deleting
|
||||
# client.json would just be re-seeded from config and re-running
|
||||
# registration can't help. Only auto-heal dynamically-registered
|
||||
# clients. See _maybe_flag_poisoned_client.
|
||||
self._hermes_preregistered = preregistered
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Load stored tokens + client info AND seed token_expiry_time.
|
||||
|
|
@ -284,6 +317,74 @@ def _make_hermes_provider_class() -> Optional[type]:
|
|||
):
|
||||
storage.save_oauth_metadata(meta)
|
||||
|
||||
async def _maybe_flag_poisoned_client(self, response: Any) -> None:
|
||||
"""Detect a dead client registration and force re-registration.
|
||||
|
||||
When the IdP rejects our ``client_id`` with ``invalid_client`` on
|
||||
the token endpoint (token exchange or refresh), the cached client
|
||||
registration is provably dead server-side. We delete ``client.json``
|
||||
(+ stale metadata) so the SDK's next ``async_auth_flow`` takes the
|
||||
``if not client_info`` branch and re-runs RFC 7591 dynamic client
|
||||
registration. This addresses the recurring manual-reset ritual in
|
||||
GH#36767 for the auto-detectable subset (token-endpoint rejection);
|
||||
the browser-side "Redirect URI Mismatch" case has no HTTP signal
|
||||
and is handled by ``hermes mcp reauth``.
|
||||
|
||||
Conservative by construction — acts ONLY when all hold:
|
||||
* status is 400/401,
|
||||
* the request hit the discovered ``token_endpoint`` (the only
|
||||
request carrying our ``client_id``), and
|
||||
* the body carries the ``invalid_client`` error code
|
||||
(word-boundary match, so RFC 7591's ``invalid_client_metadata``
|
||||
registration error does not trip it).
|
||||
Pre-registered (config-supplied) clients are never poisoned.
|
||||
Fully best-effort: any failure here is swallowed so a detection
|
||||
miss never breaks the live auth flow.
|
||||
|
||||
Covers both the authorization-code token exchange and the
|
||||
preemptive refresh — but only when ``token_endpoint`` was
|
||||
discovered (``_initialize`` prefetches it on cold-load). If that
|
||||
discovery was skipped, the guard returns early and the user falls
|
||||
back to ``hermes mcp reauth``.
|
||||
"""
|
||||
try:
|
||||
if self._hermes_preregistered:
|
||||
return
|
||||
status = getattr(response, "status_code", None)
|
||||
if status not in (400, 401):
|
||||
return
|
||||
meta = getattr(self.context, "oauth_metadata", None)
|
||||
token_endpoint = (
|
||||
str(meta.token_endpoint)
|
||||
if meta is not None and getattr(meta, "token_endpoint", None)
|
||||
else None
|
||||
)
|
||||
req = getattr(response, "request", None)
|
||||
req_url = str(req.url) if req is not None else None
|
||||
if not token_endpoint or not req_url:
|
||||
return
|
||||
if not _same_endpoint(req_url, token_endpoint):
|
||||
return
|
||||
body = await response.aread()
|
||||
# Word-boundary match: matches `"error":"invalid_client"` but
|
||||
# not the RFC 7591 registration error `invalid_client_metadata`
|
||||
# (the trailing `_metadata` removes the right-hand boundary).
|
||||
if not re.search(rb"\binvalid_client\b", body.lower()):
|
||||
return
|
||||
|
||||
storage = self.context.storage
|
||||
from tools.mcp_oauth import HermesTokenStorage
|
||||
if isinstance(storage, HermesTokenStorage):
|
||||
storage.poison_client_registration()
|
||||
# Drop the in-memory client so the SDK re-registers next flow.
|
||||
self.context.client_info = None
|
||||
self._initialized = False
|
||||
except Exception as exc: # pragma: no cover — defensive, must not throw
|
||||
logger.debug(
|
||||
"MCP OAuth '%s': invalid_client detection failed (non-fatal): %s",
|
||||
self._hermes_server_name, exc,
|
||||
)
|
||||
|
||||
async def async_auth_flow(self, request): # type: ignore[override]
|
||||
# Pre-flow hook: ask the manager to refresh from disk if needed.
|
||||
# Any failure here is non-fatal — we just log and proceed with
|
||||
|
|
@ -317,6 +418,9 @@ def _make_hermes_provider_class() -> Optional[type]:
|
|||
outgoing = await inner.__anext__()
|
||||
while True:
|
||||
incoming = yield outgoing
|
||||
# Sniff the response for a dead-client-registration signal
|
||||
# before handing it back to the SDK (best-effort, GH#36767).
|
||||
await self._maybe_flag_poisoned_client(incoming)
|
||||
outgoing = await inner.asend(incoming)
|
||||
except StopAsyncIteration:
|
||||
# Persist any metadata the SDK discovered lazily during the
|
||||
|
|
@ -439,6 +543,7 @@ class MCPOAuthManager:
|
|||
|
||||
return _HERMES_PROVIDER_CLS(
|
||||
server_name=server_name,
|
||||
preregistered=bool(cfg.get("client_id")),
|
||||
server_url=entry.server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue