mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-07 08:02:23 +00:00
refactor(auth): collapse Nous inference fallback controls
This commit is contained in:
parent
89a3d038cf
commit
0bac7dd05b
13 changed files with 1071 additions and 240 deletions
|
|
@ -11,6 +11,12 @@ Architecture:
|
|||
- resolve_provider() picks the active provider via priority chain
|
||||
- resolve_*_runtime_credentials() handles token refresh and key minting
|
||||
- logout_command() is the CLI entry point for clearing auth
|
||||
|
||||
Nous authentication paths:
|
||||
- Invoke JWT (preferred): use a scoped access_token directly for inference.
|
||||
- Legacy session key (fallback): mint an opaque 24h key when JWT auth is
|
||||
unavailable, or when HERMES_AGENT_USE_LEGACY_SESSION_KEYS is set for
|
||||
debugging or rollback.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -71,6 +77,15 @@ NOUS_LEGACY_AGENT_KEY_SCOPE = "inference:mint_agent_key"
|
|||
NOUS_INFERENCE_INVOKE_SCOPE = "inference:invoke"
|
||||
DEFAULT_NOUS_SCOPE = f"{NOUS_INFERENCE_INVOKE_SCOPE} {NOUS_LEGACY_AGENT_KEY_SCOPE}"
|
||||
NOUS_LEGACY_SESSION_KEYS_ENV = "HERMES_AGENT_USE_LEGACY_SESSION_KEYS"
|
||||
NOUS_DEVICE_CODE_SOURCE = "device_code"
|
||||
NOUS_INFERENCE_AUTH_AUTO = "auto"
|
||||
NOUS_INFERENCE_AUTH_FRESH = "fresh"
|
||||
NOUS_INFERENCE_AUTH_LEGACY = "legacy"
|
||||
NOUS_INFERENCE_AUTH_MODES = frozenset({
|
||||
NOUS_INFERENCE_AUTH_AUTO,
|
||||
NOUS_INFERENCE_AUTH_FRESH,
|
||||
NOUS_INFERENCE_AUTH_LEGACY,
|
||||
})
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS = 30 * 60 # 30 minutes
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS = 120 # refresh 2 min before expiry
|
||||
NOUS_INVOKE_JWT_MIN_TTL_SECONDS = ACCESS_TOKEN_REFRESH_SKEW_SECONDS
|
||||
|
|
@ -1554,6 +1569,8 @@ def _decode_jwt_claims(token: Any) -> Dict[str, Any]:
|
|||
|
||||
|
||||
def _scope_values(raw_scope: Any) -> set[str]:
|
||||
# OAuth token responses normally return a space-separated string. Keep
|
||||
# collection support for JWT ``scp`` claims and older stored test fixtures.
|
||||
scopes: set[str] = set()
|
||||
if isinstance(raw_scope, str):
|
||||
for part in raw_scope.replace(",", " ").split():
|
||||
|
|
@ -1575,37 +1592,24 @@ def _nous_scope_has_invoke(raw_scope: Any) -> bool:
|
|||
return NOUS_INFERENCE_INVOKE_SCOPE in _scope_values(raw_scope)
|
||||
|
||||
|
||||
def _nous_invoke_jwt_is_usable(
|
||||
def _normalize_nous_auth_mode(auth_mode: Optional[str]) -> str:
|
||||
mode = str(auth_mode or NOUS_INFERENCE_AUTH_AUTO).strip().lower()
|
||||
if mode not in NOUS_INFERENCE_AUTH_MODES:
|
||||
allowed = ", ".join(sorted(NOUS_INFERENCE_AUTH_MODES))
|
||||
raise ValueError(
|
||||
f"Invalid Nous inference auth mode {auth_mode!r}; expected one of: {allowed}"
|
||||
)
|
||||
return mode
|
||||
|
||||
|
||||
def _nous_invoke_jwt_status(
|
||||
token: Any,
|
||||
*,
|
||||
scope: Any = None,
|
||||
expires_at: Any = None,
|
||||
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
|
||||
) -> bool:
|
||||
claims = _decode_jwt_claims(token)
|
||||
if not claims:
|
||||
return False
|
||||
scopes = (
|
||||
_scope_values(scope)
|
||||
| _scope_values(claims.get("scope"))
|
||||
| _scope_values(claims.get("scp"))
|
||||
)
|
||||
if NOUS_INFERENCE_INVOKE_SCOPE not in scopes:
|
||||
return False
|
||||
exp = claims.get("exp")
|
||||
skew = max(0, int(min_ttl_seconds))
|
||||
if isinstance(exp, (int, float)):
|
||||
return float(exp) > (time.time() + skew)
|
||||
return not _is_expiring(expires_at, skew)
|
||||
|
||||
|
||||
def _nous_invoke_jwt_unavailable_reason(
|
||||
token: Any,
|
||||
*,
|
||||
scope: Any = None,
|
||||
expires_at: Any = None,
|
||||
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
|
||||
) -> str:
|
||||
) -> Optional[str]:
|
||||
"""Return None when the token can be used for inference, else a reason."""
|
||||
claims = _decode_jwt_claims(token)
|
||||
if not claims:
|
||||
return "access_token_not_jwt"
|
||||
|
|
@ -1618,11 +1622,149 @@ def _nous_invoke_jwt_unavailable_reason(
|
|||
return "missing_inference_invoke_scope"
|
||||
exp = claims.get("exp")
|
||||
skew = max(0, int(min_ttl_seconds))
|
||||
if isinstance(exp, (int, float)) and float(exp) <= (time.time() + skew):
|
||||
return "invoke_jwt_expiring"
|
||||
if not isinstance(exp, (int, float)) and _is_expiring(expires_at, skew):
|
||||
if isinstance(exp, (int, float)):
|
||||
if float(exp) <= (time.time() + skew):
|
||||
return "invoke_jwt_expiring"
|
||||
return None
|
||||
if _is_expiring(expires_at, skew):
|
||||
return "invoke_jwt_expiry_unknown_or_expiring"
|
||||
return "invoke_jwt_unavailable"
|
||||
return None
|
||||
|
||||
|
||||
def _nous_invoke_jwt_is_usable(
|
||||
token: Any,
|
||||
*,
|
||||
scope: Any = None,
|
||||
expires_at: Any = None,
|
||||
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
|
||||
) -> bool:
|
||||
return (
|
||||
_nous_invoke_jwt_status(
|
||||
token,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
min_ttl_seconds=min_ttl_seconds,
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
|
||||
def _nous_invoke_jwt_unavailable_reason(
|
||||
token: Any,
|
||||
*,
|
||||
scope: Any = None,
|
||||
expires_at: Any = None,
|
||||
min_ttl_seconds: int = NOUS_INVOKE_JWT_MIN_TTL_SECONDS,
|
||||
) -> str:
|
||||
return (
|
||||
_nous_invoke_jwt_status(
|
||||
token,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
min_ttl_seconds=min_ttl_seconds,
|
||||
)
|
||||
or "invoke_jwt_unavailable"
|
||||
)
|
||||
|
||||
|
||||
def _nous_can_select_invoke_jwt(auth_mode: str = NOUS_INFERENCE_AUTH_AUTO) -> bool:
|
||||
return (
|
||||
not _nous_legacy_session_keys_forced()
|
||||
and _normalize_nous_auth_mode(auth_mode) != NOUS_INFERENCE_AUTH_LEGACY
|
||||
)
|
||||
|
||||
|
||||
def _nous_legacy_session_key_reason(
|
||||
token: Any,
|
||||
*,
|
||||
scope: Any = None,
|
||||
expires_at: Any = None,
|
||||
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
|
||||
) -> str:
|
||||
if _normalize_nous_auth_mode(auth_mode) == NOUS_INFERENCE_AUTH_LEGACY:
|
||||
return "forced_legacy_session_key"
|
||||
if _nous_legacy_session_keys_forced():
|
||||
return "forced_legacy_session_keys"
|
||||
return _nous_invoke_jwt_unavailable_reason(
|
||||
token,
|
||||
scope=scope,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
|
||||
def _nous_cached_agent_key_is_usable(
|
||||
state: Dict[str, Any],
|
||||
min_ttl_seconds: int,
|
||||
) -> bool:
|
||||
return _agent_key_is_usable(state, min_ttl_seconds)
|
||||
|
||||
|
||||
def _choose_nous_inference_auth_path(
|
||||
state: Dict[str, Any],
|
||||
*,
|
||||
access_token: Any = None,
|
||||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
auth_mode = _normalize_nous_auth_mode(auth_mode)
|
||||
token = state.get("access_token") if access_token is None else access_token
|
||||
if (
|
||||
_nous_can_select_invoke_jwt(auth_mode)
|
||||
and _nous_invoke_jwt_is_usable(
|
||||
token,
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
)
|
||||
):
|
||||
return "invoke_jwt", None
|
||||
if (
|
||||
auth_mode == NOUS_INFERENCE_AUTH_AUTO
|
||||
and _nous_cached_agent_key_is_usable(
|
||||
state,
|
||||
max(60, int(min_key_ttl_seconds)),
|
||||
)
|
||||
):
|
||||
return "legacy_session_key_cache", None
|
||||
return (
|
||||
"legacy_session_key_mint",
|
||||
_nous_legacy_session_key_reason(
|
||||
token,
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
auth_mode=auth_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _log_nous_invoke_jwt_selected(
|
||||
*,
|
||||
access_token: Any,
|
||||
sequence_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.info("Nous inference auth: using NAS invoke JWT")
|
||||
_oauth_trace(
|
||||
"nous_invoke_jwt_selected",
|
||||
sequence_id=sequence_id,
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
|
||||
|
||||
def _log_nous_legacy_session_key_selected(
|
||||
reason: str,
|
||||
*,
|
||||
access_token: Any,
|
||||
sequence_id: Optional[str] = None,
|
||||
) -> None:
|
||||
logger.info(
|
||||
"Nous inference auth: using legacy session key path (%s)",
|
||||
reason,
|
||||
)
|
||||
_oauth_trace(
|
||||
"nous_legacy_session_key_selected",
|
||||
sequence_id=sequence_id,
|
||||
reason=reason,
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
|
||||
|
||||
def _nous_jwt_expires_at(token: Any, fallback_expires_at: Any = None) -> Optional[str]:
|
||||
|
|
@ -1645,7 +1787,17 @@ def _set_nous_agent_key_from_invoke_jwt(
|
|||
if not isinstance(access_token, str) or not access_token.strip():
|
||||
return
|
||||
now = datetime.now(timezone.utc)
|
||||
effective_obtained_at = obtained_at or now.isoformat()
|
||||
existing_obtained_at = state.get("agent_key_obtained_at")
|
||||
if obtained_at:
|
||||
effective_obtained_at = obtained_at
|
||||
elif (
|
||||
state.get("agent_key") == access_token
|
||||
and isinstance(existing_obtained_at, str)
|
||||
and existing_obtained_at.strip()
|
||||
):
|
||||
effective_obtained_at = existing_obtained_at
|
||||
else:
|
||||
effective_obtained_at = now.isoformat()
|
||||
expires_at = _nous_jwt_expires_at(access_token, state.get("expires_at"))
|
||||
expires_epoch = _parse_iso_timestamp(expires_at)
|
||||
expires_in = (
|
||||
|
|
@ -1664,6 +1816,38 @@ def _set_nous_agent_key_from_invoke_jwt(
|
|||
state["agent_key_obtained_at"] = effective_obtained_at
|
||||
|
||||
|
||||
def _select_nous_invoke_jwt(
|
||||
state: Dict[str, Any],
|
||||
*,
|
||||
access_token: Any = None,
|
||||
sequence_id: Optional[str] = None,
|
||||
) -> None:
|
||||
if isinstance(access_token, str) and access_token.strip():
|
||||
state["access_token"] = access_token
|
||||
_set_nous_agent_key_from_invoke_jwt(state)
|
||||
_log_nous_invoke_jwt_selected(
|
||||
access_token=state.get("access_token"),
|
||||
sequence_id=sequence_id,
|
||||
)
|
||||
|
||||
|
||||
_NOUS_EFFECTIVE_STATE_IGNORED_KEYS = frozenset({
|
||||
# These are derived from expires_at/JWT exp and naturally tick down between
|
||||
# reads. Persisting only these changes makes auth.json noisy and defeats
|
||||
# the mtime-keyed auth-status cache.
|
||||
"expires_in",
|
||||
"agent_key_expires_in",
|
||||
})
|
||||
|
||||
|
||||
def _nous_effective_provider_state(state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
for key, value in state.items()
|
||||
if key not in _NOUS_EFFECTIVE_STATE_IGNORED_KEYS
|
||||
}
|
||||
|
||||
|
||||
def _codex_access_token_is_expiring(access_token: Any, skew_seconds: int) -> bool:
|
||||
claims = _decode_jwt_claims(access_token)
|
||||
exp = claims.get("exp")
|
||||
|
|
@ -3476,6 +3660,57 @@ def _is_nous_invoke_scope_refusal(exc: Exception) -> bool:
|
|||
)
|
||||
|
||||
|
||||
def _nous_device_scope(
|
||||
requested_scope: Optional[str],
|
||||
*,
|
||||
default_scope: str = DEFAULT_NOUS_SCOPE,
|
||||
) -> Tuple[str, bool]:
|
||||
explicit_scope = requested_scope is not None
|
||||
scope = requested_scope or default_scope
|
||||
if _nous_legacy_session_keys_forced():
|
||||
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
|
||||
return scope, explicit_scope
|
||||
|
||||
|
||||
def _request_nous_device_code_with_scope_fallback(
|
||||
*,
|
||||
client: httpx.Client,
|
||||
portal_base_url: str,
|
||||
client_id: str,
|
||||
scope: str,
|
||||
allow_legacy_fallback: bool,
|
||||
) -> Tuple[Dict[str, Any], str]:
|
||||
try:
|
||||
return (
|
||||
_request_device_code(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
),
|
||||
scope,
|
||||
)
|
||||
except Exception as exc:
|
||||
if (
|
||||
allow_legacy_fallback
|
||||
and _nous_scope_has_invoke(scope)
|
||||
and _is_nous_invoke_scope_refusal(exc)
|
||||
):
|
||||
logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope")
|
||||
_oauth_trace("nous_device_code_invoke_scope_refused")
|
||||
retry_scope = NOUS_LEGACY_AGENT_KEY_SCOPE
|
||||
return (
|
||||
_request_device_code(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=retry_scope,
|
||||
),
|
||||
retry_scope,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _poll_for_token(
|
||||
client: httpx.Client,
|
||||
portal_base_url: str,
|
||||
|
|
@ -3817,6 +4052,39 @@ def _quarantine_nous_oauth_state(
|
|||
invalidate_nous_auth_status_cache()
|
||||
|
||||
|
||||
def _quarantine_nous_pool_entries(
|
||||
auth_store: Dict[str, Any],
|
||||
error: AuthError,
|
||||
*,
|
||||
reason: str,
|
||||
) -> bool:
|
||||
"""Remove singleton-seeded Nous pool entries that contain dead OAuth state."""
|
||||
pool = auth_store.get("credential_pool")
|
||||
if not isinstance(pool, dict):
|
||||
return False
|
||||
entries = pool.get("nous")
|
||||
if not isinstance(entries, list):
|
||||
return False
|
||||
|
||||
retained = []
|
||||
removed = False
|
||||
singleton_sources = {NOUS_DEVICE_CODE_SOURCE, f"manual:{NOUS_DEVICE_CODE_SOURCE}"}
|
||||
for entry in entries:
|
||||
if isinstance(entry, dict) and entry.get("source") in singleton_sources:
|
||||
removed = True
|
||||
continue
|
||||
retained.append(entry)
|
||||
|
||||
if removed:
|
||||
pool["nous"] = retained
|
||||
_oauth_trace(
|
||||
"nous_pool_device_code_quarantined",
|
||||
reason=reason,
|
||||
error_code=error.code,
|
||||
)
|
||||
return removed
|
||||
|
||||
|
||||
def _try_import_shared_nous_state(
|
||||
*,
|
||||
timeout_seconds: float = 15.0,
|
||||
|
|
@ -3842,7 +4110,7 @@ def _try_import_shared_nous_state(
|
|||
|
||||
# Build a full state dict so refresh_nous_oauth_from_state has every
|
||||
# field it needs. force_refresh=True gets us a fresh access_token
|
||||
# for this profile; force_mint=True gets us a fresh agent_key.
|
||||
# for this profile; fresh auth mode avoids stale cached legacy keys.
|
||||
state: Dict[str, Any] = {
|
||||
"access_token": shared.get("access_token"),
|
||||
"refresh_token": shared.get("refresh_token"),
|
||||
|
|
@ -3863,7 +4131,7 @@ def _try_import_shared_nous_state(
|
|||
min_key_ttl_seconds=min_key_ttl_seconds,
|
||||
timeout_seconds=timeout_seconds,
|
||||
force_refresh=True,
|
||||
force_mint=True,
|
||||
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
|
||||
)
|
||||
_write_shared_nous_state(refreshed)
|
||||
except AuthError as exc:
|
||||
|
|
@ -4121,6 +4389,11 @@ def resolve_nous_access_token(
|
|||
exc,
|
||||
reason="managed_access_token_refresh_failure",
|
||||
)
|
||||
_quarantine_nous_pool_entries(
|
||||
auth_store,
|
||||
exc,
|
||||
reason="managed_access_token_refresh_failure",
|
||||
)
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
raise
|
||||
|
|
@ -4167,9 +4440,10 @@ def refresh_nous_oauth_pure(
|
|||
insecure: Optional[bool] = None,
|
||||
ca_bundle: Optional[str] = None,
|
||||
force_refresh: bool = False,
|
||||
force_mint: bool = False,
|
||||
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
|
||||
) -> Dict[str, Any]:
|
||||
"""Refresh Nous OAuth state without mutating auth.json."""
|
||||
auth_mode = _normalize_nous_auth_mode(auth_mode)
|
||||
state: Dict[str, Any] = {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
|
|
@ -4229,38 +4503,17 @@ def refresh_nous_oauth_pure(
|
|||
now.timestamp() + access_ttl, tz=timezone.utc
|
||||
).isoformat()
|
||||
|
||||
if (
|
||||
not legacy_session_keys
|
||||
and _nous_invoke_jwt_is_usable(
|
||||
state.get("access_token"),
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
)
|
||||
):
|
||||
_set_nous_agent_key_from_invoke_jwt(state)
|
||||
logger.info("Nous inference auth: using NAS invoke JWT")
|
||||
_oauth_trace(
|
||||
"nous_invoke_jwt_selected",
|
||||
access_token_fp=_token_fingerprint(state.get("access_token")),
|
||||
)
|
||||
elif force_mint or not _agent_key_is_usable(state, min_agent_key_ttl):
|
||||
fallback_reason = (
|
||||
"forced_legacy_session_keys"
|
||||
if legacy_session_keys
|
||||
else _nous_invoke_jwt_unavailable_reason(
|
||||
state.get("access_token"),
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
)
|
||||
)
|
||||
logger.info(
|
||||
"Nous inference auth: using legacy session key path (%s)",
|
||||
fallback_reason,
|
||||
)
|
||||
_oauth_trace(
|
||||
"nous_legacy_session_key_selected",
|
||||
reason=fallback_reason,
|
||||
access_token_fp=_token_fingerprint(state.get("access_token")),
|
||||
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
|
||||
state,
|
||||
min_key_ttl_seconds=min_agent_key_ttl,
|
||||
auth_mode=auth_mode,
|
||||
)
|
||||
if selected_auth_path == "invoke_jwt":
|
||||
_select_nous_invoke_jwt(state)
|
||||
elif selected_auth_path == "legacy_session_key_mint":
|
||||
_log_nous_legacy_session_key_selected(
|
||||
fallback_reason or "legacy_session_key_required",
|
||||
access_token=state.get("access_token"),
|
||||
)
|
||||
mint_payload = _mint_agent_key(
|
||||
client=client,
|
||||
|
|
@ -4288,7 +4541,7 @@ def refresh_nous_oauth_from_state(
|
|||
min_key_ttl_seconds: int = DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
timeout_seconds: float = 15.0,
|
||||
force_refresh: bool = False,
|
||||
force_mint: bool = False,
|
||||
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
|
||||
) -> Dict[str, Any]:
|
||||
"""Refresh Nous OAuth from a state dict. Thin wrapper around refresh_nous_oauth_pure."""
|
||||
tls = state.get("tls") or {}
|
||||
|
|
@ -4309,13 +4562,10 @@ def refresh_nous_oauth_from_state(
|
|||
insecure=tls.get("insecure"),
|
||||
ca_bundle=tls.get("ca_bundle"),
|
||||
force_refresh=force_refresh,
|
||||
force_mint=force_mint,
|
||||
auth_mode=auth_mode,
|
||||
)
|
||||
|
||||
|
||||
NOUS_DEVICE_CODE_SOURCE = "device_code"
|
||||
|
||||
|
||||
def persist_nous_credentials(
|
||||
creds: Dict[str, Any],
|
||||
*,
|
||||
|
|
@ -4390,7 +4640,7 @@ def resolve_nous_runtime_credentials(
|
|||
timeout_seconds: float = 15.0,
|
||||
insecure: Optional[bool] = None,
|
||||
ca_bundle: Optional[str] = None,
|
||||
force_mint: bool = False,
|
||||
auth_mode: str = NOUS_INFERENCE_AUTH_AUTO,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Resolve Nous inference credentials for runtime use.
|
||||
|
|
@ -4402,6 +4652,7 @@ def resolve_nous_runtime_credentials(
|
|||
Returns dict with: provider, base_url, api_key, key_id, expires_at,
|
||||
expires_in, source ("invoke_jwt", "cache", or "portal"), and auth_path.
|
||||
"""
|
||||
auth_mode = _normalize_nous_auth_mode(auth_mode)
|
||||
min_key_ttl_seconds = max(60, int(min_key_ttl_seconds))
|
||||
sequence_id = uuid.uuid4().hex[:12]
|
||||
|
||||
|
|
@ -4413,6 +4664,9 @@ def resolve_nous_runtime_credentials(
|
|||
raise AuthError("Hermes is not logged into Nous Portal.",
|
||||
provider="nous", relogin_required=True)
|
||||
|
||||
persisted_state = dict(state)
|
||||
state_persisted = False
|
||||
|
||||
portal_base_url = (
|
||||
_optional_base_url(state.get("portal_base_url"))
|
||||
or os.getenv("HERMES_PORTAL_BASE_URL")
|
||||
|
|
@ -4427,6 +4681,17 @@ def resolve_nous_runtime_credentials(
|
|||
client_id = str(state.get("client_id") or DEFAULT_NOUS_CLIENT_ID)
|
||||
|
||||
def _persist_state(reason: str) -> None:
|
||||
nonlocal persisted_state, state_persisted
|
||||
if (
|
||||
_nous_effective_provider_state(state)
|
||||
== _nous_effective_provider_state(persisted_state)
|
||||
):
|
||||
_oauth_trace(
|
||||
"nous_state_persist_skipped",
|
||||
sequence_id=sequence_id,
|
||||
reason=reason,
|
||||
)
|
||||
return
|
||||
try:
|
||||
_save_provider_state(auth_store, "nous", state)
|
||||
_save_auth_store(auth_store)
|
||||
|
|
@ -4445,6 +4710,8 @@ def resolve_nous_runtime_credentials(
|
|||
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
|
||||
access_token_fp=_token_fingerprint(state.get("access_token")),
|
||||
)
|
||||
persisted_state = dict(state)
|
||||
state_persisted = True
|
||||
# Mirror post-refresh state to the shared store so sibling
|
||||
# profiles don't hold stale refresh_tokens after rotation.
|
||||
# Best-effort — any failure is logged and swallowed inside
|
||||
|
|
@ -4456,7 +4723,7 @@ def resolve_nous_runtime_credentials(
|
|||
_oauth_trace(
|
||||
"nous_runtime_credentials_start",
|
||||
sequence_id=sequence_id,
|
||||
force_mint=bool(force_mint),
|
||||
auth_mode=auth_mode,
|
||||
min_key_ttl_seconds=min_key_ttl_seconds,
|
||||
refresh_token_fp=_token_fingerprint(state.get("refresh_token")),
|
||||
)
|
||||
|
|
@ -4520,6 +4787,11 @@ def resolve_nous_runtime_credentials(
|
|||
exc,
|
||||
reason="runtime_access_refresh_failure",
|
||||
)
|
||||
_quarantine_nous_pool_entries(
|
||||
auth_store,
|
||||
exc,
|
||||
reason="runtime_access_refresh_failure",
|
||||
)
|
||||
_persist_state("terminal_runtime_access_refresh_failure")
|
||||
raise
|
||||
now = datetime.now(timezone.utc)
|
||||
|
|
@ -4554,50 +4826,28 @@ def resolve_nous_runtime_credentials(
|
|||
# the opaque session key.
|
||||
used_cached_key = False
|
||||
mint_payload: Optional[Dict[str, Any]] = None
|
||||
selected_auth_path = "legacy_session_key"
|
||||
legacy_session_keys = _nous_legacy_session_keys_forced()
|
||||
selected_auth_path, fallback_reason = _choose_nous_inference_auth_path(
|
||||
state,
|
||||
access_token=access_token,
|
||||
min_key_ttl_seconds=min_key_ttl_seconds,
|
||||
auth_mode=auth_mode,
|
||||
)
|
||||
|
||||
if (
|
||||
not legacy_session_keys
|
||||
and _nous_invoke_jwt_is_usable(
|
||||
access_token,
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
)
|
||||
):
|
||||
_set_nous_agent_key_from_invoke_jwt(state)
|
||||
selected_auth_path = "invoke_jwt"
|
||||
logger.info("Nous inference auth: using NAS invoke JWT")
|
||||
_oauth_trace(
|
||||
"nous_invoke_jwt_selected",
|
||||
if selected_auth_path == "invoke_jwt":
|
||||
_select_nous_invoke_jwt(
|
||||
state,
|
||||
access_token=access_token,
|
||||
sequence_id=sequence_id,
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
elif not force_mint and _agent_key_is_usable(state, min_key_ttl_seconds):
|
||||
elif selected_auth_path == "legacy_session_key_cache":
|
||||
used_cached_key = True
|
||||
selected_auth_path = "legacy_session_key_cache"
|
||||
logger.info("Nous inference auth: using cached legacy session key")
|
||||
logger.info("Nous inference auth: using cached agent_key")
|
||||
_oauth_trace("agent_key_reuse", sequence_id=sequence_id)
|
||||
else:
|
||||
fallback_reason = (
|
||||
"forced_legacy_session_keys"
|
||||
if legacy_session_keys
|
||||
else _nous_invoke_jwt_unavailable_reason(
|
||||
access_token,
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
)
|
||||
)
|
||||
selected_auth_path = "legacy_session_key_mint"
|
||||
logger.info(
|
||||
"Nous inference auth: using legacy session key path (%s)",
|
||||
fallback_reason,
|
||||
)
|
||||
_oauth_trace(
|
||||
"nous_legacy_session_key_selected",
|
||||
_log_nous_legacy_session_key_selected(
|
||||
fallback_reason or "legacy_session_key_required",
|
||||
access_token=access_token,
|
||||
sequence_id=sequence_id,
|
||||
reason=fallback_reason,
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
try:
|
||||
_oauth_trace(
|
||||
|
|
@ -4646,6 +4896,11 @@ def resolve_nous_runtime_credentials(
|
|||
exc,
|
||||
reason="runtime_mint_retry_refresh_failure",
|
||||
)
|
||||
_quarantine_nous_pool_entries(
|
||||
auth_store,
|
||||
exc,
|
||||
reason="runtime_mint_retry_refresh_failure",
|
||||
)
|
||||
_persist_state("terminal_runtime_mint_retry_refresh_failure")
|
||||
raise
|
||||
now = datetime.now(timezone.utc)
|
||||
|
|
@ -4674,22 +4929,24 @@ def resolve_nous_runtime_credentials(
|
|||
# Persist retry refresh immediately for crash safety and cross-process visibility.
|
||||
_persist_state("post_refresh_mint_retry")
|
||||
|
||||
if (
|
||||
not legacy_session_keys
|
||||
and _nous_invoke_jwt_is_usable(
|
||||
access_token,
|
||||
scope=state.get("scope"),
|
||||
expires_at=state.get("expires_at"),
|
||||
)
|
||||
):
|
||||
_set_nous_agent_key_from_invoke_jwt(state)
|
||||
retry_auth_mode = (
|
||||
NOUS_INFERENCE_AUTH_LEGACY
|
||||
if auth_mode == NOUS_INFERENCE_AUTH_LEGACY
|
||||
else NOUS_INFERENCE_AUTH_FRESH
|
||||
)
|
||||
retry_auth_path, _ = _choose_nous_inference_auth_path(
|
||||
state,
|
||||
access_token=access_token,
|
||||
min_key_ttl_seconds=min_key_ttl_seconds,
|
||||
auth_mode=retry_auth_mode,
|
||||
)
|
||||
if retry_auth_path == "invoke_jwt":
|
||||
mint_payload = None
|
||||
selected_auth_path = "invoke_jwt"
|
||||
logger.info("Nous inference auth: using NAS invoke JWT")
|
||||
_oauth_trace(
|
||||
"nous_invoke_jwt_selected",
|
||||
_select_nous_invoke_jwt(
|
||||
state,
|
||||
access_token=access_token,
|
||||
sequence_id=sequence_id,
|
||||
access_token_fp=_token_fingerprint(access_token),
|
||||
)
|
||||
else:
|
||||
mint_payload = _mint_agent_key(
|
||||
|
|
@ -4727,7 +4984,8 @@ def resolve_nous_runtime_credentials(
|
|||
|
||||
_persist_state("resolve_nous_runtime_credentials_final")
|
||||
|
||||
_sync_nous_pool_from_auth_store()
|
||||
if state_persisted:
|
||||
_sync_nous_pool_from_auth_store()
|
||||
|
||||
api_key = state.get("agent_key")
|
||||
if not isinstance(api_key, str) or not api_key:
|
||||
|
|
@ -6433,10 +6691,7 @@ def _nous_device_code_login(
|
|||
or pconfig.inference_base_url
|
||||
).rstrip("/")
|
||||
client_id = client_id or pconfig.client_id
|
||||
explicit_scope = scope is not None
|
||||
scope = scope or pconfig.scope
|
||||
if _nous_legacy_session_keys_forced():
|
||||
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
|
||||
scope, explicit_scope = _nous_device_scope(scope, default_scope=pconfig.scope)
|
||||
timeout = httpx.Timeout(timeout_seconds)
|
||||
verify: bool | str = False if insecure else (ca_bundle if ca_bundle else True)
|
||||
|
||||
|
|
@ -6451,30 +6706,13 @@ def _nous_device_code_login(
|
|||
print(f"TLS verification: custom CA bundle ({ca_bundle})")
|
||||
|
||||
with httpx.Client(timeout=timeout, headers={"Accept": "application/json"}, verify=verify) as client:
|
||||
try:
|
||||
device_data = _request_device_code(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
)
|
||||
except Exception as exc:
|
||||
if (
|
||||
not explicit_scope
|
||||
and _nous_scope_has_invoke(scope)
|
||||
and _is_nous_invoke_scope_refusal(exc)
|
||||
):
|
||||
logger.info("Nous inference auth: NAS refused invoke scope, retrying legacy scope")
|
||||
_oauth_trace("nous_device_code_invoke_scope_refused")
|
||||
scope = NOUS_LEGACY_AGENT_KEY_SCOPE
|
||||
device_data = _request_device_code(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
)
|
||||
else:
|
||||
raise
|
||||
device_data, scope = _request_nous_device_code_with_scope_fallback(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
allow_legacy_fallback=not explicit_scope,
|
||||
)
|
||||
|
||||
verification_url = str(device_data["verification_uri_complete"])
|
||||
user_code = str(device_data["user_code"])
|
||||
|
|
@ -6543,7 +6781,7 @@ def _nous_device_code_login(
|
|||
min_key_ttl_seconds=min_key_ttl_seconds,
|
||||
timeout_seconds=timeout_seconds,
|
||||
force_refresh=False,
|
||||
force_mint=True,
|
||||
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
|
||||
)
|
||||
except AuthError as exc:
|
||||
if exc.code == "subscription_required":
|
||||
|
|
|
|||
|
|
@ -81,6 +81,21 @@ class UpstreamAdapter(ABC):
|
|||
refresh fails. The proxy will return 401 to the client.
|
||||
"""
|
||||
|
||||
def get_retry_credential(
|
||||
self,
|
||||
*,
|
||||
failed_credential: UpstreamCredential,
|
||||
status_code: int,
|
||||
) -> Optional[UpstreamCredential]:
|
||||
"""Return an alternate credential after an upstream auth failure.
|
||||
|
||||
The default is no retry. Providers can override this for one-shot
|
||||
fallback paths, such as switching from a preferred token type to a
|
||||
legacy bearer after the upstream rejects the first request.
|
||||
"""
|
||||
del failed_credential, status_code
|
||||
return None
|
||||
|
||||
def describe(self) -> str:
|
||||
"""One-line status summary for ``proxy status``."""
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -19,13 +19,16 @@ from typing import Any, Dict, FrozenSet, Optional
|
|||
from hermes_cli.auth import (
|
||||
AuthError,
|
||||
DEFAULT_NOUS_INFERENCE_URL,
|
||||
NOUS_INFERENCE_AUTH_AUTO,
|
||||
NOUS_INFERENCE_AUTH_LEGACY,
|
||||
_load_auth_store,
|
||||
_is_terminal_nous_refresh_error,
|
||||
_quarantine_nous_oauth_state,
|
||||
_quarantine_nous_pool_entries,
|
||||
_save_auth_store,
|
||||
_write_shared_nous_state,
|
||||
refresh_nous_oauth_from_state,
|
||||
)
|
||||
)
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -76,6 +79,21 @@ class NousPortalAdapter(UpstreamAdapter):
|
|||
)
|
||||
|
||||
def get_credential(self) -> UpstreamCredential:
|
||||
return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_AUTO)
|
||||
|
||||
def get_retry_credential(
|
||||
self,
|
||||
*,
|
||||
failed_credential: UpstreamCredential,
|
||||
status_code: int,
|
||||
) -> Optional[UpstreamCredential]:
|
||||
del failed_credential
|
||||
if status_code != 401:
|
||||
return None
|
||||
logger.info("proxy: Nous upstream rejected bearer; retrying with legacy session key")
|
||||
return self._get_credential(auth_mode=NOUS_INFERENCE_AUTH_LEGACY)
|
||||
|
||||
def _get_credential(self, *, auth_mode: str) -> UpstreamCredential:
|
||||
with self._lock:
|
||||
state = self._read_state()
|
||||
if state is None:
|
||||
|
|
@ -84,7 +102,10 @@ class NousPortalAdapter(UpstreamAdapter):
|
|||
)
|
||||
|
||||
try:
|
||||
refreshed = refresh_nous_oauth_from_state(state)
|
||||
refreshed = refresh_nous_oauth_from_state(
|
||||
state,
|
||||
auth_mode=auth_mode,
|
||||
)
|
||||
except AuthError as exc:
|
||||
if _is_terminal_nous_refresh_error(exc):
|
||||
_quarantine_nous_oauth_state(
|
||||
|
|
@ -92,7 +113,11 @@ class NousPortalAdapter(UpstreamAdapter):
|
|||
exc,
|
||||
reason="proxy_refresh_failure",
|
||||
)
|
||||
self._save_state(state)
|
||||
self._save_state(
|
||||
state,
|
||||
quarantine_error=exc,
|
||||
quarantine_reason="proxy_refresh_failure",
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Failed to refresh Nous Portal credentials: {exc}"
|
||||
) from exc
|
||||
|
|
@ -136,9 +161,21 @@ class NousPortalAdapter(UpstreamAdapter):
|
|||
return None
|
||||
return dict(state) # copy so the refresh helper can mutate freely
|
||||
|
||||
def _save_state(self, state: Dict[str, Any]) -> None:
|
||||
def _save_state(
|
||||
self,
|
||||
state: Dict[str, Any],
|
||||
*,
|
||||
quarantine_error: Optional[AuthError] = None,
|
||||
quarantine_reason: Optional[str] = None,
|
||||
) -> None:
|
||||
try:
|
||||
store = _load_auth_store()
|
||||
if quarantine_error is not None and quarantine_reason:
|
||||
_quarantine_nous_pool_entries(
|
||||
store,
|
||||
quarantine_error,
|
||||
reason=quarantine_reason,
|
||||
)
|
||||
providers = store.setdefault("providers", {})
|
||||
providers["nous"] = state
|
||||
_save_auth_store(store)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ except ImportError:
|
|||
web = None # type: ignore[assignment]
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter
|
||||
from hermes_cli.proxy.adapters.base import UpstreamAdapter, UpstreamCredential
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -136,50 +136,93 @@ def create_app(adapter: UpstreamAdapter) -> "web.Application":
|
|||
logger.warning("proxy: credential resolution failed: %s", exc)
|
||||
return _json_error(401, str(exc), code="upstream_auth_failed")
|
||||
|
||||
upstream_url = f"{cred.base_url.rstrip('/')}{rel_path}"
|
||||
# Preserve query string verbatim.
|
||||
if request.query_string:
|
||||
upstream_url = f"{upstream_url}?{request.query_string}"
|
||||
|
||||
# Forward body verbatim. Read into memory once — request bodies for
|
||||
# chat/completions/embeddings are small (<1MB typically). If we ever
|
||||
# need to forward large multipart uploads we'll switch to streaming
|
||||
# the request body too.
|
||||
body = await request.read()
|
||||
|
||||
fwd_headers = _filter_request_headers(request.headers)
|
||||
fwd_headers["Authorization"] = f"{cred.token_type} {cred.bearer}"
|
||||
|
||||
logger.debug(
|
||||
"proxy: forwarding %s %s -> %s (body=%d bytes)",
|
||||
request.method, rel_path, upstream_url, len(body),
|
||||
)
|
||||
|
||||
# Use a per-request session so connection state doesn't leak between
|
||||
# clients. Could be optimized to a shared session later.
|
||||
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=300)
|
||||
try:
|
||||
session = aiohttp.ClientSession(timeout=timeout)
|
||||
except Exception as exc: # pragma: no cover - aiohttp setup issue
|
||||
return _json_error(500, f"proxy session init failed: {exc}")
|
||||
|
||||
try:
|
||||
upstream_resp = await session.request(
|
||||
request.method,
|
||||
upstream_url,
|
||||
data=body if body else None,
|
||||
headers=fwd_headers,
|
||||
allow_redirects=False,
|
||||
async def _send_upstream(active_cred: UpstreamCredential):
|
||||
upstream_url = f"{active_cred.base_url.rstrip('/')}{rel_path}"
|
||||
# Preserve query string verbatim.
|
||||
if request.query_string:
|
||||
upstream_url = f"{upstream_url}?{request.query_string}"
|
||||
|
||||
fwd_headers = _filter_request_headers(request.headers)
|
||||
fwd_headers["Authorization"] = f"{active_cred.token_type} {active_cred.bearer}"
|
||||
|
||||
logger.debug(
|
||||
"proxy: forwarding %s %s -> %s (body=%d bytes)",
|
||||
request.method, rel_path, upstream_url, len(body),
|
||||
)
|
||||
except aiohttp.ClientError as exc:
|
||||
await session.close()
|
||||
logger.warning("proxy: upstream connection failed: %s", exc)
|
||||
return _json_error(502, f"upstream connection failed: {exc}",
|
||||
code="upstream_unreachable")
|
||||
except asyncio.TimeoutError:
|
||||
await session.close()
|
||||
return _json_error(504, "upstream request timed out",
|
||||
code="upstream_timeout")
|
||||
|
||||
try:
|
||||
session = aiohttp.ClientSession(timeout=timeout)
|
||||
except Exception as exc: # pragma: no cover - aiohttp setup issue
|
||||
raise RuntimeError(f"proxy session init failed: {exc}") from exc
|
||||
|
||||
try:
|
||||
upstream_resp = await session.request(
|
||||
request.method,
|
||||
upstream_url,
|
||||
data=body if body else None,
|
||||
headers=fwd_headers,
|
||||
allow_redirects=False,
|
||||
)
|
||||
except Exception:
|
||||
await session.close()
|
||||
raise
|
||||
return session, upstream_resp
|
||||
|
||||
async def _open_upstream(active_cred: UpstreamCredential):
|
||||
try:
|
||||
return await _send_upstream(active_cred)
|
||||
except RuntimeError as exc:
|
||||
return _json_error(500, str(exc)), None
|
||||
except aiohttp.ClientError as exc:
|
||||
logger.warning("proxy: upstream connection failed: %s", exc)
|
||||
return (
|
||||
_json_error(
|
||||
502,
|
||||
f"upstream connection failed: {exc}",
|
||||
code="upstream_unreachable",
|
||||
),
|
||||
None,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return (
|
||||
_json_error(
|
||||
504,
|
||||
"upstream request timed out",
|
||||
code="upstream_timeout",
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
session_or_response, upstream_resp = await _open_upstream(cred)
|
||||
if upstream_resp is None:
|
||||
return session_or_response
|
||||
session = session_or_response
|
||||
|
||||
if upstream_resp.status == 401:
|
||||
try:
|
||||
retry_cred = adapter.get_retry_credential(
|
||||
failed_credential=cred,
|
||||
status_code=upstream_resp.status,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("proxy: retry credential resolution failed: %s", exc)
|
||||
retry_cred = None
|
||||
|
||||
if retry_cred is not None:
|
||||
upstream_resp.release()
|
||||
await session.close()
|
||||
session_or_response, upstream_resp = await _open_upstream(retry_cred)
|
||||
if upstream_resp is None:
|
||||
return session_or_response
|
||||
session = session_or_response
|
||||
|
||||
# Stream response back. Headers first, then chunked body.
|
||||
resp = web.StreamResponse(
|
||||
|
|
|
|||
|
|
@ -1815,7 +1815,11 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
|
|||
so the UI can render the verification page link + user code.
|
||||
"""
|
||||
if provider_id == "nous":
|
||||
from hermes_cli.auth import _request_device_code, PROVIDER_REGISTRY
|
||||
from hermes_cli.auth import (
|
||||
_nous_device_scope,
|
||||
_request_nous_device_code_with_scope_fallback,
|
||||
PROVIDER_REGISTRY,
|
||||
)
|
||||
import httpx
|
||||
pconfig = PROVIDER_REGISTRY["nous"]
|
||||
portal_base_url = (
|
||||
|
|
@ -1824,22 +1828,31 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
|
|||
or pconfig.portal_base_url
|
||||
).rstrip("/")
|
||||
client_id = pconfig.client_id
|
||||
scope = pconfig.scope
|
||||
scope, explicit_scope = _nous_device_scope(None, default_scope=pconfig.scope)
|
||||
|
||||
def _do_nous_device_request():
|
||||
with httpx.Client(timeout=httpx.Timeout(15.0), headers={"Accept": "application/json"}) as client:
|
||||
return _request_device_code(
|
||||
with httpx.Client(
|
||||
timeout=httpx.Timeout(15.0),
|
||||
headers={"Accept": "application/json"},
|
||||
) as client:
|
||||
return _request_nous_device_code_with_scope_fallback(
|
||||
client=client,
|
||||
portal_base_url=portal_base_url,
|
||||
client_id=client_id,
|
||||
scope=scope,
|
||||
allow_legacy_fallback=not explicit_scope,
|
||||
)
|
||||
device_data = await asyncio.get_running_loop().run_in_executor(None, _do_nous_device_request)
|
||||
|
||||
device_data, effective_scope = await asyncio.get_running_loop().run_in_executor(
|
||||
None, _do_nous_device_request
|
||||
)
|
||||
sid, sess = _new_oauth_session("nous", "device_code")
|
||||
sess["device_code"] = str(device_data["device_code"])
|
||||
sess["interval"] = int(device_data["interval"])
|
||||
sess["expires_at"] = time.time() + int(device_data["expires_in"])
|
||||
sess["portal_base_url"] = portal_base_url
|
||||
sess["client_id"] = client_id
|
||||
sess["scope"] = effective_scope
|
||||
threading.Thread(
|
||||
target=_nous_poller, args=(sid,), daemon=True, name=f"oauth-poll-{sid[:6]}"
|
||||
).start()
|
||||
|
|
@ -1968,7 +1981,11 @@ async def _start_device_code_flow(provider_id: str) -> Dict[str, Any]:
|
|||
|
||||
def _nous_poller(session_id: str) -> None:
|
||||
"""Background poller that drives a Nous device-code flow to completion."""
|
||||
from hermes_cli.auth import _poll_for_token, refresh_nous_oauth_from_state
|
||||
from hermes_cli.auth import (
|
||||
NOUS_INFERENCE_AUTH_FRESH,
|
||||
_poll_for_token,
|
||||
refresh_nous_oauth_from_state,
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
import httpx
|
||||
with _oauth_sessions_lock:
|
||||
|
|
@ -1979,6 +1996,7 @@ def _nous_poller(session_id: str) -> None:
|
|||
client_id = sess["client_id"]
|
||||
device_code = sess["device_code"]
|
||||
interval = sess["interval"]
|
||||
scope = sess.get("scope")
|
||||
expires_in = max(60, int(sess["expires_at"] - time.time()))
|
||||
try:
|
||||
with httpx.Client(timeout=httpx.Timeout(15.0), headers={"Accept": "application/json"}) as client:
|
||||
|
|
@ -1997,7 +2015,7 @@ def _nous_poller(session_id: str) -> None:
|
|||
"portal_base_url": portal_base_url,
|
||||
"inference_base_url": token_data.get("inference_base_url"),
|
||||
"client_id": client_id,
|
||||
"scope": token_data.get("scope"),
|
||||
"scope": token_data.get("scope") or scope,
|
||||
"token_type": token_data.get("token_type", "Bearer"),
|
||||
"access_token": token_data["access_token"],
|
||||
"refresh_token": token_data.get("refresh_token"),
|
||||
|
|
@ -2009,8 +2027,11 @@ def _nous_poller(session_id: str) -> None:
|
|||
"expires_in": token_ttl,
|
||||
}
|
||||
full_state = refresh_nous_oauth_from_state(
|
||||
auth_state, min_key_ttl_seconds=300, timeout_seconds=15.0,
|
||||
force_refresh=False, force_mint=True,
|
||||
auth_state,
|
||||
min_key_ttl_seconds=300,
|
||||
timeout_seconds=15.0,
|
||||
force_refresh=False,
|
||||
auth_mode=NOUS_INFERENCE_AUTH_FRESH,
|
||||
)
|
||||
from hermes_cli.auth import persist_nous_credentials
|
||||
persist_nous_credentials(full_state)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue