mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-31 06:51:29 +00:00
feat(dashboard-auth): plugins/dashboard_auth/nous — contract-compliant Nous OAuth provider
Bundled, kind=backend, auto-loads. Activates ONLY when Portal-injected
env vars are present:
HERMES_DASHBOARD_OAUTH_CLIENT_ID — agent:{instance_id}
HERMES_DASHBOARD_PORTAL_URL — Portal base URL
Loopback / --insecure operators leave both unset and never see this
plugin register anything. The fail-closed branch in start_server handles
the 'public bind + zero providers' case independently.
Implementation follows nous-account-service PR #180's published OAuth
contract verbatim:
- client_id is per-instance (agent:{instance_id}); the suffix is
cross-checked against the token's agent_instance_id claim as
defense-in-depth (contract C9).
- scope is agent_dashboard:access only (contract C3).
- aud is the bare client_id, no hermes-cli: prefix (contract C2).
- RS256 JWT verification against /.well-known/jwks.json with
5-minute cache (contract C7).
- No refresh tokens in V1: refresh_session always raises
RefreshExpiredError; revoke_session is a no-op (contract C5).
- oauth_contract_version claim: missing → warn + proceed; present
and != 1 → refuse (contract C11, OQ-C2 tolerant treatment).
- redirect_uri validated client-side as defense before bouncing to
Portal; authoritative check is server-side per agent-redirect-uri.ts.
41 new tests covering construction, plugin-entry env gating, start_login
shape, complete_login httpx-mocked happy path + error mapping,
verify_session JWT verification (RSA keypair fixture, full claim-check
matrix), refresh_session always raising, revoke_session no-op.
PyJWT + cryptography are already in the venv (jose was previously
suggested; switched to pyjwt[crypto] since the latter is already
pulled in transitively).
This commit is contained in:
parent
d64626388a
commit
171921c9b2
3 changed files with 992 additions and 0 deletions
432
plugins/dashboard_auth/nous/__init__.py
Normal file
432
plugins/dashboard_auth/nous/__init__.py
Normal file
|
|
@ -0,0 +1,432 @@
|
|||
"""NousDashboardAuthProvider — Nous Portal OAuth (authorization-code + PKCE).
|
||||
|
||||
Implements ``nous-account-service/docs/agent-dashboard-oauth-contract.md``
|
||||
(PR #180). The plugin auto-loads (bundled, kind=backend) but only registers
|
||||
its provider when the Portal-injected env vars are present, so loopback /
|
||||
``--insecure`` operators are unaffected.
|
||||
|
||||
Required env vars (Portal injects at Fly.io provisioning):
|
||||
|
||||
HERMES_DASHBOARD_OAUTH_CLIENT_ID — shape ``agent:{agent_instance_id}``
|
||||
HERMES_DASHBOARD_PORTAL_URL — e.g. ``https://portal.nousresearch.com``
|
||||
|
||||
Key contract points encoded here:
|
||||
|
||||
- client_id is per-instance (``agent:{instance_id}``); the suffix is also
|
||||
cross-checked against the token's ``agent_instance_id`` claim as
|
||||
defense-in-depth.
|
||||
- scope is ``agent_dashboard:access`` only (no OIDC scopes).
|
||||
- tokens are RS256 JWTs verified against ``/.well-known/jwks.json``;
|
||||
JWKS is cached for 5 minutes.
|
||||
- V1 has NO refresh tokens — ``refresh_session`` always raises
|
||||
``RefreshExpiredError`` so the middleware redirects to ``/auth/login``.
|
||||
- audience claim is the bare ``client_id`` (no ``hermes-cli:`` prefix).
|
||||
- tolerant ``oauth_contract_version`` check: missing → warn + proceed;
|
||||
present and ``!= 1`` → refuse.
|
||||
|
||||
The cookie payload returned by ``start_login`` stashes the PKCE
|
||||
``code_verifier`` and the OAuth ``state`` parameter for the
|
||||
``/auth/callback`` handler to retrieve. The auth-route layer is the owner
|
||||
of cookie names; this provider just hands back ``{"code_verifier": …,
|
||||
"state": …}`` and the route serializes those into the ``hermes_session_pkce``
|
||||
cookie.
|
||||
|
||||
Forward compatibility: if a future Portal contract starts issuing refresh
|
||||
tokens, ``complete_login`` already captures the value forward-compatibly
|
||||
(populates ``Session.refresh_token``). Wiring the RT cookie back into the
|
||||
middleware's near-expiry refresh path lives in the host application, not
|
||||
here.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from hermes_cli.dashboard_auth import (
|
||||
DashboardAuthProvider,
|
||||
InvalidCodeError,
|
||||
LoginStart,
|
||||
ProviderError,
|
||||
RefreshExpiredError,
|
||||
Session,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Contract constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Contract C3: scope name for the dashboard flow.
|
||||
_SCOPE = "agent_dashboard:access"
|
||||
|
||||
# Contract C11: emitted claim should equal 1; tolerant (warn) if missing.
|
||||
_EXPECTED_CONTRACT_VERSION = 1
|
||||
|
||||
# Contract C7: JWKS Cache-Control max-age=300.
|
||||
_JWKS_CACHE_SECONDS = 300
|
||||
|
||||
# httpx timeout for the token endpoint POST.
|
||||
_TOKEN_ENDPOINT_TIMEOUT_SEC = 10.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _b64url_no_pad(raw: bytes) -> str:
|
||||
"""Base64url-encode without ``=`` padding (RFC 7636 §4)."""
|
||||
return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class NousDashboardAuthProvider(DashboardAuthProvider):
|
||||
"""Nous Portal OAuth via authorization-code + PKCE (S256)."""
|
||||
|
||||
name = "nous"
|
||||
display_name = "Nous Research"
|
||||
|
||||
def __init__(self, *, client_id: str, portal_url: str) -> None:
|
||||
if not client_id.startswith("agent:"):
|
||||
# Defense-in-depth. The plugin entry point already filters, but
|
||||
# the provider should never be constructible with a malformed id.
|
||||
raise ValueError(
|
||||
"client_id must match contract shape 'agent:{instance_id}', "
|
||||
f"got {client_id!r}"
|
||||
)
|
||||
self._client_id = client_id
|
||||
self._agent_instance_id = client_id[len("agent:") :]
|
||||
self._portal_url = portal_url.rstrip("/")
|
||||
self._jwks_url = f"{self._portal_url}/.well-known/jwks.json"
|
||||
self._authorize_url = f"{self._portal_url}/oauth/authorize"
|
||||
self._token_url = f"{self._portal_url}/api/oauth/token"
|
||||
# PyJWKClient is lazily imported so plugin discovery doesn't pay the
|
||||
# crypto-import cost when the provider isn't activated.
|
||||
self._jwks_client: Any = None
|
||||
|
||||
# ---- public API (DashboardAuthProvider) -------------------------------
|
||||
|
||||
def start_login(self, *, redirect_uri: str) -> LoginStart:
|
||||
self._validate_redirect_uri(redirect_uri)
|
||||
|
||||
code_verifier = _b64url_no_pad(secrets.token_bytes(64)) # ~86 chars
|
||||
code_challenge = _b64url_no_pad(
|
||||
hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
)
|
||||
state = _b64url_no_pad(secrets.token_bytes(32))
|
||||
|
||||
params = {
|
||||
"response_type": "code",
|
||||
"client_id": self._client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": _SCOPE,
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
redirect_url = f"{self._authorize_url}?{urllib.parse.urlencode(params)}"
|
||||
# The auth-route layer expects ``cookie_payload[\"hermes_session_pkce\"]``
|
||||
# as a single semicolon-delimited string of ``key=value`` segments,
|
||||
# matching the stub provider's shape. The route handler prepends
|
||||
# ``provider=`` so the callback knows which plugin to dispatch to.
|
||||
cookie_payload = {
|
||||
"hermes_session_pkce": f"state={state};verifier={code_verifier}",
|
||||
}
|
||||
return LoginStart(redirect_url=redirect_url, cookie_payload=cookie_payload)
|
||||
|
||||
def complete_login(
|
||||
self,
|
||||
*,
|
||||
code: str,
|
||||
state: str,
|
||||
code_verifier: str,
|
||||
redirect_uri: str,
|
||||
) -> Session:
|
||||
# ``state`` is verified by the auth-route layer before this call
|
||||
# (it checks the cookie-stashed state matches the query-param state);
|
||||
# we just receive it for symmetry with the protocol. Nous Portal
|
||||
# doesn't re-check state at the token endpoint, so we ignore it here.
|
||||
_ = state
|
||||
|
||||
try:
|
||||
response = httpx.post(
|
||||
self._token_url,
|
||||
data={
|
||||
"grant_type": "authorization_code",
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"client_id": self._client_id,
|
||||
"code_verifier": code_verifier,
|
||||
},
|
||||
headers={"Accept": "application/json"},
|
||||
timeout=_TOKEN_ENDPOINT_TIMEOUT_SEC,
|
||||
)
|
||||
except httpx.RequestError as exc:
|
||||
raise ProviderError(f"Portal token endpoint unreachable: {exc}") from exc
|
||||
|
||||
if response.status_code == 400:
|
||||
# Contract: invalid_code, invalid_grant, redirect_uri_mismatch all
|
||||
# surface as 400 with an OAuth-shaped JSON error envelope.
|
||||
body = self._parse_json_body(response)
|
||||
error_code = body.get("error", "invalid_request")
|
||||
raise InvalidCodeError(f"Portal rejected code: {error_code}")
|
||||
if response.status_code != 200:
|
||||
raise ProviderError(
|
||||
f"Portal token endpoint returned {response.status_code}: "
|
||||
f"{response.text[:200]!r}"
|
||||
)
|
||||
|
||||
payload = self._parse_json_body(response)
|
||||
access_token = payload.get("access_token")
|
||||
if not access_token or not isinstance(access_token, str):
|
||||
raise ProviderError("Portal token response missing access_token")
|
||||
|
||||
token_type = str(payload.get("token_type", "")).lower()
|
||||
if token_type and token_type != "bearer":
|
||||
raise ProviderError(f"unexpected token_type={token_type!r}")
|
||||
|
||||
claims = self._verify_jwt(access_token)
|
||||
# Contract V1: no refresh token expected. If a future Portal ever
|
||||
# adds one, capture it forward-compatibly.
|
||||
refresh_token = payload.get("refresh_token") or ""
|
||||
if not isinstance(refresh_token, str):
|
||||
refresh_token = ""
|
||||
return self._session_from_claims(access_token, refresh_token, claims)
|
||||
|
||||
def refresh_session(self, *, refresh_token: str) -> Session:
|
||||
# Contract V1 has no refresh tokens — always force re-auth. If a
|
||||
# future Portal contract starts issuing them, this method needs to
|
||||
# be re-implemented; until then it's an unconditional refusal.
|
||||
raise RefreshExpiredError(
|
||||
"Nous Portal does not issue refresh tokens in OAuth contract v1; "
|
||||
"user must re-authenticate via /auth/login."
|
||||
)
|
||||
|
||||
def verify_session(self, *, access_token: str) -> Optional[Session]:
|
||||
# Contract: returns None on expiry/invalidity (middleware then
|
||||
# triggers redirect-to-login since refresh_session can never succeed
|
||||
# under V1); raises ProviderError if the IDP is unreachable.
|
||||
try:
|
||||
claims = self._verify_jwt(access_token)
|
||||
except InvalidCodeError:
|
||||
# Expired/invalid token — middleware contract is None, not raise.
|
||||
return None
|
||||
except ProviderError:
|
||||
# JWKS unreachable, etc. Bubble up so middleware emits 503.
|
||||
raise
|
||||
# verify_session has no access to the original refresh_token; pass
|
||||
# "" because in contract V1 there is none anyway.
|
||||
return self._session_from_claims(access_token, "", claims)
|
||||
|
||||
def revoke_session(self, *, refresh_token: str) -> None:
|
||||
# Contract V1: no refresh tokens to revoke, and no Portal revocation
|
||||
# endpoint documented for dashboard tokens. Logout is purely
|
||||
# client-side cookie clearing; this is a best-effort no-op.
|
||||
_ = refresh_token
|
||||
return None
|
||||
|
||||
# ---- internals --------------------------------------------------------
|
||||
|
||||
def _validate_redirect_uri(self, redirect_uri: str) -> None:
|
||||
"""Surface obviously-broken redirect_uris before bouncing to Portal.
|
||||
|
||||
The Portal-side check (``agent-redirect-uri.ts``) is authoritative;
|
||||
this is a fast-fail for the common operator-error case.
|
||||
"""
|
||||
parsed = urllib.parse.urlparse(redirect_uri)
|
||||
if parsed.scheme not in ("https", "http"):
|
||||
raise ProviderError(
|
||||
f"redirect_uri must be http(s), got {redirect_uri!r}"
|
||||
)
|
||||
if parsed.scheme == "http" and parsed.hostname not in (
|
||||
"localhost",
|
||||
"127.0.0.1",
|
||||
):
|
||||
raise ProviderError(
|
||||
"redirect_uri may only use http:// for localhost/127.0.0.1, "
|
||||
f"got {redirect_uri!r}"
|
||||
)
|
||||
if not parsed.path or not parsed.path.endswith("/auth/callback"):
|
||||
raise ProviderError(
|
||||
"redirect_uri path must end with '/auth/callback', "
|
||||
f"got {redirect_uri!r}"
|
||||
)
|
||||
|
||||
def _parse_json_body(self, response: httpx.Response) -> Dict[str, Any]:
|
||||
ctype = response.headers.get("content-type", "")
|
||||
if not ctype.startswith("application/json"):
|
||||
return {}
|
||||
try:
|
||||
body = response.json()
|
||||
except ValueError:
|
||||
return {}
|
||||
return body if isinstance(body, dict) else {}
|
||||
|
||||
def _get_jwks_client(self) -> Any:
|
||||
if self._jwks_client is None:
|
||||
from jwt import PyJWKClient # lazy import
|
||||
|
||||
self._jwks_client = PyJWKClient(
|
||||
self._jwks_url,
|
||||
cache_keys=True,
|
||||
lifespan=_JWKS_CACHE_SECONDS,
|
||||
)
|
||||
return self._jwks_client
|
||||
|
||||
def _verify_jwt(self, access_token: str) -> Dict[str, Any]:
|
||||
# Lazy import — keeps startup fast for operators who never trigger
|
||||
# the gated path.
|
||||
import jwt
|
||||
|
||||
try:
|
||||
signing_key = self._get_jwks_client().get_signing_key_from_jwt(
|
||||
access_token
|
||||
)
|
||||
except jwt.PyJWKClientError as exc:
|
||||
raise ProviderError(f"JWKS lookup failed: {exc}") from exc
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
raise ProviderError(f"JWKS lookup failed: {exc!r}") from exc
|
||||
|
||||
try:
|
||||
claims = jwt.decode(
|
||||
access_token,
|
||||
signing_key.key,
|
||||
algorithms=["RS256"],
|
||||
# Contract C2: aud is the bare client_id.
|
||||
audience=self._client_id,
|
||||
# Contract: issuer is the Portal base URL.
|
||||
issuer=self._portal_url,
|
||||
options={"require": ["exp", "iat", "aud", "iss", "sub"]},
|
||||
)
|
||||
except jwt.ExpiredSignatureError as exc:
|
||||
# verify_session() catches this and returns None per protocol.
|
||||
raise InvalidCodeError(f"access token expired: {exc}") from exc
|
||||
except jwt.InvalidTokenError as exc:
|
||||
raise ProviderError(
|
||||
f"access token verification failed: {exc}"
|
||||
) from exc
|
||||
|
||||
self._check_agent_instance_id(claims)
|
||||
self._check_contract_version(claims)
|
||||
return claims
|
||||
|
||||
def _check_agent_instance_id(self, claims: Dict[str, Any]) -> None:
|
||||
"""Contract C9: cross-check agent_instance_id against our config."""
|
||||
token_instance_id = claims.get("agent_instance_id")
|
||||
if token_instance_id is None:
|
||||
# Tolerated — the claim is documented as "should" not "must".
|
||||
# Our audience check on the bare client_id already binds the
|
||||
# token to this instance; agent_instance_id is defense-in-depth.
|
||||
return
|
||||
if token_instance_id != self._agent_instance_id:
|
||||
raise ProviderError(
|
||||
f"agent_instance_id mismatch: token={token_instance_id!r} "
|
||||
f"vs configured={self._agent_instance_id!r}"
|
||||
)
|
||||
|
||||
def _check_contract_version(self, claims: Dict[str, Any]) -> None:
|
||||
"""Contract C11 — tolerant treatment per OQ-C2."""
|
||||
contract_version = claims.get("oauth_contract_version")
|
||||
if contract_version is None:
|
||||
logger.warning(
|
||||
"Nous Portal token missing oauth_contract_version claim "
|
||||
"(contract says it should be %d); proceeding anyway.",
|
||||
_EXPECTED_CONTRACT_VERSION,
|
||||
)
|
||||
return
|
||||
if contract_version != _EXPECTED_CONTRACT_VERSION:
|
||||
raise ProviderError(
|
||||
f"unsupported oauth_contract_version={contract_version!r}, "
|
||||
f"expected {_EXPECTED_CONTRACT_VERSION}"
|
||||
)
|
||||
|
||||
def _session_from_claims(
|
||||
self,
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
claims: Dict[str, Any],
|
||||
) -> Session:
|
||||
# Contract C4: no email / display_name in tokens. AuthWidget will
|
||||
# show user_id (truncated). Session fields kept for forward-compat.
|
||||
user_id = str(claims.get("sub", ""))
|
||||
if not user_id:
|
||||
raise ProviderError("token missing 'sub' (user_id) claim")
|
||||
return Session(
|
||||
user_id=user_id,
|
||||
email="",
|
||||
display_name="",
|
||||
org_id=str(claims.get("org_id") or ""),
|
||||
provider=self.name,
|
||||
expires_at=int(claims["exp"]),
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def register(ctx) -> None:
|
||||
"""Plugin entry — called by the plugin loader at startup.
|
||||
|
||||
Registers ``NousDashboardAuthProvider`` only when the Portal-injected
|
||||
env vars are present. Operator-owned dashboards (loopback / ``--insecure``)
|
||||
leave these unset, so this plugin is a no-op for them.
|
||||
|
||||
The gate-engagement layer (``hermes_cli.web_server.should_require_auth``
|
||||
+ the fail-closed check in ``start_server``) handles the "public bind
|
||||
with zero providers" case independently, so silently returning here
|
||||
is safe — it just means no Nous provider gets registered.
|
||||
"""
|
||||
client_id = os.environ.get("HERMES_DASHBOARD_OAUTH_CLIENT_ID", "").strip()
|
||||
portal_url = os.environ.get("HERMES_DASHBOARD_PORTAL_URL", "").strip()
|
||||
|
||||
if not client_id or not portal_url:
|
||||
logger.debug(
|
||||
"dashboard-auth-nous: env vars missing "
|
||||
"(HERMES_DASHBOARD_OAUTH_CLIENT_ID set=%s, "
|
||||
"HERMES_DASHBOARD_PORTAL_URL set=%s); not registering provider.",
|
||||
bool(client_id),
|
||||
bool(portal_url),
|
||||
)
|
||||
return
|
||||
|
||||
if not client_id.startswith("agent:"):
|
||||
logger.warning(
|
||||
"dashboard-auth-nous: HERMES_DASHBOARD_OAUTH_CLIENT_ID=%r does not "
|
||||
"match contract shape 'agent:{instance_id}'; not registering "
|
||||
"provider. Set this env var to the value provisioned by Nous Portal.",
|
||||
client_id,
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
provider = NousDashboardAuthProvider(
|
||||
client_id=client_id, portal_url=portal_url
|
||||
)
|
||||
except ValueError as exc:
|
||||
logger.warning("dashboard-auth-nous: refusing to register: %s", exc)
|
||||
return
|
||||
|
||||
ctx.register_dashboard_auth_provider(provider)
|
||||
logger.info(
|
||||
"dashboard-auth-nous: registered provider (client_id=%s, portal=%s)",
|
||||
client_id,
|
||||
portal_url,
|
||||
)
|
||||
8
plugins/dashboard_auth/nous/plugin.yaml
Normal file
8
plugins/dashboard_auth/nous/plugin.yaml
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
name: nous
|
||||
version: 1.0.0
|
||||
description: "Dashboard auth provider — OAuth 2.0 (authorization-code + PKCE) against Nous Portal. Auto-activates when HERMES_DASHBOARD_OAUTH_CLIENT_ID is set (Portal injects this at Fly.io provisioning)."
|
||||
author: NousResearch
|
||||
kind: backend
|
||||
requires_env:
|
||||
- HERMES_DASHBOARD_OAUTH_CLIENT_ID
|
||||
- HERMES_DASHBOARD_PORTAL_URL
|
||||
552
tests/plugins/dashboard_auth/test_nous_provider.py
Normal file
552
tests/plugins/dashboard_auth/test_nous_provider.py
Normal file
|
|
@ -0,0 +1,552 @@
|
|||
"""Tests for the bundled Nous dashboard-auth plugin.
|
||||
|
||||
Covers four shapes from Phase 4 of ``.hermes/plans/2026-05-21-dashboard-oauth-auth.md``:
|
||||
|
||||
1. Plugin entry-point registration gating (env var checks).
|
||||
2. ``start_login`` shape (PKCE/state, authorize URL parameters).
|
||||
3. ``complete_login`` httpx-mocked happy path + error mapping.
|
||||
4. ``verify_session`` JWT verification — RSA keypair, audience/issuer pinning,
|
||||
``agent_instance_id`` cross-check, ``oauth_contract_version`` tolerance.
|
||||
|
||||
Also exercises ``revoke_session`` (no-op) and ``refresh_session``
|
||||
(unconditional ``RefreshExpiredError``).
|
||||
|
||||
All HTTP is mocked: nothing in this file talks to a real Portal.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import jwt
|
||||
import pytest
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
|
||||
import plugins.dashboard_auth.nous as nous_plugin
|
||||
from hermes_cli.dashboard_auth import (
|
||||
InvalidCodeError,
|
||||
LoginStart,
|
||||
ProviderError,
|
||||
RefreshExpiredError,
|
||||
Session,
|
||||
assert_protocol_compliance,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RSA keypair fixture (module-scope — keygen is slow)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def rsa_keypair() -> Dict[str, Any]:
|
||||
"""Generate an RS256 keypair + matching JWK for verify_session tests."""
|
||||
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
|
||||
private_pem = key.private_bytes(
|
||||
encoding=serialization.Encoding.PEM,
|
||||
format=serialization.PrivateFormat.PKCS8,
|
||||
encryption_algorithm=serialization.NoEncryption(),
|
||||
).decode()
|
||||
public_numbers = key.public_key().public_numbers()
|
||||
|
||||
def _b64url_uint(n: int) -> str:
|
||||
length = (n.bit_length() + 7) // 8
|
||||
return (
|
||||
base64.urlsafe_b64encode(n.to_bytes(length, "big")).rstrip(b"=").decode()
|
||||
)
|
||||
|
||||
jwk = {
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"alg": "RS256",
|
||||
"kid": "test-key-1",
|
||||
"n": _b64url_uint(public_numbers.n),
|
||||
"e": _b64url_uint(public_numbers.e),
|
||||
}
|
||||
return {"private_pem": private_pem, "jwk": jwk, "kid": jwk["kid"]}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token-mint helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mint_token(
|
||||
rsa_keypair: Dict[str, Any],
|
||||
*,
|
||||
iss: str = "https://portal.example.com",
|
||||
aud: str = "agent:inst123",
|
||||
sub: str = "usr_abc",
|
||||
agent_instance_id: str | None = "inst123",
|
||||
oauth_contract_version: Any = 1,
|
||||
org_id: str | None = "org_xyz",
|
||||
scope: str = "agent_dashboard:access",
|
||||
ttl_seconds: int = 900,
|
||||
extra_claims: Dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
now = int(time.time())
|
||||
claims = {
|
||||
"iss": iss,
|
||||
"aud": aud,
|
||||
"sub": sub,
|
||||
"iat": now,
|
||||
"exp": now + ttl_seconds,
|
||||
"scope": scope,
|
||||
}
|
||||
if agent_instance_id is not None:
|
||||
claims["agent_instance_id"] = agent_instance_id
|
||||
if oauth_contract_version is not None:
|
||||
claims["oauth_contract_version"] = oauth_contract_version
|
||||
if org_id is not None:
|
||||
claims["org_id"] = org_id
|
||||
if extra_claims:
|
||||
claims.update(extra_claims)
|
||||
return jwt.encode(
|
||||
claims,
|
||||
rsa_keypair["private_pem"],
|
||||
algorithm="RS256",
|
||||
headers={"kid": rsa_keypair["kid"]},
|
||||
)
|
||||
|
||||
|
||||
def _patched_jwks(provider: nous_plugin.NousDashboardAuthProvider, rsa_keypair):
|
||||
"""Patch the provider's JWKS client to return our fixture key."""
|
||||
fake_key = MagicMock()
|
||||
fake_key.key = serialization.load_pem_private_key(
|
||||
rsa_keypair["private_pem"].encode(), password=None
|
||||
).public_key()
|
||||
fake_client = MagicMock()
|
||||
fake_client.get_signing_key_from_jwt.return_value = fake_key
|
||||
provider._jwks_client = fake_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstruction:
|
||||
def test_protocol_compliance(self):
|
||||
assert_protocol_compliance(nous_plugin.NousDashboardAuthProvider)
|
||||
|
||||
def test_name_and_display(self):
|
||||
p = nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:inst1", portal_url="https://portal.example.com"
|
||||
)
|
||||
assert p.name == "nous"
|
||||
assert p.display_name == "Nous Research"
|
||||
|
||||
def test_extracts_agent_instance_id(self):
|
||||
p = nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:abc-123", portal_url="https://portal.example.com"
|
||||
)
|
||||
assert p._agent_instance_id == "abc-123"
|
||||
|
||||
def test_strips_trailing_slash_from_portal_url(self):
|
||||
p = nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:x", portal_url="https://portal.example.com/"
|
||||
)
|
||||
assert p._portal_url == "https://portal.example.com"
|
||||
|
||||
def test_rejects_malformed_client_id(self):
|
||||
with pytest.raises(ValueError, match="agent:"):
|
||||
nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="hermes-dashboard", portal_url="https://x"
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Plugin entry point: env-gated registration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginRegister:
|
||||
def test_skips_when_client_id_missing(self, monkeypatch):
|
||||
monkeypatch.delenv("HERMES_DASHBOARD_OAUTH_CLIENT_ID", raising=False)
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_PORTAL_URL", "https://p.example")
|
||||
ctx = MagicMock()
|
||||
nous_plugin.register(ctx)
|
||||
ctx.register_dashboard_auth_provider.assert_not_called()
|
||||
|
||||
def test_skips_when_portal_url_missing(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_OAUTH_CLIENT_ID", "agent:x")
|
||||
monkeypatch.delenv("HERMES_DASHBOARD_PORTAL_URL", raising=False)
|
||||
ctx = MagicMock()
|
||||
nous_plugin.register(ctx)
|
||||
ctx.register_dashboard_auth_provider.assert_not_called()
|
||||
|
||||
def test_skips_when_client_id_malformed(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_OAUTH_CLIENT_ID", "hermes-dashboard")
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_PORTAL_URL", "https://p.example")
|
||||
ctx = MagicMock()
|
||||
nous_plugin.register(ctx)
|
||||
ctx.register_dashboard_auth_provider.assert_not_called()
|
||||
|
||||
def test_registers_when_both_present(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_OAUTH_CLIENT_ID", "agent:inst1")
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_PORTAL_URL", "https://p.example")
|
||||
ctx = MagicMock()
|
||||
nous_plugin.register(ctx)
|
||||
ctx.register_dashboard_auth_provider.assert_called_once()
|
||||
registered = ctx.register_dashboard_auth_provider.call_args.args[0]
|
||||
assert isinstance(registered, nous_plugin.NousDashboardAuthProvider)
|
||||
assert registered._client_id == "agent:inst1"
|
||||
|
||||
def test_strips_whitespace_from_env_vars(self, monkeypatch):
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_OAUTH_CLIENT_ID", " agent:x ")
|
||||
monkeypatch.setenv("HERMES_DASHBOARD_PORTAL_URL", " https://p.example ")
|
||||
ctx = MagicMock()
|
||||
nous_plugin.register(ctx)
|
||||
ctx.register_dashboard_auth_provider.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start_login
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestStartLogin:
|
||||
@pytest.fixture
|
||||
def provider(self):
|
||||
return nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:inst1", portal_url="https://portal.example.com"
|
||||
)
|
||||
|
||||
def test_returns_login_start(self, provider):
|
||||
result = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
assert isinstance(result, LoginStart)
|
||||
|
||||
def test_redirect_url_targets_portal_authorize(self, provider):
|
||||
result = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
assert result.redirect_url.startswith(
|
||||
"https://portal.example.com/oauth/authorize?"
|
||||
)
|
||||
|
||||
def test_authorize_url_has_required_params(self, provider):
|
||||
result = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
parsed = urllib.parse.urlparse(result.redirect_url)
|
||||
params = dict(urllib.parse.parse_qsl(parsed.query))
|
||||
assert params["response_type"] == "code"
|
||||
assert params["client_id"] == "agent:inst1"
|
||||
assert params["redirect_uri"] == "https://hermes.fly.dev/auth/callback"
|
||||
assert params["scope"] == "agent_dashboard:access"
|
||||
assert params["code_challenge_method"] == "S256"
|
||||
assert "state" in params
|
||||
assert "code_challenge" in params
|
||||
|
||||
def test_code_verifier_in_cookie_payload_43_to_128_chars(self, provider):
|
||||
result = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
assert "hermes_session_pkce" in result.cookie_payload
|
||||
pkce = result.cookie_payload["hermes_session_pkce"]
|
||||
# Shape: ``state=…;verifier=…`` (matches stub-provider convention so
|
||||
# the auth-route layer's parser works uniformly across providers).
|
||||
parts = dict(seg.split("=", 1) for seg in pkce.split(";") if "=" in seg)
|
||||
verifier = parts["verifier"]
|
||||
# RFC 7636 §4.1
|
||||
assert 43 <= len(verifier) <= 128
|
||||
|
||||
def test_state_in_cookie_payload_matches_url_param(self, provider):
|
||||
result = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
parsed = urllib.parse.urlparse(result.redirect_url)
|
||||
params = dict(urllib.parse.parse_qsl(parsed.query))
|
||||
pkce = result.cookie_payload["hermes_session_pkce"]
|
||||
parts = dict(seg.split("=", 1) for seg in pkce.split(";") if "=" in seg)
|
||||
assert parts["state"] == params["state"]
|
||||
|
||||
def test_code_challenge_is_s256_of_verifier(self, provider):
|
||||
result = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
parsed = urllib.parse.urlparse(result.redirect_url)
|
||||
params = dict(urllib.parse.parse_qsl(parsed.query))
|
||||
pkce = result.cookie_payload["hermes_session_pkce"]
|
||||
parts = dict(seg.split("=", 1) for seg in pkce.split(";") if "=" in seg)
|
||||
verifier = parts["verifier"]
|
||||
expected_challenge = (
|
||||
base64.urlsafe_b64encode(
|
||||
hashlib.sha256(verifier.encode("ascii")).digest()
|
||||
)
|
||||
.rstrip(b"=")
|
||||
.decode()
|
||||
)
|
||||
assert params["code_challenge"] == expected_challenge
|
||||
|
||||
def test_two_calls_produce_different_state_and_verifier(self, provider):
|
||||
a = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
b = provider.start_login(
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
assert a.cookie_payload["hermes_session_pkce"] != b.cookie_payload[
|
||||
"hermes_session_pkce"
|
||||
]
|
||||
|
||||
def test_rejects_non_http_scheme(self, provider):
|
||||
with pytest.raises(ProviderError, match="http"):
|
||||
provider.start_login(redirect_uri="ftp://x/auth/callback")
|
||||
|
||||
def test_rejects_http_with_non_localhost(self, provider):
|
||||
with pytest.raises(ProviderError, match="localhost"):
|
||||
provider.start_login(
|
||||
redirect_uri="http://hermes.fly.dev/auth/callback"
|
||||
)
|
||||
|
||||
def test_allows_http_localhost(self, provider):
|
||||
# Should not raise.
|
||||
provider.start_login(redirect_uri="http://localhost:8080/auth/callback")
|
||||
provider.start_login(redirect_uri="http://127.0.0.1:8080/auth/callback")
|
||||
|
||||
def test_rejects_wrong_callback_path(self, provider):
|
||||
with pytest.raises(ProviderError, match="/auth/callback"):
|
||||
provider.start_login(redirect_uri="https://x.example/oauth/cb")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# complete_login (httpx mocked)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompleteLogin:
|
||||
@pytest.fixture
|
||||
def provider(self, rsa_keypair):
|
||||
p = nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:inst123", portal_url="https://portal.example.com"
|
||||
)
|
||||
_patched_jwks(p, rsa_keypair)
|
||||
return p
|
||||
|
||||
def _mock_post(self, status_code: int, body: Any, *, ctype: str = "application/json"):
|
||||
resp = MagicMock(spec=httpx.Response)
|
||||
resp.status_code = status_code
|
||||
if isinstance(body, dict):
|
||||
resp.text = json.dumps(body)
|
||||
resp.json = MagicMock(return_value=body)
|
||||
else:
|
||||
resp.text = body
|
||||
# _parse_json_body bails on non-application/json before .json()
|
||||
# is called, but be safe for callers that pass a non-dict body
|
||||
# with ctype=application/json.
|
||||
resp.json = MagicMock(side_effect=ValueError("not json"))
|
||||
resp.headers = {"content-type": ctype}
|
||||
return resp
|
||||
|
||||
def test_happy_path_returns_session(self, provider, rsa_keypair):
|
||||
access_token = _mint_token(rsa_keypair)
|
||||
mock_resp = self._mock_post(
|
||||
200, {"access_token": access_token, "token_type": "Bearer"}
|
||||
)
|
||||
with patch("plugins.dashboard_auth.nous.httpx.post", return_value=mock_resp):
|
||||
session = provider.complete_login(
|
||||
code="abc",
|
||||
state="state-val",
|
||||
code_verifier="vfy",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
assert isinstance(session, Session)
|
||||
assert session.user_id == "usr_abc"
|
||||
assert session.provider == "nous"
|
||||
assert session.access_token == access_token
|
||||
assert session.refresh_token == "" # contract V1
|
||||
assert session.org_id == "org_xyz"
|
||||
assert session.email == ""
|
||||
assert session.display_name == ""
|
||||
|
||||
def test_400_raises_invalid_code(self, provider):
|
||||
mock_resp = self._mock_post(400, {"error": "invalid_grant"})
|
||||
with patch("plugins.dashboard_auth.nous.httpx.post", return_value=mock_resp):
|
||||
with pytest.raises(InvalidCodeError, match="invalid_grant"):
|
||||
provider.complete_login(
|
||||
code="bad", state="s", code_verifier="v",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
|
||||
def test_500_raises_provider_error(self, provider):
|
||||
mock_resp = self._mock_post(500, "internal server error", ctype="text/plain")
|
||||
mock_resp.text = "internal server error"
|
||||
with patch("plugins.dashboard_auth.nous.httpx.post", return_value=mock_resp):
|
||||
with pytest.raises(ProviderError, match="500"):
|
||||
provider.complete_login(
|
||||
code="x", state="s", code_verifier="v",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
|
||||
def test_missing_access_token_raises(self, provider):
|
||||
mock_resp = self._mock_post(200, {"token_type": "Bearer"})
|
||||
with patch("plugins.dashboard_auth.nous.httpx.post", return_value=mock_resp):
|
||||
with pytest.raises(ProviderError, match="access_token"):
|
||||
provider.complete_login(
|
||||
code="x", state="s", code_verifier="v",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
|
||||
def test_unexpected_token_type_raises(self, provider, rsa_keypair):
|
||||
access_token = _mint_token(rsa_keypair)
|
||||
mock_resp = self._mock_post(
|
||||
200, {"access_token": access_token, "token_type": "DPoP"}
|
||||
)
|
||||
with patch("plugins.dashboard_auth.nous.httpx.post", return_value=mock_resp):
|
||||
with pytest.raises(ProviderError, match="token_type"):
|
||||
provider.complete_login(
|
||||
code="x", state="s", code_verifier="v",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
|
||||
def test_network_error_raises_provider_error(self, provider):
|
||||
with patch(
|
||||
"plugins.dashboard_auth.nous.httpx.post",
|
||||
side_effect=httpx.ConnectError("conn refused"),
|
||||
):
|
||||
with pytest.raises(ProviderError, match="unreachable"):
|
||||
provider.complete_login(
|
||||
code="x", state="s", code_verifier="v",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
|
||||
def test_captures_refresh_token_if_present_forward_compat(
|
||||
self, provider, rsa_keypair
|
||||
):
|
||||
"""Forward-compat: contract V1 doesn't issue, but if a future Portal
|
||||
does, we should preserve it in the Session for later use."""
|
||||
access_token = _mint_token(rsa_keypair)
|
||||
mock_resp = self._mock_post(
|
||||
200,
|
||||
{
|
||||
"access_token": access_token,
|
||||
"token_type": "Bearer",
|
||||
"refresh_token": "rt-opaque",
|
||||
},
|
||||
)
|
||||
with patch("plugins.dashboard_auth.nous.httpx.post", return_value=mock_resp):
|
||||
session = provider.complete_login(
|
||||
code="x", state="s", code_verifier="v",
|
||||
redirect_uri="https://hermes.fly.dev/auth/callback",
|
||||
)
|
||||
assert session.refresh_token == "rt-opaque"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# verify_session
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVerifySession:
|
||||
@pytest.fixture
|
||||
def provider(self, rsa_keypair):
|
||||
p = nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:inst123", portal_url="https://portal.example.com"
|
||||
)
|
||||
_patched_jwks(p, rsa_keypair)
|
||||
return p
|
||||
|
||||
def test_happy_path_returns_session(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair)
|
||||
session = provider.verify_session(access_token=token)
|
||||
assert session is not None
|
||||
assert session.user_id == "usr_abc"
|
||||
assert session.org_id == "org_xyz"
|
||||
|
||||
def test_expired_token_returns_none(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair, ttl_seconds=-1)
|
||||
assert provider.verify_session(access_token=token) is None
|
||||
|
||||
def test_wrong_audience_raises_provider_error(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair, aud="agent:other-instance")
|
||||
with pytest.raises(ProviderError, match="verification failed"):
|
||||
provider.verify_session(access_token=token)
|
||||
|
||||
def test_wrong_issuer_raises_provider_error(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair, iss="https://evil.example")
|
||||
with pytest.raises(ProviderError, match="verification failed"):
|
||||
provider.verify_session(access_token=token)
|
||||
|
||||
def test_missing_sub_raises(self, provider, rsa_keypair):
|
||||
# PyJWT's "require" set includes sub, so this surfaces as
|
||||
# InvalidTokenError → ProviderError before we ever touch _session_from_claims.
|
||||
token = _mint_token(rsa_keypair, sub="")
|
||||
# Empty sub still encodes successfully; PyJWT's require check only
|
||||
# asserts presence. Our own _session_from_claims rejects empty.
|
||||
with pytest.raises(ProviderError, match="sub"):
|
||||
provider.verify_session(access_token=token)
|
||||
|
||||
def test_agent_instance_id_mismatch_rejected(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair, agent_instance_id="some-other-id")
|
||||
with pytest.raises(ProviderError, match="agent_instance_id mismatch"):
|
||||
provider.verify_session(access_token=token)
|
||||
|
||||
def test_agent_instance_id_missing_is_tolerated(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair, agent_instance_id=None)
|
||||
session = provider.verify_session(access_token=token)
|
||||
assert session is not None
|
||||
|
||||
def test_contract_version_missing_warns_but_succeeds(
|
||||
self, provider, rsa_keypair, caplog
|
||||
):
|
||||
import logging
|
||||
token = _mint_token(rsa_keypair, oauth_contract_version=None)
|
||||
with caplog.at_level(logging.WARNING, logger="plugins.dashboard_auth.nous"):
|
||||
session = provider.verify_session(access_token=token)
|
||||
assert session is not None
|
||||
assert any(
|
||||
"oauth_contract_version" in r.message for r in caplog.records
|
||||
)
|
||||
|
||||
def test_contract_version_mismatch_rejected(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair, oauth_contract_version=2)
|
||||
with pytest.raises(ProviderError, match="oauth_contract_version"):
|
||||
provider.verify_session(access_token=token)
|
||||
|
||||
def test_jwks_unreachable_raises_provider_error(self, provider, rsa_keypair):
|
||||
token = _mint_token(rsa_keypair)
|
||||
# Replace the patched client so it raises.
|
||||
bad_client = MagicMock()
|
||||
bad_client.get_signing_key_from_jwt.side_effect = jwt.PyJWKClientError(
|
||||
"fetch failed"
|
||||
)
|
||||
provider._jwks_client = bad_client
|
||||
with pytest.raises(ProviderError, match="JWKS"):
|
||||
provider.verify_session(access_token=token)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# refresh_session + revoke_session (V1 contract: trivial)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRefreshAndRevoke:
|
||||
@pytest.fixture
|
||||
def provider(self):
|
||||
return nous_plugin.NousDashboardAuthProvider(
|
||||
client_id="agent:inst1", portal_url="https://portal.example.com"
|
||||
)
|
||||
|
||||
def test_refresh_always_raises(self, provider):
|
||||
with pytest.raises(RefreshExpiredError):
|
||||
provider.refresh_session(refresh_token="anything")
|
||||
|
||||
def test_refresh_raises_even_with_empty_token(self, provider):
|
||||
with pytest.raises(RefreshExpiredError):
|
||||
provider.refresh_session(refresh_token="")
|
||||
|
||||
def test_revoke_is_noop(self, provider):
|
||||
# Must not raise; returns None implicitly.
|
||||
assert provider.revoke_session(refresh_token="anything") is None
|
||||
assert provider.revoke_session(refresh_token="") is None
|
||||
Loading…
Add table
Add a link
Reference in a new issue