fix(auth) fix a few cases where refresh tokens were not rotated.

This commit is contained in:
Robin Fernandes 2026-05-17 22:29:40 +10:00 committed by Teknium
parent 20bffa5b37
commit 569bc94b59
6 changed files with 166 additions and 109 deletions

View file

@ -623,18 +623,35 @@ class CredentialPool:
return entry
store_refresh = state.get("refresh_token", "")
store_access = state.get("access_token", "")
if store_refresh and store_refresh != entry.refresh_token:
comparable_updates = {
"access_token": store_access,
"refresh_token": store_refresh,
"expires_at": state.get("expires_at"),
"agent_key": state.get("agent_key"),
"agent_key_expires_at": state.get("agent_key_expires_at"),
"inference_base_url": state.get("inference_base_url"),
}
should_sync = any(
value not in (None, "") and getattr(entry, key, None) != value
for key, value in comparable_updates.items()
)
if should_sync:
logger.debug(
"Pool entry %s: syncing tokens from auth.json (Nous refresh token changed)",
"Pool entry %s: syncing Nous state from auth.json",
entry.id,
)
field_updates: Dict[str, Any] = {
"access_token": store_access,
"refresh_token": store_refresh,
"last_status": None,
"last_status_at": None,
"last_error_code": None,
"last_error_reason": None,
"last_error_message": None,
"last_error_reset_at": None,
}
if store_access:
field_updates["access_token"] = store_access
if store_refresh:
field_updates["refresh_token"] = store_refresh
if state.get("expires_at"):
field_updates["expires_at"] = state["expires_at"]
if state.get("agent_key"):
@ -813,40 +830,15 @@ class CredentialPool:
synced = self._sync_nous_entry_from_auth_store(entry)
if synced is not entry:
entry = synced
nous_state = {
"access_token": entry.access_token,
"refresh_token": entry.refresh_token,
"client_id": entry.client_id,
"portal_base_url": entry.portal_base_url,
"inference_base_url": entry.inference_base_url,
"token_type": entry.token_type,
"scope": entry.scope,
"obtained_at": entry.obtained_at,
"expires_at": entry.expires_at,
"agent_key": entry.agent_key,
"agent_key_expires_at": entry.agent_key_expires_at,
"tls": entry.tls,
}
refreshed = auth_mod.refresh_nous_oauth_from_state(
nous_state,
auth_mod.resolve_nous_runtime_credentials(
min_key_ttl_seconds=DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
force_refresh=force,
inference_auth_mode=(
auth_mod.NOUS_INFERENCE_AUTH_MODE_LEGACY
if force
else auth_mod.NOUS_INFERENCE_AUTH_MODE_AUTO
),
)
# Apply returned fields: dataclass fields via replace, extras via dict update
field_updates = {}
extra_updates = dict(entry.extra)
_field_names = {f.name for f in fields(entry)}
for k, v in refreshed.items():
if k in _field_names:
field_updates[k] = v
elif k in _EXTRA_KEYS:
extra_updates[k] = v
updated = replace(entry, extra=extra_updates, **field_updates)
updated = self._sync_nous_entry_from_auth_store(entry)
else:
return entry
except Exception as exc:

View file

@ -41,7 +41,7 @@ from dataclasses import dataclass, field
from datetime import datetime, timezone
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
import httpx
@ -89,11 +89,6 @@ NOUS_INFERENCE_AUTH_MODES = frozenset({
NOUS_AUTH_PATH_INVOKE_JWT = "invoke_jwt"
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE = "legacy_session_key_cache"
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT = "legacy_session_key_mint"
NOUS_AUTH_PATHS = frozenset({
NOUS_AUTH_PATH_INVOKE_JWT,
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE,
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT,
})
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
NOUS_INVOKE_JWT_MIN_TTL_SECONDS = ACCESS_TOKEN_REFRESH_SKEW_SECONDS
@ -3991,7 +3986,7 @@ def _is_terminal_nous_refresh_error(exc: Exception) -> bool:
return (
isinstance(exc, AuthError)
and exc.provider == "nous"
and exc.code in {"invalid_grant", "invalid_token"}
and exc.code in {"invalid_grant", "invalid_token", "refresh_token_reused"}
and bool(exc.relogin_required)
)
@ -4103,12 +4098,16 @@ def _try_import_shared_nous_state(
"tls": {"insecure": False, "ca_bundle": None},
}
def _persist_shared_refresh(updated_state: Dict[str, Any], _reason: str) -> None:
_write_shared_nous_state(updated_state)
refreshed = refresh_nous_oauth_from_state(
state,
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
inference_auth_mode=NOUS_INFERENCE_AUTH_MODE_FRESH,
on_state_update=_persist_shared_refresh,
)
_write_shared_nous_state(refreshed)
except AuthError as exc:
@ -4163,7 +4162,7 @@ def _refresh_access_token(
code = str(error_payload.get("error", "invalid_grant"))
description = str(error_payload.get("error_description") or "Refresh token exchange failed")
relogin = code in {"invalid_grant", "invalid_token"}
relogin = code in {"invalid_grant", "invalid_token", "refresh_token_reused"}
# Detect the OAuth 2.1 "refresh token reuse" signal from the Nous portal
# server and surface an actionable message. This fires when an external
@ -4173,7 +4172,7 @@ def _refresh_access_token(
# retires the original RT, Hermes's next refresh uses it, and the whole
# session chain gets revoked as a token-theft signal (#15099).
lowered = description.lower()
if "reuse" in lowered or "reuse detected" in lowered:
if code == "refresh_token_reused" or "reuse" in lowered or "reuse detected" in lowered:
description = (
"Nous Portal detected refresh-token reuse and revoked this session.\n"
"This usually means an external process (monitoring script, "
@ -4185,6 +4184,7 @@ def _refresh_access_token(
"instead.\n"
"Re-authenticate with: hermes auth add nous"
)
relogin = True
raise AuthError(description, provider="nous", code=code, relogin_required=relogin)
@ -4418,8 +4418,14 @@ def refresh_nous_oauth_pure(
ca_bundle: Optional[str] = None,
force_refresh: bool = False,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
on_state_update: Optional[Callable[[Dict[str, Any], str], None]] = None,
) -> Dict[str, Any]:
"""Refresh Nous OAuth state without mutating auth.json."""
"""Refresh Nous OAuth state without mutating auth.json directly.
``on_state_update`` is called after a successful access-token refresh and
before any subsequent agent-key mint. Callers that own persistent state can
use it to save the newly rotated refresh token before later work can fail.
"""
inference_auth_mode = _normalize_nous_inference_auth_mode(inference_auth_mode)
state: Dict[str, Any] = {
"access_token": access_token,
@ -4479,6 +4485,8 @@ def refresh_nous_oauth_pure(
state["expires_at"] = datetime.fromtimestamp(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
if on_state_update is not None:
on_state_update(dict(state), "post_refresh_access_token")
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
state,
@ -4519,6 +4527,7 @@ def refresh_nous_oauth_from_state(
timeout_seconds: float = 15.0,
force_refresh: bool = False,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
on_state_update: Optional[Callable[[Dict[str, Any], str], None]] = None,
) -> Dict[str, Any]:
"""Refresh Nous OAuth from a state dict. Thin wrapper around refresh_nous_oauth_pure."""
tls = state.get("tls") or {}
@ -4540,6 +4549,7 @@ def refresh_nous_oauth_from_state(
ca_bundle=tls.get("ca_bundle"),
force_refresh=force_refresh,
inference_auth_mode=inference_auth_mode,
on_state_update=on_state_update,
)
@ -4603,6 +4613,7 @@ def persist_nous_credentials(
def _sync_nous_pool_from_auth_store() -> None:
"""Best-effort pool reseed after providers.nous changes; never fail login."""
try:
from agent.credential_pool import load_pool

View file

@ -1,13 +1,13 @@
"""Nous Portal upstream adapter.
Reads the user's Nous OAuth state from ``~/.hermes/auth.json``, refreshes
the access token and resolves the ``agent_key`` compatibility credential
when needed, then exposes the upstream base URL plus bearer for the proxy
server to forward to.
Reads the user's Nous OAuth state from ``~/.hermes/auth.json`` through the
shared runtime resolver, refreshes the access token and resolves the
``agent_key`` compatibility credential when needed, then exposes the upstream
base URL plus bearer for the proxy server to forward to.
The ``agent_key`` field may hold either a NAS invoke JWT or the legacy
opaque session key. The refresh helper handles both see
:func:`hermes_cli.auth.refresh_nous_oauth_from_state`.
:func:`hermes_cli.auth.resolve_nous_runtime_credentials`.
"""
from __future__ import annotations
@ -22,12 +22,13 @@ from hermes_cli.auth import (
NOUS_INFERENCE_AUTH_MODE_AUTO,
NOUS_INFERENCE_AUTH_MODE_LEGACY,
_load_auth_store,
_auth_store_lock,
_is_terminal_nous_refresh_error,
_quarantine_nous_oauth_state,
_quarantine_nous_pool_entries,
_save_auth_store,
_write_shared_nous_state,
refresh_nous_oauth_from_state,
resolve_nous_runtime_credentials,
)
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
@ -50,9 +51,8 @@ class NousPortalAdapter(UpstreamAdapter):
"""Proxy upstream for the Nous Portal inference API."""
def __init__(self) -> None:
# Lock guards _load → refresh → _save against parallel proxy requests
# racing to refresh expired tokens. Refresh itself is HTTP, so we
# hold the lock across the network call (brief; OAuth refresh is fast).
# Serialize proxy requests in this process; cross-process token refresh
# and persistence are handled by resolve_nous_runtime_credentials().
self._lock = threading.Lock()
@property
@ -107,8 +107,7 @@ class NousPortalAdapter(UpstreamAdapter):
)
try:
refreshed = refresh_nous_oauth_from_state(
state,
refreshed = resolve_nous_runtime_credentials(
inference_auth_mode=inference_auth_mode,
)
except AuthError as exc:
@ -131,22 +130,20 @@ class NousPortalAdapter(UpstreamAdapter):
f"Failed to refresh Nous Portal credentials: {exc}"
) from exc
self._save_state(refreshed)
agent_key = refreshed.get("agent_key")
agent_key = refreshed.get("api_key")
if not agent_key:
raise RuntimeError(
"Nous Portal refresh did not return a usable agent_key. "
"Try `hermes login nous` to re-authenticate."
)
base_url = refreshed.get("inference_base_url") or DEFAULT_NOUS_INFERENCE_URL
base_url = refreshed.get("base_url") or DEFAULT_NOUS_INFERENCE_URL
base_url = base_url.rstrip("/")
return UpstreamCredential(
bearer=agent_key,
base_url=base_url,
expires_at=refreshed.get("agent_key_expires_at"),
expires_at=refreshed.get("expires_at"),
)
# ------------------------------------------------------------------
@ -156,7 +153,8 @@ class NousPortalAdapter(UpstreamAdapter):
def _read_state(self) -> Optional[Dict[str, Any]]:
try:
store = _load_auth_store()
with _auth_store_lock():
store = _load_auth_store()
except Exception as exc:
logger.warning("proxy: failed to load auth store: %s", exc)
return None
@ -174,21 +172,20 @@ class NousPortalAdapter(UpstreamAdapter):
quarantine_reason: Optional[str] = None,
) -> None:
try:
store = _load_auth_store()
if quarantine_error is not None and quarantine_reason:
_quarantine_nous_pool_entries(
store,
quarantine_error,
reason=quarantine_reason,
)
providers = store.setdefault("providers", {})
providers["nous"] = state
_save_auth_store(store)
with _auth_store_lock():
store = _load_auth_store()
if quarantine_error is not None and quarantine_reason:
_quarantine_nous_pool_entries(
store,
quarantine_error,
reason=quarantine_reason,
)
providers = store.setdefault("providers", {})
providers["nous"] = state
_save_auth_store(store)
_write_shared_nous_state(state)
except Exception as exc:
# Best effort — we still return the fresh credential. The next
# request just won't see cached state, which means another refresh.
logger.warning("proxy: failed to persist refreshed Nous state: %s", exc)
logger.warning("proxy: failed to persist Nous quarantine state: %s", exc)
__all__ = ["NousPortalAdapter"]

View file

@ -625,7 +625,7 @@ def test_nous_pool_terminal_refresh_removes_device_code_entry(tmp_path, monkeypa
"access_token": "manual-nous-key",
}))
monkeypatch.setattr(auth_mod, "refresh_nous_oauth_from_state", _terminal_refresh_failure)
monkeypatch.setattr(auth_mod, "resolve_nous_runtime_credentials", _terminal_refresh_failure)
assert pool.try_refresh_current() is None

View file

@ -1426,6 +1426,36 @@ def test_refresh_token_reuse_detection_surfaces_actionable_message():
assert exc_info.value.relogin_required is True
def test_refresh_token_reuse_error_code_is_terminal():
"""Nous may return refresh_token_reused as the OAuth error code itself."""
from hermes_cli import auth as auth_mod
class _FakeResponse:
status_code = 400
def json(self):
return {
"error": "refresh_token_reused",
"error_description": "Refresh token reuse detected",
}
class _FakeClient:
def post(self, *args, **kwargs):
return _FakeResponse()
with pytest.raises(AuthError) as exc_info:
auth_mod._refresh_access_token(
client=_FakeClient(),
portal_base_url="https://portal.nousresearch.com",
client_id="hermes-cli",
refresh_token="rt_consumed_elsewhere",
)
assert exc_info.value.code == "refresh_token_reused"
assert exc_info.value.relogin_required is True
assert auth_mod._is_terminal_nous_refresh_error(exc_info.value) is True
def test_refresh_token_exchange_sends_refresh_token_header():
"""Nous refresh tokens must be sent in a header so sandbox proxies can
substitute placeholder credentials without parsing form bodies.
@ -1686,6 +1716,46 @@ def test_try_import_shared_returns_none_on_refresh_failure(
assert auth_mod._read_shared_nous_state() is None
def test_try_import_shared_persists_rotated_token_when_mint_fails(
shared_store_env, monkeypatch,
):
"""A forced shared import refresh rotates the single-use token before minting.
If the later agent-key mint fails, the shared store must still keep the
rotated refresh token; otherwise the next import attempt replays the
consumed token and trips refresh-token reuse.
"""
from hermes_cli import auth as auth_mod
shared_state = _full_state_fixture()
shared_state["refresh_token"] = "refresh-old"
shared_state["access_token"] = "access-old"
auth_mod._write_shared_nous_state(shared_state)
def _fake_refresh_access_token(*, client, portal_base_url, client_id, refresh_token):
assert refresh_token == "refresh-old"
return {
"access_token": "access-new",
"refresh_token": "refresh-new",
"expires_in": 900,
"token_type": "Bearer",
}
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
assert access_token == "access-new"
raise AuthError("credits exhausted", provider="nous", code="insufficient_credits")
monkeypatch.setattr(auth_mod, "_refresh_access_token", _fake_refresh_access_token)
monkeypatch.setattr(auth_mod, "_mint_agent_key", _fake_mint_agent_key)
assert auth_mod._try_import_shared_nous_state() is None
shared_after = auth_mod._read_shared_nous_state()
assert shared_after is not None
assert shared_after["refresh_token"] == "refresh-new"
assert shared_after["access_token"] == "access-new"
def test_try_import_shared_rehydrates_on_success(shared_store_env, monkeypatch):
"""Happy path: stored refresh_token is accepted, forced refresh+mint
returns a fresh access_token + agent_key, and the returned dict has

View file

@ -103,7 +103,7 @@ def test_nous_adapter_authenticated_with_refresh_token_only(tmp_path, monkeypatc
assert NousPortalAdapter().is_authenticated()
def test_nous_adapter_get_credential_refreshes_and_persists(tmp_path, monkeypatch):
def test_nous_adapter_get_credential_uses_runtime_resolver(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_auth_store(tmp_path, {
"access_token": "access-tok",
@ -114,32 +114,24 @@ def test_nous_adapter_get_credential_refreshes_and_persists(tmp_path, monkeypatc
})
refreshed_state = {
"access_token": "access-tok",
"refresh_token": "refresh-tok",
"client_id": "hermes-cli",
"portal_base_url": "https://portal.nousresearch.com",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"agent_key": "minted-bearer",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
"api_key": "minted-bearer",
"base_url": "https://inference-api.nousresearch.com/v1",
"expires_at": "2099-01-01T00:00:00Z",
}
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
return_value=refreshed_state,
) as mock_refresh:
) as mock_resolve:
adapter = NousPortalAdapter()
cred = adapter.get_credential()
mock_refresh.assert_called_once()
mock_resolve.assert_called_once()
assert cred.bearer == "minted-bearer"
assert cred.base_url == "https://inference-api.nousresearch.com/v1"
assert cred.expires_at == "2099-01-01T00:00:00Z"
assert cred.token_type == "Bearer"
# Verify state was persisted back
stored = json.loads((tmp_path / "auth.json").read_text())
assert stored["providers"]["nous"]["agent_key"] == "minted-bearer"
def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
@ -153,19 +145,15 @@ def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch)
})
refreshed_state = {
"access_token": "jwt-access",
"refresh_token": "refresh-tok",
"client_id": "hermes-cli",
"portal_base_url": "https://portal.nousresearch.com",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"agent_key": "legacy-bearer",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
"api_key": "legacy-bearer",
"base_url": "https://inference-api.nousresearch.com/v1",
"expires_at": "2099-01-01T00:00:00Z",
}
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
return_value=refreshed_state,
) as mock_refresh:
) as mock_resolve:
adapter = NousPortalAdapter()
cred = adapter.get_retry_credential(
failed_credential=UpstreamCredential(
@ -177,7 +165,7 @@ def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch)
assert cred is not None
assert cred.bearer == "legacy-bearer"
assert mock_refresh.call_args.kwargs["inference_auth_mode"] == "legacy"
assert mock_resolve.call_args.kwargs["inference_auth_mode"] == "legacy"
def test_nous_adapter_retry_credential_skips_opaque_bearer(tmp_path, monkeypatch):
@ -189,8 +177,8 @@ def test_nous_adapter_retry_credential_skips_opaque_bearer(tmp_path, monkeypatch
})
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
) as mock_refresh:
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
) as mock_resolve:
adapter = NousPortalAdapter()
cred = adapter.get_retry_credential(
failed_credential=UpstreamCredential(
@ -201,7 +189,7 @@ def test_nous_adapter_retry_credential_skips_opaque_bearer(tmp_path, monkeypatch
)
assert cred is None
mock_refresh.assert_not_called()
mock_resolve.assert_not_called()
def test_nous_adapter_get_credential_raises_when_not_logged_in(tmp_path, monkeypatch):
@ -219,7 +207,7 @@ def test_nous_adapter_get_credential_raises_on_refresh_failure(tmp_path, monkeyp
})
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
side_effect=RuntimeError("Refresh session has been revoked"),
):
adapter = NousPortalAdapter()
@ -240,7 +228,7 @@ def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch
assert load_pool("nous").select() is not None
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
side_effect=AuthError(
"Refresh session has been revoked",
provider="nous",
@ -270,7 +258,7 @@ def test_nous_adapter_get_credential_raises_when_no_agent_key_returned(tmp_path,
})
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
return_value={"access_token": "a", "refresh_token": "r"},
):
adapter = NousPortalAdapter()
@ -291,7 +279,7 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
counter = [0]
counter_lock = threading.Lock()
def serializing_refresh(state, **kwargs):
def serializing_refresh(**kwargs):
# If another thread is already inside refresh, the lock is broken.
if in_flight.is_set():
overlap_detected.set()
@ -305,10 +293,9 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
counter[0] += 1
idx = counter[0]
return {
**state,
"agent_key": f"key-{idx}",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"api_key": f"key-{idx}",
"expires_at": "2099-01-01T00:00:00Z",
"base_url": "https://inference-api.nousresearch.com/v1",
}
finally:
in_flight.clear()
@ -324,7 +311,7 @@ def test_nous_adapter_concurrent_refresh_serialized(tmp_path, monkeypatch):
errors.append(exc)
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
"hermes_cli.proxy.adapters.nous_portal.resolve_nous_runtime_credentials",
side_effect=serializing_refresh,
):
threads = [threading.Thread(target=worker) for _ in range(3)]