Fix Nous refresh token rotation failure when API key mint/retrieval fails

This commit is contained in:
Robin Fernandes 2026-03-02 17:18:15 +11:00
parent 7b38afc179
commit 5e5e0efc60
2 changed files with 292 additions and 6 deletions

View file

@ -21,8 +21,10 @@ import os
import shutil import shutil
import stat import stat
import base64 import base64
import hashlib
import subprocess import subprocess
import time import time
import uuid
import webbrowser import webbrowser
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -147,6 +149,31 @@ def format_auth_error(error: Exception) -> str:
return str(error) return str(error)
def _token_fingerprint(token: Any) -> Optional[str]:
    """Derive a short, non-reversible identifier for a token.

    The token is stripped of surrounding whitespace and hashed with SHA-256;
    only the first 12 hex characters of the digest are returned, so the raw
    token bytes never appear in the result.

    Returns:
        The 12-character hex prefix, or ``None`` when *token* is not a string
        or is empty/whitespace-only.
    """
    normalized = token.strip() if isinstance(token, str) else ""
    if normalized:
        digest = hashlib.sha256(normalized.encode("utf-8"))
        return digest.hexdigest()[:12]
    return None
def _oauth_trace_enabled() -> bool:
    """Report whether OAuth tracing is switched on via ``HERMES_OAUTH_TRACE``.

    The variable is read on every call. Its value is trimmed and lower-cased,
    and tracing is on only for ``1``, ``true``, ``yes`` or ``on``.
    """
    flag = os.getenv("HERMES_OAUTH_TRACE", "")
    return flag.strip().lower() in {"1", "true", "yes", "on"}
def _oauth_trace(event: str, *, sequence_id: Optional[str] = None, **fields: Any) -> None:
    """Log one structured OAuth trace record when tracing is enabled.

    Does nothing unless ``_oauth_trace_enabled()`` returns true. Otherwise the
    record is serialised as JSON with sorted keys and logged at INFO level
    with the prefix ``oauth_trace``.

    Args:
        event: Name of the traced event; stored under the ``event`` key.
        sequence_id: Optional correlation id; included only when truthy.
        **fields: Extra key/value pairs merged into the record.
    """
    if _oauth_trace_enabled():
        record: Dict[str, Any] = {"event": event}
        if sequence_id:
            record["sequence_id"] = sequence_id
        record.update(fields)
        serialized = json.dumps(record, sort_keys=True, ensure_ascii=False)
        logger.info("oauth_trace %s", serialized)
# ============================================================================= # =============================================================================
# Auth Store — persistence layer for ~/.hermes/auth.json # Auth Store — persistence layer for ~/.hermes/auth.json
# ============================================================================= # =============================================================================
@ -216,7 +243,29 @@ def _save_auth_store(auth_store: Dict[str, Any]) -> Path:
auth_file.parent.mkdir(parents=True, exist_ok=True) auth_file.parent.mkdir(parents=True, exist_ok=True)
auth_store["version"] = AUTH_STORE_VERSION auth_store["version"] = AUTH_STORE_VERSION
auth_store["updated_at"] = datetime.now(timezone.utc).isoformat() auth_store["updated_at"] = datetime.now(timezone.utc).isoformat()
auth_file.write_text(json.dumps(auth_store, indent=2) + "\n") payload = json.dumps(auth_store, indent=2) + "\n"
tmp_path = auth_file.with_name(f"{auth_file.name}.tmp.{os.getpid()}.{uuid.uuid4().hex}")
try:
with tmp_path.open("w", encoding="utf-8") as handle:
handle.write(payload)
handle.flush()
os.fsync(handle.fileno())
os.replace(tmp_path, auth_file)
try:
dir_fd = os.open(str(auth_file.parent), os.O_RDONLY)
except OSError:
dir_fd = None
if dir_fd is not None:
try:
os.fsync(dir_fd)
finally:
os.close(dir_fd)
finally:
try:
if tmp_path.exists():
tmp_path.unlink()
except OSError:
pass
# Restrict file permissions to owner only # Restrict file permissions to owner only
try: try:
auth_file.chmod(stat.S_IRUSR | stat.S_IWUSR) auth_file.chmod(stat.S_IRUSR | stat.S_IWUSR)
@ -906,6 +955,7 @@ def resolve_nous_runtime_credentials(
expires_in, source ("cache" or "portal"). expires_in, source ("cache" or "portal").
""" """
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds)) min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
sequence_id = uuid.uuid4().hex[:12]
with _auth_store_lock(): with _auth_store_lock():
auth_store = _load_auth_store() auth_store = _load_auth_store()
@ -928,8 +978,35 @@ def resolve_nous_runtime_credentials(
).rstrip("/") ).rstrip("/")
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID) client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
def _persist_state(reason: str) -> None:
    """Save the enclosing ``state``/``auth_store`` and trace the outcome.

    Passes ``state`` to ``_save_provider_state`` under the ``"nous"`` provider
    and then calls ``_save_auth_store``. If either call raises, a
    ``nous_state_persist_failed`` trace carrying the exception type name is
    emitted and the exception is re-raised unchanged. On success a
    ``nous_state_persisted`` trace is emitted with fingerprints of the
    refresh and access tokens rather than the tokens themselves.

    Args:
        reason: Label recorded in the trace to say why the save happened.
    """
    try:
        _save_provider_state(auth_store, "nous", state)
        _save_auth_store(auth_store)
    except Exception as exc:
        failure_kind = type(exc).__name__
        _oauth_trace(
            "nous_state_persist_failed",
            sequence_id=sequence_id,
            reason=reason,
            error_type=failure_kind,
        )
        raise
    else:
        # Fingerprints are computed only after a successful save.
        refresh_fp = _token_fingerprint(state.get("refresh_token"))
        access_fp = _token_fingerprint(state.get("access_token"))
        _oauth_trace(
            "nous_state_persisted",
            sequence_id=sequence_id,
            reason=reason,
            refresh_token_fp=refresh_fp,
            access_token_fp=access_fp,
        )
verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state) verify = _resolve_verify(insecure=insecure, ca_bundle=ca_bundle, auth_state=state)
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0) timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
_oauth_trace(
"nous_runtime_credentials_start",
sequence_id=sequence_id,
force_mint=bool(force_mint),
min_key_ttl_seconds=min_key_ttl_seconds,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
)
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client: with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
access_token = state.get("access_token") access_token = state.get("access_token")
@ -945,12 +1022,19 @@ def resolve_nous_runtime_credentials(
raise AuthError("Session expired and no refresh token is available.", raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True) provider="nous", relogin_required=True)
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="access_expiring",
refresh_token_fp=_token_fingerprint(refresh_token),
)
refreshed = _refresh_access_token( refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url, client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token, client_id=client_id, refresh_token=refresh_token,
) )
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
previous_refresh_token = refresh_token
state["access_token"] = refreshed["access_token"] state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token state["refresh_token"] = refreshed.get("refresh_token") or refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
@ -964,6 +1048,16 @@ def resolve_nous_runtime_credentials(
now.timestamp() + access_ttl, tz=timezone.utc now.timestamp() + access_ttl, tz=timezone.utc
).isoformat() ).isoformat()
access_token = state["access_token"] access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="access_expiring",
previous_refresh_token_fp=_token_fingerprint(previous_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
# Step 2: mint agent key if missing/expiring # Step 2: mint agent key if missing/expiring
used_cached_key = False used_cached_key = False
@ -971,23 +1065,45 @@ def resolve_nous_runtime_credentials(
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds): if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
used_cached_key = True used_cached_key = True
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
else: else:
try: try:
_oauth_trace(
"mint_start",
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
mint_payload = _mint_agent_key( mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url, client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds, access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
) )
except AuthError as exc: except AuthError as exc:
_oauth_trace(
"mint_error",
sequence_id=sequence_id,
code=exc.code,
)
# Retry path: access token may be stale server-side despite local checks # Retry path: access token may be stale server-side despite local checks
if exc.code in {"invalid_token", "invalid_grant"} and isinstance(refresh_token, str) and refresh_token: latest_refresh_token = state.get("refresh_token")
if (
exc.code in {"invalid_token", "invalid_grant"}
and isinstance(latest_refresh_token, str)
and latest_refresh_token
):
_oauth_trace(
"refresh_start",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
refresh_token_fp=_token_fingerprint(latest_refresh_token),
)
refreshed = _refresh_access_token( refreshed = _refresh_access_token(
client=client, portal_base_url=portal_base_url, client=client, portal_base_url=portal_base_url,
client_id=client_id, refresh_token=refresh_token, client_id=client_id, refresh_token=latest_refresh_token,
) )
now = datetime.now(timezone.utc) now = datetime.now(timezone.utc)
access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in")) access_ttl = _coerce_ttl_seconds(refreshed.get("expires_in"))
state["access_token"] = refreshed["access_token"] state["access_token"] = refreshed["access_token"]
state["refresh_token"] = refreshed.get("refresh_token") or refresh_token state["refresh_token"] = refreshed.get("refresh_token") or latest_refresh_token
state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer" state["token_type"] = refreshed.get("token_type") or state.get("token_type") or "Bearer"
state["scope"] = refreshed.get("scope") or state.get("scope") state["scope"] = refreshed.get("scope") or state.get("scope")
refreshed_url = _optional_base_url(refreshed.get("inference_base_url")) refreshed_url = _optional_base_url(refreshed.get("inference_base_url"))
@ -999,6 +1115,16 @@ def resolve_nous_runtime_credentials(
now.timestamp() + access_ttl, tz=timezone.utc now.timestamp() + access_ttl, tz=timezone.utc
).isoformat() ).isoformat()
access_token = state["access_token"] access_token = state["access_token"]
refresh_token = state["refresh_token"]
_oauth_trace(
"refresh_success",
sequence_id=sequence_id,
reason="mint_retry_after_invalid_token",
previous_refresh_token_fp=_token_fingerprint(latest_refresh_token),
new_refresh_token_fp=_token_fingerprint(refresh_token),
)
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key( mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url, client=client, portal_base_url=portal_base_url,
@ -1018,6 +1144,11 @@ def resolve_nous_runtime_credentials(
minted_url = _optional_base_url(mint_payload.get("inference_base_url")) minted_url = _optional_base_url(mint_payload.get("inference_base_url"))
if minted_url: if minted_url:
inference_base_url = minted_url inference_base_url = minted_url
_oauth_trace(
"mint_success",
sequence_id=sequence_id,
reused=bool(mint_payload.get("reused", False)),
)
# Persist routing and TLS metadata for non-interactive refresh/mint # Persist routing and TLS metadata for non-interactive refresh/mint
state["portal_base_url"] = portal_base_url state["portal_base_url"] = portal_base_url
@ -1028,8 +1159,7 @@ def resolve_nous_runtime_credentials(
"ca_bundle": verify if isinstance(verify, str) else None, "ca_bundle": verify if isinstance(verify, str) else None,
} }
_save_provider_state(auth_store, "nous", state) _persist_state("resolve_nous_runtime_credentials_final")
_save_auth_store(auth_store)
api_key = state.get("agent_key") api_key = state.get("agent_key")
if not isinstance(api_key, str) or not api_key: if not isinstance(api_key, str) or not api_key:

View file

@ -0,0 +1,156 @@
"""Regression tests for Nous OAuth refresh + agent-key mint interactions."""
import json
from datetime import datetime, timezone
from pathlib import Path
import httpx
import pytest
from hermes_cli.auth import AuthError, get_provider_auth_state, resolve_nous_runtime_credentials
def _setup_nous_auth(
    hermes_home: Path,
    *,
    access_token: str = "access-old",
    refresh_token: str = "refresh-old",
) -> None:
    """Write a fixture ``auth.json`` with a Nous provider entry.

    Creates *hermes_home* (and parents) if needed. The stored state has
    ``expires_in`` of 0, a fixed ``expires_at`` timestamp, and every
    ``agent_key*`` field set to ``None``.

    Args:
        hermes_home: Directory that receives ``auth.json``.
        access_token: Access token stored in the provider state.
        refresh_token: Refresh token stored in the provider state.
    """
    hermes_home.mkdir(parents=True, exist_ok=True)

    nous_state = {
        "portal_base_url": "https://portal.example.com",
        "inference_base_url": "https://inference.example.com/v1",
        "client_id": "hermes-cli",
        "token_type": "Bearer",
        "scope": "inference:mint_agent_key",
        "access_token": access_token,
        "refresh_token": refresh_token,
        "obtained_at": "2026-02-01T00:00:00+00:00",
        "expires_in": 0,
        "expires_at": "2026-02-01T00:00:00+00:00",
    }
    # No agent key has been minted yet: all agent-key fields start as None.
    nous_state.update(
        dict.fromkeys(
            (
                "agent_key",
                "agent_key_id",
                "agent_key_expires_at",
                "agent_key_expires_in",
                "agent_key_reused",
                "agent_key_obtained_at",
            )
        )
    )

    document = {
        "version": 1,
        "active_provider": "nous",
        "providers": {"nous": nous_state},
    }
    auth_path = hermes_home / "auth.json"
    auth_path.write_text(json.dumps(document, indent=2))
def _mint_payload(api_key: str = "agent-key") -> dict:
    """Build a fake agent-key mint response.

    ``expires_at`` is the current UTC time in ISO-8601 form; the other fields
    are fixed apart from *api_key*.
    """
    issued_at = datetime.now(timezone.utc)
    return dict(
        api_key=api_key,
        key_id="key-id-1",
        expires_at=issued_at.isoformat(),
        expires_in=1800,
        reused=False,
    )
def test_refresh_token_persisted_when_mint_returns_insufficient_credits(tmp_path, monkeypatch):
    """A rotated refresh token is stored even when the first mint is rejected.

    The first mint raises ``AuthError(code="insufficient_credits")``. The
    stored state must already hold the rotated tokens, and a second resolve
    call must refresh with the rotated token and return the minted key.
    """
    home_dir = tmp_path / "hermes"
    _setup_nous_auth(home_dir, refresh_token="refresh-old")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    seen_refresh_tokens = []
    mint_attempts = []

    def _rotating_refresh(*, client, portal_base_url, client_id, refresh_token):
        # Each refresh hands back a new numbered token pair.
        seen_refresh_tokens.append(refresh_token)
        generation = len(seen_refresh_tokens)
        return {
            "access_token": f"access-{generation}",
            "refresh_token": f"refresh-{generation}",
            "expires_in": 0,
            "token_type": "Bearer",
        }

    def _mint_fails_once(*, client, portal_base_url, access_token, min_ttl_seconds):
        # Only the very first mint attempt is rejected.
        mint_attempts.append(access_token)
        if len(mint_attempts) > 1:
            return _mint_payload(api_key="agent-key-2")
        raise AuthError("credits exhausted", provider="nous", code="insufficient_credits")

    monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _rotating_refresh)
    monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _mint_fails_once)

    with pytest.raises(AuthError) as excinfo:
        resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
    assert excinfo.value.code == "insufficient_credits"

    persisted = get_provider_auth_state("nous")
    assert persisted is not None
    assert persisted["refresh_token"] == "refresh-1"
    assert persisted["access_token"] == "access-1"

    credentials = resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
    assert credentials["api_key"] == "agent-key-2"
    assert seen_refresh_tokens == ["refresh-old", "refresh-1"]
def test_refresh_token_persisted_when_mint_times_out(tmp_path, monkeypatch):
    """A rotated refresh token is stored even when minting times out.

    The mint call raises ``httpx.ReadTimeout``, which must propagate, while
    the stored state already holds the refreshed access/refresh tokens.
    """
    home_dir = tmp_path / "hermes"
    _setup_nous_auth(home_dir, refresh_token="refresh-old")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    rotated_tokens = {
        "access_token": "access-1",
        "refresh_token": "refresh-1",
        "expires_in": 0,
        "token_type": "Bearer",
    }

    def _refresh_to_rotated(*, client, portal_base_url, client_id, refresh_token):
        # Return a fresh dict per call, as a real response would.
        return dict(rotated_tokens)

    def _mint_always_times_out(*, client, portal_base_url, access_token, min_ttl_seconds):
        raise httpx.ReadTimeout("mint timeout")

    monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _refresh_to_rotated)
    monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _mint_always_times_out)

    with pytest.raises(httpx.ReadTimeout):
        resolve_nous_runtime_credentials(min_key_ttl_seconds=300)

    persisted = get_provider_auth_state("nous")
    assert persisted is not None
    assert persisted["refresh_token"] == "refresh-1"
    assert persisted["access_token"] == "access-1"
def test_mint_retry_uses_latest_rotated_refresh_token(tmp_path, monkeypatch):
    """The mint retry refreshes with the newest rotated refresh token.

    The first mint raises ``AuthError(code="invalid_token")``. The resolve
    call must still succeed, and the refresh helper must have been given
    ``refresh-old`` first and then the rotated ``refresh-1``.
    """
    home_dir = tmp_path / "hermes"
    _setup_nous_auth(home_dir, refresh_token="refresh-old")
    monkeypatch.setenv("HERMES_HOME", str(home_dir))

    seen_refresh_tokens = []
    mint_attempts = []

    def _rotating_refresh(*, client, portal_base_url, client_id, refresh_token):
        # Each refresh hands back a new numbered token pair.
        seen_refresh_tokens.append(refresh_token)
        generation = len(seen_refresh_tokens)
        return {
            "access_token": f"access-{generation}",
            "refresh_token": f"refresh-{generation}",
            "expires_in": 0,
            "token_type": "Bearer",
        }

    def _mint_rejects_first_token(*, client, portal_base_url, access_token, min_ttl_seconds):
        # Only the very first mint attempt reports a stale access token.
        mint_attempts.append(access_token)
        if len(mint_attempts) > 1:
            return _mint_payload(api_key="agent-key")
        raise AuthError("stale access token", provider="nous", code="invalid_token")

    monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _rotating_refresh)
    monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _mint_rejects_first_token)

    credentials = resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
    assert credentials["api_key"] == "agent-key"
    assert seen_refresh_tokens == ["refresh-old", "refresh-1"]