mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-13 03:52:00 +00:00
fix(mcp-oauth): persist OAuth server metadata across process restarts (#21226)
The MCP SDK discovers OAuth server metadata (token_endpoint, etc.) on
demand and keeps it in memory only. Without disk persistence, a restart
with valid cached refresh tokens forces the SDK to fall back to the
guessed '{server_url}/token' path — which returns 404 on most real
providers (Notion, Atlassian, GitHub remote MCP, etc.) and triggers a
full browser re-authorization even though the refresh token is fine.
Add a .meta.json file next to the existing tokens/client_info files:
HERMES_HOME/mcp-tokens/<server>.json -- tokens (existing)
HERMES_HOME/mcp-tokens/<server>.client.json -- client info (existing)
HERMES_HOME/mcp-tokens/<server>.meta.json -- oauth metadata (new)
Changes:
- HermesTokenStorage.save_oauth_metadata / load_oauth_metadata / _meta_path
— disk layer for the discovered OAuthMetadata.
- HermesTokenStorage.remove() now also clears .meta.json so
'hermes mcp remove <name>' and the manager's remove() path clean up fully.
- HermesMCPOAuthProvider._initialize cold-restores from disk before the
existing pre-flight discovery runs. If disk has metadata we skip the
discovery HTTP round-trips entirely.
- HermesMCPOAuthProvider._prefetch_oauth_metadata now persists ASM as
soon as it's discovered, so even the first pre-flight run seeds disk.
- HermesMCPOAuthProvider._persist_oauth_metadata_if_changed() is called
at the end of async_auth_flow so metadata discovered via the SDK's
lazy 401-branch (not pre-flight) is also saved for next time.
Tests cover the storage roundtrip (save/load/missing/corrupt/remove) and
the manager provider path (cold-load restore, skip-when-in-memory,
persist-on-discover, noop-when-unchanged, end-to-end async_auth_flow).
Co-authored-by: nocturnum91 <50326054+nocturnum91@users.noreply.github.com>
This commit is contained in:
parent
3c439ec681
commit
c4a7992317
4 changed files with 293 additions and 1 deletions
213
tests/tools/test_mcp_oauth_metadata.py
Normal file
213
tests/tools/test_mcp_oauth_metadata.py
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
"""Tests for OAuth server metadata persistence across process restarts.
|
||||
|
||||
Covers:
|
||||
- :class:`HermesTokenStorage` ``.meta.json`` roundtrip (save / load / remove)
|
||||
- The production manager provider
|
||||
(:class:`tools.mcp_oauth_manager.HermesMCPOAuthProvider`) restoring metadata
|
||||
on cold-load init and persisting metadata at the end of ``async_auth_flow``.
|
||||
|
||||
Context
|
||||
=======
|
||||
The MCP SDK discovers OAuth server metadata (``token_endpoint``, etc.)
|
||||
on-demand and keeps it in memory only. Without disk persistence a restart
|
||||
forces the SDK to fall back to guessing ``{server_url}/token``, which returns
|
||||
404 on most real providers and triggers a full browser re-auth even when the
|
||||
refresh token is still valid. These tests lock in the disk persistence
|
||||
layer so refresh across restarts stays quiet.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from mcp.shared.auth import OAuthMetadata
|
||||
|
||||
from tools.mcp_oauth import HermesTokenStorage
|
||||
from tools.mcp_oauth_manager import _HERMES_PROVIDER_CLS
|
||||
|
||||
|
||||
def _make_metadata(token_endpoint: str = "https://auth.example.com/oauth/token") -> OAuthMetadata:
|
||||
return OAuthMetadata.model_validate(
|
||||
{
|
||||
"issuer": "https://auth.example.com",
|
||||
"authorization_endpoint": "https://auth.example.com/oauth/authorize",
|
||||
"token_endpoint": token_endpoint,
|
||||
"response_types_supported": ["code"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HermesTokenStorage metadata roundtrip
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMetadataStorage:
|
||||
def test_save_and_load_roundtrip(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("example-server")
|
||||
|
||||
meta = _make_metadata()
|
||||
storage.save_oauth_metadata(meta)
|
||||
|
||||
meta_path = tmp_path / "mcp-tokens" / "example-server.meta.json"
|
||||
assert meta_path.exists()
|
||||
|
||||
loaded = storage.load_oauth_metadata()
|
||||
assert loaded is not None
|
||||
assert str(loaded.token_endpoint) == "https://auth.example.com/oauth/token"
|
||||
assert str(loaded.issuer).rstrip("/") == "https://auth.example.com"
|
||||
|
||||
def test_load_missing_returns_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("nonexistent")
|
||||
assert storage.load_oauth_metadata() is None
|
||||
|
||||
def test_load_corrupt_returns_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("corrupt-server")
|
||||
|
||||
# Write something that doesn't validate as OAuthMetadata
|
||||
meta_path = storage._meta_path()
|
||||
meta_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
meta_path.write_text(json.dumps({"issuer": "not-a-url", "wrong_field": 123}))
|
||||
|
||||
assert storage.load_oauth_metadata() is None
|
||||
|
||||
def test_remove_deletes_meta_file(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("cleanup-server")
|
||||
|
||||
storage.save_oauth_metadata(_make_metadata())
|
||||
assert storage._meta_path().exists()
|
||||
|
||||
storage.remove()
|
||||
assert not storage._meta_path().exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Manager-path provider (HermesMCPOAuthProvider) — production code path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _manager_provider_with_context(storage: HermesTokenStorage, **context_attrs):
|
||||
"""Build an uninitialized manager provider with a mocked context.
|
||||
|
||||
Bypasses the full OAuthClientProvider init so we can exercise the
|
||||
override logic in isolation.
|
||||
"""
|
||||
if _HERMES_PROVIDER_CLS is None:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
provider = _HERMES_PROVIDER_CLS.__new__(_HERMES_PROVIDER_CLS)
|
||||
provider._hermes_server_name = context_attrs.get("server_name", "srv")
|
||||
context = MagicMock()
|
||||
context.storage = storage
|
||||
context.oauth_metadata = context_attrs.get("oauth_metadata")
|
||||
context.current_tokens = context_attrs.get("current_tokens")
|
||||
context.server_url = context_attrs.get("server_url", "https://example.com")
|
||||
context.update_token_expiry = MagicMock()
|
||||
provider.context = context
|
||||
return provider
|
||||
|
||||
|
||||
class TestManagerOAuthProviderMetadata:
|
||||
def test_initialize_restores_metadata_from_disk(self, tmp_path, monkeypatch):
|
||||
"""Cold-load: if we have no in-memory metadata but disk has some, restore it."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("mgr-srv")
|
||||
storage.save_oauth_metadata(_make_metadata("https://mgr.example.com/token"))
|
||||
provider = _manager_provider_with_context(storage, oauth_metadata=None)
|
||||
|
||||
with patch.object(
|
||||
_HERMES_PROVIDER_CLS.__bases__[0], "_initialize", new=AsyncMock()
|
||||
):
|
||||
asyncio.run(provider._initialize())
|
||||
|
||||
assert provider.context.oauth_metadata is not None
|
||||
assert str(provider.context.oauth_metadata.token_endpoint) == \
|
||||
"https://mgr.example.com/token"
|
||||
|
||||
def test_initialize_skips_restore_when_in_memory_present(self, tmp_path, monkeypatch):
|
||||
"""If SDK already has metadata in memory, don't overwrite from disk."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("mgr-srv2")
|
||||
storage.save_oauth_metadata(_make_metadata("https://disk.example.com/token"))
|
||||
in_memory = _make_metadata("https://memory.example.com/token")
|
||||
|
||||
provider = _manager_provider_with_context(storage, oauth_metadata=in_memory)
|
||||
|
||||
with patch.object(
|
||||
_HERMES_PROVIDER_CLS.__bases__[0], "_initialize", new=AsyncMock()
|
||||
):
|
||||
asyncio.run(provider._initialize())
|
||||
|
||||
assert str(provider.context.oauth_metadata.token_endpoint) == \
|
||||
"https://memory.example.com/token"
|
||||
|
||||
def test_persist_metadata_if_changed_writes_on_first_discover(self, tmp_path, monkeypatch):
|
||||
"""When nothing on disk yet, persist what the SDK discovered in-memory."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("persist-srv")
|
||||
assert storage.load_oauth_metadata() is None
|
||||
|
||||
discovered = _make_metadata("https://discovered.example.com/token")
|
||||
provider = _manager_provider_with_context(storage, oauth_metadata=discovered)
|
||||
|
||||
provider._persist_oauth_metadata_if_changed()
|
||||
|
||||
loaded = storage.load_oauth_metadata()
|
||||
assert loaded is not None
|
||||
assert str(loaded.token_endpoint) == "https://discovered.example.com/token"
|
||||
|
||||
def test_persist_metadata_noop_when_unchanged(self, tmp_path, monkeypatch):
|
||||
"""No-op write when disk already matches in-memory metadata."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("noop-srv")
|
||||
meta = _make_metadata("https://same.example.com/token")
|
||||
storage.save_oauth_metadata(meta)
|
||||
|
||||
provider = _manager_provider_with_context(storage, oauth_metadata=meta)
|
||||
|
||||
with patch.object(
|
||||
HermesTokenStorage, "save_oauth_metadata"
|
||||
) as save_spy:
|
||||
provider._persist_oauth_metadata_if_changed()
|
||||
save_spy.assert_not_called()
|
||||
|
||||
def test_async_auth_flow_persists_on_completion(self, tmp_path, monkeypatch):
|
||||
"""End-to-end: running the wrapped auth_flow persists discovered metadata."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("flow-srv")
|
||||
provider = _manager_provider_with_context(
|
||||
storage,
|
||||
oauth_metadata=_make_metadata("https://flow.example.com/token"),
|
||||
server_name="flow-srv",
|
||||
)
|
||||
|
||||
async def fake_parent_flow(self, request):
|
||||
if False:
|
||||
yield # pragma: no cover -- make this an async generator
|
||||
return
|
||||
|
||||
manager = MagicMock()
|
||||
manager.invalidate_if_disk_changed = AsyncMock(return_value=False)
|
||||
|
||||
with patch.object(
|
||||
_HERMES_PROVIDER_CLS.__bases__[0],
|
||||
"async_auth_flow",
|
||||
new=fake_parent_flow,
|
||||
), patch("tools.mcp_oauth_manager.get_manager", return_value=manager):
|
||||
async def drive():
|
||||
gen = provider.async_auth_flow(MagicMock())
|
||||
async for _ in gen:
|
||||
pass
|
||||
|
||||
asyncio.run(drive())
|
||||
|
||||
loaded = storage.load_oauth_metadata()
|
||||
assert loaded is not None
|
||||
assert str(loaded.token_endpoint) == "https://flow.example.com/token"
|
||||
Loading…
Add table
Add a link
Reference in a new issue