Switch to JWT token for inference against Nous, falling back to old opaque token on failure.

This commit is contained in:
Robin Fernandes 2026-05-17 19:34:44 +10:00 committed by Teknium
parent c905562623
commit 89a3d038cf
10 changed files with 780 additions and 45 deletions

View file

@ -67,9 +67,13 @@ AUTH_LOCK_TIMEOUT_SECONDS = 15.0
DEFAULT_NOUS_PORTAL_URL = "https://portal.nousresearch.com"
DEFAULT_NOUS_INFERENCE_URL = "https://inference-api.nousresearch.com/v1"
DEFAULT_NOUS_CLIENT_ID = "hermes-cli"
DEFAULT_NOUS_SCOPE = "inference:mint_agent_key"
NOUS_LEGACY_AGENT_KEY_SCOPE = "inference:mint_agent_key"
NOUS_INFERENCE_INVOKE_SCOPE = "inference:invoke"
DEFAULT_NOUS_SCOPE = f"{NOUS_INFERENCE_INVOKE_SCOPE} {NOUS_LEGACY_AGENT_KEY_SCOPE}"
NOUS_LEGACY_SESSION_KEYS_ENV = "HERMES_AGENT_USE_LEGACY_SESSION_KEYS"
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
NOUS_INVOKE_JWT_MIN_TTL_SECONDS = ACCESS_TOKEN_REFRESH_SKEW_SECONDS
DEVICE_AUTH_POLL_INTERVAL_CAP_SECONDS = 1 # poll at most every 1s
DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api/codex"
DEFAULT_XAI_OAUTH_BASE_URL = "https://api.x.ai/v1"
@ -1549,6 +1553,117 @@ def _decode_jwt_claims(token: Any) -> Dict[str, Any]:
return claims if isinstance(claims, dict) else {}
def _scope_values(raw_scope: Any) -> set[str]:
scopes: set[str] = set()
if isinstance(raw_scope, str):
for part in raw_scope.replace(",", " ").split():
cleaned = part.strip()
if cleaned:
scopes.add(cleaned)
elif isinstance(raw_scope, (list, tuple, set, frozenset)):
for item in raw_scope:
if isinstance(item, str):
scopes.update(_scope_values(item))
return scopes
def _nous_legacy_session_keys_forced() -> bool:
return is_truthy_value(os.getenv(NOUS_LEGACY_SESSION_KEYS_ENV), default=False)
def _nous_scope_has_invoke(raw_scope: Any) -> bool:
return NOUS_INFERENCE_INVOKE_SCOPE in _scope_values(raw_scope)
def _nous_invoke_jwt_is_usable(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> bool:
claims = _decode_jwt_claims(token)
if not claims:
return False
scopes = (
_scope_values(scope)
| _scope_values(claims.get("scope"))
| _scope_values(claims.get("scp"))
)
if NOUS_INFERENCE_INVOKE_SCOPE not in scopes:
return False
exp = claims.get("exp")
skew = max(0, int(min_ttl_seconds))
if isinstance(exp, (int, float)):
return float(exp) > (time.time() + skew)
return not _is_expiring(expires_at, skew)
def _nous_invoke_jwt_unavailable_reason(
token: Any,
*,
scope: Any = None,
expires_at: Any = None,
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
) -> str:
claims = _decode_jwt_claims(token)
if not claims:
return "access_token_not_jwt"
scopes = (
_scope_values(scope)
| _scope_values(claims.get("scope"))
| _scope_values(claims.get("scp"))
)
if NOUS_INFERENCE_INVOKE_SCOPE not in scopes:
return "missing_inference_invoke_scope"
exp = claims.get("exp")
skew = max(0, int(min_ttl_seconds))
if isinstance(exp, (int, float)) and float(exp) <= (time.time() + skew):
return "invoke_jwt_expiring"
if not isinstance(exp, (int, float)) and _is_expiring(expires_at, skew):
return "invoke_jwt_expiry_unknown_or_expiring"
return "invoke_jwt_unavailable"
def _nous_jwt_expires_at(token: Any, fallback_expires_at: Any = None) -> Optional[str]:
claims = _decode_jwt_claims(token)
exp = claims.get("exp")
if isinstance(exp, (int, float)):
try:
return datetime.fromtimestamp(float(exp), tz=timezone.utc).isoformat()
except Exception:
pass
return fallback_expires_at if isinstance(fallback_expires_at, str) else None
def _set_nous_agent_key_from_invoke_jwt(
state: Dict[str, Any],
*,
obtained_at: Optional[str] = None,
) -> None:
access_token = state.get("access_token")
if not isinstance(access_token, str) or not access_token.strip():
return
now = datetime.now(timezone.utc)
effective_obtained_at = obtained_at or now.isoformat()
expires_at = _nous_jwt_expires_at(access_token, state.get("expires_at"))
expires_epoch = _parse_iso_timestamp(expires_at)
expires_in = (
max(0, int(expires_epoch - time.time()))
if expires_epoch is not None
else _coerce_ttl_seconds(state.get("expires_in"))
)
if expires_at:
state["expires_at"] = expires_at
state["expires_in"] = expires_in
state["agent_key"] = access_token
state["agent_key_id"] = None
state["agent_key_expires_at"] = expires_at
state["agent_key_expires_in"] = expires_in
state["agent_key_reused"] = False
state["agent_key_obtained_at"] = effective_obtained_at
def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> bool:
claims = _decode_jwt_claims(access_token)
exp = claims.get("exp")
@ -3333,6 +3448,34 @@ def _request_device_code(
return data
def _is_nous_invoke_scope_refusal(exc: Exception) -> bool:
if not isinstance(exc, httpx.HTTPStatusError):
return False
response = exc.response
if response.status_code not in {400, 401, 403}:
return False
try:
payload = response.json()
except Exception:
payload = {}
text = " ".join(
str(value)
for value in (
payload.get("error") if isinstance(payload, dict) else None,
payload.get("error_description") if isinstance(payload, dict) else None,
response.text,
)
if value
).lower()
if not text:
return False
return (
"invalid_scope" in text
or "unsupported_scope" in text
or "scope" in text and NOUS_INFERENCE_INVOKE_SCOPE in text
)
def _poll_for_token(
client: httpx.Client,
portal_base_url: str,
@ -3524,8 +3667,9 @@ def _write_shared_nous_state(state: Dict[str, Any]) -> None:
is a convenience layer; the per-profile auth.json remains the source
of truth.
We deliberately omit the short-lived ``agent_key`` (24h TTL, profile-
specific) only the long-lived OAuth tokens are cross-profile useful.
We deliberately omit the runtime ``agent_key`` compatibility field
(either an invoke JWT or legacy opaque session key) only OAuth tokens
are cross-profile useful.
"""
refresh_token = state.get("refresh_token")
access_token = state.get("access_token")
@ -3894,6 +4038,14 @@ def _agent_key_is_usable(state: Dict[str, Any], min_ttl_seconds: int) -> bool:
key = state.get("agent_key")
if not isinstance(key, str) or not key.strip():
return False
if _decode_jwt_claims(key):
if _nous_legacy_session_keys_forced():
return False
return _nous_invoke_jwt_is_usable(
key,
scope=state.get("scope"),
expires_at=state.get("agent_key_expires_at"),
)
return not _is_expiring(state.get("agent_key_expires_at"), min_ttl_seconds)
@ -4039,7 +4191,23 @@ def refresh_nous_oauth_pure(
timeout = httpx.Timeout(timeout_seconds if timeout_seconds else 15.0)
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
if force_refresh or _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
min_agent_key_ttl = max(60, int(min_key_ttl_seconds))
legacy_session_keys = _nous_legacy_session_keys_forced()
current_invoke_jwt_usable = (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
state.get("access_token"),
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
)
if (
force_refresh
or (
_is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS)
and not current_invoke_jwt_usable
)
):
refreshed = _refresh_access_token(
client=client,
portal_base_url=state["portal_base_url"],
@ -4061,7 +4229,39 @@ def refresh_nous_oauth_pure(
now.timestamp() + access_ttl, tz=timezone.utc
).isoformat()
if force_mint or not _agent_key_is_usable(state, max(60, int(min_key_ttl_seconds))):
if (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
state.get("access_token"),
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
_set_nous_agent_key_from_invoke_jwt(state)
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
access_token_fp=_token_fingerprint(state.get("access_token")),
)
elif force_mint or not _agent_key_is_usable(state, min_agent_key_ttl):
fallback_reason = (
"forced_legacy_session_keys"
if legacy_session_keys
else _nous_invoke_jwt_unavailable_reason(
state.get("access_token"),
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
)
logger.info(
"Nous inference auth: using legacy session key path (%s)",
fallback_reason,
)
_oauth_trace(
"nous_legacy_session_key_selected",
reason=fallback_reason,
access_token_fp=_token_fingerprint(state.get("access_token")),
)
mint_payload = _mint_agent_key(
client=client,
portal_base_url=state["portal_base_url"],
@ -4175,6 +4375,15 @@ def persist_nous_credentials(
)
def _sync_nous_pool_from_auth_store() -> None:
try:
from agent.credential_pool import load_pool
load_pool("nous")
except Exception as exc:
logger.debug("Failed to sync Nous credential pool from auth store: %s", exc)
def resolve_nous_runtime_credentials(
*,
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
@ -4191,7 +4400,7 @@ def resolve_nous_runtime_credentials(
Concurrent processes coordinate through the auth store file lock.
Returns dict with: provider, base_url, api_key, key_id, expires_at,
expires_in, source ("cache" or "portal").
expires_in, source ("invoke_jwt", "cache", or "portal"), and auth_path.
"""
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
sequence_id = uuid.uuid4().hex[:12]
@ -4260,15 +4469,35 @@ def resolve_nous_runtime_credentials(
raise AuthError("No access token found for Nous Portal login.",
provider="nous", relogin_required=True)
# Step 1: refresh access token if expiring
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
# Step 1: refresh access token if expiring. If the access token
# is already a valid invoke JWT, trust its own exp claim even when
# older auth.json metadata has a stale/missing expires_at.
current_invoke_jwt_usable = (
not _nous_legacy_session_keys_forced()
and _nous_invoke_jwt_is_usable(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
)
if (
_is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS)
and not current_invoke_jwt_usable
):
with _nous_shared_store_lock(timeout_seconds=max(timeout_seconds + 5.0, AUTH_LOCK_TIMEOUT_SECONDS)):
if _merge_shared_nous_oauth_state(state):
access_token = state.get("access_token")
refresh_token = state.get("refresh_token")
_persist_state("post_shared_merge_access_expiring")
if _is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
if (
_is_expiring(state.get("expires_at"), ACCESS_TOKEN_REFRESH_SKEW_SECONDS)
and not _nous_invoke_jwt_is_usable(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
if not isinstance(refresh_token, str) or not refresh_token:
raise AuthError("Session expired and no refresh token is available.",
provider="nous", relogin_required=True)
@ -4320,14 +4549,56 @@ def resolve_nous_runtime_credentials(
# Persist immediately so downstream mint failures cannot drop rotated refresh tokens.
_persist_state("post_refresh_access_expiring")
# Step 2: mint agent key if missing/expiring
# Step 2: resolve the compatibility ``agent_key`` field. Preferred
# path stores the NAS invoke JWT there; legacy path mints/reuses
# the opaque session key.
used_cached_key = False
mint_payload: Optional[Dict[str, Any]] = None
selected_auth_path = "legacy_session_key"
legacy_session_keys = _nous_legacy_session_keys_forced()
if not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
if (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
_set_nous_agent_key_from_invoke_jwt(state)
selected_auth_path = "invoke_jwt"
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
elif not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
used_cached_key = True
selected_auth_path = "legacy_session_key_cache"
logger.info("Nous inference auth: using cached legacy session key")
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
else:
fallback_reason = (
"forced_legacy_session_keys"
if legacy_session_keys
else _nous_invoke_jwt_unavailable_reason(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
)
selected_auth_path = "legacy_session_key_mint"
logger.info(
"Nous inference auth: using legacy session key path (%s)",
fallback_reason,
)
_oauth_trace(
"nous_legacy_session_key_selected",
sequence_id=sequence_id,
reason=fallback_reason,
access_token_fp=_token_fingerprint(access_token),
)
try:
_oauth_trace(
"mint_start",
@ -4403,10 +4674,28 @@ def resolve_nous_runtime_credentials(
# Persist retry refresh immediately for crash safety and cross-process visibility.
_persist_state("post_refresh_mint_retry")
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
)
if (
not legacy_session_keys
and _nous_invoke_jwt_is_usable(
access_token,
scope=state.get("scope"),
expires_at=state.get("expires_at"),
)
):
_set_nous_agent_key_from_invoke_jwt(state)
mint_payload = None
selected_auth_path = "invoke_jwt"
logger.info("Nous inference auth: using NAS invoke JWT")
_oauth_trace(
"nous_invoke_jwt_selected",
sequence_id=sequence_id,
access_token_fp=_token_fingerprint(access_token),
)
else:
mint_payload = _mint_agent_key(
client=client, portal_base_url=portal_base_url,
access_token=access_token, min_ttl_seconds=min_key_ttl_seconds,
)
else:
raise
@ -4438,6 +4727,8 @@ def resolve_nous_runtime_credentials(
_persist_state("resolve_nous_runtime_credentials_final")
_sync_nous_pool_from_auth_store()
api_key = state.get("agent_key")
if not isinstance(api_key, str) or not api_key:
raise AuthError("Failed to resolve a Nous inference API key",
@ -4458,7 +4749,12 @@ def resolve_nous_runtime_credentials(
"key_id": state.get("agent_key_id"),
"expires_at": expires_at,
"expires_in": expires_in,
"source": "cache" if used_cached_key else "portal",
"source": (
"invoke_jwt"
if selected_auth_path == "invoke_jwt"
else ("cache" if used_cached_key else "portal")
),
"auth_path": selected_auth_path,
}
@ -6137,7 +6433,10 @@ def _nous_device_code_login(
or pconfig.inference_base_url
).rstrip("/")
client_id = client_id or pconfig.client_id
explicit_scope = scope is not None
scope = scope or pconfig.scope
if _nous_legacy_session_keys_forced():
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
timeout = httpx.Timeout(timeout_seconds)
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
@ -6152,12 +6451,30 @@ def _nous_device_code_login(
print(f"TLS verification: custom CA bundle ({ca_bundle})")
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
device_data = _request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
)
try:
device_data = _request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
)
except Exception as exc:
if (
not explicit_scope
and _nous_scope_has_invoke(scope)
and _is_nous_invoke_scope_refusal(exc)
):
logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope")
_oauth_trace("nous_device_code_invoke_scope_refused")
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
device_data = _request_device_code(
client=client,
portal_base_url=portal_base_url,
client_id=client_id,
scope=scope,
)
else:
raise
verification_url = str(device_data["verification_uri_complete"])
user_code = str(device_data["user_code"])
@ -6287,7 +6604,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
portal_base_url=getattr(args, "portal_url", None),
inference_base_url=getattr(args, "inference_url", None),
client_id=getattr(args, "client_id", None) or pconfig.client_id,
scope=getattr(args, "scope", None) or pconfig.scope,
scope=getattr(args, "scope", None),
open_browser=not getattr(args, "no_browser", False),
timeout_seconds=timeout_seconds,
insecure=insecure,
@ -6314,6 +6631,7 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
# these credentials. Best-effort: any I/O failure is logged and
# swallowed inside the helper.
_write_shared_nous_state(auth_state)
_sync_nous_pool_from_auth_store()
print()
print("Login successful!")

View file

@ -1,12 +1,13 @@
"""Nous Portal upstream adapter.
Reads the user's Nous OAuth state from ``~/.hermes/auth.json``, refreshes
the access token and mints a fresh agent key when needed, and exposes the
upstream base URL plus minted bearer for the proxy server to forward to.
the access token and resolves the ``agent_key`` compatibility credential
when needed, then exposes the upstream base URL plus bearer for the proxy
server to forward to.
The minted ``agent_key`` (not the OAuth ``access_token``) is what
``inference-api.nousresearch.com`` accepts as a bearer. The refresh helper
already handles both see :func:`hermes_cli.auth.refresh_nous_oauth_from_state`.
The ``agent_key`` field may hold either a NAS invoke JWT or the legacy
opaque session key. The refresh helper handles both see
:func:`hermes_cli.auth.refresh_nous_oauth_from_state`.
"""
from __future__ import annotations

View file

@ -875,10 +875,9 @@ def _resolve_explicit_runtime(
explicit_base_url
or str(state.get("inference_base_url") or auth_mod.DEFAULT_NOUS_INFERENCE_URL).strip().rstrip("/")
)
# Only use agent_key for inference — access_token is an OAuth token for the
# portal API (minting keys, refreshing tokens), not for the inference API.
# Falling back to access_token sends an OAuth bearer token to the inference
# endpoint, which returns 404 because it is not a valid inference credential.
# Only use the agent_key compatibility field for inference. It may be
# either a NAS invoke JWT or a legacy opaque session key; raw OAuth
# access_token fallback is handled by resolve_nous_runtime_credentials().
api_key = explicit_api_key or str(state.get("agent_key") or "").strip()
expires_at = state.get("agent_key_expires_at") or state.get("expires_at")
if not api_key:
@ -1069,17 +1068,19 @@ def resolve_runtime_provider(
getattr(entry, "runtime_api_key", None)
or getattr(entry, "access_token", "")
)
# For Nous, the pool entry's runtime_api_key is the agent_key — a
# short-lived inference credential (~30 min TTL). The pool doesn't
# For Nous, the pool entry's runtime_api_key is the agent_key
# compatibility field: either an invoke JWT or legacy opaque key.
# The pool doesn't
# refresh it during selection (that would trigger network calls in
# non-runtime contexts like `hermes auth list`). If the key is
# expired, clear pool_api_key so we fall through to
# resolve_nous_runtime_credentials() which handles refresh + mint.
# resolve_nous_runtime_credentials() which handles refresh + fallback.
if provider == "nous" and entry is not None and pool_api_key:
min_ttl = max(60, int(os.getenv("HERMES_NOUS_MIN_KEY_TTL_SECONDS", "1800")))
nous_state = {
"agent_key": getattr(entry, "agent_key", None),
"agent_key_expires_at": getattr(entry, "agent_key_expires_at", None),
"scope": getattr(entry, "scope", None),
}
if not _agent_key_is_usable(nous_state, min_ttl):
logger.debug("Nous pool entry agent_key expired/missing, falling through to runtime resolution")