refactor(auth): collapse Nous inference fallback controls

This commit is contained in:
Robin Fernandes 2026-05-17 20:34:39 +10:00 committed by Teknium
parent 89a3d038cf
commit 0bac7dd05b
13 changed files with 1071 additions and 240 deletions

View file

@ -1252,12 +1252,20 @@ def _resolve_nous_runtime_api(*, force_refresh: bool = False) -> Optional[tuple[
or the credential pool.
"""
try:
from hermes_cli.auth import resolve_nous_runtime_credentials
from hermes_cli.auth import (
NOUS_INFERENCE_AUTH_AUTO,
NOUS_INFERENCE_AUTH_LEGACY,
resolve_nous_runtime_credentials,
)
creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
force_mint=force_refresh,
auth_mode=(
NOUS_INFERENCE_AUTH_LEGACY
if force_refresh
else NOUS_INFERENCE_AUTH_AUTO
),
)
except Exception as exc:
logger.debug("Auxiliary Nous runtime credential resolution failed: %s", exc)
@ -2501,12 +2509,12 @@ def _refresh_provider_credentials(provider: str) -> bool:
_evict_cached_clients(normalized)
return True
if normalized == "nous":
from hermes_cli.auth import resolve_nous_runtime_credentials
from hermes_cli.auth import NOUS_INFERENCE_AUTH_LEGACY, resolve_nous_runtime_credentials
creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
force_mint=True,
auth_mode=NOUS_INFERENCE_AUTH_LEGACY,
)
if not str(creds.get("api_key", "") or "").strip():
return False

View file

@ -831,7 +831,11 @@ class CredentialPool:
nous_state,
min_key_ttl_seconds=DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
force_refresh=force,
force_mint=force,
auth_mode=(
auth_mod.NOUS_INFERENCE_AUTH_LEGACY
if force
else auth_mod.NOUS_INFERENCE_AUTH_AUTO
),
)
# Apply returned fields: dataclass fields via replace, extras via dict update
field_updates = {}
@ -952,25 +956,27 @@ class CredentialPool:
exc,
reason="credential_pool_refresh_failure",
)
auth_mod._quarantine_nous_pool_entries(
auth_store,
exc,
reason="credential_pool_refresh_failure",
)
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
except Exception as clear_exc:
logger.debug("Failed to clear terminal Nous OAuth state: %s", clear_exc)
cleared = replace(
entry,
access_token=None,
refresh_token=None,
agent_key=None,
agent_key_expires_at=None,
)
self._replace_entry(entry, cleared)
singleton_sources = {
auth_mod.NOUS_DEVICE_CODE_SOURCE,
f"manual:{auth_mod.NOUS_DEVICE_CODE_SOURCE}",
}
self._entries = [
item for item in self._entries
if item.source not in singleton_sources
]
if self._current_id == entry.id:
self._current_id = None
self._persist()
self._mark_exhausted(
cleared,
401,
{"reason": getattr(exc, "code", None), "message": str(exc)},
)
return None
self._mark_exhausted(entry, None)
return None
@ -1408,7 +1414,22 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup
elif provider == "nous":
state = _load_provider_state(auth_store, "nous")
if state and not _is_suppressed(provider, "device_code"):
has_runtime_material = bool(
isinstance(state, dict)
and (
str(state.get("access_token") or "").strip()
or str(state.get("agent_key") or "").strip()
)
)
if state and not has_runtime_material:
retained = [
entry for entry in entries
if entry.source not in {"device_code", "manual:device_code"}
]
if len(retained) != len(entries):
entries[:] = retained
changed = True
if state and has_runtime_material and not _is_suppressed(provider, "device_code"):
active_sources.add("device_code")
# Prefer a user-supplied label embedded in the singleton state
# (set by persist_nous_credentials(label=...) when the user ran

View file

@ -11,6 +11,12 @@ Architecture:
- resolve_provider() picks the active provider via priority chain
- resolve_*_runtime_credentials() handles token refresh and key minting
- logout_command() is the CLI entry point for clearing auth
Nous authentication paths:
- Invoke JWT (preferred): use a scoped access_token directly for inference.
- Legacy session key (fallback): mint an opaque 24h key when JWT auth is
unavailable, or when HERMES_AGENT_USE_LEGACY_SESSION_KEYS is set for
debugging or rollback.
"""
from __future__ import annotations
@ -71,6 +77,15 @@ NOUS_LEGACY_AGENT_KEY_SCOPE = "inference:mint_agent_key"
NOUS_INFERENCE_INVOKE_SCOPE = "inference:invoke"
DEFAULT_NOUS_SCOPE = f"{NOUS_INFERENCE_INVOKE_SCOPE} {NOUS_LEGACY_AGENT_KEY_SCOPE}"
NOUS_LEGACY_SESSION_KEYS_ENV = "HERMES_AGENT_USE_LEGACY_SESSION_KEYS"
NOUS_DEVICE_CODE_SOURCE = "device_code"
NOUS_INFERENCE_AUTH_AUTO = "auto"
NOUS_INFERENCE_AUTH_FRESH = "fresh"
NOUS_INFERENCE_AUTH_LEGACY = "legacy"
NOUS_INFERENCE_AUTH_MODES = frozenset({
NOUS_INFERENCE_AUTH_AUTO,
NOUS_INFERENCE_AUTH_FRESH,
NOUS_INFERENCE_AUTH_LEGACY,
})
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
NOUS_INVOKE_JWT_MIN_TTL_SECONDS = ACCESS_TOKEN_REFRESH_SKEW_SECONDS
@ -1554,6 +1569,8 @@ def _decode_jwt_claims(token: Any) -> Dict[str, Any]:
def _scope_values(raw_scope: Any) -> set[str]:
# OAuth token responses normally return a space-separated string. Keep
# collection support for JWT ``scp`` claims and older stored test fixtures.
scopes: set[str] = set()
if isinstance(raw_scope, str):
for part in raw_scope.replace(",", " ").split():
@ -1575,37 +1592,24 @@ def _nous_scope_has_invoke(raw_scope: Any) -> bool:
return NOUS_INFERENCE_INVOKE_SCOPE in _scope_values(raw_scope)
def _nous_invoke_jwt_is_usable(
def _normalize_nous_auth_mode(auth_mode: Optional[str]) -> str:
mode = str(auth_mode or NOUS_INFERENCE_AUTH_AUTO).strip().lower()
if mode not in NOUS_INFERENCE_AUTH_MODES:
allowed = ", ".join(sorted(NOUS_INFERENCE_AUTH_MODES))
raise ValueError(
f"Invalid Nous inference auth mode {auth_mode!r}; expected one of: {allowed}"
)
return mode
def _nous_invoke_jwt_status(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> bool:
claims = _decode_jwt_claims(token)
if not claims:
return False
scopes = (
_scope_values(scope)
| _scope_values(claims.get("scope"))
| _scope_values(claims.get("scp"))
)
if NOUS_INFERENCE_INVOKE_SCOPE not in scopes:
return False
exp = claims.get("exp")
skew = max(0, int(min_ttl_seconds))
if isinstance(exp, (int, float)):
return float(exp) > (time.time() + skew)
return not _is_expiring(expires_at, skew)
def _nous_invoke_jwt_unavailable_reason(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> str:
) -> Optional[str]:
"""Return None when the token can be used for inference, else a reason."""
claims = _decode_jwt_claims(token)
if not claims:
return "access_token_not_jwt"
@ -1618,11 +1622,149 @@ def _nous_invoke_jwt_unavailable_reason(
return "missing_inference_invoke_scope"
exp = claims.get("exp")
skew = max(0, int(min_ttl_seconds))
if isinstance(exp, (int, float)) and float(exp) <= (time.time() + skew):
return "invoke_jwt_expiring"
if not isinstance(exp, (int, float)) and _is_expiring(expires_at, skew):
if isinstance(exp, (int, float)):
if float(exp) <= (time.time() + skew):
return "invoke_jwt_expiring"
return None
if _is_expiring(expires_at, skew):
return "invoke_jwt_expiry_unknown_or_expiring"
return "invoke_jwt_unavailable"
return None
def _nous_invoke_jwt_is_usable(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> bool:
return (
_nous_invoke_jwt_status(
token,
scope=scope,
expires_at=expires_at,
min_ttl_seconds=min_ttl_seconds,
)
is None
)
def _nous_invoke_jwt_unavailable_reason(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> str:
return (
_nous_invoke_jwt_status(
token,
scope=scope,
expires_at=expires_at,
min_ttl_seconds=min_ttl_seconds,
)
or "invoke_jwt_unavailable"
)
def _nous_can_select_invoke_jwt(auth_mode: str = NOUS_INFERENCE_AUTH_AUTO) -> bool:
return (
not _nous_legacy_session_keys_forced()
and _normalize_nous_auth_mode(auth_mode) != NOUS_INFERENCE_AUTH_LEGACY
)
def _nous_legacy_session_key_reason(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
) -> str:
if _normalize_nous_auth_mode(auth_mode) == NOUS_INFERENCE_AUTH_LEGACY:
return "forced_legacy_session_key"
if _nous_legacy_session_keys_forced():
return "forced_legacy_session_keys"
return _nous_invoke_jwt_unavailable_reason(
token,
scope=scope,
expires_at=expires_at,
)
def _nous_cached_agent_key_is_usable(
state: Dict[str, Any],
min_ttl_seconds: int,
) -> bool:
return _agent_key_is_usable(state, min_ttl_seconds)
def _choose_nous_inference_auth_path(
state: Dict[str, Any],
*,
access_token: Any = None,
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
) -> Tuple[str, Optional[str]]:
auth_mode = _normalize_nous_auth_mode(auth_mode)
token = state.get("access_token") if access_token is None else access_token
if (
_nous_can_select_invoke_jwt(auth_mode)
and _nous_invoke_jwt_is_usable(
token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
return "invoke_jwt", None
if (
auth_mode == NOUS_INFERENCE_AUTH_AUTO
and _nous_cached_agent_key_is_usable(
state,
max(60, int(min_key_ttl_seconds)),
)
):
return "legacy_session_key_cache", None
return (
"legacy_session_key_mint",
_nous_legacy_session_key_reason(
token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
auth_mode=auth_mode,
),
)
def _log_nous_invoke_jwt_selected(
*,
access_token: Any,
sequence_id: Optional[str] = None,
) -> None:
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
def _log_nous_legacy_session_key_selected(
reason: str,
*,
access_token: Any,
sequence_id: Optional[str] = None,
) -> None:
logger.info(
"Nous inference auth: using legacy session key path (%s)",
reason,
)
_oauth_trace(
"nous_legacy_session_key_selected",
sequence_id=sequence_id,
reason=reason,
access_token_fp=_token_fingerprint(access_token),
)
def _nous_jwt_expires_at(token: Any, fallback_expires_at: Any = None) -> Optional[str]:
@ -1645,7 +1787,17 @@ def _set_nous_agent_key_from_invoke_jwt(
if not isinstance(access_token, str) or not access_token.strip():
return
now = datetime.now(timezone.utc)
effective_obtained_at = obtained_at or now.isoformat()
existing_obtained_at = state.get("agent_key_obtained_at")
if obtained_at:
effective_obtained_at = obtained_at
elif (
state.get("agent_key") == access_token
and isinstance(existing_obtained_at, str)
and existing_obtained_at.strip()
):
effective_obtained_at = existing_obtained_at
else:
effective_obtained_at = now.isoformat()
expires_at = _nous_jwt_expires_at(access_token, state.get("expires_at"))
expires_epoch = _parse_iso_timestamp(expires_at)
expires_in = (
@ -1664,6 +1816,38 @@ def _set_nous_agent_key_from_invoke_jwt(
state["agent_key_obtained_at"] = effective_obtained_at
def _select_nous_invoke_jwt(
state: Dict[str, Any],
*,
access_token: Any = None,
sequence_id: Optional[str] = None,
) -> None:
if isinstance(access_token, str) and access_token.strip():
state["access_token"] = access_token
_set_nous_agent_key_from_invoke_jwt(state)
_log_nous_invoke_jwt_selected(
access_token=state.get("access_token"),
sequence_id=sequence_id,
)
_NOUS_EFFECTIVE_STATE_IGNORED_KEYS = frozenset({
# These are derived from expires_at/JWT exp and naturally tick down between
# reads. Persisting only these changes makes auth.json noisy and defeats
# the mtime-keyed auth-status cache.
"expires_in",
"agent_key_expires_in",
})
def _nous_effective_provider_state(state: Dict[str, Any]) -> Dict[str, Any]:
return {
key: value
for key, value in state.items()
if key not in _NOUS_EFFECTIVE_STATE_IGNORED_KEYS
}
def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> bool:
claims = _decode_jwt_claims(access_token)
exp = claims.get("exp")
@ -3476,6 +3660,57 @@ def _is_nous_invoke_scope_refusal(exc: Exception) -> bool:
)
def _nous_device_scope(
requested_scope: Optional[str],
*,
default_scope: str = DEFAULT_NOUS_SCOPE,
) -> Tuple[str, bool]:
explicit_scope = requested_scope is not None
scope = requested_scope or default_scope
if _nous_legacy_session_keys_forced():
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
return scope, explicit_scope
def _request_nous_device_code_with_scope_fallback(
*,
client: httpx.Client,
portal_base_url: str,
client_id: str,
scope: str,
allow_legacy_fallback: bool,
) -> Tuple[Dict[str, Any], str]:
try:
return (
_request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
),
scope,
)
except Exception as exc:
if (
allow_legacy_fallback
and _nous_scope_has_invoke(scope)
and _is_nous_invoke_scope_refusal(exc)
):
logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope")
_oauth_trace("nous_device_code_invoke_scope_refused")
retry_scope = NOUS_LEGACY_AGENT_KEY_SCOPE
return (
_request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=retry_scope,
),
retry_scope,
)
raise
def _poll_for_token(
client: httpx.Client,
portal_base_url: str,
@ -3817,6 +4052,39 @@ def _quarantine_nous_oauth_state(
invalidate_nous_auth_status_cache()
def _quarantine_nous_pool_entries(
auth_store: Dict[str, Any],
error: AuthError,
*,
reason: str,
) -> bool:
"""Remove singleton-seeded Nous pool entries that contain dead OAuth state."""
pool = auth_store.get("credential_pool")
if not isinstance(pool, dict):
return False
entries = pool.get("nous")
if not isinstance(entries, list):
return False
retained = []
removed = False
singleton_sources = {NOUS_DEVICE_CODE_SOURCE, f"manual:{NOUS_DEVICE_CODE_SOURCE}"}
for entry in entries:
if isinstance(entry, dict) and entry.get("source") in singleton_sources:
removed = True
continue
retained.append(entry)
if removed:
pool["nous"] = retained
_oauth_trace(
"nous_pool_device_code_quarantined",
reason=reason,
error_code=error.code,
)
return removed
def _try_import_shared_nous_state(
*,
timeout_seconds: float = 15.0,
@ -3842,7 +4110,7 @@ def _try_import_shared_nous_state(
# Build a full state dict so refresh_nous_oauth_from_state has every
# field it needs. force_refresh=True gets us a fresh access_token
# for this profile; force_mint=True gets us a fresh agent_key.
# for this profile; fresh auth mode avoids stale cached legacy keys.
state: Dict[str, Any] = {
"access_token": shared.get("access_token"),
"refresh_token": shared.get("refresh_token"),
@ -3863,7 +4131,7 @@ def _try_import_shared_nous_state(
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
force_mint=True,
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
)
_write_shared_nous_state(refreshed)
except AuthError as exc:
@ -4121,6 +4389,11 @@ def resolve_nous_access_token(
exc,
reason="managed_access_token_refresh_failure",
)
_quarantine_nous_pool_entries(
auth_store,
exc,
reason="managed_access_token_refresh_failure",
)
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
raise
@ -4167,9 +4440,10 @@ def refresh_nous_oauth_pure(
insecure: Optional[bool] = None,
ca_bundle: Optional[str] = None,
force_refresh: bool = False,
force_mint: bool = False,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
) -> Dict[str, Any]:
"""Refresh Nous OAuth state without mutating auth.json."""
auth_mode = _normalize_nous_auth_mode(auth_mode)
state: Dict[str, Any] = {
"access_token": access_token,
"refresh_token": refresh_token,
@ -4229,38 +4503,17 @@ def refresh_nous_oauth_pure(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
if (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
state.get("access_token"),
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
_set_nous_agent_key_from_invoke_jwt(state)
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
access_token_fp=_token_fingerprint(state.get("access_token")),
)
elif force_mint or not _agent_key_is_usable(state, min_agent_key_ttl):
fallback_reason = (
"forced_legacy_session_keys"
if legacy_session_keys
else _nous_invoke_jwt_unavailable_reason(
state.get("access_token"),
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
)
logger.info(
"Nous inference auth: using legacy session key path (%s)",
fallback_reason,
)
_oauth_trace(
"nous_legacy_session_key_selected",
reason=fallback_reason,
access_token_fp=_token_fingerprint(state.get("access_token")),
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
state,
min_key_ttl_seconds=min_agent_key_ttl,
auth_mode=auth_mode,
)
if selected_auth_path == "invoke_jwt":
_select_nous_invoke_jwt(state)
elif selected_auth_path == "legacy_session_key_mint":
_log_nous_legacy_session_key_selected(
fallback_reason or "legacy_session_key_required",
access_token=state.get("access_token"),
)
mint_payload = _mint_agent_key(
client=client,
@ -4288,7 +4541,7 @@ def refresh_nous_oauth_from_state(
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
timeout_seconds: float = 15.0,
force_refresh: bool = False,
force_mint: bool = False,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
) -> Dict[str, Any]:
"""Refresh Nous OAuth from a state dict. Thin wrapper around refresh_nous_oauth_pure."""
tls = state.get("tls") or {}
@ -4309,13 +4562,10 @@ def refresh_nous_oauth_from_state(
insecure=tls.get("insecure"),
ca_bundle=tls.get("ca_bundle"),
force_refresh=force_refresh,
force_mint=force_mint,
auth_mode=auth_mode,
)
NOUS_DEVICE_CODE_SOURCE = "device_code"
def persist_nous_credentials(
creds: Dict[str, Any],
*,
@ -4390,7 +4640,7 @@ def resolve_nous_runtime_credentials(
timeout_seconds: float = 15.0,
insecure: Optional[bool] = None,
ca_bundle: Optional[str] = None,
force_mint: bool = False,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
) -> Dict[str, Any]:
"""
Resolve Nous inference credentials for runtime use.
@ -4402,6 +4652,7 @@ def resolve_nous_runtime_credentials(
Returns dict with: provider, base_url, api_key, key_id, expires_at,
expires_in, source ("invoke_jwt", "cache", or "portal"), and auth_path.
"""
auth_mode = _normalize_nous_auth_mode(auth_mode)
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
sequence_id = uuid.uuid4().hex[:12]
@ -4413,6 +4664,9 @@ def resolve_nous_runtime_credentials(
raise AuthError("Hermes is not logged into Nous Portal.",
provider="nous", relogin_required=True)
persisted_state = dict(state)
state_persisted = False
portal_base_url = (
_optional_base_url(state.get("portal_base_url"))
or os.getenv("HERMES_PORTAL_BASE_URL")
@ -4427,6 +4681,17 @@ def resolve_nous_runtime_credentials(
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
def _persist_state(reason: str) -> None:
nonlocal persisted_state, state_persisted
if (
_nous_effective_provider_state(state)
== _nous_effective_provider_state(persisted_state)
):
_oauth_trace(
"nous_state_persist_skipped",
sequence_id=sequence_id,
reason=reason,
)
return
try:
_save_provider_state(auth_store, "nous", state)
_save_auth_store(auth_store)
@ -4445,6 +4710,8 @@ def resolve_nous_runtime_credentials(
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
access_token_fp=_token_fingerprint(state.get("access_token")),
)
persisted_state = dict(state)
state_persisted = True
# Mirror post-refresh state to the shared store so sibling
# profiles don't hold stale refresh_tokens after rotation.
# Best-effort — any failure is logged and swallowed inside
@ -4456,7 +4723,7 @@ def resolve_nous_runtime_credentials(
_oauth_trace(
"nous_runtime_credentials_start",
sequence_id=sequence_id,
force_mint=bool(force_mint),
auth_mode=auth_mode,
min_key_ttl_seconds=min_key_ttl_seconds,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
)
@ -4520,6 +4787,11 @@ def resolve_nous_runtime_credentials(
exc,
reason="runtime_access_refresh_failure",
)
_quarantine_nous_pool_entries(
auth_store,
exc,
reason="runtime_access_refresh_failure",
)
_persist_state("terminal_runtime_access_refresh_failure")
raise
now = datetime.now(timezone.utc)
@ -4554,50 +4826,28 @@ def resolve_nous_runtime_credentials(
# the opaque session key.
used_cached_key = False
mint_payload: Optional[Dict[str, Any]] = None
selected_auth_path = "legacy_session_key"
legacy_session_keys = _nous_legacy_session_keys_forced()
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
state,
access_token=access_token,
min_key_ttl_seconds=min_key_ttl_seconds,
auth_mode=auth_mode,
)
if (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
_set_nous_agent_key_from_invoke_jwt(state)
selected_auth_path = "invoke_jwt"
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
if selected_auth_path == "invoke_jwt":
_select_nous_invoke_jwt(
state,
access_token=access_token,
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
elif not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
elif selected_auth_path == "legacy_session_key_cache":
used_cached_key = True
selected_auth_path = "legacy_session_key_cache"
logger.info("Nous inference auth: using cached legacy session key")
logger.info("Nous inference auth: using cached agent_key")
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
else:
fallback_reason = (
"forced_legacy_session_keys"
if legacy_session_keys
else _nous_invoke_jwt_unavailable_reason(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
)
selected_auth_path = "legacy_session_key_mint"
logger.info(
"Nous inference auth: using legacy session key path (%s)",
fallback_reason,
)
_oauth_trace(
"nous_legacy_session_key_selected",
_log_nous_legacy_session_key_selected(
fallback_reason or "legacy_session_key_required",
access_token=access_token,
sequence_id=sequence_id,
reason=fallback_reason,
access_token_fp=_token_fingerprint(access_token),
)
try:
_oauth_trace(
@ -4646,6 +4896,11 @@ def resolve_nous_runtime_credentials(
exc,
reason="runtime_mint_retry_refresh_failure",
)
_quarantine_nous_pool_entries(
auth_store,
exc,
reason="runtime_mint_retry_refresh_failure",
)
_persist_state("terminal_runtime_mint_retry_refresh_failure")
raise
now = datetime.now(timezone.utc)
@ -4674,22 +4929,24 @@ def resolve_nous_runtime_credentials(
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
if (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
_set_nous_agent_key_from_invoke_jwt(state)
retry_auth_mode = (
NOUS_INFERENCE_AUTH_LEGACY
if auth_mode == NOUS_INFERENCE_AUTH_LEGACY
else NOUS_INFERENCE_AUTH_FRESH
)
retry_auth_path, _ = _choose_nous_inference_auth_path(
state,
access_token=access_token,
min_key_ttl_seconds=min_key_ttl_seconds,
auth_mode=retry_auth_mode,
)
if retry_auth_path == "invoke_jwt":
mint_payload = None
selected_auth_path = "invoke_jwt"
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
_select_nous_invoke_jwt(
state,
access_token=access_token,
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
else:
mint_payload = _mint_agent_key(
@ -4727,7 +4984,8 @@ def resolve_nous_runtime_credentials(
_persist_state("resolve_nous_runtime_credentials_final")
_sync_nous_pool_from_auth_store()
if state_persisted:
_sync_nous_pool_from_auth_store()
api_key = state.get("agent_key")
if not isinstance(api_key, str) or not api_key:
@ -6433,10 +6691,7 @@ def _nous_device_code_login(
or pconfig.inference_base_url
).rstrip("/")
client_id = client_id or pconfig.client_id
explicit_scope = scope is not None
scope = scope or pconfig.scope
if _nous_legacy_session_keys_forced():
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
scope, explicit_scope = _nous_device_scope(scope, default_scope=pconfig.scope)
timeout = httpx.Timeout(timeout_seconds)
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
@ -6451,30 +6706,13 @@ def _nous_device_code_login(
print(f"TLS verification: custom CA bundle ({ca_bundle})")
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
try:
device_data = _request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
)
except Exception as exc:
if (
not explicit_scope
and _nous_scope_has_invoke(scope)
and _is_nous_invoke_scope_refusal(exc)
):
logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope")
_oauth_trace("nous_device_code_invoke_scope_refused")
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
device_data = _request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
)
else:
raise
device_data, scope = _request_nous_device_code_with_scope_fallback(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
allow_legacy_fallback=not explicit_scope,
)
verification_url = str(device_data["verification_uri_complete"])
user_code = str(device_data["user_code"])
@ -6543,7 +6781,7 @@ def _nous_device_code_login(
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=False,
force_mint=True,
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
)
except AuthError as exc:
if exc.code == "subscription_required":

View file

@ -81,6 +81,21 @@ class UpstreamAdapter(ABC):
refresh fails. The proxy will return 401 to the client.
"""
def get_retry_credential(
self,
*,
failed_credential: UpstreamCredential,
status_code: int,
) -> Optional[UpstreamCredential]:
"""Return an alternate credential after an upstream auth failure.
The default is no retry. Providers can override this for one-shot
fallback paths, such as switching from a preferred token type to a
legacy bearer after the upstream rejects the first request.
"""
del failed_credential, status_code
return None
def describe(self) -> str:
"""One-line status summary for ``proxy status``."""
try:

View file

@ -19,13 +19,16 @@ from typing import Any, Dict, FrozenSet, Optional
from hermes_cli.auth import (
AuthError,
DEFAULT_NOUS_INFERENCE_URL,
NOUS_INFERENCE_AUTH_AUTO,
NOUS_INFERENCE_AUTH_LEGACY,
_load_auth_store,
_is_terminal_nous_refresh_error,
_quarantine_nous_oauth_state,
_quarantine_nous_pool_entries,
_save_auth_store,
_write_shared_nous_state,
refresh_nous_oauth_from_state,
)
)
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
logger = logging.getLogger(__name__)
@ -76,6 +79,21 @@ class NousPortalAdapter(UpstreamAdapter):
)
def get_credential(self) -> UpstreamCredential:
return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_AUTO)
def get_retry_credential(
self,
*,
failed_credential: UpstreamCredential,
status_code: int,
) -> Optional[UpstreamCredential]:
del failed_credential
if status_code != 401:
return None
logger.info("proxy: Nous upstream rejected bearer; retrying with legacy session key")
return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_LEGACY)
def _get_credential(self, *, auth_mode: str) -> UpstreamCredential:
with self._lock:
state = self._read_state()
if state is None:
@ -84,7 +102,10 @@ class NousPortalAdapter(UpstreamAdapter):
)
try:
refreshed = refresh_nous_oauth_from_state(state)
refreshed = refresh_nous_oauth_from_state(
state,
auth_mode=auth_mode,
)
except AuthError as exc:
if _is_terminal_nous_refresh_error(exc):
_quarantine_nous_oauth_state(
@ -92,7 +113,11 @@ class NousPortalAdapter(UpstreamAdapter):
exc,
reason="proxy_refresh_failure",
)
self._save_state(state)
self._save_state(
state,
quarantine_error=exc,
quarantine_reason="proxy_refresh_failure",
)
raise RuntimeError(
f"Failed to refresh Nous Portal credentials: {exc}"
) from exc
@ -136,9 +161,21 @@ class NousPortalAdapter(UpstreamAdapter):
return None
return dict(state) # copy so the refresh helper can mutate freely
def _save_state(self, state: Dict[str, Any]) -> None:
def _save_state(
self,
state: Dict[str, Any],
*,
quarantine_error: Optional[AuthError] = None,
quarantine_reason: Optional[str] = None,
) -> None:
try:
store = _load_auth_store()
if quarantine_error is not None and quarantine_reason:
_quarantine_nous_pool_entries(
store,
quarantine_error,
reason=quarantine_reason,
)
providers = store.setdefault("providers", {})
providers["nous"] = state
_save_auth_store(store)

View file

@ -26,7 +26,7 @@ except ImportError:
web = None # type: ignore[assignment]
AIOHTTP_AVAILABLE = False
from hermes_cli.proxy.adapters.base import UpstreamAdapter
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
logger = logging.getLogger(__name__)
@ -136,50 +136,93 @@ def create_app(adapter: UpstreamAdapter) -> "web.Application":
logger.warning("proxy: credential resolution failed: %s", exc)
return _json_error(401, str(exc), code="upstream_auth_failed")
upstream_url = f"{cred.base_url.rstrip('/')}{rel_path}"
# Preserve query string verbatim.
if request.query_string:
upstream_url = f"{upstream_url}?{request.query_string}"
# Forward body verbatim. Read into memory once — request bodies for
# chat/completions/embeddings are small (<1MB typically). If we ever
# need to forward large multipart uploads we'll switch to streaming
# the request body too.
body = await request.read()
fwd_headers = _filter_request_headers(request.headers)
fwd_headers["Authorization"] = f"{cred.token_type} {cred.bearer}"
logger.debug(
"proxy: forwarding %s %s -> %s (body=%d bytes)",
request.method, rel_path, upstream_url, len(body),
)
# Use a per-request session so connection state doesn't leak between
# clients. Could be optimized to a shared session later.
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=300)
try:
session = aiohttp.ClientSession(timeout=timeout)
except Exception as exc: # pragma: no cover - aiohttp setup issue
return _json_error(500, f"proxy session init failed: {exc}")
try:
upstream_resp = await session.request(
request.method,
upstream_url,
data=body if body else None,
headers=fwd_headers,
allow_redirects=False,
async def _send_upstream(active_cred: UpstreamCredential):
upstream_url = f"{active_cred.base_url.rstrip('/')}{rel_path}"
# Preserve query string verbatim.
if request.query_string:
upstream_url = f"{upstream_url}?{request.query_string}"
fwd_headers = _filter_request_headers(request.headers)
fwd_headers["Authorization"] = f"{active_cred.token_type} {active_cred.bearer}"
logger.debug(
"proxy: forwarding %s %s -> %s (body=%d bytes)",
request.method, rel_path, upstream_url, len(body),
)
except aiohttp.ClientError as exc:
await session.close()
logger.warning("proxy: upstream connection failed: %s", exc)
return _json_error(502, f"upstream connection failed: {exc}",
code="upstream_unreachable")
except asyncio.TimeoutError:
await session.close()
return _json_error(504, "upstream request timed out",
code="upstream_timeout")
try:
session = aiohttp.ClientSession(timeout=timeout)
except Exception as exc: # pragma: no cover - aiohttp setup issue
raise RuntimeError(f"proxy session init failed: {exc}") from exc
try:
upstream_resp = await session.request(
request.method,
upstream_url,
data=body if body else None,
headers=fwd_headers,
allow_redirects=False,
)
except Exception:
await session.close()
raise
return session, upstream_resp
async def _open_upstream(active_cred: UpstreamCredential):
try:
return await _send_upstream(active_cred)
except RuntimeError as exc:
return _json_error(500, str(exc)), None
except aiohttp.ClientError as exc:
logger.warning("proxy: upstream connection failed: %s", exc)
return (
_json_error(
502,
f"upstream connection failed: {exc}",
code="upstream_unreachable",
),
None,
)
except asyncio.TimeoutError:
return (
_json_error(
504,
"upstream request timed out",
code="upstream_timeout",
),
None,
)
session_or_response, upstream_resp = await _open_upstream(cred)
if upstream_resp is None:
return session_or_response
session = session_or_response
if upstream_resp.status == 401:
try:
retry_cred = adapter.get_retry_credential(
failed_credential=cred,
status_code=upstream_resp.status,
)
except Exception as exc:
logger.warning("proxy: retry credential resolution failed: %s", exc)
retry_cred = None
if retry_cred is not None:
upstream_resp.release()
await session.close()
session_or_response, upstream_resp = await _open_upstream(retry_cred)
if upstream_resp is None:
return session_or_response
session = session_or_response
# Stream response back. Headers first, then chunked body.
resp = web.StreamResponse(

View file

@ -1815,7 +1815,11 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
so the UI can render the verification page link + user code.
"""
if provider_id == "nous":
from hermes_cli.auth import _request_device_code, PROVIDER_REGISTRY
from hermes_cli.auth import (
_nous_device_scope,
_request_nous_device_code_with_scope_fallback,
PROVIDER_REGISTRY,
)
import httpx
pconfig = PROVIDER_REGISTRY["nous"]
portal_base_url = (
@ -1824,22 +1828,31 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
or pconfig.portal_base_url
).rstrip("/")
client_id = pconfig.client_id
scope = pconfig.scope
scope, explicit_scope = _nous_device_scope(None, default_scope=pconfig.scope)
def _do_nous_device_request():
with httpx.Client(timeout=httpx.Timeout(15.0), headers={"Accept": "application/json"}) as client:
return _request_device_code(
with httpx.Client(
timeout=httpx.Timeout(15.0),
headers={"Accept": "application/json"},
) as client:
return _request_nous_device_code_with_scope_fallback(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
allow_legacy_fallback=not explicit_scope,
)
device_data = await asyncio.get_running_loop().run_in_executor(None, _do_nous_device_request)
device_data, effective_scope = await asyncio.get_running_loop().run_in_executor(
None, _do_nous_device_request
)
sid, sess = _new_oauth_session("nous", "device_code")
sess["device_code"] = str(device_data["device_code"])
sess["interval"] = int(device_data["interval"])
sess["expires_at"] = time.time() + int(device_data["expires_in"])
sess["portal_base_url"] = portal_base_url
sess["client_id"] = client_id
sess["scope"] = effective_scope
threading.Thread(
target=_nous_poller, args=(sid,), daemon=True, name=f"oauth-poll-{sid[:6]}"
).start()
@ -1968,7 +1981,11 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
def _nous_poller(session_id: str) -> None:
"""Background poller that drives a Nous device-code flow to completion."""
from hermes_cli.auth import _poll_for_token, refresh_nous_oauth_from_state
from hermes_cli.auth import (
NOUS_INFERENCE_AUTH_FRESH,
_poll_for_token,
refresh_nous_oauth_from_state,
)
from datetime import datetime, timezone
import httpx
with _oauth_sessions_lock:
@ -1979,6 +1996,7 @@ def _nous_poller(session_id: str) -> None:
client_id = sess["client_id"]
device_code = sess["device_code"]
interval = sess["interval"]
scope = sess.get("scope")
expires_in = max(60, int(sess["expires_at"] - time.time()))
try:
with httpx.Client(timeout=httpx.Timeout(15.0), headers={"Accept": "application/json"}) as client:
@ -1997,7 +2015,7 @@ def _nous_poller(session_id: str) -> None:
"portal_base_url": portal_base_url,
"inference_base_url": token_data.get("inference_base_url"),
"client_id": client_id,
"scope": token_data.get("scope"),
"scope": token_data.get("scope") or scope,
"token_type": token_data.get("token_type", "Bearer"),
"access_token": token_data["access_token"],
"refresh_token": token_data.get("refresh_token"),
@ -2009,8 +2027,11 @@ def _nous_poller(session_id: str) -> None:
"expires_in": token_ttl,
}
full_state = refresh_nous_oauth_from_state(
auth_state, min_key_ttl_seconds=300, timeout_seconds=15.0,
force_refresh=False, force_mint=True,
auth_state,
min_key_ttl_seconds=300,
timeout_seconds=15.0,
force_refresh=False,
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
)
from hermes_cli.auth import persist_nous_credentials
persist_nous_credentials(full_state)

View file

@ -2628,12 +2628,20 @@ class AIAgent:
return False
try:
from hermes_cli.auth import resolve_nous_runtime_credentials
from hermes_cli.auth import (
NOUS_INFERENCE_AUTH_AUTO,
NOUS_INFERENCE_AUTH_LEGACY,
resolve_nous_runtime_credentials,
)
creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))),
timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")),
force_mint=force,
auth_mode=(
NOUS_INFERENCE_AUTH_LEGACY
if force
else NOUS_INFERENCE_AUTH_AUTO
),
)
except Exception as exc:
logger.debug("Nous credential refresh failed: %s", exc)

View file

@ -566,7 +566,7 @@ def test_load_pool_mirrors_nous_invoke_jwt_agent_key_runtime_api_key(tmp_path, m
assert pool_entry["agent_key_expires_at"] == expires_at
def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch):
def test_nous_pool_terminal_refresh_removes_device_code_entry(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
monkeypatch.setenv("HERMES_SHARED_AUTH_DIR", str(tmp_path / "shared"))
_write_auth_store(
@ -591,7 +591,7 @@ def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch):
},
)
from agent.credential_pool import load_pool
from agent.credential_pool import PooledCredential, load_pool
from hermes_cli import auth as auth_mod
from hermes_cli.auth import AuthError
@ -606,18 +606,30 @@ def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch):
relogin_required=True,
)
pool = load_pool("nous")
selected = pool.select()
assert selected is not None
assert selected.source == "device_code"
pool.add_entry(PooledCredential.from_dict("nous", {
"id": "legacy-seeded",
"source": "manual:device_code",
"auth_type": "oauth",
"access_token": "old-access-token",
"refresh_token": "old-refresh-token",
"agent_key": "old-agent-key",
}))
pool.add_entry(PooledCredential.from_dict("nous", {
"id": "manual-key",
"source": "manual",
"auth_type": "api_key",
"access_token": "manual-nous-key",
}))
monkeypatch.setattr(auth_mod, "refresh_nous_oauth_from_state", _terminal_refresh_failure)
pool = load_pool("nous")
assert pool.select() is not None
assert pool.try_refresh_current() is None
entry = pool.entries()[0]
assert entry.last_status == "exhausted"
assert entry.last_error_code == 401
assert entry.refresh_token is None
assert entry.access_token is None
assert entry.agent_key is None
assert [entry.id for entry in pool.entries()] == ["manual-key"]
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
nous_state = auth_payload["providers"]["nous"]
@ -625,11 +637,63 @@ def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch):
assert not nous_state.get("access_token")
assert not nous_state.get("agent_key")
assert nous_state["last_auth_error"]["code"] == "invalid_grant"
assert [entry["id"] for entry in auth_payload["credential_pool"]["nous"]] == ["manual-key"]
assert pool.try_refresh_current() is None
assert refresh_calls["count"] == 1
def test_load_pool_removes_nous_device_code_when_singleton_quarantined(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(
tmp_path,
{
"version": 1,
"active_provider": "nous",
"providers": {
"nous": {
"portal_base_url": "https://portal.example.com",
"inference_base_url": "https://inference.example.com/v1",
"client_id": "hermes-cli",
"last_auth_error": {"code": "invalid_grant"},
}
},
"credential_pool": {
"nous": [
{
"id": "seeded-current",
"source": "device_code",
"auth_type": "oauth",
"access_token": "stale-access",
"refresh_token": "stale-refresh",
"agent_key": "stale-agent",
},
{
"id": "seeded-legacy",
"source": "manual:device_code",
"auth_type": "oauth",
"access_token": "older-stale-access",
},
{
"id": "manual-key",
"source": "manual",
"auth_type": "api_key",
"access_token": "manual-nous-key",
},
]
},
},
)
from agent.credential_pool import load_pool
pool = load_pool("nous")
assert [entry.id for entry in pool.entries()] == ["manual-key"]
auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
assert [entry["id"] for entry in auth_payload["credential_pool"]["nous"]] == ["manual-key"]
def test_load_pool_removes_stale_file_backed_singleton_entry(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False)

View file

@ -231,6 +231,83 @@ def test_resolve_nous_runtime_credentials_prefers_invoke_jwt_and_mirrors(
assert pool_entries[0]["source"] == auth_mod.NOUS_DEVICE_CODE_SOURCE
def test_resolve_nous_runtime_credentials_invoke_jwt_is_idempotent(
tmp_path,
monkeypatch,
):
import hermes_cli.auth as auth_mod
hermes_home = tmp_path / "hermes"
hermes_home.mkdir(parents=True, exist_ok=True)
exp = int(time.time() + 3600)
expires_at = datetime.fromtimestamp(exp, tz=timezone.utc).isoformat()
token = _jwt_with_claims({
"sub": "test-user",
"scope": auth_mod.DEFAULT_NOUS_SCOPE,
"exp": exp,
})
original_obtained_at = "2026-04-17T22:00:10+00:00"
auth_store = {
"version": 1,
"active_provider": "nous",
"providers": {
"nous": {
"portal_base_url": "https://portal.example.com",
"inference_base_url": "https://inference.example.com/v1",
"client_id": "hermes-cli",
"token_type": "Bearer",
"scope": auth_mod.DEFAULT_NOUS_SCOPE,
"access_token": token,
"refresh_token": "refresh-token",
"obtained_at": "2026-02-01T00:00:00+00:00",
"expires_in": 123,
"expires_at": expires_at,
"agent_key": token,
"agent_key_id": None,
"agent_key_expires_at": expires_at,
"agent_key_expires_in": 123,
"agent_key_reused": False,
"agent_key_obtained_at": original_obtained_at,
"tls": {"insecure": False, "ca_bundle": None},
},
},
}
auth_path = hermes_home / "auth.json"
auth_path.write_text(json.dumps(auth_store, indent=2))
before_content = auth_path.read_text()
before_mtime = auth_path.stat().st_mtime_ns
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
def _unexpected_mint(*args, **kwargs):
raise AssertionError("stable invoke JWT should not mint a legacy key")
def _unexpected_shared_write(*args, **kwargs):
raise AssertionError("unchanged invoke JWT resolution should not sync shared store")
sync_calls = []
monkeypatch.setattr(auth_mod, "_mint_agent_key", _unexpected_mint)
monkeypatch.setattr(auth_mod, "_write_shared_nous_state", _unexpected_shared_write)
monkeypatch.setattr(
auth_mod,
"_sync_nous_pool_from_auth_store",
lambda: sync_calls.append(True),
)
creds = auth_mod.resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
assert creds["api_key"] == token
assert creds["source"] == "invoke_jwt"
assert auth_path.read_text() == before_content
assert auth_path.stat().st_mtime_ns == before_mtime
assert sync_calls == []
payload = json.loads(auth_path.read_text())
assert (
payload["providers"]["nous"]["agent_key_obtained_at"]
== original_obtained_at
)
def test_resolve_nous_runtime_credentials_trusts_invoke_jwt_exp_over_stale_metadata(
tmp_path,
monkeypatch,
@ -301,6 +378,41 @@ def test_resolve_nous_runtime_credentials_does_not_apply_legacy_ttl_to_invoke_jw
assert payload["credential_pool"]["nous"][0]["agent_key"] == token
def test_legacy_auth_mode_bypasses_usable_invoke_jwt(tmp_path, monkeypatch):
import hermes_cli.auth as auth_mod
hermes_home = tmp_path / "hermes"
token = _invoke_jwt(seconds=3600)
_setup_nous_auth(
hermes_home,
access_token=token,
scope=auth_mod.DEFAULT_NOUS_SCOPE,
expires_at=_future_iso(3600),
expires_in=3600,
)
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
mint_calls = []
def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds):
del client, portal_base_url, min_ttl_seconds
mint_calls.append(access_token)
return _mint_payload(api_key="legacy-after-jwt-401")
monkeypatch.setattr(auth_mod, "_mint_agent_key", _fake_mint_agent_key)
creds = auth_mod.resolve_nous_runtime_credentials(
min_key_ttl_seconds=300,
auth_mode=auth_mod.NOUS_INFERENCE_AUTH_LEGACY,
)
assert mint_calls == [token]
assert creds["api_key"] == "legacy-after-jwt-401"
assert creds["auth_path"] == "legacy_session_key_mint"
payload = json.loads((hermes_home / "auth.json").read_text())
assert payload["providers"]["nous"]["agent_key"] == "legacy-after-jwt-401"
def test_resolve_nous_runtime_credentials_falls_back_when_invoke_scope_missing(
tmp_path,
monkeypatch,
@ -735,6 +847,9 @@ def test_terminal_refresh_failure_quarantines_tokens(
hermes_home = tmp_path / "hermes"
_setup_nous_auth(hermes_home, refresh_token="refresh-old")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
from agent.credential_pool import load_pool
assert load_pool("nous").select() is not None
shared_state = _full_state_fixture()
shared_state["access_token"] = "access-old"
@ -765,6 +880,8 @@ def test_terminal_refresh_failure_quarantines_tokens(
assert not state_after_failure.get("agent_key")
assert state_after_failure["last_auth_error"]["code"] == "invalid_grant"
assert auth_mod._read_shared_nous_state() is None
payload = json.loads((hermes_home / "auth.json").read_text())
assert payload.get("credential_pool", {}).get("nous") == []
with pytest.raises(AuthError, match="No access token found"):
auth_mod.resolve_nous_runtime_credentials(min_key_ttl_seconds=300)
@ -780,6 +897,9 @@ def test_managed_access_token_refresh_failure_quarantines_tokens(
hermes_home = tmp_path / "hermes"
_setup_nous_auth(hermes_home, refresh_token="refresh-old")
monkeypatch.setenv("HERMES_HOME", str(hermes_home))
from agent.credential_pool import load_pool
assert load_pool("nous").select() is not None
refresh_calls: list[str] = []
@ -802,6 +922,8 @@ def test_managed_access_token_refresh_failure_quarantines_tokens(
assert not state_after_failure.get("refresh_token")
assert not state_after_failure.get("access_token")
assert state_after_failure["last_auth_error"]["message"] == "Invalid refresh token"
payload = json.loads((hermes_home / "auth.json").read_text())
assert payload.get("credential_pool", {}).get("nous") == []
with pytest.raises(AuthError, match="No access token found"):
auth_mod.resolve_nous_access_token()
@ -1076,7 +1198,11 @@ def test_persist_nous_credentials_allows_recovery_from_401(tmp_path, monkeypatch
calls after a Nous 401 before the fix it would raise AuthError because
providers.nous was empty.
"""
from hermes_cli.auth import persist_nous_credentials, resolve_nous_runtime_credentials
from hermes_cli.auth import (
NOUS_INFERENCE_AUTH_FRESH,
persist_nous_credentials,
resolve_nous_runtime_credentials,
)
hermes_home = tmp_path / "hermes"
hermes_home.mkdir(parents=True, exist_ok=True)
@ -1104,7 +1230,10 @@ def test_persist_nous_credentials_allows_recovery_from_401(tmp_path, monkeypatch
monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token)
monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key)
creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300, force_mint=True)
creds = resolve_nous_runtime_credentials(
min_key_ttl_seconds=300,
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
)
assert creds["api_key"] == "new-agent-key"
@ -1569,7 +1698,7 @@ def test_try_import_shared_rehydrates_on_success(shared_store_env, monkeypatch):
def _fake_refresh(state, **kwargs):
# Simulate portal returning fresh tokens + a new agent_key
assert kwargs.get("force_refresh") is True
assert kwargs.get("force_mint") is True
assert kwargs.get("auth_mode") == auth_mod.NOUS_INFERENCE_AUTH_FRESH
return {
**state,
"access_token": "fresh-access-tok",
@ -1697,7 +1826,7 @@ def test_runtime_refresh_uses_newer_shared_token_before_local_stale_token(
creds = auth_mod.resolve_nous_runtime_credentials(
min_key_ttl_seconds=300,
force_mint=True,
auth_mode=auth_mod.NOUS_INFERENCE_AUTH_FRESH,
)
assert creds["api_key"] == "agent-key-from-shared-token"

View file

@ -141,6 +141,45 @@ def test_nous_adapter_get_credential_refreshes_and_persists(tmp_path, monkeypatc
assert stored["providers"]["nous"]["agent_key"] == "minted-bearer"
def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_auth_store(tmp_path, {
"access_token": "jwt-access",
"refresh_token": "refresh-tok",
"client_id": "hermes-cli",
"portal_base_url": "https://portal.nousresearch.com",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"agent_key": "jwt-access",
})
refreshed_state = {
"access_token": "jwt-access",
"refresh_token": "refresh-tok",
"client_id": "hermes-cli",
"portal_base_url": "https://portal.nousresearch.com",
"inference_base_url": "https://inference-api.nousresearch.com/v1",
"agent_key": "legacy-bearer",
"agent_key_expires_at": "2099-01-01T00:00:00Z",
}
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
return_value=refreshed_state,
) as mock_refresh:
adapter = NousPortalAdapter()
cred = adapter.get_retry_credential(
failed_credential=UpstreamCredential(
bearer="jwt-access",
base_url="https://inference-api.nousresearch.com/v1",
),
status_code=401,
)
assert cred is not None
assert cred.bearer == "legacy-bearer"
assert mock_refresh.call_args.kwargs["auth_mode"] == "legacy"
def test_nous_adapter_get_credential_raises_when_not_logged_in(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
adapter = NousPortalAdapter()
@ -166,6 +205,7 @@ def test_nous_adapter_get_credential_raises_on_refresh_failure(tmp_path, monkeyp
def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch):
from hermes_cli.auth import AuthError
from agent.credential_pool import load_pool
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
_write_auth_store(tmp_path, {
@ -173,6 +213,7 @@ def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch
"refresh_token": "refresh-tok",
"agent_key": "stale-agent-key",
})
assert load_pool("nous").select() is not None
with patch(
"hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state",
@ -193,6 +234,7 @@ def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch
assert not nous_state.get("access_token")
assert not nous_state.get("agent_key")
assert nous_state["last_auth_error"]["code"] == "invalid_grant"
assert stored.get("credential_pool", {}).get("nous") == []
def test_nous_adapter_get_credential_raises_when_no_agent_key_returned(tmp_path, monkeypatch):
@ -291,12 +333,15 @@ class FakeAdapter(UpstreamAdapter):
"""A test adapter that returns a fixed credential without touching disk."""
def __init__(self, base_url: str, bearer: str = "test-bearer",
allowed=None, raise_on_credential=False):
allowed=None, raise_on_credential=False,
retry_bearer: str | None = None):
self._base_url = base_url
self._bearer = bearer
self._allowed = frozenset(allowed or ["/chat/completions"])
self._raise = raise_on_credential
self._retry_bearer = retry_bearer
self.calls = 0
self.retry_calls = 0
@property
def name(self): return "fake"
@ -318,6 +363,17 @@ class FakeAdapter(UpstreamAdapter):
expires_at="2099-01-01T00:00:00Z",
)
def get_retry_credential(self, *, failed_credential, status_code):
del failed_credential
self.retry_calls += 1
if status_code != 401 or not self._retry_bearer:
return None
return UpstreamCredential(
bearer=self._retry_bearer,
base_url=self._base_url,
expires_at="2099-01-01T00:00:00Z",
)
async def _start_runner(app: "web.Application"):
"""Spin up an aiohttp app on an ephemeral localhost port. Returns (runner, base_url)."""
@ -358,6 +414,25 @@ def _build_fake_upstream(captured: Dict[str, Any]) -> "web.Application":
return app
def _build_retrying_fake_upstream(captured: Dict[str, Any]) -> "web.Application":
async def maybe_unauthorized(request):
body = await request.read()
auth = request.headers.get("Authorization")
captured["requests"].append({
"method": request.method,
"path": request.path,
"auth": auth,
"body": body.decode("utf-8") if body else "",
})
if auth == "Bearer jwt-bearer":
return web.json_response({"error": "bad token"}, status=401)
return web.json_response({"ok": True})
app = web.Application()
app.router.add_route("*", "/v1/chat/completions", maybe_unauthorized)
return app
def test_server_forwards_chat_completions():
async def run():
captured: Dict[str, Any] = {"requests": []}
@ -388,6 +463,41 @@ def test_server_forwards_chat_completions():
asyncio.run(run())
def test_server_retries_once_with_adapter_retry_credential_on_401():
async def run():
captured: Dict[str, Any] = {"requests": []}
upstream_runner, upstream_base = await _start_runner(
_build_retrying_fake_upstream(captured)
)
adapter = FakeAdapter(
f"{upstream_base}/v1",
bearer="jwt-bearer",
retry_bearer="legacy-bearer",
)
proxy_runner, proxy_base = await _start_runner(create_app(adapter))
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{proxy_base}/v1/chat/completions",
json={"model": "Hermes-4-70B"},
) as resp:
assert resp.status == 200
data = await resp.json()
assert data["ok"] is True
assert adapter.retry_calls == 1
assert [req["auth"] for req in captured["requests"]] == [
"Bearer jwt-bearer",
"Bearer legacy-bearer",
]
finally:
await proxy_runner.cleanup()
await upstream_runner.cleanup()
asyncio.run(run())
def test_server_rejects_disallowed_path():
async def run():
adapter = FakeAdapter("http://unused.example/v1", allowed=["/chat/completions"])

View file

@ -19,11 +19,12 @@ The fix:
These tests pin the corrected behavior.
"""
import asyncio
import time
from datetime import datetime, timezone
from unittest.mock import patch
import pytest
import httpx
from fastapi.testclient import TestClient
from hermes_cli.web_server import _SESSION_TOKEN, app
@ -32,6 +33,32 @@ client = TestClient(app)
HEADERS = {"X-Hermes-Session-Token": _SESSION_TOKEN}
def _fake_nous_device_data():
return {
"device_code": "device-code",
"user_code": "NOUS-1234",
"verification_uri": "https://portal.nousresearch.com/device",
"verification_uri_complete": (
"https://portal.nousresearch.com/device?user_code=NOUS-1234"
),
"expires_in": 600,
"interval": 5,
}
def _invoke_scope_refusal():
request = httpx.Request("POST", "https://portal.nousresearch.com/oauth/device/code")
response = httpx.Response(
400,
json={
"error": "invalid_scope",
"error_description": "unsupported scope inference:invoke",
},
request=request,
)
return httpx.HTTPStatusError("invalid scope", request=request, response=response)
def test_minimax_login_does_not_launch_anthropic_flow():
"""Click 'Login' on MiniMax → MUST NOT return claude.ai auth_url."""
fake_user_code_resp = {
@ -48,6 +75,9 @@ def test_minimax_login_does_not_launch_anthropic_flow():
), patch(
"hermes_cli.auth._minimax_pkce_pair",
return_value=("verifier-stub", "challenge-stub", "stub-state"),
), patch(
"hermes_cli.web_server._minimax_poller",
return_value=None,
):
resp = client.post(
"/api/providers/oauth/minimax-oauth/start",
@ -69,6 +99,113 @@ def test_minimax_login_does_not_launch_anthropic_flow():
assert body["expires_in"] == 600
def test_nous_dashboard_device_flow_honors_legacy_scope_override(monkeypatch):
from hermes_cli import auth as auth_mod
from hermes_cli import web_server as ws
requested_scopes = []
def fake_request_device_code(**kwargs):
requested_scopes.append(kwargs["scope"])
return _fake_nous_device_data()
monkeypatch.setenv(auth_mod.NOUS_LEGACY_SESSION_KEYS_ENV, "true")
monkeypatch.setattr(auth_mod, "_request_device_code", fake_request_device_code)
monkeypatch.setattr(ws, "_nous_poller", lambda sid: None)
result = asyncio.run(ws._start_device_code_flow("nous"))
try:
assert requested_scopes == [auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE]
assert result["flow"] == "device_code"
assert result["user_code"] == "NOUS-1234"
assert (
ws._oauth_sessions[result["session_id"]]["scope"]
== auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE
)
finally:
ws._oauth_sessions.pop(result["session_id"], None)
def test_nous_dashboard_device_flow_retries_legacy_scope_on_invoke_refusal(monkeypatch):
from hermes_cli import auth as auth_mod
from hermes_cli import web_server as ws
requested_scopes = []
def fake_request_device_code(**kwargs):
requested_scopes.append(kwargs["scope"])
if len(requested_scopes) == 1:
raise _invoke_scope_refusal()
return _fake_nous_device_data()
monkeypatch.delenv(auth_mod.NOUS_LEGACY_SESSION_KEYS_ENV, raising=False)
monkeypatch.setattr(auth_mod, "_request_device_code", fake_request_device_code)
monkeypatch.setattr(ws, "_nous_poller", lambda sid: None)
result = asyncio.run(ws._start_device_code_flow("nous"))
try:
assert requested_scopes == [
auth_mod.DEFAULT_NOUS_SCOPE,
auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE,
]
assert (
ws._oauth_sessions[result["session_id"]]["scope"]
== auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE
)
finally:
ws._oauth_sessions.pop(result["session_id"], None)
def test_nous_dashboard_poller_preserves_effective_scope_when_token_omits_scope(monkeypatch):
from hermes_cli import auth as auth_mod
from hermes_cli import web_server as ws
session_id = "nous-effective-scope-test"
ws._oauth_sessions[session_id] = {
"session_id": session_id,
"provider": "nous",
"flow": "device_code",
"created_at": time.time(),
"status": "pending",
"error_message": None,
"portal_base_url": "https://portal.nousresearch.com",
"client_id": "hermes-cli",
"device_code": "device-code",
"interval": 5,
"expires_at": time.time() + 600,
"scope": auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE,
}
captured_state = {}
def fake_refresh_nous_oauth_from_state(state, **kwargs):
captured_state.update(state)
return {**state, "agent_key": "legacy-agent-key"}
monkeypatch.setattr(
auth_mod,
"_poll_for_token",
lambda **kwargs: {
"access_token": "access-token",
"refresh_token": "refresh-token",
"expires_in": 3600,
"token_type": "Bearer",
},
)
monkeypatch.setattr(
auth_mod,
"refresh_nous_oauth_from_state",
fake_refresh_nous_oauth_from_state,
)
monkeypatch.setattr(auth_mod, "persist_nous_credentials", lambda state: None)
try:
ws._nous_poller(session_id)
assert captured_state["scope"] == auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE
assert ws._oauth_sessions[session_id]["status"] == "approved"
finally:
ws._oauth_sessions.pop(session_id, None)
def test_minimax_dashboard_poller_accepts_absolute_ms_expired_in():
"""Dashboard MiniMax completion must accept unix-ms token expiry values."""
from hermes_cli import web_server as ws

View file

@ -3667,7 +3667,7 @@ class TestNousCredentialRefresh:
assert ok is True
assert closed["value"] is True
assert captured["force_mint"] is True
assert captured["auth_mode"] == "legacy"
assert rebuilt["kwargs"]["api_key"] == "new-nous-key"
assert (
rebuilt["kwargs"]["base_url"] == "https://inference-api.nousresearch.com/v1"