diff --git a/scripts/release.py b/scripts/release.py index 634f0171bf..8b7023741d 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -53,6 +53,7 @@ AUTHOR_MAP = { "cleo@edaphic.xyz": "curiouscleo", "127238744+teknium1@users.noreply.github.com": "teknium1", "128259593+Gutslabs@users.noreply.github.com": "Gutslabs", + "50326054+nocturnum91@users.noreply.github.com": "nocturnum91", "159539633+MottledShadow@users.noreply.github.com": "MottledShadow", "aludwin+gh@gmail.com": "adamludwin", "ngusev@astralinux.ru": "NikolayGusev-astra", diff --git a/tests/tools/test_mcp_oauth_metadata.py b/tests/tools/test_mcp_oauth_metadata.py new file mode 100644 index 0000000000..5d161075e6 --- /dev/null +++ b/tests/tools/test_mcp_oauth_metadata.py @@ -0,0 +1,213 @@ +"""Tests for OAuth server metadata persistence across process restarts. + +Covers: +- :class:`HermesTokenStorage` ``.meta.json`` roundtrip (save / load / remove) +- The production manager provider + (:class:`tools.mcp_oauth_manager.HermesMCPOAuthProvider`) restoring metadata + on cold-load init and persisting metadata at the end of ``async_auth_flow``. + +Context +======= +The MCP SDK discovers OAuth server metadata (``token_endpoint``, etc.) +on-demand and keeps it in memory only. Without disk persistence a restart +forces the SDK to fall back to guessing ``{server_url}/token``, which returns +404 on most real providers and triggers a full browser re-auth even when the +refresh token is still valid. These tests lock in the disk persistence +layer so refresh across restarts stays quiet. 
def _make_metadata(token_endpoint: str = "https://auth.example.com/oauth/token") -> OAuthMetadata:
    """Return a minimal ``OAuthMetadata`` fixture for the example auth server.

    Only ``token_endpoint`` is configurable; the issuer, authorization
    endpoint and supported response types are fixed so tests can tell
    fixtures apart purely by their token endpoint.
    """
    raw_fields = {
        "issuer": "https://auth.example.com",
        "authorization_endpoint": "https://auth.example.com/oauth/authorize",
        "token_endpoint": token_endpoint,
        "response_types_supported": ["code"],
    }
    return OAuthMetadata.model_validate(raw_fields)
def _manager_provider_with_context(storage: HermesTokenStorage, **context_attrs):
    """Create a manager provider via ``__new__`` and attach a mocked context.

    ``__init__`` is never run (this bypasses the full OAuthClientProvider
    init), so the override logic can be exercised in isolation.  Recognised
    ``context_attrs`` keys: ``server_name`` (default ``"srv"``),
    ``oauth_metadata``, ``current_tokens`` (both default ``None``) and
    ``server_url`` (default ``"https://example.com"``).  The calling test is
    skipped when the provider class is unavailable.
    """
    if _HERMES_PROVIDER_CLS is None:
        pytest.skip("MCP SDK auth not available")

    bare_provider = _HERMES_PROVIDER_CLS.__new__(_HERMES_PROVIDER_CLS)
    bare_provider._hermes_server_name = context_attrs.get("server_name", "srv")

    # Assemble the fake context completely before handing it to the provider.
    fake_context = MagicMock()
    fake_context.storage = storage
    for optional_attr in ("oauth_metadata", "current_tokens"):
        setattr(fake_context, optional_attr, context_attrs.get(optional_attr))
    fake_context.server_url = context_attrs.get("server_url", "https://example.com")
    fake_context.update_token_expiry = MagicMock()

    bare_provider.context = fake_context
    return bare_provider
def test_persist_metadata_if_changed_writes_on_first_discover(self, tmp_path, monkeypatch):
    """With no metadata file on disk, the in-memory metadata gets written."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    token_store = HermesTokenStorage("persist-srv")
    # Precondition: nothing persisted for this server yet.
    assert token_store.load_oauth_metadata() is None

    provider = _manager_provider_with_context(
        token_store,
        oauth_metadata=_make_metadata("https://discovered.example.com/token"),
    )
    provider._persist_oauth_metadata_if_changed()

    persisted = token_store.load_oauth_metadata()
    assert persisted is not None
    assert str(persisted.token_endpoint) == "https://discovered.example.com/token"
def test_async_auth_flow_persists_on_completion(self, tmp_path, monkeypatch):
    """Driving the wrapped ``async_auth_flow`` to exhaustion saves the metadata."""
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    token_store = HermesTokenStorage("flow-srv")
    flow_provider = _manager_provider_with_context(
        token_store,
        oauth_metadata=_make_metadata("https://flow.example.com/token"),
        server_name="flow-srv",
    )

    async def fake_parent_flow(self, request):
        # Stand-in for the parent flow: an async generator that yields nothing.
        if False:
            yield  # pragma: no cover -- make this an async generator
        return

    fake_manager = MagicMock()
    fake_manager.invalidate_if_disk_changed = AsyncMock(return_value=False)

    async def exhaust_flow():
        async for _ in flow_provider.async_auth_flow(MagicMock()):
            pass

    parent_cls = _HERMES_PROVIDER_CLS.__bases__[0]
    with patch.object(parent_cls, "async_auth_flow", new=fake_parent_flow):
        with patch("tools.mcp_oauth_manager.get_manager", return_value=fake_manager):
            asyncio.run(exhaust_flow())

    persisted = token_store.load_oauth_metadata()
    assert persisted is not None
    assert str(persisted.token_endpoint) == "https://flow.example.com/token"
def load_oauth_metadata(self) -> "OAuthMetadata | None":
    """Read this server's persisted OAuth metadata, if any.

    Returns ``None`` when ``_read_json`` yields nothing for the metadata
    path, or when the stored payload does not validate as ``OAuthMetadata``
    (in which case a warning is logged and the file is ignored).
    """
    raw = _read_json(self._meta_path())
    if raw is not None:
        try:
            return OAuthMetadata.model_validate(raw)
        except (ValueError, TypeError, KeyError) as exc:
            logger.warning("Corrupt OAuth metadata at %s -- ignoring: %s", self._meta_path(), exc)
    return None
def _persist_oauth_metadata_if_changed(self) -> None:
    """Persist discovered OAuth metadata for future process restarts.

    Called after the SDK's normal 401-branch auth flow completes so
    metadata discovered via the lazy path (not pre-flight) is also
    saved.

    No-op when there is no in-memory metadata, when the context's storage
    is not a ``HermesTokenStorage``, or when the metadata already on disk
    serializes identically to the in-memory metadata.  The comparison
    covers every field, not just ``token_endpoint``, so a change to any
    other discovered value is also written out.
    """
    meta = self.context.oauth_metadata
    if meta is None:
        return
    storage = self.context.storage
    from tools.mcp_oauth import HermesTokenStorage
    if not isinstance(storage, HermesTokenStorage):
        return
    existing = storage.load_oauth_metadata()
    if existing is not None:
        # Compare the same JSON-mode dump that ``save_oauth_metadata``
        # writes, so "unchanged" means "the file would be byte-equivalent".
        dump_opts = {"exclude_none": True, "mode": "json"}
        if existing.model_dump(**dump_opts) == meta.model_dump(**dump_opts):
            return
    storage.save_oauth_metadata(meta)