fix(auth): fix a few cases where refresh tokens were not rotated.

This commit is contained in:
Robin Fernandes 2026-05-17 22:29:40 +10:00 committed by Teknium
parent 20bffa5b37
commit 569bc94b59
6 changed files with 166 additions and 109 deletions

View file

@ -41,7 +41,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
import httpx
@ -89,11 +89,6 @@ NOUS_INFERENCE_AUTH_MODES = frozenset({
NOUS_AUTH_PATH_INVOKE_JWT = "invoke_jwt"
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE = "legacy_session_key_cache"
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT = "legacy_session_key_mint"
NOUS_AUTH_PATHS = frozenset({
NOUS_AUTH_PATH_INVOKE_JWT,
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE,
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT,
})
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
NOUS_INVOKE_JWT_MIN_TTL_SECONDS = ACCESS_TOKEN_REFRESH_SKEW_SECONDS
@ -3991,7 +3986,7 @@ def _is_terminal_nous_refresh_error(exc: Exception) -> bool:
return (
isinstance(exc, AuthError)
and exc.provider == "nous"
and exc.code in {"invalid_grant", "invalid_token"}
and exc.code in {"invalid_grant", "invalid_token", "refresh_token_reused"}
and bool(exc.relogin_required)
)
@ -4103,12 +4098,16 @@ def _try_import_shared_nous_state(
"tls": {"insecure": False, "ca_bundle": None},
}
def _persist_shared_refresh(updated_state: Dict[str, Any], _reason: str) -> None:
_write_shared_nous_state(updated_state)
refreshed = refresh_nous_oauth_from_state(
state,
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
inference_auth_mode=NOUS_INFERENCE_AUTH_MODE_FRESH,
on_state_update=_persist_shared_refresh,
)
_write_shared_nous_state(refreshed)
except AuthError as exc:
@ -4163,7 +4162,7 @@ def _refresh_access_token(
code = str(error_payload.get("error", "invalid_grant"))
description = str(error_payload.get("error_description") or "Refresh token exchange failed")
relogin = code in {"invalid_grant", "invalid_token"}
relogin = code in {"invalid_grant", "invalid_token", "refresh_token_reused"}
# Detect the OAuth 2.1 "refresh token reuse" signal from the Nous portal
# server and surface an actionable message. This fires when an external
@ -4173,7 +4172,7 @@ def _refresh_access_token(
# retires the original RT, Hermes's next refresh uses it, and the whole
# session chain gets revoked as a token-theft signal (#15099).
lowered = description.lower()
if "reuse" in lowered or "reuse detected" in lowered:
if code == "refresh_token_reused" or "reuse" in lowered or "reuse detected" in lowered:
description = (
"Nous Portal detected refresh-token reuse and revoked this session.\n"
"This usually means an external process (monitoring script, "
@ -4185,6 +4184,7 @@ def _refresh_access_token(
"instead.\n"
"Re-authenticate with: hermes auth add nous"
)
relogin = True
raise AuthError(description, provider="nous", code=code, relogin_required=relogin)
@ -4418,8 +4418,14 @@ def refresh_nous_oauth_pure(
ca_bundle: Optional[str] = None,
force_refresh: bool = False,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
on_state_update: Optional[Callable[[Dict[str, Any], str], None]] = None,
) -> Dict[str, Any]:
"""Refresh Nous OAuth state without mutating auth.json."""
"""Refresh Nous OAuth state without mutating auth.json directly.
``on_state_update`` is called after a successful access-token refresh and
before any subsequent agent-key mint. Callers that own persistent state can
use it to save the newly rotated refresh token before later work can fail.
"""
inference_auth_mode = _normalize_nous_inference_auth_mode(inference_auth_mode)
state: Dict[str, Any] = {
"access_token": access_token,
@ -4479,6 +4485,8 @@ def refresh_nous_oauth_pure(
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
if on_state_update is not None:
on_state_update(dict(state), "post_refresh_access_token")
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
state,
@ -4519,6 +4527,7 @@ def refresh_nous_oauth_from_state(
timeout_seconds: float = 15.0,
force_refresh: bool = False,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
on_state_update: Optional[Callable[[Dict[str, Any], str], None]] = None,
) -> Dict[str, Any]:
"""Refresh Nous OAuth from a state dict. Thin wrapper around refresh_nous_oauth_pure."""
tls = state.get("tls") or {}
@ -4540,6 +4549,7 @@ def refresh_nous_oauth_from_state(
ca_bundle=tls.get("ca_bundle"),
force_refresh=force_refresh,
inference_auth_mode=inference_auth_mode,
on_state_update=on_state_update,
)
@ -4603,6 +4613,7 @@ def persist_nous_credentials(
def _sync_nous_pool_from_auth_store() -> None:
"""Best-effort pool reseed after providers.nous changes; never fail login."""
try:
from agent.credential_pool import load_pool