mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat: implement MCP OAuth 2.1 PKCE client support (#5420)
Implement tools/mcp_oauth.py — the OAuth adapter that mcp_tool.py's
existing auth: oauth hook has been waiting for.
Components:
- HermesTokenStorage: persists tokens + client registration to
HERMES_HOME/mcp-tokens/<server>.json with 0o600 permissions
- Callback handler factory: per-flow isolated HTTP handlers (safe for
concurrent OAuth flows across multiple MCP servers)
- OAuthClientProvider integration: wraps the MCP SDK's httpx.Auth
subclass which handles discovery, DCR, PKCE, token exchange,
refresh, and step-up auth (403 insufficient_scope) automatically
- Non-interactive detection: warns when gateway/cron environments
try to OAuth without cached tokens
- Pre-registered client support: injects client_id/secret from config
for servers that don't support Dynamic Client Registration (e.g. Slack)
- Path traversal protection on server names
- remove_oauth_tokens() for cleanup
Config format:
mcp_servers:
sentry:
url: 'https://mcp.sentry.dev/mcp'
auth: oauth
oauth: # all optional
client_id: '...' # skip DCR
client_secret: '...' # confidential client
scope: 'read write' # server-provided by default
Also passes oauth config dict through from mcp_tool.py (was passing
only server_name and url before).
E2E verified: full OAuth flow (401 → discovery → DCR → authorize →
token exchange → authenticated request → tokens persisted) against
local test servers. 23 unit tests + 186 MCP suite tests pass.
This commit is contained in:
parent
3962bc84b7
commit
38d8446011
3 changed files with 547 additions and 293 deletions
|
|
@ -1,7 +1,8 @@
|
|||
"""Tests for tools/mcp_oauth.py — thin OAuth adapter over MCP SDK."""
|
||||
"""Tests for tools/mcp_oauth.py — OAuth 2.1 PKCE support for MCP servers."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
|
|
@ -16,6 +17,7 @@ from tools.mcp_oauth import (
|
|||
_can_open_browser,
|
||||
_is_interactive,
|
||||
_wait_for_callback,
|
||||
_make_callback_handler,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -79,34 +81,93 @@ class TestHermesTokenStorage:
|
|||
assert not (d / "test-server.json").exists()
|
||||
assert not (d / "test-server.client.json").exists()
|
||||
|
||||
def test_has_cached_tokens(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("my-server")
|
||||
|
||||
assert not storage.has_cached_tokens()
|
||||
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "my-server.json").write_text('{"access_token": "x", "token_type": "Bearer"}')
|
||||
|
||||
assert storage.has_cached_tokens()
|
||||
|
||||
def test_corrupt_tokens_returns_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("bad-server")
|
||||
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "bad-server.json").write_text("NOT VALID JSON{{{")
|
||||
|
||||
import asyncio
|
||||
assert asyncio.run(storage.get_tokens()) is None
|
||||
|
||||
def test_corrupt_client_info_returns_none(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
storage = HermesTokenStorage("bad-server")
|
||||
|
||||
d = tmp_path / "mcp-tokens"
|
||||
d.mkdir(parents=True)
|
||||
(d / "bad-server.client.json").write_text("GARBAGE")
|
||||
|
||||
import asyncio
|
||||
assert asyncio.run(storage.get_client_info()) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_oauth_auth
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildOAuthAuth:
|
||||
def test_returns_oauth_provider(self):
|
||||
def test_returns_oauth_provider(self, tmp_path, monkeypatch):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
auth = build_oauth_auth("test", "https://example.com/mcp")
|
||||
assert isinstance(auth, OAuthClientProvider)
|
||||
|
||||
def test_returns_none_without_sdk(self, monkeypatch):
|
||||
import tools.mcp_oauth as mod
|
||||
orig_import = __builtins__.__import__ if hasattr(__builtins__, '__import__') else __import__
|
||||
monkeypatch.setattr(mod, "_OAUTH_AVAILABLE", False)
|
||||
result = build_oauth_auth("test", "https://example.com")
|
||||
assert result is None
|
||||
|
||||
def _block_import(name, *args, **kwargs):
|
||||
if "mcp.client.auth" in name:
|
||||
raise ImportError("blocked")
|
||||
return orig_import(name, *args, **kwargs)
|
||||
def test_pre_registered_client_id_stored(self, tmp_path, monkeypatch):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
with patch("builtins.__import__", side_effect=_block_import):
|
||||
result = build_oauth_auth("test", "https://example.com")
|
||||
# May or may not be None depending on import caching, but shouldn't crash
|
||||
assert result is None or result is not None
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
build_oauth_auth("slack", "https://slack.example.com/mcp", {
|
||||
"client_id": "my-app-id",
|
||||
"client_secret": "my-secret",
|
||||
"scope": "channels:read",
|
||||
})
|
||||
|
||||
client_path = tmp_path / "mcp-tokens" / "slack.client.json"
|
||||
assert client_path.exists()
|
||||
data = json.loads(client_path.read_text())
|
||||
assert data["client_id"] == "my-app-id"
|
||||
assert data["client_secret"] == "my-secret"
|
||||
|
||||
def test_scope_passed_through(self, tmp_path, monkeypatch):
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
provider = build_oauth_auth("scoped", "https://example.com/mcp", {
|
||||
"scope": "read write admin",
|
||||
})
|
||||
assert provider is not None
|
||||
assert provider.context.client_metadata.scope == "read write admin"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -119,6 +180,12 @@ class TestUtilities:
|
|||
assert isinstance(port, int)
|
||||
assert 1024 <= port <= 65535
|
||||
|
||||
def test_find_free_port_unique(self):
|
||||
"""Two consecutive calls should return different ports (usually)."""
|
||||
ports = {_find_free_port() for _ in range(5)}
|
||||
# At least 2 different ports out of 5 attempts
|
||||
assert len(ports) >= 2
|
||||
|
||||
def test_can_open_browser_false_in_ssh(self, monkeypatch):
|
||||
monkeypatch.setenv("SSH_CLIENT", "1.2.3.4 1234 22")
|
||||
assert _can_open_browser() is False
|
||||
|
|
@ -127,14 +194,22 @@ class TestUtilities:
|
|||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.delenv("DISPLAY", raising=False)
|
||||
monkeypatch.delenv("WAYLAND_DISPLAY", raising=False)
|
||||
# Mock os.name and uname for non-macOS, non-Windows
|
||||
monkeypatch.setattr(os, "name", "posix")
|
||||
monkeypatch.setattr(os, "uname", lambda: type("", (), {"sysname": "Linux"})())
|
||||
assert _can_open_browser() is False
|
||||
|
||||
def test_can_open_browser_true_with_display(self, monkeypatch):
|
||||
monkeypatch.delenv("SSH_CLIENT", raising=False)
|
||||
monkeypatch.delenv("SSH_TTY", raising=False)
|
||||
monkeypatch.setenv("DISPLAY", ":0")
|
||||
monkeypatch.setattr(os, "name", "posix")
|
||||
assert _can_open_browser() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remove_oauth_tokens
|
||||
# Path traversal protection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPathTraversal:
|
||||
|
|
@ -169,11 +244,14 @@ class TestPathTraversal:
|
|||
assert "/" not in path.stem
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Callback handler isolation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCallbackHandlerIsolation:
|
||||
"""Verify concurrent OAuth flows don't share state."""
|
||||
|
||||
def test_independent_result_dicts(self):
|
||||
from tools.mcp_oauth import _make_callback_handler
|
||||
_, result_a = _make_callback_handler()
|
||||
_, result_b = _make_callback_handler()
|
||||
|
||||
|
|
@ -184,10 +262,6 @@ class TestCallbackHandlerIsolation:
|
|||
assert result_b["auth_code"] == "code_B"
|
||||
|
||||
def test_handler_writes_to_own_result(self):
|
||||
from tools.mcp_oauth import _make_callback_handler
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
assert result["auth_code"] is None
|
||||
|
||||
|
|
@ -203,13 +277,30 @@ class TestCallbackHandlerIsolation:
|
|||
assert result["auth_code"] == "test123"
|
||||
assert result["state"] == "mystate"
|
||||
|
||||
def test_handler_captures_error(self):
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
|
||||
handler = HandlerClass.__new__(HandlerClass)
|
||||
handler.path = "/callback?error=access_denied"
|
||||
handler.wfile = BytesIO()
|
||||
handler.send_response = MagicMock()
|
||||
handler.send_header = MagicMock()
|
||||
handler.end_headers = MagicMock()
|
||||
handler.do_GET()
|
||||
|
||||
assert result["auth_code"] is None
|
||||
assert result["error"] == "access_denied"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Port sharing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestOAuthPortSharing:
|
||||
"""Verify build_oauth_auth and _wait_for_callback use the same port."""
|
||||
|
||||
def test_port_stored_globally(self):
|
||||
def test_port_stored_globally(self, tmp_path, monkeypatch):
|
||||
import tools.mcp_oauth as mod
|
||||
# Reset
|
||||
mod._oauth_port = None
|
||||
|
||||
try:
|
||||
|
|
@ -217,12 +308,17 @@ class TestOAuthPortSharing:
|
|||
except ImportError:
|
||||
pytest.skip("MCP SDK auth not available")
|
||||
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
build_oauth_auth("test-port", "https://example.com/mcp")
|
||||
assert mod._oauth_port is not None
|
||||
assert isinstance(mod._oauth_port, int)
|
||||
assert 1024 <= mod._oauth_port <= 65535
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# remove_oauth_tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRemoveOAuthTokens:
|
||||
def test_removes_files(self, tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
|
|
@ -242,7 +338,7 @@ class TestRemoveOAuthTokens:
|
|||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-interactive / startup-safety tests (issue #4462)
|
||||
# Non-interactive / startup-safety tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestIsInteractive:
|
||||
|
|
|
|||
|
|
@ -1,326 +1,482 @@
|
|||
"""Thin OAuth adapter for MCP HTTP servers.
|
||||
|
||||
Wraps the MCP SDK's built-in ``OAuthClientProvider`` (which implements
|
||||
``httpx.Auth``) with Hermes-specific token storage and browser-based
|
||||
authorization. The SDK handles all of the heavy lifting: PKCE generation,
|
||||
metadata discovery, dynamic client registration, token exchange, and refresh.
|
||||
|
||||
Startup safety:
|
||||
The callback handler never calls blocking ``input()`` on the event loop.
|
||||
In non-interactive environments (no TTY, SSH, headless), the OAuth flow
|
||||
raises ``OAuthNonInteractiveError`` instead of blocking, so that the
|
||||
server degrades gracefully and other MCP servers are not affected.
|
||||
|
||||
Usage in mcp_tool.py::
|
||||
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
auth = build_oauth_auth(server_name, server_url)
|
||||
# pass ``auth`` as the httpx auth parameter
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MCP OAuth 2.1 Client Support
|
||||
|
||||
from __future__ import annotations
|
||||
Implements the browser-based OAuth 2.1 authorization code flow with PKCE
|
||||
for MCP servers that require OAuth authentication instead of static bearer
|
||||
tokens.
|
||||
|
||||
Uses the MCP Python SDK's ``OAuthClientProvider`` (an ``httpx.Auth`` subclass)
|
||||
which handles discovery, dynamic client registration, PKCE, token exchange,
|
||||
refresh, and step-up authorization automatically.
|
||||
|
||||
This module provides the glue:
|
||||
- ``HermesTokenStorage``: persists tokens/client-info to disk so they
|
||||
survive across process restarts.
|
||||
- Callback server: ephemeral localhost HTTP server to capture the OAuth
|
||||
redirect with the authorization code.
|
||||
- ``build_oauth_auth()``: entry point called by ``mcp_tool.py`` that wires
|
||||
everything together and returns the ``httpx.Auth`` object.
|
||||
|
||||
Configuration in config.yaml::
|
||||
|
||||
mcp_servers:
|
||||
my_server:
|
||||
url: "https://mcp.example.com/mcp"
|
||||
auth: oauth
|
||||
oauth: # all fields optional
|
||||
client_id: "pre-registered-id" # skip dynamic registration
|
||||
client_secret: "secret" # confidential clients only
|
||||
scope: "read write" # default: server-provided
|
||||
redirect_port: 0 # 0 = auto-pick free port
|
||||
client_name: "My Custom Client" # default: "Hermes Agent"
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
import threading
|
||||
import webbrowser
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lazy imports -- MCP SDK with OAuth support is optional
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_OAUTH_AVAILABLE = False
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import (
|
||||
OAuthClientInformationFull,
|
||||
OAuthClientMetadata,
|
||||
OAuthToken,
|
||||
)
|
||||
from pydantic import AnyUrl
|
||||
|
||||
_OAUTH_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.debug("MCP OAuth types not available -- OAuth MCP auth disabled")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class OAuthNonInteractiveError(RuntimeError):
|
||||
"""Raised when OAuth requires user interaction but the environment is non-interactive."""
|
||||
pass
|
||||
|
||||
_TOKEN_DIR_NAME = "mcp-tokens"
|
||||
"""Raised when OAuth requires browser interaction in a non-interactive env."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token storage — persists tokens + client info to ~/.hermes/mcp-tokens/
|
||||
# Module-level state
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _sanitize_server_name(name: str) -> str:
|
||||
"""Sanitize server name for safe use as a filename."""
|
||||
import re
|
||||
clean = re.sub(r"[^\w\-]", "-", name.strip().lower())
|
||||
clean = re.sub(r"-+", "-", clean).strip("-")
|
||||
return clean[:60] or "unnamed"
|
||||
|
||||
|
||||
class HermesTokenStorage:
|
||||
"""File-backed token storage implementing the MCP SDK's TokenStorage protocol."""
|
||||
|
||||
def __init__(self, server_name: str):
|
||||
self._server_name = _sanitize_server_name(server_name)
|
||||
|
||||
def _base_dir(self) -> Path:
|
||||
home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
||||
d = home / _TOKEN_DIR_NAME
|
||||
d.mkdir(parents=True, exist_ok=True)
|
||||
return d
|
||||
|
||||
def _tokens_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.json"
|
||||
|
||||
def _client_path(self) -> Path:
|
||||
return self._base_dir() / f"{self._server_name}.client.json"
|
||||
|
||||
# -- TokenStorage protocol (async) --
|
||||
|
||||
async def get_tokens(self):
|
||||
data = self._read_json(self._tokens_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthToken
|
||||
return OAuthToken(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens) -> None:
|
||||
self._write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
|
||||
|
||||
async def get_client_info(self):
|
||||
data = self._read_json(self._client_path())
|
||||
if not data:
|
||||
return None
|
||||
try:
|
||||
from mcp.shared.auth import OAuthClientInformationFull
|
||||
return OAuthClientInformationFull(**data)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info) -> None:
|
||||
self._write_json(self._client_path(), client_info.model_dump(exclude_none=True))
|
||||
|
||||
# -- helpers --
|
||||
|
||||
@staticmethod
|
||||
def _read_json(path: Path) -> dict | None:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
path.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
try:
|
||||
path.chmod(0o600)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Delete stored tokens and client info for this server."""
|
||||
for p in (self._tokens_path(), self._client_path()):
|
||||
try:
|
||||
p.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
pass
|
||||
# Port used by the most recent build_oauth_auth() call. Exposed so that
|
||||
# tests can verify the callback server and the redirect_uri share a port.
|
||||
_oauth_port: int | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Browser-based callback handler
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _get_token_dir() -> Path:
|
||||
"""Return the directory for MCP OAuth token files.
|
||||
|
||||
Uses HERMES_HOME so each profile gets its own OAuth tokens.
|
||||
Layout: ``HERMES_HOME/mcp-tokens/``
|
||||
"""
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
base = Path(get_hermes_home())
|
||||
except ImportError:
|
||||
base = Path(os.environ.get("HERMES_HOME", str(Path.home() / ".hermes")))
|
||||
return base / "mcp-tokens"
|
||||
|
||||
|
||||
def _safe_filename(name: str) -> str:
|
||||
"""Sanitize a server name for use as a filename (no path separators)."""
|
||||
return re.sub(r"[^\w\-]", "_", name).strip("_")[:128] or "default"
|
||||
|
||||
|
||||
def _find_free_port() -> int:
|
||||
"""Find an available TCP port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def _make_callback_handler():
|
||||
"""Create a callback handler class with instance-scoped result storage."""
|
||||
result = {"auth_code": None, "state": None}
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
qs = parse_qs(urlparse(self.path).query)
|
||||
result["auth_code"] = (qs.get("code") or [None])[0]
|
||||
result["state"] = (qs.get("state") or [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(b"<html><body><h3>Authorization complete. You can close this tab.</h3></body></html>")
|
||||
|
||||
def log_message(self, *_args: Any) -> None:
|
||||
pass
|
||||
|
||||
return Handler, result
|
||||
|
||||
|
||||
# Port chosen at build time and shared with the callback handler via closure.
|
||||
_oauth_port: int | None = None
|
||||
|
||||
|
||||
async def _redirect_to_browser(auth_url: str) -> None:
|
||||
"""Open the authorization URL in the user's browser."""
|
||||
def _is_interactive() -> bool:
|
||||
"""Return True if we can reasonably expect to interact with a user."""
|
||||
try:
|
||||
if _can_open_browser():
|
||||
webbrowser.open(auth_url)
|
||||
print(" Opened browser for authorization...")
|
||||
else:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
except Exception:
|
||||
print(f"\n Open this URL to authorize:\n {auth_url}\n")
|
||||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Start a local HTTP server on the pre-registered port and wait for the OAuth redirect.
|
||||
|
||||
If the callback times out, raises ``OAuthNonInteractiveError`` instead of
|
||||
calling blocking ``input()`` — the old ``input()`` call would block the
|
||||
entire MCP asyncio event loop, preventing all other MCP servers from
|
||||
connecting and potentially hanging Hermes startup indefinitely.
|
||||
"""
|
||||
global _oauth_port
|
||||
port = _oauth_port or _find_free_port()
|
||||
HandlerClass, result = _make_callback_handler()
|
||||
server = HTTPServer(("127.0.0.1", port), HandlerClass)
|
||||
|
||||
def _serve():
|
||||
server.timeout = 120
|
||||
server.handle_request()
|
||||
|
||||
thread = threading.Thread(target=_serve, daemon=True)
|
||||
thread.start()
|
||||
|
||||
for _ in range(1200): # 120 seconds
|
||||
await asyncio.sleep(0.1)
|
||||
if result["auth_code"] is not None:
|
||||
break
|
||||
|
||||
server.server_close()
|
||||
code = result["auth_code"] or ""
|
||||
state = result["state"]
|
||||
if not code:
|
||||
raise OAuthNonInteractiveError(
|
||||
"OAuth browser callback timed out after 120 seconds. "
|
||||
"Run 'hermes mcp auth <server-name>' to authorize interactively."
|
||||
)
|
||||
return code, state
|
||||
return sys.stdin.isatty()
|
||||
except (AttributeError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def _can_open_browser() -> bool:
|
||||
"""Return True if opening a browser is likely to work."""
|
||||
# Explicit SSH session → no local display
|
||||
if os.environ.get("SSH_CLIENT") or os.environ.get("SSH_TTY"):
|
||||
return False
|
||||
if not os.environ.get("DISPLAY") and os.name != "nt" and "darwin" not in os.uname().sysname.lower():
|
||||
return False
|
||||
return True
|
||||
# macOS and Windows usually have a display
|
||||
if os.name == "nt":
|
||||
return True
|
||||
try:
|
||||
if os.uname().sysname == "Darwin":
|
||||
return True
|
||||
except AttributeError:
|
||||
pass
|
||||
# Linux/other posix: need DISPLAY or WAYLAND_DISPLAY
|
||||
if os.environ.get("DISPLAY") or os.environ.get("WAYLAND_DISPLAY"):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_interactive() -> bool:
|
||||
"""Check if the current environment can support interactive OAuth flows.
|
||||
def _read_json(path: Path) -> dict | None:
|
||||
"""Read a JSON file, returning None if it doesn't exist or is invalid."""
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except (json.JSONDecodeError, OSError) as exc:
|
||||
logger.warning("Failed to read %s: %s", path, exc)
|
||||
return None
|
||||
|
||||
Returns False in headless/daemon/container environments where no user
|
||||
can interact with a browser or paste an auth code.
|
||||
|
||||
def _write_json(path: Path, data: dict) -> None:
|
||||
"""Write a dict as JSON with restricted permissions (0o600)."""
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
tmp = path.with_suffix(".tmp")
|
||||
try:
|
||||
tmp.write_text(json.dumps(data, indent=2, default=str), encoding="utf-8")
|
||||
os.chmod(tmp, 0o600)
|
||||
tmp.rename(path)
|
||||
except OSError:
|
||||
tmp.unlink(missing_ok=True)
|
||||
raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HermesTokenStorage -- persistent token/client-info on disk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class HermesTokenStorage:
|
||||
"""Persist OAuth tokens and client registration to JSON files.
|
||||
|
||||
File layout::
|
||||
|
||||
HERMES_HOME/mcp-tokens/<server_name>.json -- tokens
|
||||
HERMES_HOME/mcp-tokens/<server_name>.client.json -- client info
|
||||
"""
|
||||
if not hasattr(sys.stdin, "isatty") or not sys.stdin.isatty():
|
||||
return False
|
||||
return True
|
||||
|
||||
def __init__(self, server_name: str):
|
||||
self._server_name = _safe_filename(server_name)
|
||||
|
||||
def _tokens_path(self) -> Path:
|
||||
return _get_token_dir() / f"{self._server_name}.json"
|
||||
|
||||
def _client_info_path(self) -> Path:
|
||||
return _get_token_dir() / f"{self._server_name}.client.json"
|
||||
|
||||
# -- tokens ------------------------------------------------------------
|
||||
|
||||
async def get_tokens(self) -> "OAuthToken | None":
|
||||
data = _read_json(self._tokens_path())
|
||||
if data is None:
|
||||
return None
|
||||
try:
|
||||
return OAuthToken.model_validate(data)
|
||||
except Exception:
|
||||
logger.warning("Corrupt tokens at %s -- ignoring", self._tokens_path())
|
||||
return None
|
||||
|
||||
async def set_tokens(self, tokens: "OAuthToken") -> None:
|
||||
_write_json(self._tokens_path(), tokens.model_dump(exclude_none=True))
|
||||
logger.debug("OAuth tokens saved for %s", self._server_name)
|
||||
|
||||
# -- client info -------------------------------------------------------
|
||||
|
||||
async def get_client_info(self) -> "OAuthClientInformationFull | None":
|
||||
data = _read_json(self._client_info_path())
|
||||
if data is None:
|
||||
return None
|
||||
try:
|
||||
return OAuthClientInformationFull.model_validate(data)
|
||||
except Exception:
|
||||
logger.warning("Corrupt client info at %s -- ignoring", self._client_info_path())
|
||||
return None
|
||||
|
||||
async def set_client_info(self, client_info: "OAuthClientInformationFull") -> None:
|
||||
_write_json(self._client_info_path(), client_info.model_dump(exclude_none=True))
|
||||
logger.debug("OAuth client info saved for %s", self._server_name)
|
||||
|
||||
# -- cleanup -----------------------------------------------------------
|
||||
|
||||
def remove(self) -> None:
|
||||
"""Delete all stored OAuth state for this server."""
|
||||
for p in (self._tokens_path(), self._client_info_path()):
|
||||
p.unlink(missing_ok=True)
|
||||
|
||||
def has_cached_tokens(self) -> bool:
|
||||
"""Return True if we have tokens on disk (may be expired)."""
|
||||
return self._tokens_path().exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Callback handler factory -- each invocation gets its own result dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_callback_handler() -> tuple[type, dict]:
|
||||
"""Create a per-flow callback HTTP handler class with its own result dict.
|
||||
|
||||
Returns ``(HandlerClass, result_dict)`` where *result_dict* is a mutable
|
||||
dict that the handler writes ``auth_code`` and ``state`` into when the
|
||||
OAuth redirect arrives. Each call returns a fresh pair so concurrent
|
||||
flows don't stomp on each other.
|
||||
"""
|
||||
result: dict[str, Any] = {"auth_code": None, "state": None, "error": None}
|
||||
|
||||
class _Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self) -> None: # noqa: N802
|
||||
params = parse_qs(urlparse(self.path).query)
|
||||
code = params.get("code", [None])[0]
|
||||
state = params.get("state", [None])[0]
|
||||
error = params.get("error", [None])[0]
|
||||
|
||||
result["auth_code"] = code
|
||||
result["state"] = state
|
||||
result["error"] = error
|
||||
|
||||
body = (
|
||||
"<html><body><h2>Authorization Successful</h2>"
|
||||
"<p>You can close this tab and return to Hermes.</p></body></html>"
|
||||
) if code else (
|
||||
"<html><body><h2>Authorization Failed</h2>"
|
||||
f"<p>Error: {error or 'unknown'}</p></body></html>"
|
||||
)
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/html; charset=utf-8")
|
||||
self.end_headers()
|
||||
self.wfile.write(body.encode())
|
||||
|
||||
def log_message(self, fmt: str, *args: Any) -> None:
|
||||
logger.debug("OAuth callback: %s", fmt % args)
|
||||
|
||||
return _Handler, result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Async redirect + callback handlers for OAuthClientProvider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _redirect_handler(authorization_url: str) -> None:
|
||||
"""Show the authorization URL to the user.
|
||||
|
||||
Opens the browser automatically when possible; always prints the URL
|
||||
as a fallback for headless/SSH/gateway environments.
|
||||
"""
|
||||
msg = (
|
||||
f"\n MCP OAuth: authorization required.\n"
|
||||
f" Open this URL in your browser:\n\n"
|
||||
f" {authorization_url}\n"
|
||||
)
|
||||
print(msg, file=sys.stderr)
|
||||
|
||||
if _can_open_browser():
|
||||
try:
|
||||
opened = webbrowser.open(authorization_url)
|
||||
if opened:
|
||||
print(" (Browser opened automatically.)\n", file=sys.stderr)
|
||||
else:
|
||||
print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr)
|
||||
except Exception:
|
||||
print(" (Could not open browser — please open the URL manually.)\n", file=sys.stderr)
|
||||
else:
|
||||
print(" (Headless environment detected — open the URL manually.)\n", file=sys.stderr)
|
||||
|
||||
|
||||
async def _wait_for_callback() -> tuple[str, str | None]:
|
||||
"""Wait for the OAuth callback to arrive on the local callback server.
|
||||
|
||||
Uses the module-level ``_oauth_port`` which is set by ``build_oauth_auth``
|
||||
before this is ever called. Polls for the result without blocking the
|
||||
event loop.
|
||||
|
||||
Raises:
|
||||
OAuthNonInteractiveError: If the callback times out (no user present
|
||||
to complete the browser auth).
|
||||
"""
|
||||
global _oauth_port
|
||||
assert _oauth_port is not None, "OAuth callback port not set"
|
||||
|
||||
# The callback server is already running (started in build_oauth_auth).
|
||||
# We just need to poll for the result.
|
||||
handler_cls, result = _make_callback_handler()
|
||||
|
||||
# Start a temporary server on the known port
|
||||
try:
|
||||
server = HTTPServer(("127.0.0.1", _oauth_port), handler_cls)
|
||||
except OSError:
|
||||
# Port already in use — the server from build_oauth_auth is running.
|
||||
# Fall back to polling the server started by build_oauth_auth.
|
||||
raise OAuthNonInteractiveError(
|
||||
"OAuth callback timed out — could not bind callback port. "
|
||||
"Complete the authorization in a browser first, then retry."
|
||||
)
|
||||
|
||||
server_thread = threading.Thread(target=server.handle_request, daemon=True)
|
||||
server_thread.start()
|
||||
|
||||
timeout = 300.0
|
||||
poll_interval = 0.5
|
||||
elapsed = 0.0
|
||||
while elapsed < timeout:
|
||||
if result["auth_code"] is not None or result["error"] is not None:
|
||||
break
|
||||
await asyncio.sleep(poll_interval)
|
||||
elapsed += poll_interval
|
||||
|
||||
server.server_close()
|
||||
|
||||
if result["error"]:
|
||||
raise RuntimeError(f"OAuth authorization failed: {result['error']}")
|
||||
if result["auth_code"] is None:
|
||||
raise OAuthNonInteractiveError(
|
||||
"OAuth callback timed out — no authorization code received. "
|
||||
"Ensure you completed the browser authorization flow."
|
||||
)
|
||||
|
||||
return result["auth_code"], result["state"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_oauth_auth(server_name: str, server_url: str):
|
||||
"""Build an ``httpx.Auth`` handler for the given MCP server using OAuth 2.1 PKCE.
|
||||
|
||||
Uses the MCP SDK's ``OAuthClientProvider`` which handles discovery,
|
||||
registration, PKCE, token exchange, and refresh automatically.
|
||||
|
||||
In non-interactive environments (no TTY), this still returns a provider
|
||||
so that **cached tokens and refresh flows work**. Only the interactive
|
||||
authorization-code grant will fail fast with a clear error instead of
|
||||
blocking the event loop.
|
||||
|
||||
Returns an ``OAuthClientProvider`` instance (implements ``httpx.Auth``),
|
||||
or ``None`` if the MCP SDK auth module is not available.
|
||||
"""
|
||||
try:
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.shared.auth import OAuthClientMetadata
|
||||
except ImportError:
|
||||
logger.warning("MCP SDK auth module not available — OAuth disabled")
|
||||
return None
|
||||
|
||||
storage = HermesTokenStorage(server_name)
|
||||
interactive = _is_interactive()
|
||||
|
||||
if not interactive:
|
||||
# Check whether cached tokens exist. If they do, the SDK can still
|
||||
# use them (and refresh them) without any user interaction. If not,
|
||||
# we still build the provider — the callback_handler will raise
|
||||
# OAuthNonInteractiveError if a fresh authorization is actually
|
||||
# needed, which surfaces as a clean connection failure for this
|
||||
# server only (other MCP servers are unaffected).
|
||||
has_cached = storage._read_json(storage._tokens_path()) is not None
|
||||
if not has_cached:
|
||||
logger.warning(
|
||||
"MCP server '%s' requires OAuth but no cached tokens found "
|
||||
"and environment is non-interactive. The server will fail to "
|
||||
"connect. Run 'hermes mcp auth %s' to authorize interactively.",
|
||||
server_name, server_name,
|
||||
)
|
||||
|
||||
global _oauth_port
|
||||
_oauth_port = _find_free_port()
|
||||
redirect_uri = f"http://127.0.0.1:{_oauth_port}/callback"
|
||||
|
||||
client_metadata = OAuthClientMetadata(
|
||||
client_name="Hermes Agent",
|
||||
redirect_uris=[redirect_uri],
|
||||
grant_types=["authorization_code", "refresh_token"],
|
||||
response_types=["code"],
|
||||
scope="openid profile email offline_access",
|
||||
token_endpoint_auth_method="none",
|
||||
)
|
||||
|
||||
# In non-interactive mode, the redirect handler logs the URL and the
|
||||
# callback handler raises immediately — no blocking, no input().
|
||||
redirect_handler = _redirect_to_browser
|
||||
callback_handler = _wait_for_callback
|
||||
|
||||
if not interactive:
|
||||
async def _noninteractive_redirect(auth_url: str) -> None:
|
||||
logger.warning(
|
||||
"MCP server '%s' needs OAuth authorization (non-interactive, "
|
||||
"cannot open browser). URL: %s",
|
||||
server_name, auth_url,
|
||||
)
|
||||
|
||||
async def _noninteractive_callback() -> tuple[str, str | None]:
|
||||
raise OAuthNonInteractiveError(
|
||||
f"MCP server '{server_name}' requires interactive OAuth "
|
||||
f"authorization but the environment is non-interactive "
|
||||
f"(no TTY). Run 'hermes mcp auth {server_name}' to "
|
||||
f"authorize, then restart."
|
||||
)
|
||||
|
||||
redirect_handler = _noninteractive_redirect
|
||||
callback_handler = _noninteractive_callback
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=server_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
timeout=120.0,
|
||||
)
|
||||
|
||||
|
||||
def remove_oauth_tokens(server_name: str) -> None:
|
||||
"""Delete stored OAuth tokens and client info for a server."""
|
||||
HermesTokenStorage(server_name).remove()
|
||||
storage = HermesTokenStorage(server_name)
|
||||
storage.remove()
|
||||
logger.info("OAuth tokens removed for '%s'", server_name)
|
||||
|
||||
|
||||
def build_oauth_auth(
|
||||
server_name: str,
|
||||
server_url: str,
|
||||
oauth_config: dict | None = None,
|
||||
) -> "OAuthClientProvider | None":
|
||||
"""Build an ``httpx.Auth``-compatible OAuth handler for an MCP server.
|
||||
|
||||
Called from ``mcp_tool.py`` when a server has ``auth: oauth`` in config.
|
||||
|
||||
Args:
|
||||
server_name: Server key in mcp_servers config (used for storage).
|
||||
server_url: MCP server endpoint URL.
|
||||
oauth_config: Optional dict from the ``oauth:`` block in config.yaml.
|
||||
|
||||
Returns:
|
||||
An ``OAuthClientProvider`` instance, or None if the MCP SDK lacks
|
||||
OAuth support.
|
||||
"""
|
||||
if not _OAUTH_AVAILABLE:
|
||||
logger.warning(
|
||||
"MCP OAuth requested for '%s' but SDK auth types are not available. "
|
||||
"Install with: pip install 'mcp>=1.10.0'",
|
||||
server_name,
|
||||
)
|
||||
return None
|
||||
|
||||
global _oauth_port
|
||||
|
||||
cfg = oauth_config or {}
|
||||
|
||||
# --- Storage ---
|
||||
storage = HermesTokenStorage(server_name)
|
||||
|
||||
# --- Non-interactive warning ---
|
||||
if not _is_interactive() and not storage.has_cached_tokens():
|
||||
logger.warning(
|
||||
"MCP OAuth for '%s': non-interactive environment and no cached tokens found. "
|
||||
"The OAuth flow requires browser authorization. Run interactively first "
|
||||
"to complete the initial authorization, then cached tokens will be reused.",
|
||||
server_name,
|
||||
)
|
||||
|
||||
# --- Pick callback port ---
|
||||
redirect_port = int(cfg.get("redirect_port", 0))
|
||||
if redirect_port == 0:
|
||||
redirect_port = _find_free_port()
|
||||
_oauth_port = redirect_port
|
||||
|
||||
# --- Client metadata ---
|
||||
client_name = cfg.get("client_name", "Hermes Agent")
|
||||
scope = cfg.get("scope")
|
||||
redirect_uri = f"http://127.0.0.1:{redirect_port}/callback"
|
||||
|
||||
metadata_kwargs: dict[str, Any] = {
|
||||
"client_name": client_name,
|
||||
"redirect_uris": [AnyUrl(redirect_uri)],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "none",
|
||||
}
|
||||
if scope:
|
||||
metadata_kwargs["scope"] = scope
|
||||
|
||||
client_secret = cfg.get("client_secret")
|
||||
if client_secret:
|
||||
metadata_kwargs["token_endpoint_auth_method"] = "client_secret_post"
|
||||
|
||||
client_metadata = OAuthClientMetadata.model_validate(metadata_kwargs)
|
||||
|
||||
# --- Pre-registered client ---
|
||||
client_id = cfg.get("client_id")
|
||||
if client_id:
|
||||
info_dict: dict[str, Any] = {
|
||||
"client_id": client_id,
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": client_metadata.grant_types,
|
||||
"response_types": client_metadata.response_types,
|
||||
"token_endpoint_auth_method": client_metadata.token_endpoint_auth_method,
|
||||
}
|
||||
if client_secret:
|
||||
info_dict["client_secret"] = client_secret
|
||||
if client_name:
|
||||
info_dict["client_name"] = client_name
|
||||
if scope:
|
||||
info_dict["scope"] = scope
|
||||
|
||||
client_info = OAuthClientInformationFull.model_validate(info_dict)
|
||||
_write_json(storage._client_info_path(), client_info.model_dump(exclude_none=True))
|
||||
logger.debug("Pre-registered client_id=%s for '%s'", client_id, server_name)
|
||||
|
||||
# --- Base URL for discovery ---
|
||||
parsed = urlparse(server_url)
|
||||
base_url = f"{parsed.scheme}://{parsed.netloc}"
|
||||
|
||||
# --- Build provider ---
|
||||
provider = OAuthClientProvider(
|
||||
server_url=base_url,
|
||||
client_metadata=client_metadata,
|
||||
storage=storage,
|
||||
redirect_handler=_redirect_handler,
|
||||
callback_handler=_wait_for_callback,
|
||||
timeout=float(cfg.get("timeout", 300)),
|
||||
)
|
||||
|
||||
return provider
|
||||
|
|
|
|||
|
|
@ -892,7 +892,9 @@ class MCPServerTask:
|
|||
if self._auth_type == "oauth":
|
||||
try:
|
||||
from tools.mcp_oauth import build_oauth_auth
|
||||
_oauth_auth = build_oauth_auth(self.name, url)
|
||||
_oauth_auth = build_oauth_auth(
|
||||
self.name, url, config.get("oauth")
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("MCP OAuth setup failed for '%s': %s", self.name, exc)
|
||||
raise
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue