From 0bac7dd05bd56fd615ef4b5c499a60a42a8b32b6 Mon Sep 17 00:00:00 2001 From: Robin Fernandes Date: Sun, 17 May 2026 20:34:39 +1000 Subject: [PATCH] refactor(auth): collapse Nous inference fallback controls --- agent/auxiliary_client.py | 16 +- agent/credential_pool.py | 51 +- hermes_cli/auth.py | 544 ++++++++++++++------ hermes_cli/proxy/adapters/base.py | 15 + hermes_cli/proxy/adapters/nous_portal.py | 45 +- hermes_cli/proxy/server.py | 115 +++-- hermes_cli/web_server.py | 39 +- run_agent.py | 12 +- tests/agent/test_credential_pool.py | 84 ++- tests/hermes_cli/test_auth_nous_provider.py | 137 ++++- tests/hermes_cli/test_proxy.py | 112 +++- tests/hermes_cli/test_web_oauth_dispatch.py | 139 ++++- tests/run_agent/test_run_agent.py | 2 +- 13 files changed, 1071 insertions(+), 240 deletions(-) diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index b2733fd8a1b..e67b37b00da 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -1252,12 +1252,20 @@ def _resolve_nous_runtime_api(*, force_refresh: bool = False) -> Optional[tuple[ or the credential pool. """ try: - from hermes_cli.auth import resolve_nous_runtime_credentials + from hermes_cli.auth import ( + NOUS_INFERENCE_AUTH_AUTO, + NOUS_INFERENCE_AUTH_LEGACY, + resolve_nous_runtime_credentials, + ) creds = resolve_nous_runtime_credentials( min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))), timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")), - force_mint=force_refresh, + auth_mode=( + NOUS_INFERENCE_AUTH_LEGACY + if force_refresh + else NOUS_INFERENCE_AUTH_AUTO + ), ) except Exception as exc: logger.debug("Auxiliary Nous runtime credential resolution failed: %s", exc) @@ -2501,12 +2509,12 @@ def _refresh_provider_credentials(provider: str) -> bool: _evict_cached_clients(normalized) return True if normalized == "nous": - from hermes_cli.auth import resolve_nous_runtime_credentials + from hermes_cli.auth import NOUS_INFERENCE_AUTH_LEGACY, resolve_nous_runtime_credentials creds = resolve_nous_runtime_credentials( min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))), timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")), - force_mint=True, + auth_mode=NOUS_INFERENCE_AUTH_LEGACY, ) if not str(creds.get("api_key", "") or "").strip(): return False diff --git a/agent/credential_pool.py b/agent/credential_pool.py index b1c41977d51..7c91a08d2aa 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -831,7 +831,11 @@ class CredentialPool: nous_state, min_key_ttl_seconds=DEFAULT_AGENT_KEY_MIN_TTL_SECONDS, force_refresh=force, - force_mint=force, + auth_mode=( + auth_mod.NOUS_INFERENCE_AUTH_LEGACY + if force + else auth_mod.NOUS_INFERENCE_AUTH_AUTO + ), ) # Apply returned fields: dataclass fields via replace, extras via dict update field_updates = {} @@ -952,25 +956,27 @@ class CredentialPool: exc, reason="credential_pool_refresh_failure", ) + auth_mod._quarantine_nous_pool_entries( + auth_store, + exc, + reason="credential_pool_refresh_failure", + ) _save_provider_state(auth_store, "nous", state) _save_auth_store(auth_store) except Exception as clear_exc: logger.debug("Failed to clear terminal Nous OAuth state: %s", clear_exc) - cleared = replace( - entry, - access_token=None, - refresh_token=None, - agent_key=None, - agent_key_expires_at=None, - ) - self._replace_entry(entry, cleared) + singleton_sources = { + auth_mod.NOUS_DEVICE_CODE_SOURCE, + f"manual:{auth_mod.NOUS_DEVICE_CODE_SOURCE}", + } + self._entries = [ + item for item in self._entries + if item.source not in singleton_sources + ] + if self._current_id == entry.id: + self._current_id = None self._persist() - self._mark_exhausted( - cleared, - 401, - {"reason": getattr(exc, "code", None), "message": str(exc)}, - ) return None self._mark_exhausted(entry, None) return None @@ -1408,7 +1414,22 @@ def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tup elif provider == "nous": state = _load_provider_state(auth_store, "nous") - if state and not _is_suppressed(provider, "device_code"): + has_runtime_material = bool( + isinstance(state, dict) + and ( + str(state.get("access_token") or "").strip() + or str(state.get("agent_key") or "").strip() + ) + ) + if state and not has_runtime_material: + retained = [ + entry for entry in entries + if entry.source not in {"device_code", "manual:device_code"} + ] + if len(retained) != len(entries): + entries[:] = retained + changed = True + if state and has_runtime_material and not _is_suppressed(provider, "device_code"): active_sources.add("device_code") # Prefer a user-supplied label embedded in the singleton state # (set by persist_nous_credentials(label=...) when the user ran diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index 2a670589d48..783f2c0c655 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -11,6 +11,12 @@ Architecture: - resolve_provider() picks the active provider via priority chain - resolve_*_runtime_credentials() handles token refresh and key minting - logout_command() is the CLI entry point for clearing auth + +Nous authentication paths: +- Invoke JWT (preferred): use a scoped access_token directly for inference. +- Legacy session key (fallback): mint an opaque 24h key when JWT auth is + unavailable, or when HERMES_AGENT_USE_LEGACY_SESSION_KEYS is set for + debugging or rollback. """ from __future__ import annotations @@ -71,6 +77,15 @@ NOUS_LEGACY_AGENT_KEY_SCOPE = "inference:mint_agent_key" NOUS_INFERENCE_INVOKE_SCOPE = "inference:invoke" DEFAULT_NOUS_SCOPE = f"{NOUS_INFERENCE_INVOKE_SCOPE} {NOUS_LEGACY_AGENT_KEY_SCOPE}" NOUS_LEGACY_SESSION_KEYS_ENV = "HERMES_AGENT_USE_LEGACY_SESSION_KEYS" +NOUS_DEVICE_CODE_SOURCE = "device_code" +NOUS_INFERENCE_AUTH_AUTO = "auto" +NOUS_INFERENCE_AUTH_FRESH = "fresh" +NOUS_INFERENCE_AUTH_LEGACY = "legacy" +NOUS_INFERENCE_AUTH_MODES = frozenset({ + NOUS_INFERENCE_AUTH_AUTO, + NOUS_INFERENCE_AUTH_FRESH, + NOUS_INFERENCE_AUTH_LEGACY, +}) DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry NOUS_INVOKE_JWT_MIN_TTL_SECONDS = ACCESS_TOKEN_REFRESH_SKEW_SECONDS @@ -1554,6 +1569,8 @@ def _decode_jwt_claims(token: Any) -> Dict[str, Any]: def _scope_values(raw_scope: Any) -> set[str]: + # OAuth token responses normally return a space-separated string. Keep + # collection support for JWT ``scp`` claims and older stored test fixtures. scopes: set[str] = set() if isinstance(raw_scope, str): for part in raw_scope.replace(",", " ").split(): @@ -1575,37 +1592,24 @@ def _nous_scope_has_invoke(raw_scope: Any) -> bool: return NOUS_INFERENCE_INVOKE_SCOPE in _scope_values(raw_scope) -def _nous_invoke_jwt_is_usable( +def _normalize_nous_auth_mode(auth_mode: Optional[str]) -> str: + mode = str(auth_mode or NOUS_INFERENCE_AUTH_AUTO).strip().lower() + if mode not in NOUS_INFERENCE_AUTH_MODES: + allowed = ", ".join(sorted(NOUS_INFERENCE_AUTH_MODES)) + raise ValueError( + f"Invalid Nous inference auth mode {auth_mode!r}; expected one of: {allowed}" + ) + return mode + + +def _nous_invoke_jwt_status( token: Any, *, scope: Any = None, expires_at: Any = None, min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS, -) -> bool: - claims = _decode_jwt_claims(token) - if not claims: - return False - scopes = ( - _scope_values(scope) - | _scope_values(claims.get("scope")) - | _scope_values(claims.get("scp")) - ) - if NOUS_INFERENCE_INVOKE_SCOPE not in scopes: - return False - exp = claims.get("exp") - skew = max(0, int(min_ttl_seconds)) - if isinstance(exp, (int, float)): - return float(exp) > (time.time() + skew) - return not _is_expiring(expires_at, skew) - - -def _nous_invoke_jwt_unavailable_reason( - token: Any, - *, - scope: Any = None, - expires_at: Any = None, - min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS, -) -> str: +) -> Optional[str]: + """Return None when the token can be used for inference, else a reason.""" claims = _decode_jwt_claims(token) if not claims: return "access_token_not_jwt" @@ -1618,11 +1622,149 @@ def _nous_invoke_jwt_unavailable_reason( return "missing_inference_invoke_scope" exp = claims.get("exp") skew = max(0, int(min_ttl_seconds)) - if isinstance(exp, (int, float)) and float(exp) <= (time.time() + skew): - return "invoke_jwt_expiring" - if not isinstance(exp, (int, float)) and _is_expiring(expires_at, skew): + if isinstance(exp, (int, float)): + if float(exp) <= (time.time() + skew): + return "invoke_jwt_expiring" + return None + if _is_expiring(expires_at, skew): return "invoke_jwt_expiry_unknown_or_expiring" - return "invoke_jwt_unavailable" + return None + + +def _nous_invoke_jwt_is_usable( + token: Any, + *, + scope: Any = None, + expires_at: Any = None, + min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS, +) -> bool: + return ( + _nous_invoke_jwt_status( + token, + scope=scope, + expires_at=expires_at, + min_ttl_seconds=min_ttl_seconds, + ) + is None + ) + + +def _nous_invoke_jwt_unavailable_reason( + token: Any, + *, + scope: Any = None, + expires_at: Any = None, + min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS, +) -> str: + return ( + _nous_invoke_jwt_status( + token, + scope=scope, + expires_at=expires_at, + min_ttl_seconds=min_ttl_seconds, + ) + or "invoke_jwt_unavailable" + ) + + +def _nous_can_select_invoke_jwt(auth_mode: str = NOUS_INFERENCE_AUTH_AUTO) -> bool: + return ( + not _nous_legacy_session_keys_forced() + and _normalize_nous_auth_mode(auth_mode) != NOUS_INFERENCE_AUTH_LEGACY + ) + + +def _nous_legacy_session_key_reason( + token: Any, + *, + scope: Any = None, + expires_at: Any = None, + auth_mode: str = NOUS_INFERENCE_AUTH_AUTO, +) -> str: + if _normalize_nous_auth_mode(auth_mode) == NOUS_INFERENCE_AUTH_LEGACY: + return "forced_legacy_session_key" + if _nous_legacy_session_keys_forced(): + return "forced_legacy_session_keys" + return _nous_invoke_jwt_unavailable_reason( + token, + scope=scope, + expires_at=expires_at, + ) + + +def _nous_cached_agent_key_is_usable( + state: Dict[str, Any], + min_ttl_seconds: int, +) -> bool: + return _agent_key_is_usable(state, min_ttl_seconds) + + +def _choose_nous_inference_auth_path( + state: Dict[str, Any], + *, + access_token: Any = None, + min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS, + auth_mode: str = NOUS_INFERENCE_AUTH_AUTO, +) -> Tuple[str, Optional[str]]: + auth_mode = _normalize_nous_auth_mode(auth_mode) + token = state.get("access_token") if access_token is None else access_token + if ( + _nous_can_select_invoke_jwt(auth_mode) + and _nous_invoke_jwt_is_usable( + token, + scope=state.get("scope"), + expires_at=state.get("expires_at"), + ) + ): + return "invoke_jwt", None + if ( + auth_mode == NOUS_INFERENCE_AUTH_AUTO + and _nous_cached_agent_key_is_usable( + state, + max(60, int(min_key_ttl_seconds)), + ) + ): + return "legacy_session_key_cache", None + return ( + "legacy_session_key_mint", + _nous_legacy_session_key_reason( + token, + scope=state.get("scope"), + expires_at=state.get("expires_at"), + auth_mode=auth_mode, + ), + ) + + +def _log_nous_invoke_jwt_selected( + *, + access_token: Any, + sequence_id: Optional[str] = None, +) -> None: + logger.info("Nous inference auth: using NAS invoke JWT") + _oauth_trace( + "nous_invoke_jwt_selected", + sequence_id=sequence_id, + access_token_fp=_token_fingerprint(access_token), + ) + + +def _log_nous_legacy_session_key_selected( + reason: str, + *, + access_token: Any, + sequence_id: Optional[str] = None, +) -> None: + logger.info( + "Nous inference auth: using legacy session key path (%s)", + reason, + ) + _oauth_trace( + "nous_legacy_session_key_selected", + sequence_id=sequence_id, + reason=reason, + access_token_fp=_token_fingerprint(access_token), + ) def _nous_jwt_expires_at(token: Any, fallback_expires_at: Any = None) -> Optional[str]: @@ -1645,7 +1787,17 @@ def _set_nous_agent_key_from_invoke_jwt( if not isinstance(access_token, str) or not access_token.strip(): return now = datetime.now(timezone.utc) - effective_obtained_at = obtained_at or now.isoformat() + existing_obtained_at = state.get("agent_key_obtained_at") + if obtained_at: + effective_obtained_at = obtained_at + elif ( + state.get("agent_key") == access_token + and isinstance(existing_obtained_at, str) + and existing_obtained_at.strip() + ): + effective_obtained_at = existing_obtained_at + else: + effective_obtained_at = now.isoformat() expires_at = _nous_jwt_expires_at(access_token, state.get("expires_at")) expires_epoch = _parse_iso_timestamp(expires_at) expires_in = ( @@ -1664,6 +1816,38 @@ def _set_nous_agent_key_from_invoke_jwt( state["agent_key_obtained_at"] = effective_obtained_at +def _select_nous_invoke_jwt( + state: Dict[str, Any], + *, + access_token: Any = None, + sequence_id: Optional[str] = None, +) -> None: + if isinstance(access_token, str) and access_token.strip(): + state["access_token"] = access_token + _set_nous_agent_key_from_invoke_jwt(state) + _log_nous_invoke_jwt_selected( + access_token=state.get("access_token"), + sequence_id=sequence_id, + ) + + +_NOUS_EFFECTIVE_STATE_IGNORED_KEYS = frozenset({ + # These are derived from expires_at/JWT exp and naturally tick down between + # reads. Persisting only these changes makes auth.json noisy and defeats + # the mtime-keyed auth-status cache. + "expires_in", + "agent_key_expires_in", +}) + + +def _nous_effective_provider_state(state: Dict[str, Any]) -> Dict[str, Any]: + return { + key: value + for key, value in state.items() + if key not in _NOUS_EFFECTIVE_STATE_IGNORED_KEYS + } + + def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> bool: claims = _decode_jwt_claims(access_token) exp = claims.get("exp") @@ -3476,6 +3660,57 @@ def _is_nous_invoke_scope_refusal(exc: Exception) -> bool: ) +def _nous_device_scope( + requested_scope: Optional[str], + *, + default_scope: str = DEFAULT_NOUS_SCOPE, +) -> Tuple[str, bool]: + explicit_scope = requested_scope is not None + scope = requested_scope or default_scope + if _nous_legacy_session_keys_forced(): + scope = NOUS_LEGACY_AGENT_KEY_SCOPE + return scope, explicit_scope + + +def _request_nous_device_code_with_scope_fallback( + *, + client: httpx.Client, + portal_base_url: str, + client_id: str, + scope: str, + allow_legacy_fallback: bool, +) -> Tuple[Dict[str, Any], str]: + try: + return ( + _request_device_code( + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + scope=scope, + ), + scope, + ) + except Exception as exc: + if ( + allow_legacy_fallback + and _nous_scope_has_invoke(scope) + and _is_nous_invoke_scope_refusal(exc) + ): + logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope") + _oauth_trace("nous_device_code_invoke_scope_refused") + retry_scope = NOUS_LEGACY_AGENT_KEY_SCOPE + return ( + _request_device_code( + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + scope=retry_scope, + ), + retry_scope, + ) + raise + + def _poll_for_token( client: httpx.Client, portal_base_url: str, @@ -3817,6 +4052,39 @@ def _quarantine_nous_oauth_state( invalidate_nous_auth_status_cache() +def _quarantine_nous_pool_entries( + auth_store: Dict[str, Any], + error: AuthError, + *, + reason: str, +) -> bool: + """Remove singleton-seeded Nous pool entries that contain dead OAuth state.""" + pool = auth_store.get("credential_pool") + if not isinstance(pool, dict): + return False + entries = pool.get("nous") + if not isinstance(entries, list): + return False + + retained = [] + removed = False + singleton_sources = {NOUS_DEVICE_CODE_SOURCE, f"manual:{NOUS_DEVICE_CODE_SOURCE}"} + for entry in entries: + if isinstance(entry, dict) and entry.get("source") in singleton_sources: + removed = True + continue + retained.append(entry) + + if removed: + pool["nous"] = retained + _oauth_trace( + "nous_pool_device_code_quarantined", + reason=reason, + error_code=error.code, + ) + return removed + + def _try_import_shared_nous_state( *, timeout_seconds: float = 15.0, @@ -3842,7 +4110,7 @@ def _try_import_shared_nous_state( # Build a full state dict so refresh_nous_oauth_from_state has every # field it needs. force_refresh=True gets us a fresh access_token - # for this profile; force_mint=True gets us a fresh agent_key. + # for this profile; fresh auth mode avoids stale cached legacy keys. state: Dict[str, Any] = { "access_token": shared.get("access_token"), "refresh_token": shared.get("refresh_token"), @@ -3863,7 +4131,7 @@ def _try_import_shared_nous_state( min_key_ttl_seconds=min_key_ttl_seconds, timeout_seconds=timeout_seconds, force_refresh=True, - force_mint=True, + auth_mode=NOUS_INFERENCE_AUTH_FRESH, ) _write_shared_nous_state(refreshed) except AuthError as exc: @@ -4121,6 +4389,11 @@ def resolve_nous_access_token( exc, reason="managed_access_token_refresh_failure", ) + _quarantine_nous_pool_entries( + auth_store, + exc, + reason="managed_access_token_refresh_failure", + ) _save_provider_state(auth_store, "nous", state) _save_auth_store(auth_store) raise @@ -4167,9 +4440,10 @@ def refresh_nous_oauth_pure( insecure: Optional[bool] = None, ca_bundle: Optional[str] = None, force_refresh: bool = False, - force_mint: bool = False, + auth_mode: str = NOUS_INFERENCE_AUTH_AUTO, ) -> Dict[str, Any]: """Refresh Nous OAuth state without mutating auth.json.""" + auth_mode = _normalize_nous_auth_mode(auth_mode) state: Dict[str, Any] = { "access_token": access_token, "refresh_token": refresh_token, @@ -4229,38 +4503,17 @@ def refresh_nous_oauth_pure( now.timestamp() + access_ttl, tz=timezone.utc ).isoformat() - if ( - not legacy_session_keys - and _nous_invoke_jwt_is_usable( - state.get("access_token"), - scope=state.get("scope"), - expires_at=state.get("expires_at"), - ) - ): - _set_nous_agent_key_from_invoke_jwt(state) - logger.info("Nous inference auth: using NAS invoke JWT") - _oauth_trace( - "nous_invoke_jwt_selected", - access_token_fp=_token_fingerprint(state.get("access_token")), - ) - elif force_mint or not _agent_key_is_usable(state, min_agent_key_ttl): - fallback_reason = ( - "forced_legacy_session_keys" - if legacy_session_keys - else _nous_invoke_jwt_unavailable_reason( - state.get("access_token"), - scope=state.get("scope"), - expires_at=state.get("expires_at"), - ) - ) - logger.info( - "Nous inference auth: using legacy session key path (%s)", - fallback_reason, - ) - _oauth_trace( - "nous_legacy_session_key_selected", - reason=fallback_reason, - access_token_fp=_token_fingerprint(state.get("access_token")), + selected_auth_path, fallback_reason = _choose_nous_inference_auth_path( + state, + min_key_ttl_seconds=min_agent_key_ttl, + auth_mode=auth_mode, + ) + if selected_auth_path == "invoke_jwt": + _select_nous_invoke_jwt(state) + elif selected_auth_path == "legacy_session_key_mint": + _log_nous_legacy_session_key_selected( + fallback_reason or "legacy_session_key_required", + access_token=state.get("access_token"), ) mint_payload = _mint_agent_key( client=client, @@ -4288,7 +4541,7 @@ def refresh_nous_oauth_from_state( min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS, timeout_seconds: float = 15.0, force_refresh: bool = False, - force_mint: bool = False, + auth_mode: str = NOUS_INFERENCE_AUTH_AUTO, ) -> Dict[str, Any]: """Refresh Nous OAuth from a state dict. Thin wrapper around refresh_nous_oauth_pure.""" tls = state.get("tls") or {} @@ -4309,13 +4562,10 @@ def refresh_nous_oauth_from_state( insecure=tls.get("insecure"), ca_bundle=tls.get("ca_bundle"), force_refresh=force_refresh, - force_mint=force_mint, + auth_mode=auth_mode, ) -NOUS_DEVICE_CODE_SOURCE = "device_code" - - def persist_nous_credentials( creds: Dict[str, Any], *, @@ -4390,7 +4640,7 @@ def resolve_nous_runtime_credentials( timeout_seconds: float = 15.0, insecure: Optional[bool] = None, ca_bundle: Optional[str] = None, - force_mint: bool = False, + auth_mode: str = NOUS_INFERENCE_AUTH_AUTO, ) -> Dict[str, Any]: """ Resolve Nous inference credentials for runtime use. @@ -4402,6 +4652,7 @@ def resolve_nous_runtime_credentials( Returns dict with: provider, base_url, api_key, key_id, expires_at, expires_in, source ("invoke_jwt", "cache", or "portal"), and auth_path. """ + auth_mode = _normalize_nous_auth_mode(auth_mode) min_key_ttl_seconds = max(60, int(min_key_ttl_seconds)) sequence_id = uuid.uuid4().hex[:12] @@ -4413,6 +4664,9 @@ def resolve_nous_runtime_credentials( raise AuthError("Hermes is not logged into Nous Portal.", provider="nous", relogin_required=True) + persisted_state = dict(state) + state_persisted = False + portal_base_url = ( _optional_base_url(state.get("portal_base_url")) or os.getenv("HERMES_PORTAL_BASE_URL") @@ -4427,6 +4681,17 @@ def resolve_nous_runtime_credentials( client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID) def _persist_state(reason: str) -> None: + nonlocal persisted_state, state_persisted + if ( + _nous_effective_provider_state(state) + == _nous_effective_provider_state(persisted_state) + ): + _oauth_trace( + "nous_state_persist_skipped", + sequence_id=sequence_id, + reason=reason, + ) + return try: _save_provider_state(auth_store, "nous", state) _save_auth_store(auth_store) @@ -4445,6 +4710,8 @@ def resolve_nous_runtime_credentials( refresh_token_fp=_token_fingerprint(state.get("refresh_token")), access_token_fp=_token_fingerprint(state.get("access_token")), ) + persisted_state = dict(state) + state_persisted = True # Mirror post-refresh state to the shared store so sibling # profiles don't hold stale refresh_tokens after rotation. # Best-effort — any failure is logged and swallowed inside @@ -4456,7 +4723,7 @@ def resolve_nous_runtime_credentials( _oauth_trace( "nous_runtime_credentials_start", sequence_id=sequence_id, - force_mint=bool(force_mint), + auth_mode=auth_mode, min_key_ttl_seconds=min_key_ttl_seconds, refresh_token_fp=_token_fingerprint(state.get("refresh_token")), ) @@ -4520,6 +4787,11 @@ def resolve_nous_runtime_credentials( exc, reason="runtime_access_refresh_failure", ) + _quarantine_nous_pool_entries( + auth_store, + exc, + reason="runtime_access_refresh_failure", + ) _persist_state("terminal_runtime_access_refresh_failure") raise now = datetime.now(timezone.utc) @@ -4554,50 +4826,28 @@ def resolve_nous_runtime_credentials( # the opaque session key. used_cached_key = False mint_payload: Optional[Dict[str, Any]] = None - selected_auth_path = "legacy_session_key" - legacy_session_keys = _nous_legacy_session_keys_forced() + selected_auth_path, fallback_reason = _choose_nous_inference_auth_path( + state, + access_token=access_token, + min_key_ttl_seconds=min_key_ttl_seconds, + auth_mode=auth_mode, + ) - if ( - not legacy_session_keys - and _nous_invoke_jwt_is_usable( - access_token, - scope=state.get("scope"), - expires_at=state.get("expires_at"), - ) - ): - _set_nous_agent_key_from_invoke_jwt(state) - selected_auth_path = "invoke_jwt" - logger.info("Nous inference auth: using NAS invoke JWT") - _oauth_trace( - "nous_invoke_jwt_selected", + if selected_auth_path == "invoke_jwt": + _select_nous_invoke_jwt( + state, + access_token=access_token, sequence_id=sequence_id, - access_token_fp=_token_fingerprint(access_token), ) - elif not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds): + elif selected_auth_path == "legacy_session_key_cache": used_cached_key = True - selected_auth_path = "legacy_session_key_cache" - logger.info("Nous inference auth: using cached legacy session key") + logger.info("Nous inference auth: using cached agent_key") _oauth_trace("agent_key_reuse", sequence_id=sequence_id) else: - fallback_reason = ( - "forced_legacy_session_keys" - if legacy_session_keys - else _nous_invoke_jwt_unavailable_reason( - access_token, - scope=state.get("scope"), - expires_at=state.get("expires_at"), - ) - ) - selected_auth_path = "legacy_session_key_mint" - logger.info( - "Nous inference auth: using legacy session key path (%s)", - fallback_reason, - ) - _oauth_trace( - "nous_legacy_session_key_selected", + _log_nous_legacy_session_key_selected( + fallback_reason or "legacy_session_key_required", + access_token=access_token, sequence_id=sequence_id, - reason=fallback_reason, - access_token_fp=_token_fingerprint(access_token), ) try: _oauth_trace( @@ -4646,6 +4896,11 @@ def resolve_nous_runtime_credentials( exc, reason="runtime_mint_retry_refresh_failure", ) + _quarantine_nous_pool_entries( + auth_store, + exc, + reason="runtime_mint_retry_refresh_failure", + ) _persist_state("terminal_runtime_mint_retry_refresh_failure") raise now = datetime.now(timezone.utc) @@ -4674,22 +4929,24 @@ def resolve_nous_runtime_credentials( # Persist retry refresh immediately for crash safety and cross-process visibility. _persist_state("post_refresh_mint_retry") - if ( - not legacy_session_keys - and _nous_invoke_jwt_is_usable( - access_token, - scope=state.get("scope"), - expires_at=state.get("expires_at"), - ) - ): - _set_nous_agent_key_from_invoke_jwt(state) + retry_auth_mode = ( + NOUS_INFERENCE_AUTH_LEGACY + if auth_mode == NOUS_INFERENCE_AUTH_LEGACY + else NOUS_INFERENCE_AUTH_FRESH + ) + retry_auth_path, _ = _choose_nous_inference_auth_path( + state, + access_token=access_token, + min_key_ttl_seconds=min_key_ttl_seconds, + auth_mode=retry_auth_mode, + ) + if retry_auth_path == "invoke_jwt": mint_payload = None selected_auth_path = "invoke_jwt" - logger.info("Nous inference auth: using NAS invoke JWT") - _oauth_trace( - "nous_invoke_jwt_selected", + _select_nous_invoke_jwt( + state, + access_token=access_token, sequence_id=sequence_id, - access_token_fp=_token_fingerprint(access_token), ) else: mint_payload = _mint_agent_key( @@ -4727,7 +4984,8 @@ def resolve_nous_runtime_credentials( _persist_state("resolve_nous_runtime_credentials_final") - _sync_nous_pool_from_auth_store() + if state_persisted: + _sync_nous_pool_from_auth_store() api_key = state.get("agent_key") if not isinstance(api_key, str) or not api_key: @@ -6433,10 +6691,7 @@ def _nous_device_code_login( or pconfig.inference_base_url ).rstrip("/") client_id = client_id or pconfig.client_id - explicit_scope = scope is not None - scope = scope or pconfig.scope - if _nous_legacy_session_keys_forced(): - scope = NOUS_LEGACY_AGENT_KEY_SCOPE + scope, explicit_scope = _nous_device_scope(scope, default_scope=pconfig.scope) timeout = httpx.Timeout(timeout_seconds) verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True) @@ -6451,30 +6706,13 @@ def _nous_device_code_login( print(f"TLS verification: custom CA bundle ({ca_bundle})") with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client: - try: - device_data = _request_device_code( - client=client, - portal_base_url=portal_base_url, - client_id=client_id, - scope=scope, - ) - except Exception as exc: - if ( - not explicit_scope - and _nous_scope_has_invoke(scope) - and _is_nous_invoke_scope_refusal(exc) - ): - logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope") - _oauth_trace("nous_device_code_invoke_scope_refused") - scope = NOUS_LEGACY_AGENT_KEY_SCOPE - device_data = _request_device_code( - client=client, - portal_base_url=portal_base_url, - client_id=client_id, - scope=scope, - ) - else: - raise + device_data, scope = _request_nous_device_code_with_scope_fallback( + client=client, + portal_base_url=portal_base_url, + client_id=client_id, + scope=scope, + allow_legacy_fallback=not explicit_scope, + ) verification_url = str(device_data["verification_uri_complete"]) user_code = str(device_data["user_code"]) @@ -6543,7 +6781,7 @@ def _nous_device_code_login( min_key_ttl_seconds=min_key_ttl_seconds, timeout_seconds=timeout_seconds, force_refresh=False, - force_mint=True, + auth_mode=NOUS_INFERENCE_AUTH_FRESH, ) except AuthError as exc: if exc.code == "subscription_required": diff --git a/hermes_cli/proxy/adapters/base.py b/hermes_cli/proxy/adapters/base.py index 5ac8a5dcedd..c7f36e25a2b 100644 --- a/hermes_cli/proxy/adapters/base.py +++ b/hermes_cli/proxy/adapters/base.py @@ -81,6 +81,21 @@ class UpstreamAdapter(ABC): refresh fails. The proxy will return 401 to the client. """ + def get_retry_credential( + self, + *, + failed_credential: UpstreamCredential, + status_code: int, + ) -> Optional[UpstreamCredential]: + """Return an alternate credential after an upstream auth failure. + + The default is no retry. Providers can override this for one-shot + fallback paths, such as switching from a preferred token type to a + legacy bearer after the upstream rejects the first request. + """ + del failed_credential, status_code + return None + def describe(self) -> str: """One-line status summary for ``proxy status``.""" try: diff --git a/hermes_cli/proxy/adapters/nous_portal.py b/hermes_cli/proxy/adapters/nous_portal.py index b69f9d52644..a8cfd4cbada 100644 --- a/hermes_cli/proxy/adapters/nous_portal.py +++ b/hermes_cli/proxy/adapters/nous_portal.py @@ -19,13 +19,16 @@ from typing import Any, Dict, FrozenSet, Optional from hermes_cli.auth import ( AuthError, DEFAULT_NOUS_INFERENCE_URL, + NOUS_INFERENCE_AUTH_AUTO, + NOUS_INFERENCE_AUTH_LEGACY, _load_auth_store, _is_terminal_nous_refresh_error, _quarantine_nous_oauth_state, + _quarantine_nous_pool_entries, _save_auth_store, _write_shared_nous_state, refresh_nous_oauth_from_state, -) + ) from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential logger = logging.getLogger(__name__) @@ -76,6 +79,21 @@ class NousPortalAdapter(UpstreamAdapter): ) def get_credential(self) -> UpstreamCredential: + return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_AUTO) + + def get_retry_credential( + self, + *, + failed_credential: UpstreamCredential, + status_code: int, + ) -> Optional[UpstreamCredential]: + del failed_credential + if status_code != 401: + return None + logger.info("proxy: Nous upstream rejected bearer; retrying with legacy session key") + return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_LEGACY) + + def _get_credential(self, *, auth_mode: str) -> UpstreamCredential: with self._lock: state = self._read_state() if state is None: @@ -84,7 +102,10 @@ class NousPortalAdapter(UpstreamAdapter): ) try: - refreshed = refresh_nous_oauth_from_state(state) + refreshed = refresh_nous_oauth_from_state( + state, + auth_mode=auth_mode, + ) except AuthError as exc: if _is_terminal_nous_refresh_error(exc): _quarantine_nous_oauth_state( @@ -92,7 +113,11 @@ class NousPortalAdapter(UpstreamAdapter): exc, reason="proxy_refresh_failure", ) - self._save_state(state) + self._save_state( + state, + quarantine_error=exc, + quarantine_reason="proxy_refresh_failure", + ) raise RuntimeError( f"Failed to refresh Nous Portal credentials: {exc}" ) from exc @@ -136,9 +161,21 @@ class NousPortalAdapter(UpstreamAdapter): return None return dict(state) # copy so the refresh helper can mutate freely - def _save_state(self, state: Dict[str, Any]) -> None: + def _save_state( + self, + state: Dict[str, Any], + *, + quarantine_error: Optional[AuthError] = None, + quarantine_reason: Optional[str] = None, + ) -> None: try: store = _load_auth_store() + if quarantine_error is not None and quarantine_reason: + _quarantine_nous_pool_entries( + store, + quarantine_error, + reason=quarantine_reason, + ) providers = store.setdefault("providers", {}) providers["nous"] = state _save_auth_store(store) diff --git a/hermes_cli/proxy/server.py b/hermes_cli/proxy/server.py index fa497f13291..a72f75d67ee 100644 --- a/hermes_cli/proxy/server.py +++ b/hermes_cli/proxy/server.py @@ -26,7 +26,7 @@ except ImportError: web = None # type: ignore[assignment] AIOHTTP_AVAILABLE = False -from hermes_cli.proxy.adapters.base import UpstreamAdapter +from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential logger = logging.getLogger(__name__) @@ -136,50 +136,93 @@ def create_app(adapter: UpstreamAdapter) -> "web.Application": logger.warning("proxy: credential resolution failed: %s", exc) return _json_error(401, str(exc), code="upstream_auth_failed") - upstream_url = f"{cred.base_url.rstrip('/')}{rel_path}" - # Preserve query string verbatim. - if request.query_string: - upstream_url = f"{upstream_url}?{request.query_string}" - # Forward body verbatim. Read into memory once — request bodies for # chat/completions/embeddings are small (<1MB typically). If we ever # need to forward large multipart uploads we'll switch to streaming # the request body too. body = await request.read() - fwd_headers = _filter_request_headers(request.headers) - fwd_headers["Authorization"] = f"{cred.token_type} {cred.bearer}" - - logger.debug( - "proxy: forwarding %s %s -> %s (body=%d bytes)", - request.method, rel_path, upstream_url, len(body), - ) - - # Use a per-request session so connection state doesn't leak between - # clients. Could be optimized to a shared session later. timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=300) - try: - session = aiohttp.ClientSession(timeout=timeout) - except Exception as exc: # pragma: no cover - aiohttp setup issue - return _json_error(500, f"proxy session init failed: {exc}") - try: - upstream_resp = await session.request( - request.method, - upstream_url, - data=body if body else None, - headers=fwd_headers, - allow_redirects=False, + async def _send_upstream(active_cred: UpstreamCredential): + upstream_url = f"{active_cred.base_url.rstrip('/')}{rel_path}" + # Preserve query string verbatim. + if request.query_string: + upstream_url = f"{upstream_url}?{request.query_string}" + + fwd_headers = _filter_request_headers(request.headers) + fwd_headers["Authorization"] = f"{active_cred.token_type} {active_cred.bearer}" + + logger.debug( + "proxy: forwarding %s %s -> %s (body=%d bytes)", + request.method, rel_path, upstream_url, len(body), ) - except aiohttp.ClientError as exc: - await session.close() - logger.warning("proxy: upstream connection failed: %s", exc) - return _json_error(502, f"upstream connection failed: {exc}", - code="upstream_unreachable") - except asyncio.TimeoutError: - await session.close() - return _json_error(504, "upstream request timed out", - code="upstream_timeout") + + try: + session = aiohttp.ClientSession(timeout=timeout) + except Exception as exc: # pragma: no cover - aiohttp setup issue + raise RuntimeError(f"proxy session init failed: {exc}") from exc + + try: + upstream_resp = await session.request( + request.method, + upstream_url, + data=body if body else None, + headers=fwd_headers, + allow_redirects=False, + ) + except Exception: + await session.close() + raise + return session, upstream_resp + + async def _open_upstream(active_cred: UpstreamCredential): + try: + return await _send_upstream(active_cred) + except RuntimeError as exc: + return _json_error(500, str(exc)), None + except aiohttp.ClientError as exc: + logger.warning("proxy: upstream connection failed: %s", exc) + return ( + _json_error( + 502, + f"upstream connection failed: {exc}", + code="upstream_unreachable", + ), + None, + ) + except asyncio.TimeoutError: + return ( + _json_error( + 504, + "upstream request timed out", + code="upstream_timeout", + ), + None, + ) + + session_or_response, upstream_resp = await _open_upstream(cred) + if upstream_resp is None: + return session_or_response + session = session_or_response + + if upstream_resp.status == 401: + try: + retry_cred = adapter.get_retry_credential( + failed_credential=cred, + status_code=upstream_resp.status, + ) + except Exception as exc: + logger.warning("proxy: retry credential resolution failed: %s", exc) + retry_cred = None + + if retry_cred is not None: + upstream_resp.release() + await session.close() + session_or_response, upstream_resp = await _open_upstream(retry_cred) + if upstream_resp is None: + return session_or_response + session = session_or_response # Stream response back. Headers first, then chunked body. resp = web.StreamResponse( diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 8a1e4aca2e1..bfd47e9cc24 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -1815,7 +1815,11 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]: so the UI can render the verification page link + user code. """ if provider_id == "nous": - from hermes_cli.auth import _request_device_code, PROVIDER_REGISTRY + from hermes_cli.auth import ( + _nous_device_scope, + _request_nous_device_code_with_scope_fallback, + PROVIDER_REGISTRY, + ) import httpx pconfig = PROVIDER_REGISTRY["nous"] portal_base_url = ( @@ -1824,22 +1828,31 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]: or pconfig.portal_base_url ).rstrip("/") client_id = pconfig.client_id - scope = pconfig.scope + scope, explicit_scope = _nous_device_scope(None, default_scope=pconfig.scope) + def _do_nous_device_request(): - with httpx.Client(timeout=httpx.Timeout(15.0), headers={"Accept": "application/json"}) as client: - return _request_device_code( + with httpx.Client( + timeout=httpx.Timeout(15.0), + headers={"Accept": "application/json"}, + ) as client: + return _request_nous_device_code_with_scope_fallback( client=client, portal_base_url=portal_base_url, client_id=client_id, scope=scope, + allow_legacy_fallback=not explicit_scope, ) - device_data = await asyncio.get_running_loop().run_in_executor(None, _do_nous_device_request) + + device_data, effective_scope = await asyncio.get_running_loop().run_in_executor( + None, _do_nous_device_request + ) sid, sess = _new_oauth_session("nous", "device_code") sess["device_code"] = str(device_data["device_code"]) sess["interval"] = int(device_data["interval"]) sess["expires_at"] = time.time() + int(device_data["expires_in"]) sess["portal_base_url"] = portal_base_url sess["client_id"] = client_id + sess["scope"] = effective_scope threading.Thread( target=_nous_poller, args=(sid,), daemon=True, name=f"oauth-poll-{sid[:6]}" ).start() @@ -1968,7 +1981,11 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]: def _nous_poller(session_id: str) -> None: """Background poller that drives a Nous device-code flow to completion.""" - from hermes_cli.auth import _poll_for_token, refresh_nous_oauth_from_state + from hermes_cli.auth import ( + NOUS_INFERENCE_AUTH_FRESH, + _poll_for_token, + refresh_nous_oauth_from_state, + ) from datetime import datetime, timezone import httpx with _oauth_sessions_lock: @@ -1979,6 +1996,7 @@ def _nous_poller(session_id: str) -> None: client_id = sess["client_id"] device_code = sess["device_code"] interval = sess["interval"] + scope = sess.get("scope") expires_in = max(60, int(sess["expires_at"] - time.time())) try: with httpx.Client(timeout=httpx.Timeout(15.0), headers={"Accept": "application/json"}) as client: @@ -1997,7 +2015,7 @@ def _nous_poller(session_id: str) -> None: "portal_base_url": portal_base_url, "inference_base_url": token_data.get("inference_base_url"), "client_id": client_id, - "scope": token_data.get("scope"), + "scope": token_data.get("scope") or scope, "token_type": token_data.get("token_type", "Bearer"), "access_token": token_data["access_token"], "refresh_token": token_data.get("refresh_token"), @@ -2009,8 +2027,11 @@ def _nous_poller(session_id: str) -> None: "expires_in": token_ttl, } full_state = refresh_nous_oauth_from_state( - auth_state, min_key_ttl_seconds=300, timeout_seconds=15.0, - force_refresh=False, force_mint=True, + auth_state, + min_key_ttl_seconds=300, + timeout_seconds=15.0, + force_refresh=False, + auth_mode=NOUS_INFERENCE_AUTH_FRESH, ) from hermes_cli.auth import persist_nous_credentials persist_nous_credentials(full_state) diff --git a/run_agent.py b/run_agent.py index 6e9877a1182..1244d372fdf 100644 --- a/run_agent.py +++ b/run_agent.py @@ -2628,12 +2628,20 @@ class AIAgent: return False try: - from hermes_cli.auth import resolve_nous_runtime_credentials + from hermes_cli.auth import ( + NOUS_INFERENCE_AUTH_AUTO, + NOUS_INFERENCE_AUTH_LEGACY, + resolve_nous_runtime_credentials, + ) creds = resolve_nous_runtime_credentials( min_key_ttl_seconds=max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800"))), timeout_seconds=float(os.getenv("HERMES_NOUS_TIMEOUT_SECONDS", "15")), - force_mint=force, + auth_mode=( + NOUS_INFERENCE_AUTH_LEGACY + if force + else NOUS_INFERENCE_AUTH_AUTO + ), ) except Exception as exc: logger.debug("Nous credential refresh failed: %s", exc) diff --git a/tests/agent/test_credential_pool.py b/tests/agent/test_credential_pool.py index f7eaf9fa273..875b08d91f0 100644 --- a/tests/agent/test_credential_pool.py +++ b/tests/agent/test_credential_pool.py @@ -566,7 +566,7 @@ def test_load_pool_mirrors_nous_invoke_jwt_agent_key_runtime_api_key(tmp_path, m assert pool_entry["agent_key_expires_at"] == expires_at -def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch): +def test_nous_pool_terminal_refresh_removes_device_code_entry(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) monkeypatch.setenv("HERMES_SHARED_AUTH_DIR", str(tmp_path / "shared")) _write_auth_store( @@ -591,7 +591,7 @@ def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch): }, ) - from agent.credential_pool import load_pool + from agent.credential_pool import PooledCredential, load_pool from hermes_cli import auth as auth_mod from hermes_cli.auth import AuthError @@ -606,18 +606,30 @@ def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch): relogin_required=True, ) + pool = load_pool("nous") + selected = pool.select() + assert selected is not None + assert selected.source == "device_code" + pool.add_entry(PooledCredential.from_dict("nous", { + "id": "legacy-seeded", + "source": "manual:device_code", + "auth_type": "oauth", + "access_token": "old-access-token", + "refresh_token": "old-refresh-token", + "agent_key": "old-agent-key", + })) + pool.add_entry(PooledCredential.from_dict("nous", { + "id": "manual-key", + "source": "manual", + "auth_type": "api_key", + "access_token": "manual-nous-key", + })) + monkeypatch.setattr(auth_mod, "refresh_nous_oauth_from_state", _terminal_refresh_failure) - pool = load_pool("nous") - assert pool.select() is not None assert pool.try_refresh_current() is None - entry = pool.entries()[0] - assert entry.last_status == "exhausted" - assert entry.last_error_code == 401 - assert entry.refresh_token is None - assert entry.access_token is None - assert entry.agent_key is None + assert [entry.id for entry in pool.entries()] == ["manual-key"] auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text()) nous_state = auth_payload["providers"]["nous"] @@ -625,11 +637,63 @@ def test_nous_pool_terminal_refresh_clears_tokens(tmp_path, monkeypatch): assert not nous_state.get("access_token") assert not nous_state.get("agent_key") assert nous_state["last_auth_error"]["code"] == "invalid_grant" + assert [entry["id"] for entry in auth_payload["credential_pool"]["nous"]] == ["manual-key"] assert pool.try_refresh_current() is None assert refresh_calls["count"] == 1 +def test_load_pool_removes_nous_device_code_when_singleton_quarantined(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + _write_auth_store( + tmp_path, + { + "version": 1, + "active_provider": "nous", + "providers": { + "nous": { + "portal_base_url": "https://portal.example.com", + "inference_base_url": "https://inference.example.com/v1", + "client_id": "hermes-cli", + "last_auth_error": {"code": "invalid_grant"}, + } + }, + "credential_pool": { + "nous": [ + { + "id": "seeded-current", + "source": "device_code", + "auth_type": "oauth", + "access_token": "stale-access", + "refresh_token": "stale-refresh", + "agent_key": "stale-agent", + }, + { + "id": "seeded-legacy", + "source": "manual:device_code", + "auth_type": "oauth", + "access_token": "older-stale-access", + }, + { + "id": "manual-key", + "source": "manual", + "auth_type": "api_key", + "access_token": "manual-nous-key", + }, + ] + }, + }, + ) + + from agent.credential_pool import load_pool + + pool = load_pool("nous") + + assert [entry.id for entry in pool.entries()] == ["manual-key"] + auth_payload = json.loads((tmp_path / "hermes" / "auth.json").read_text()) + assert [entry["id"] for entry in auth_payload["credential_pool"]["nous"]] == ["manual-key"] + + def test_load_pool_removes_stale_file_backed_singleton_entry(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) monkeypatch.delenv("ANTHROPIC_API_KEY", raising=False) diff --git a/tests/hermes_cli/test_auth_nous_provider.py b/tests/hermes_cli/test_auth_nous_provider.py index 1d07737a857..0bdb1330a29 100644 --- a/tests/hermes_cli/test_auth_nous_provider.py +++ b/tests/hermes_cli/test_auth_nous_provider.py @@ -231,6 +231,83 @@ def test_resolve_nous_runtime_credentials_prefers_invoke_jwt_and_mirrors( assert pool_entries[0]["source"] == auth_mod.NOUS_DEVICE_CODE_SOURCE +def test_resolve_nous_runtime_credentials_invoke_jwt_is_idempotent( + tmp_path, + monkeypatch, +): + import hermes_cli.auth as auth_mod + + hermes_home = tmp_path / "hermes" + hermes_home.mkdir(parents=True, exist_ok=True) + exp = int(time.time() + 3600) + expires_at = datetime.fromtimestamp(exp, tz=timezone.utc).isoformat() + token = _jwt_with_claims({ + "sub": "test-user", + "scope": auth_mod.DEFAULT_NOUS_SCOPE, + "exp": exp, + }) + original_obtained_at = "2026-04-17T22:00:10+00:00" + auth_store = { + "version": 1, + "active_provider": "nous", + "providers": { + "nous": { + "portal_base_url": "https://portal.example.com", + "inference_base_url": "https://inference.example.com/v1", + "client_id": "hermes-cli", + "token_type": "Bearer", + "scope": auth_mod.DEFAULT_NOUS_SCOPE, + "access_token": token, + "refresh_token": "refresh-token", + "obtained_at": "2026-02-01T00:00:00+00:00", + "expires_in": 123, + "expires_at": expires_at, + "agent_key": token, + "agent_key_id": None, + "agent_key_expires_at": expires_at, + "agent_key_expires_in": 123, + "agent_key_reused": False, + "agent_key_obtained_at": original_obtained_at, + "tls": {"insecure": False, "ca_bundle": None}, + }, + }, + } + auth_path = hermes_home / "auth.json" + auth_path.write_text(json.dumps(auth_store, indent=2)) + before_content = auth_path.read_text() + before_mtime = auth_path.stat().st_mtime_ns + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + def _unexpected_mint(*args, **kwargs): + raise AssertionError("stable invoke JWT should not mint a legacy key") + + def _unexpected_shared_write(*args, **kwargs): + raise AssertionError("unchanged invoke JWT resolution should not sync shared store") + + sync_calls = [] + + monkeypatch.setattr(auth_mod, "_mint_agent_key", _unexpected_mint) + monkeypatch.setattr(auth_mod, "_write_shared_nous_state", _unexpected_shared_write) + monkeypatch.setattr( + auth_mod, + "_sync_nous_pool_from_auth_store", + lambda: sync_calls.append(True), + ) + + creds = auth_mod.resolve_nous_runtime_credentials(min_key_ttl_seconds=300) + + assert creds["api_key"] == token + assert creds["source"] == "invoke_jwt" + assert auth_path.read_text() == before_content + assert auth_path.stat().st_mtime_ns == before_mtime + assert sync_calls == [] + payload = json.loads(auth_path.read_text()) + assert ( + payload["providers"]["nous"]["agent_key_obtained_at"] + == original_obtained_at + ) + + def test_resolve_nous_runtime_credentials_trusts_invoke_jwt_exp_over_stale_metadata( tmp_path, monkeypatch, @@ -301,6 +378,41 @@ def test_resolve_nous_runtime_credentials_does_not_apply_legacy_ttl_to_invoke_jw assert payload["credential_pool"]["nous"][0]["agent_key"] == token +def test_legacy_auth_mode_bypasses_usable_invoke_jwt(tmp_path, monkeypatch): + import hermes_cli.auth as auth_mod + + hermes_home = tmp_path / "hermes" + token = _invoke_jwt(seconds=3600) + _setup_nous_auth( + hermes_home, + access_token=token, + scope=auth_mod.DEFAULT_NOUS_SCOPE, + expires_at=_future_iso(3600), + expires_in=3600, + ) + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + + mint_calls = [] + + def _fake_mint_agent_key(*, client, portal_base_url, access_token, min_ttl_seconds): + del client, portal_base_url, min_ttl_seconds + mint_calls.append(access_token) + return _mint_payload(api_key="legacy-after-jwt-401") + + monkeypatch.setattr(auth_mod, "_mint_agent_key", _fake_mint_agent_key) + + creds = auth_mod.resolve_nous_runtime_credentials( + min_key_ttl_seconds=300, + auth_mode=auth_mod.NOUS_INFERENCE_AUTH_LEGACY, + ) + + assert mint_calls == [token] + assert creds["api_key"] == "legacy-after-jwt-401" + assert creds["auth_path"] == "legacy_session_key_mint" + payload = json.loads((hermes_home / "auth.json").read_text()) + assert payload["providers"]["nous"]["agent_key"] == "legacy-after-jwt-401" + + def test_resolve_nous_runtime_credentials_falls_back_when_invoke_scope_missing( tmp_path, monkeypatch, @@ -735,6 +847,9 @@ def test_terminal_refresh_failure_quarantines_tokens( hermes_home = tmp_path / "hermes" _setup_nous_auth(hermes_home, refresh_token="refresh-old") monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + from agent.credential_pool import load_pool + + assert load_pool("nous").select() is not None shared_state = _full_state_fixture() shared_state["access_token"] = "access-old" @@ -765,6 +880,8 @@ def test_terminal_refresh_failure_quarantines_tokens( assert not state_after_failure.get("agent_key") assert state_after_failure["last_auth_error"]["code"] == "invalid_grant" assert auth_mod._read_shared_nous_state() is None + payload = json.loads((hermes_home / "auth.json").read_text()) + assert payload.get("credential_pool", {}).get("nous") == [] with pytest.raises(AuthError, match="No access token found"): auth_mod.resolve_nous_runtime_credentials(min_key_ttl_seconds=300) @@ -780,6 +897,9 @@ def test_managed_access_token_refresh_failure_quarantines_tokens( hermes_home = tmp_path / "hermes" _setup_nous_auth(hermes_home, refresh_token="refresh-old") monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + from agent.credential_pool import load_pool + + assert load_pool("nous").select() is not None refresh_calls: list[str] = [] @@ -802,6 +922,8 @@ def test_managed_access_token_refresh_failure_quarantines_tokens( assert not state_after_failure.get("refresh_token") assert not state_after_failure.get("access_token") assert state_after_failure["last_auth_error"]["message"] == "Invalid refresh token" + payload = json.loads((hermes_home / "auth.json").read_text()) + assert payload.get("credential_pool", {}).get("nous") == [] with pytest.raises(AuthError, match="No access token found"): auth_mod.resolve_nous_access_token() @@ -1076,7 +1198,11 @@ def test_persist_nous_credentials_allows_recovery_from_401(tmp_path, monkeypatch calls after a Nous 401 — before the fix it would raise AuthError because providers.nous was empty. """ - from hermes_cli.auth import persist_nous_credentials, resolve_nous_runtime_credentials + from hermes_cli.auth import ( + NOUS_INFERENCE_AUTH_FRESH, + persist_nous_credentials, + resolve_nous_runtime_credentials, + ) hermes_home = tmp_path / "hermes" hermes_home.mkdir(parents=True, exist_ok=True) @@ -1104,7 +1230,10 @@ def test_persist_nous_credentials_allows_recovery_from_401(tmp_path, monkeypatch monkeypatch.setattr("hermes_cli.auth._refresh_access_token", _fake_refresh_access_token) monkeypatch.setattr("hermes_cli.auth._mint_agent_key", _fake_mint_agent_key) - creds = resolve_nous_runtime_credentials(min_key_ttl_seconds=300, force_mint=True) + creds = resolve_nous_runtime_credentials( + min_key_ttl_seconds=300, + auth_mode=NOUS_INFERENCE_AUTH_FRESH, + ) assert creds["api_key"] == "new-agent-key" @@ -1569,7 +1698,7 @@ def test_try_import_shared_rehydrates_on_success(shared_store_env, monkeypatch): def _fake_refresh(state, **kwargs): # Simulate portal returning fresh tokens + a new agent_key assert kwargs.get("force_refresh") is True - assert kwargs.get("force_mint") is True + assert kwargs.get("auth_mode") == auth_mod.NOUS_INFERENCE_AUTH_FRESH return { **state, "access_token": "fresh-access-tok", @@ -1697,7 +1826,7 @@ def test_runtime_refresh_uses_newer_shared_token_before_local_stale_token( creds = auth_mod.resolve_nous_runtime_credentials( min_key_ttl_seconds=300, - force_mint=True, + auth_mode=auth_mod.NOUS_INFERENCE_AUTH_FRESH, ) assert creds["api_key"] == "agent-key-from-shared-token" diff --git a/tests/hermes_cli/test_proxy.py b/tests/hermes_cli/test_proxy.py index 3ab06eeb92f..9303fb1c702 100644 --- a/tests/hermes_cli/test_proxy.py +++ b/tests/hermes_cli/test_proxy.py @@ -141,6 +141,45 @@ def test_nous_adapter_get_credential_refreshes_and_persists(tmp_path, monkeypatc assert stored["providers"]["nous"]["agent_key"] == "minted-bearer" +def test_nous_adapter_retry_credential_forces_legacy_mint(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path)) + _write_auth_store(tmp_path, { + "access_token": "jwt-access", + "refresh_token": "refresh-tok", + "client_id": "hermes-cli", + "portal_base_url": "https://portal.nousresearch.com", + "inference_base_url": "https://inference-api.nousresearch.com/v1", + "agent_key": "jwt-access", + }) + + refreshed_state = { + "access_token": "jwt-access", + "refresh_token": "refresh-tok", + "client_id": "hermes-cli", + "portal_base_url": "https://portal.nousresearch.com", + "inference_base_url": "https://inference-api.nousresearch.com/v1", + "agent_key": "legacy-bearer", + "agent_key_expires_at": "2099-01-01T00:00:00Z", + } + + with patch( + "hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state", + return_value=refreshed_state, + ) as mock_refresh: + adapter = NousPortalAdapter() + cred = adapter.get_retry_credential( + failed_credential=UpstreamCredential( + bearer="jwt-access", + base_url="https://inference-api.nousresearch.com/v1", + ), + status_code=401, + ) + + assert cred is not None + assert cred.bearer == "legacy-bearer" + assert mock_refresh.call_args.kwargs["auth_mode"] == "legacy" + + def test_nous_adapter_get_credential_raises_when_not_logged_in(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path)) adapter = NousPortalAdapter() @@ -166,6 +205,7 @@ def test_nous_adapter_get_credential_raises_on_refresh_failure(tmp_path, monkeyp def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch): from hermes_cli.auth import AuthError + from agent.credential_pool import load_pool monkeypatch.setenv("HERMES_HOME", str(tmp_path)) _write_auth_store(tmp_path, { @@ -173,6 +213,7 @@ def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch "refresh_token": "refresh-tok", "agent_key": "stale-agent-key", }) + assert load_pool("nous").select() is not None with patch( "hermes_cli.proxy.adapters.nous_portal.refresh_nous_oauth_from_state", @@ -193,6 +234,7 @@ def test_nous_adapter_quarantines_terminal_refresh_failure(tmp_path, monkeypatch assert not nous_state.get("access_token") assert not nous_state.get("agent_key") assert nous_state["last_auth_error"]["code"] == "invalid_grant" + assert stored.get("credential_pool", {}).get("nous") == [] def test_nous_adapter_get_credential_raises_when_no_agent_key_returned(tmp_path, monkeypatch): @@ -291,12 +333,15 @@ class FakeAdapter(UpstreamAdapter): """A test adapter that returns a fixed credential without touching disk.""" def __init__(self, base_url: str, bearer: str = "test-bearer", - allowed=None, raise_on_credential=False): + allowed=None, raise_on_credential=False, + retry_bearer: str | None = None): self._base_url = base_url self._bearer = bearer self._allowed = frozenset(allowed or ["/chat/completions"]) self._raise = raise_on_credential + self._retry_bearer = retry_bearer self.calls = 0 + self.retry_calls = 0 @property def name(self): return "fake" @@ -318,6 +363,17 @@ class FakeAdapter(UpstreamAdapter): expires_at="2099-01-01T00:00:00Z", ) + def get_retry_credential(self, *, failed_credential, status_code): + del failed_credential + self.retry_calls += 1 + if status_code != 401 or not self._retry_bearer: + return None + return UpstreamCredential( + bearer=self._retry_bearer, + base_url=self._base_url, + expires_at="2099-01-01T00:00:00Z", + ) + async def _start_runner(app: "web.Application"): """Spin up an aiohttp app on an ephemeral localhost port. Returns (runner, base_url).""" @@ -358,6 +414,25 @@ def _build_fake_upstream(captured: Dict[str, Any]) -> "web.Application": return app +def _build_retrying_fake_upstream(captured: Dict[str, Any]) -> "web.Application": + async def maybe_unauthorized(request): + body = await request.read() + auth = request.headers.get("Authorization") + captured["requests"].append({ + "method": request.method, + "path": request.path, + "auth": auth, + "body": body.decode("utf-8") if body else "", + }) + if auth == "Bearer jwt-bearer": + return web.json_response({"error": "bad token"}, status=401) + return web.json_response({"ok": True}) + + app = web.Application() + app.router.add_route("*", "/v1/chat/completions", maybe_unauthorized) + return app + + def test_server_forwards_chat_completions(): async def run(): captured: Dict[str, Any] = {"requests": []} @@ -388,6 +463,41 @@ def test_server_forwards_chat_completions(): asyncio.run(run()) +def test_server_retries_once_with_adapter_retry_credential_on_401(): + async def run(): + captured: Dict[str, Any] = {"requests": []} + upstream_runner, upstream_base = await _start_runner( + _build_retrying_fake_upstream(captured) + ) + adapter = FakeAdapter( + f"{upstream_base}/v1", + bearer="jwt-bearer", + retry_bearer="legacy-bearer", + ) + proxy_runner, proxy_base = await _start_runner(create_app(adapter)) + + try: + async with aiohttp.ClientSession() as session: + async with session.post( + f"{proxy_base}/v1/chat/completions", + json={"model": "Hermes-4-70B"}, + ) as resp: + assert resp.status == 200 + data = await resp.json() + assert data["ok"] is True + + assert adapter.retry_calls == 1 + assert [req["auth"] for req in captured["requests"]] == [ + "Bearer jwt-bearer", + "Bearer legacy-bearer", + ] + finally: + await proxy_runner.cleanup() + await upstream_runner.cleanup() + + asyncio.run(run()) + + def test_server_rejects_disallowed_path(): async def run(): adapter = FakeAdapter("http://unused.example/v1", allowed=["/chat/completions"]) diff --git a/tests/hermes_cli/test_web_oauth_dispatch.py b/tests/hermes_cli/test_web_oauth_dispatch.py index 23b72a303cf..b9ee20ccae8 100644 --- a/tests/hermes_cli/test_web_oauth_dispatch.py +++ b/tests/hermes_cli/test_web_oauth_dispatch.py @@ -19,11 +19,12 @@ The fix: These tests pin the corrected behavior. """ +import asyncio import time from datetime import datetime, timezone from unittest.mock import patch -import pytest +import httpx from fastapi.testclient import TestClient from hermes_cli.web_server import _SESSION_TOKEN, app @@ -32,6 +33,32 @@ client = TestClient(app) HEADERS = {"X-Hermes-Session-Token": _SESSION_TOKEN} +def _fake_nous_device_data(): + return { + "device_code": "device-code", + "user_code": "NOUS-1234", + "verification_uri": "https://portal.nousresearch.com/device", + "verification_uri_complete": ( + "https://portal.nousresearch.com/device?user_code=NOUS-1234" + ), + "expires_in": 600, + "interval": 5, + } + + +def _invoke_scope_refusal(): + request = httpx.Request("POST", "https://portal.nousresearch.com/oauth/device/code") + response = httpx.Response( + 400, + json={ + "error": "invalid_scope", + "error_description": "unsupported scope inference:invoke", + }, + request=request, + ) + return httpx.HTTPStatusError("invalid scope", request=request, response=response) + + def test_minimax_login_does_not_launch_anthropic_flow(): """Click 'Login' on MiniMax → MUST NOT return claude.ai auth_url.""" fake_user_code_resp = { @@ -48,6 +75,9 @@ def test_minimax_login_does_not_launch_anthropic_flow(): ), patch( "hermes_cli.auth._minimax_pkce_pair", return_value=("verifier-stub", "challenge-stub", "stub-state"), + ), patch( + "hermes_cli.web_server._minimax_poller", + return_value=None, ): resp = client.post( "/api/providers/oauth/minimax-oauth/start", @@ -69,6 +99,113 @@ def test_minimax_login_does_not_launch_anthropic_flow(): assert body["expires_in"] == 600 +def test_nous_dashboard_device_flow_honors_legacy_scope_override(monkeypatch): + from hermes_cli import auth as auth_mod + from hermes_cli import web_server as ws + + requested_scopes = [] + + def fake_request_device_code(**kwargs): + requested_scopes.append(kwargs["scope"]) + return _fake_nous_device_data() + + monkeypatch.setenv(auth_mod.NOUS_LEGACY_SESSION_KEYS_ENV, "true") + monkeypatch.setattr(auth_mod, "_request_device_code", fake_request_device_code) + monkeypatch.setattr(ws, "_nous_poller", lambda sid: None) + + result = asyncio.run(ws._start_device_code_flow("nous")) + try: + assert requested_scopes == [auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE] + assert result["flow"] == "device_code" + assert result["user_code"] == "NOUS-1234" + assert ( + ws._oauth_sessions[result["session_id"]]["scope"] + == auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE + ) + finally: + ws._oauth_sessions.pop(result["session_id"], None) + + +def test_nous_dashboard_device_flow_retries_legacy_scope_on_invoke_refusal(monkeypatch): + from hermes_cli import auth as auth_mod + from hermes_cli import web_server as ws + + requested_scopes = [] + + def fake_request_device_code(**kwargs): + requested_scopes.append(kwargs["scope"]) + if len(requested_scopes) == 1: + raise _invoke_scope_refusal() + return _fake_nous_device_data() + + monkeypatch.delenv(auth_mod.NOUS_LEGACY_SESSION_KEYS_ENV, raising=False) + monkeypatch.setattr(auth_mod, "_request_device_code", fake_request_device_code) + monkeypatch.setattr(ws, "_nous_poller", lambda sid: None) + + result = asyncio.run(ws._start_device_code_flow("nous")) + try: + assert requested_scopes == [ + auth_mod.DEFAULT_NOUS_SCOPE, + auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE, + ] + assert ( + ws._oauth_sessions[result["session_id"]]["scope"] + == auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE + ) + finally: + ws._oauth_sessions.pop(result["session_id"], None) + + +def test_nous_dashboard_poller_preserves_effective_scope_when_token_omits_scope(monkeypatch): + from hermes_cli import auth as auth_mod + from hermes_cli import web_server as ws + + session_id = "nous-effective-scope-test" + ws._oauth_sessions[session_id] = { + "session_id": session_id, + "provider": "nous", + "flow": "device_code", + "created_at": time.time(), + "status": "pending", + "error_message": None, + "portal_base_url": "https://portal.nousresearch.com", + "client_id": "hermes-cli", + "device_code": "device-code", + "interval": 5, + "expires_at": time.time() + 600, + "scope": auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE, + } + captured_state = {} + + def fake_refresh_nous_oauth_from_state(state, **kwargs): + captured_state.update(state) + return {**state, "agent_key": "legacy-agent-key"} + + monkeypatch.setattr( + auth_mod, + "_poll_for_token", + lambda **kwargs: { + "access_token": "access-token", + "refresh_token": "refresh-token", + "expires_in": 3600, + "token_type": "Bearer", + }, + ) + monkeypatch.setattr( + auth_mod, + "refresh_nous_oauth_from_state", + fake_refresh_nous_oauth_from_state, + ) + monkeypatch.setattr(auth_mod, "persist_nous_credentials", lambda state: None) + + try: + ws._nous_poller(session_id) + assert captured_state["scope"] == auth_mod.NOUS_LEGACY_AGENT_KEY_SCOPE + assert ws._oauth_sessions[session_id]["status"] == "approved" + finally: + ws._oauth_sessions.pop(session_id, None) + + def test_minimax_dashboard_poller_accepts_absolute_ms_expired_in(): """Dashboard MiniMax completion must accept unix-ms token expiry values.""" from hermes_cli import web_server as ws diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index a72359227a6..e569da31666 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -3667,7 +3667,7 @@ class TestNousCredentialRefresh: assert ok is True assert closed["value"] is True - assert captured["force_mint"] is True + assert captured["auth_mode"] == "legacy" assert rebuilt["kwargs"]["api_key"] == "new-nous-key" assert ( rebuilt["kwargs"]["base_url"] == "https://inference-api.nousresearch.com/v1"