mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-05 02:31:47 +00:00
feat: implement MCP OAuth 2.1 PKCE client support (#5420)
Implement tools/mcp_oauth.py — the OAuth adapter that mcp_tool.py's
existing auth: oauth hook has been waiting for.
Components:
- HermesTokenStorage: persists tokens + client registration to
HERMES_HOME/mcp-tokens/<server>.json with 0o600 permissions
- Callback handler factory: per-flow isolated HTTP handlers (safe for
concurrent OAuth flows across multiple MCP servers)
- OAuthClientProvider integration: wraps the MCP SDK's httpx.Auth
subclass which handles discovery, DCR, PKCE, token exchange,
refresh, and step-up auth (403 insufficient_scope) automatically
- Non-interactive detection: warns when gateway/cron environments
try to OAuth without cached tokens
- Pre-registered client support: injects client_id/secret from config
for servers that don't support Dynamic Client Registration (e.g. Slack)
- Path traversal protection on server names
- remove_oauth_tokens() for cleanup
Config format:
mcp_servers:
sentry:
url: 'https://mcp.sentry.dev/mcp'
auth: oauth
oauth: # all optional
client_id: '...' # skip DCR
client_secret: '...' # confidential client
scope: 'read write' # server-provided by default
Also passes oauth config dict through from mcp_tool.py (was passing
only server_name and url before).
E2E verified: full OAuth flow (401 → discovery → DCR → authorize →
token exchange → authenticated request → tokens persisted) against
local test servers. 23 unit tests + 186 MCP suite tests pass.
This commit is contained in:
parent
3962bc84b7
commit
38d8446011
3 changed files with 547 additions and 293 deletions
|
|
@ -1,7 +1,8 @@
|
|||
"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK."""
|
||||
"""Tests for tools/mcp_oauth.py — OAuth 2.1 PKCE support for MCP servers."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ from tools.mcp_oauth import (
|
|||
_can_open_browser,
|
||||
_is_interactive,
|
||||
_wait_for_callback,
|
||||
_make_callback_handler,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -79,34 +81,93 @@ class TestHermesTokenStorage:
|
|||
assert not (d / "test-server.json").exists()
|
||||
assert not (d / "test-server.client.json").exists()
|
||||
|
||||
def test_has_cached_tokens(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("my-server")
|
||||
|
||||
assert not storage.has_cached_tokens()
|
||||
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "my-server.json").write_text('{"access_token": "x", "token_type": "Bearer"}')
|
||||
|
||||
assert storage.has_cached_tokens()
|
||||
|
||||
def test_corrupt_tokens_returns_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("bad-server")
|
||||
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "bad-server.json").write_text("NOT VALID JSON{{{")
|
||||
|
||||
import asyncio
|
||||
assert asyncio.run(storage.get_tokens()) is None
|
||||
|
||||
def test_corrupt_client_info_returns_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("bad-server")
|
||||
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "bad-server.client.json").write_text("GARBAGE")
|
||||
|
||||
import asyncio
|
||||
assert asyncio.run(storage.get_client_info()) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_oauth_auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildOAuthAuth:
|
||||
def test_returns_oauth_provider(self):
|
||||
def test_returns_oauth_provider(self, tmp_path, monkeypatch):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
auth = build_oauth_auth("test", "https://example.com/mcp")
|
||||
assert isinstance(auth, OAuthClientProvider)
|
||||
|
||||
def test_returns_none_without_sdk(self, monkeypatch):
|
||||
import tools.mcp_oauth as mod
|
||||
orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
|
||||
monkeypatch.setattr(mod, "_OAUTH_AVAILABLE", False)
|
||||
result = build_oauth_auth("test", "https://example.com")
|
||||
assert result is None
|
||||
|
||||
def _block_import(name, *args, **kwargs):
|
||||
if "mcp.client.auth" in name:
|
||||
raise ImportError("blocked")
|
||||
return orig_import(name, *args, **kwargs)
|
||||
def test_pre_registered_client_id_stored(self, tmp_path, monkeypatch):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
with patch("builtins.__import__", side_effect=_block_import):
|
||||
result = build_oauth_auth("test", "https://example.com")
|
||||
# May or may not be None depending on import caching, but shouldn't crash
|
||||
assert result is None or result is not None
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
build_oauth_auth("slack", "https://slack.example.com/mcp", {
|
||||
"client_id": "my-app-id",
|
||||
"client_secret": "my-secret",
|
||||
"scope": "channels:read",
|
||||
})
|
||||
|
||||
client_path = tmp_path / "mcp-tokens" / "slack.client.json"
|
||||
assert client_path.exists()
|
||||
data = json.loads(client_path.read_text())
|
||||
assert data["client_id"] == "my-app-id"
|
||||
assert data["client_secret"] == "my-secret"
|
||||
|
||||
def test_scope_passed_through(self, tmp_path, monkeypatch):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
provider = build_oauth_auth("scoped", "https://example.com/mcp", {
|
||||
"scope": "read write admin",
|
||||
})
|
||||
assert provider is not None
|
||||
assert provider.context.client_metadata.scope == "read write admin"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -119,6 +180,12 @@ class TestUtilities:
|
|||
assert isinstance(port, int)
|
||||
assert 1024 <= port <= 65535
|
||||
|
||||
def test_find_free_port_unique(self):
|
||||
"""Two consecutive calls should return different ports (usually)."""
|
||||
ports = {_find_free_port() for _ in range(5)}
|
||||
# At least 2 different ports out of 5 attempts
|
||||
assert len(ports) >= 2
|
||||
|
||||
def test_can_open_browser_false_in_ssh(self, monkeypatch):
|
||||
monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
|
||||
assert _can_open_browser() is False
|
||||
|
|
@ -127,14 +194,22 @@ class TestUtilities:
|
|||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("DISPLAY", raising=False)
|
||||
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
|
||||
# Mock os.name and uname for non-macOS, non-Windows
|
||||
monkeypatch.setattr(os, "name", "posix")
|
||||
monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})())
|
||||
assert _can_open_browser() is False
|
||||
|
||||
def test_can_open_browser_true_with_display(self, monkeypatch):
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.setenv("DISPLAY", ":0")
|
||||
monkeypatch.setattr(os, "name", "posix")
|
||||
assert _can_open_browser() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remove_oauth_tokens
|
||||
# Path traversal protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPathTraversal:
|
||||
|
|
@ -169,11 +244,14 @@ class TestPathTraversal:
|
|||
assert "/" not in path.stem
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Callback handler isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCallbackHandlerIsolation:
|
||||
"""Verify concurrent OAuth flows don't share state."""
|
||||
|
||||
def test_independent_result_dicts(self):
|
||||
from tools.mcp_oauth import _make_callback_handler
|
||||
_, result_a = _make_callback_handler()
|
||||
_, result_b = _make_callback_handler()
|
||||
|
||||
|
|
@ -184,10 +262,6 @@ class TestCallbackHandlerIsolation:
|
|||
assert result_b["auth_code"] == "code_B"
|
||||
|
||||
def test_handler_writes_to_own_result(self):
|
||||
from tools.mcp_oauth import _make_callback_handler
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
assert result["auth_code"] is None
|
||||
|
||||
|
|
@ -203,13 +277,30 @@ class TestCallbackHandlerIsolation:
|
|||
assert result["auth_code"] == "test123"
|
||||
assert result["state"] == "mystate"
|
||||
|
||||
def test_handler_captures_error(self):
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
|
||||
handler = HandlerClass.__new__(HandlerClass)
|
||||
handler.path = "/callback?error=access_denied"
|
||||
handler.wfile = BytesIO()
|
||||
handler.send_response = MagicMock()
|
||||
handler.send_header = MagicMock()
|
||||
handler.end_headers = MagicMock()
|
||||
handler.do_GET()
|
||||
|
||||
assert result["auth_code"] is None
|
||||
assert result["error"] == "access_denied"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Port sharing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOAuthPortSharing:
|
||||
"""Verify build_oauth_auth and _wait_for_callback use the same port."""
|
||||
|
||||
def test_port_stored_globally(self):
|
||||
def test_port_stored_globally(self, tmp_path, monkeypatch):
|
||||
import tools.mcp_oauth as mod
|
||||
# Reset
|
||||
mod._oauth_port = None
|
||||
|
||||
try:
|
||||
|
|
@ -217,12 +308,17 @@ class TestOAuthPortSharing:
|
|||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
build_oauth_auth("test-port", "https://example.com/mcp")
|
||||
assert mod._oauth_port is not None
|
||||
assert isinstance(mod._oauth_port, int)
|
||||
assert 1024 <= mod._oauth_port <= 65535
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remove_oauth_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRemoveOAuthTokens:
|
||||
def test_removes_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
@ -242,7 +338,7 @@ class TestRemoveOAuthTokens:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-interactive / startup-safety tests (issue #4462)
|
||||
# Non-interactive / startup-safety tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsInteractive:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue