refactor(auth): mostly cleanups and style changes

This commit is contained in:
Robin Fernandes 2026-05-17 21:18:53 +10:00 committed by Teknium
parent 0bac7dd05b
commit 20bffa5b37
10 changed files with 145 additions and 125 deletions

View file

@ -78,13 +78,21 @@ NOUS_INFERENCE_INVOKE_SCOPE = "inference:invoke"
DEFAULT_NOUS_SCOPE = f"{NOUS_INFERENCE_INVOKE_SCOPE} {NOUS_LEGACY_AGENT_KEY_SCOPE}"
NOUS_LEGACY_SESSION_KEYS_ENV = "HERMES_AGENT_USE_LEGACY_SESSION_KEYS"
NOUS_DEVICE_CODE_SOURCE = "device_code"
NOUS_INFERENCE_AUTH_AUTO = "auto"
NOUS_INFERENCE_AUTH_FRESH = "fresh"
NOUS_INFERENCE_AUTH_LEGACY = "legacy"
NOUS_INFERENCE_AUTH_MODE_AUTO = "auto"
NOUS_INFERENCE_AUTH_MODE_FRESH = "fresh"
NOUS_INFERENCE_AUTH_MODE_LEGACY = "legacy"
NOUS_INFERENCE_AUTH_MODES = frozenset({
NOUS_INFERENCE_AUTH_AUTO,
NOUS_INFERENCE_AUTH_FRESH,
NOUS_INFERENCE_AUTH_LEGACY,
NOUS_INFERENCE_AUTH_MODE_AUTO,
NOUS_INFERENCE_AUTH_MODE_FRESH,
NOUS_INFERENCE_AUTH_MODE_LEGACY,
})
NOUS_AUTH_PATH_INVOKE_JWT = "invoke_jwt"
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE = "legacy_session_key_cache"
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT = "legacy_session_key_mint"
NOUS_AUTH_PATHS = frozenset({
NOUS_AUTH_PATH_INVOKE_JWT,
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE,
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT,
})
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
@ -1592,12 +1600,13 @@ def _nous_scope_has_invoke(raw_scope: Any) -> bool:
return NOUS_INFERENCE_INVOKE_SCOPE in _scope_values(raw_scope)
def _normalize_nous_auth_mode(auth_mode: Optional[str]) -> str:
mode = str(auth_mode or NOUS_INFERENCE_AUTH_AUTO).strip().lower()
def _normalize_nous_inference_auth_mode(inference_auth_mode: Optional[str]) -> str:
mode = str(inference_auth_mode or NOUS_INFERENCE_AUTH_MODE_AUTO).strip().lower()
if mode not in NOUS_INFERENCE_AUTH_MODES:
allowed = ", ".join(sorted(NOUS_INFERENCE_AUTH_MODES))
raise ValueError(
f"Invalid Nous inference auth mode {auth_mode!r}; expected one of: {allowed}"
"Invalid Nous inference auth mode "
f"{inference_auth_mode!r}; expected one of: {allowed}"
)
return mode
@ -1649,89 +1658,57 @@ def _nous_invoke_jwt_is_usable(
)
def _nous_invoke_jwt_unavailable_reason(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> str:
return (
_nous_invoke_jwt_status(
token,
scope=scope,
expires_at=expires_at,
min_ttl_seconds=min_ttl_seconds,
)
or "invoke_jwt_unavailable"
)
def _nous_can_select_invoke_jwt(auth_mode: str = NOUS_INFERENCE_AUTH_AUTO) -> bool:
return (
not _nous_legacy_session_keys_forced()
and _normalize_nous_auth_mode(auth_mode) != NOUS_INFERENCE_AUTH_LEGACY
)
def _nous_legacy_session_key_reason(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
) -> str:
if _normalize_nous_auth_mode(auth_mode) == NOUS_INFERENCE_AUTH_LEGACY:
if inference_auth_mode == NOUS_INFERENCE_AUTH_MODE_LEGACY:
return "forced_legacy_session_key"
if _nous_legacy_session_keys_forced():
return "forced_legacy_session_keys"
return _nous_invoke_jwt_unavailable_reason(
token,
scope=scope,
expires_at=expires_at,
return (
_nous_invoke_jwt_status(token, scope=scope, expires_at=expires_at)
or "invoke_jwt_unavailable"
)
def _nous_cached_agent_key_is_usable(
state: Dict[str, Any],
min_ttl_seconds: int,
) -> bool:
return _agent_key_is_usable(state, min_ttl_seconds)
def _choose_nous_inference_auth_path(
state: Dict[str, Any],
*,
access_token: Any = None,
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
) -> Tuple[str, Optional[str]]:
auth_mode = _normalize_nous_auth_mode(auth_mode)
inference_auth_mode = _normalize_nous_inference_auth_mode(inference_auth_mode)
token = state.get("access_token") if access_token is None else access_token
if (
_nous_can_select_invoke_jwt(auth_mode)
not _nous_legacy_session_keys_forced()
and inference_auth_mode != NOUS_INFERENCE_AUTH_MODE_LEGACY
and _nous_invoke_jwt_is_usable(
token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
return "invoke_jwt", None
return NOUS_AUTH_PATH_INVOKE_JWT, None
if (
auth_mode == NOUS_INFERENCE_AUTH_AUTO
and _nous_cached_agent_key_is_usable(
inference_auth_mode == NOUS_INFERENCE_AUTH_MODE_AUTO
and _agent_key_is_usable(
state,
max(60, int(min_key_ttl_seconds)),
)
):
return "legacy_session_key_cache", None
return NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE, None
return (
"legacy_session_key_mint",
NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT,
_nous_legacy_session_key_reason(
token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
auth_mode=auth_mode,
inference_auth_mode=inference_auth_mode,
),
)
@ -3660,7 +3637,7 @@ def _is_nous_invoke_scope_refusal(exc: Exception) -> bool:
)
def _nous_device_scope(
def _nous_device_scope_with_env_override(
requested_scope: Optional[str],
*,
default_scope: str = DEFAULT_NOUS_SCOPE,
@ -4131,7 +4108,7 @@ def _try_import_shared_nous_state(
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=True,
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
inference_auth_mode=NOUS_INFERENCE_AUTH_MODE_FRESH,
)
_write_shared_nous_state(refreshed)
except AuthError as exc:
@ -4440,10 +4417,10 @@ def refresh_nous_oauth_pure(
insecure: Optional[bool] = None,
ca_bundle: Optional[str] = None,
force_refresh: bool = False,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
) -> Dict[str, Any]:
"""Refresh Nous OAuth state without mutating auth.json."""
auth_mode = _normalize_nous_auth_mode(auth_mode)
inference_auth_mode = _normalize_nous_inference_auth_mode(inference_auth_mode)
state: Dict[str, Any] = {
"access_token": access_token,
"refresh_token": refresh_token,
@ -4506,11 +4483,11 @@ def refresh_nous_oauth_pure(
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
state,
min_key_ttl_seconds=min_agent_key_ttl,
auth_mode=auth_mode,
inference_auth_mode=inference_auth_mode,
)
if selected_auth_path == "invoke_jwt":
if selected_auth_path == NOUS_AUTH_PATH_INVOKE_JWT:
_select_nous_invoke_jwt(state)
elif selected_auth_path == "legacy_session_key_mint":
elif selected_auth_path == NOUS_AUTH_PATH_LEGACY_SESSION_KEY_MINT:
_log_nous_legacy_session_key_selected(
fallback_reason or "legacy_session_key_required",
access_token=state.get("access_token"),
@ -4541,7 +4518,7 @@ def refresh_nous_oauth_from_state(
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
timeout_seconds: float = 15.0,
force_refresh: bool = False,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
) -> Dict[str, Any]:
"""Refresh Nous OAuth from a state dict. Thin wrapper around refresh_nous_oauth_pure."""
tls = state.get("tls") or {}
@ -4562,7 +4539,7 @@ def refresh_nous_oauth_from_state(
insecure=tls.get("insecure"),
ca_bundle=tls.get("ca_bundle"),
force_refresh=force_refresh,
auth_mode=auth_mode,
inference_auth_mode=inference_auth_mode,
)
@ -4640,7 +4617,7 @@ def resolve_nous_runtime_credentials(
timeout_seconds: float = 15.0,
insecure: Optional[bool] = None,
ca_bundle: Optional[str] = None,
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
inference_auth_mode: str = NOUS_INFERENCE_AUTH_MODE_AUTO,
) -> Dict[str, Any]:
"""
Resolve Nous inference credentials for runtime use.
@ -4652,7 +4629,7 @@ def resolve_nous_runtime_credentials(
Returns dict with: provider, base_url, api_key, key_id, expires_at,
expires_in, source ("invoke_jwt", "cache", or "portal"), and auth_path.
"""
auth_mode = _normalize_nous_auth_mode(auth_mode)
inference_auth_mode = _normalize_nous_inference_auth_mode(inference_auth_mode)
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
sequence_id = uuid.uuid4().hex[:12]
@ -4682,6 +4659,8 @@ def resolve_nous_runtime_credentials(
def _persist_state(reason: str) -> None:
nonlocal persisted_state, state_persisted
# Skip writes where only derived TTL countdowns changed; this keeps
# the mtime-keyed Nous auth-status cache warm during read paths.
if (
_nous_effective_provider_state(state)
== _nous_effective_provider_state(persisted_state)
@ -4723,7 +4702,7 @@ def resolve_nous_runtime_credentials(
_oauth_trace(
"nous_runtime_credentials_start",
sequence_id=sequence_id,
auth_mode=auth_mode,
inference_auth_mode=inference_auth_mode,
min_key_ttl_seconds=min_key_ttl_seconds,
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
)
@ -4830,16 +4809,16 @@ def resolve_nous_runtime_credentials(
state,
access_token=access_token,
min_key_ttl_seconds=min_key_ttl_seconds,
auth_mode=auth_mode,
inference_auth_mode=inference_auth_mode,
)
if selected_auth_path == "invoke_jwt":
if selected_auth_path == NOUS_AUTH_PATH_INVOKE_JWT:
_select_nous_invoke_jwt(
state,
access_token=access_token,
sequence_id=sequence_id,
)
elif selected_auth_path == "legacy_session_key_cache":
elif selected_auth_path == NOUS_AUTH_PATH_LEGACY_SESSION_KEY_CACHE:
used_cached_key = True
logger.info("Nous inference auth: using cached agent_key")
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
@ -4929,20 +4908,20 @@ def resolve_nous_runtime_credentials(
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
retry_auth_mode = (
NOUS_INFERENCE_AUTH_LEGACY
if auth_mode == NOUS_INFERENCE_AUTH_LEGACY
else NOUS_INFERENCE_AUTH_FRESH
retry_inference_auth_mode = (
NOUS_INFERENCE_AUTH_MODE_LEGACY
if inference_auth_mode == NOUS_INFERENCE_AUTH_MODE_LEGACY
else NOUS_INFERENCE_AUTH_MODE_FRESH
)
retry_auth_path, _ = _choose_nous_inference_auth_path(
state,
access_token=access_token,
min_key_ttl_seconds=min_key_ttl_seconds,
auth_mode=retry_auth_mode,
inference_auth_mode=retry_inference_auth_mode,
)
if retry_auth_path == "invoke_jwt":
if retry_auth_path == NOUS_AUTH_PATH_INVOKE_JWT:
mint_payload = None
selected_auth_path = "invoke_jwt"
selected_auth_path = NOUS_AUTH_PATH_INVOKE_JWT
_select_nous_invoke_jwt(
state,
access_token=access_token,
@ -5008,8 +4987,8 @@ def resolve_nous_runtime_credentials(
"expires_at": expires_at,
"expires_in": expires_in,
"source": (
"invoke_jwt"
if selected_auth_path == "invoke_jwt"
NOUS_AUTH_PATH_INVOKE_JWT
if selected_auth_path == NOUS_AUTH_PATH_INVOKE_JWT
else ("cache" if used_cached_key else "portal")
),
"auth_path": selected_auth_path,
@ -6691,7 +6670,10 @@ def _nous_device_code_login(
or pconfig.inference_base_url
).rstrip("/")
client_id = client_id or pconfig.client_id
scope, explicit_scope = _nous_device_scope(scope, default_scope=pconfig.scope)
scope, explicit_scope = _nous_device_scope_with_env_override(
scope,
default_scope=pconfig.scope,
)
timeout = httpx.Timeout(timeout_seconds)
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
@ -6781,7 +6763,7 @@ def _nous_device_code_login(
min_key_ttl_seconds=min_key_ttl_seconds,
timeout_seconds=timeout_seconds,
force_refresh=False,
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
inference_auth_mode=NOUS_INFERENCE_AUTH_MODE_FRESH,
)
except AuthError as exc:
if exc.code == "subscription_required":