hermes-agent/agent/credential_pool.py
2026-03-23 22:37:13 +05:30

456 lines
17 KiB
Python

"""Persistent multi-credential pool for same-provider failover."""
from __future__ import annotations
import time
import uuid
import os
from dataclasses import dataclass, fields
from typing import Any, Dict, List, Optional
from hermes_constants import OPENROUTER_BASE_URL
import hermes_cli.auth as auth_mod
from hermes_cli.auth import (
ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
PROVIDER_REGISTRY,
_agent_key_is_usable,
_codex_access_token_is_expiring,
_decode_jwt_claims,
_is_expiring,
_load_auth_store,
_load_provider_state,
read_credential_pool,
write_credential_pool,
)
# Cooldown, in seconds (24 hours), during which CredentialPool.select() skips a
# credential whose last_status is "exhausted" before making it eligible again.
EXHAUSTED_TTL_SECONDS = 24 * 60 * 60
@dataclass
class PooledCredential:
    """One credential stored in a provider's credential pool.

    ``provider`` is supplied by the pool and is never serialised; every other
    field round-trips through :meth:`from_dict` / :meth:`to_dict`.
    """

    provider: str
    id: str
    label: str
    auth_type: str
    priority: int
    source: str
    access_token: str
    refresh_token: Optional[str] = None
    # Failover bookkeeping; these three keys are always emitted by to_dict(),
    # even when they are None.
    last_status: Optional[str] = None
    last_status_at: Optional[float] = None
    last_error_code: Optional[int] = None
    base_url: Optional[str] = None
    expires_at: Optional[str] = None
    expires_at_ms: Optional[int] = None
    last_refresh: Optional[str] = None
    token_type: Optional[str] = None
    scope: Optional[str] = None
    client_id: Optional[str] = None
    portal_base_url: Optional[str] = None
    # Preferred over base_url by runtime_base_url when provider == "nous".
    inference_base_url: Optional[str] = None
    obtained_at: Optional[str] = None
    expires_in: Optional[int] = None
    # Preferred over access_token by runtime_api_key when provider == "nous".
    agent_key: Optional[str] = None
    agent_key_id: Optional[str] = None
    agent_key_expires_at: Optional[str] = None
    agent_key_expires_in: Optional[int] = None
    agent_key_reused: Optional[bool] = None
    agent_key_obtained_at: Optional[str] = None
    tls: Optional[Dict[str, Any]] = None

    @classmethod
    def from_dict(cls, provider: str, payload: Dict[str, Any]) -> "PooledCredential":
        """Build a credential from a stored mapping.

        Keys that are not dataclass fields are ignored. A key whose value is
        ``None`` is treated exactly like a missing key, so required fields
        (``id``, ``label``, ``auth_type``, ``priority``, ``source``,
        ``access_token``) always receive their defaults instead of ``None``.

        Args:
            provider: Provider name the credential belongs to.
            payload: Mapping of field names to values; it is not modified.

        Returns:
            A new ``PooledCredential``.
        """
        allowed = {f.name for f in fields(cls) if f.name != "provider"}
        # Dropping None values keeps e.g. ``"priority": None`` from bypassing
        # the defaults below and later breaking priority-based sorting.
        data = {
            key: value
            for key, value in payload.items()
            if key in allowed and value is not None
        }
        data.setdefault("id", uuid.uuid4().hex[:6])
        fallback_label = payload.get("source")
        if fallback_label is None:
            fallback_label = provider
        data.setdefault("label", fallback_label)
        data.setdefault("auth_type", "api_key")
        data.setdefault("priority", 0)
        data.setdefault("source", "manual")
        data.setdefault("access_token", "")
        return cls(provider=provider, **data)

    def to_dict(self) -> Dict[str, Any]:
        """Serialise the credential to a plain mapping.

        ``provider`` is omitted and fields whose value is ``None`` are left
        out, except for the three status fields, which are always present.
        """
        result: Dict[str, Any] = {}
        for field_def in fields(self):
            if field_def.name == "provider":
                continue
            value = getattr(self, field_def.name)
            if value is not None:
                result[field_def.name] = value
        for key in ("last_status", "last_status_at", "last_error_code"):
            result.setdefault(key, getattr(self, key))
        return result

    @property
    def runtime_api_key(self) -> str:
        """Key to send at request time; never ``None``.

        For the ``nous`` provider the agent key wins over the access token.
        """
        if self.provider == "nous":
            return str(self.agent_key or self.access_token or "")
        return str(self.access_token or "")

    @property
    def runtime_base_url(self) -> Optional[str]:
        """Base URL to use at request time, or ``None`` when none is stored.

        For the ``nous`` provider ``inference_base_url`` wins over ``base_url``.
        """
        if self.provider == "nous":
            return self.inference_base_url or self.base_url
        return self.base_url
def _label_from_token(token: str, fallback: str) -> str:
    """Derive a human-readable label from the claims decoded out of *token*.

    The ``email``, ``preferred_username`` and ``upn`` claims are tried in that
    order; the first one holding a non-blank string is returned stripped.
    *fallback* is returned when none of them qualifies.
    """
    decoded = _decode_jwt_claims(token)
    candidates = (decoded.get(name) for name in ("email", "preferred_username", "upn"))
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        cleaned = candidate.strip()
        if cleaned:
            return cleaned
    return fallback
def _next_priority(entries: List[PooledCredential]) -> int:
return max((entry.priority for entry in entries), default=-1) + 1
class CredentialPool:
    """Priority-ordered credentials for one provider, with failover.

    The pool remembers which entry was handed out last (the "current" entry)
    and writes itself back through ``write_credential_pool`` whenever an
    entry's status, tokens or position changes.
    """

    def __init__(self, provider: str, entries: List[PooledCredential]):
        """Create a pool for *provider* holding a priority-sorted copy of *entries*."""
        self.provider = provider
        self._entries = sorted(entries, key=lambda entry: entry.priority)
        self._current_id: Optional[str] = None

    def has_credentials(self) -> bool:
        """Return True when the pool holds at least one entry."""
        return bool(self._entries)

    def entries(self) -> List[PooledCredential]:
        """Return a new list of the entries, sorted by ascending priority."""
        return sorted(self._entries, key=lambda entry: entry.priority)

    def current(self) -> Optional[PooledCredential]:
        """Return the entry last chosen by :meth:`select`, or None."""
        if not self._current_id:
            return None
        return next((entry for entry in self._entries if entry.id == self._current_id), None)

    def _persist(self) -> None:
        """Write every entry, in priority order, to the credential store."""
        write_credential_pool(
            self.provider,
            [entry.to_dict() for entry in sorted(self._entries, key=lambda item: item.priority)],
        )

    def _mark_exhausted(self, entry: PooledCredential, status_code: Optional[int]) -> None:
        """Flag *entry* as exhausted now, record *status_code*, and persist."""
        entry.last_status = "exhausted"
        entry.last_status_at = time.time()
        entry.last_error_code = status_code
        self._persist()

    def _refresh_entry(self, entry: PooledCredential, *, force: bool) -> Optional[PooledCredential]:
        """Refresh the OAuth tokens of *entry* in place.

        Args:
            entry: Credential to refresh.
            force: When True, an entry that cannot be refreshed (not OAuth or
                no refresh token) is marked exhausted; for the ``nous``
                provider it is also forwarded as ``force_refresh`` and
                ``force_mint``.

        Returns:
            *entry* after a successful refresh (status reset to ``"ok"`` and
            persisted), *entry* untouched when the provider has no refresh
            flow here, or None when the entry cannot be refreshed or the
            refresh raised (the entry is then marked exhausted).
        """
        if entry.auth_type != "oauth" or not entry.refresh_token:
            if force:
                self._mark_exhausted(entry, None)
            return None
        try:
            if self.provider == "anthropic":
                from agent.anthropic_adapter import refresh_anthropic_oauth_pure

                refreshed = refresh_anthropic_oauth_pure(
                    entry.refresh_token,
                    use_json=entry.source.endswith("hermes_pkce"),
                )
                entry.access_token = refreshed["access_token"]
                entry.refresh_token = refreshed["refresh_token"]
                entry.expires_at_ms = refreshed["expires_at_ms"]
            elif self.provider == "openai-codex":
                refreshed = auth_mod.refresh_codex_oauth_pure(
                    entry.access_token,
                    entry.refresh_token,
                )
                entry.access_token = refreshed["access_token"]
                entry.refresh_token = refreshed["refresh_token"]
                entry.last_refresh = refreshed.get("last_refresh")
            elif self.provider == "nous":
                refreshed = auth_mod.refresh_nous_oauth_pure(
                    entry.access_token,
                    entry.refresh_token,
                    entry.client_id or "hermes-cli",
                    entry.portal_base_url or "https://portal.nousresearch.com",
                    entry.inference_base_url or "https://inference-api.nousresearch.com/v1",
                    token_type=entry.token_type or "Bearer",
                    scope=entry.scope or "",
                    obtained_at=entry.obtained_at,
                    expires_at=entry.expires_at,
                    agent_key=entry.agent_key,
                    agent_key_expires_at=entry.agent_key_expires_at,
                    min_key_ttl_seconds=DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
                    force_refresh=force,
                    force_mint=force,
                )
                # Copy only real dataclass fields. A plain hasattr() check also
                # matches methods and read-only properties, so a stray key
                # could clobber a method or raise AttributeError and wrongly
                # mark a freshly refreshed entry as exhausted.
                writable = {field_def.name for field_def in fields(entry)}
                for key, value in refreshed.items():
                    if key in writable:
                        setattr(entry, key, value)
            else:
                return entry
        except Exception as exc:
            # Best-effort: a failed refresh takes the entry out of rotation
            # rather than crashing the caller. Only the exception type is
            # logged so that no token material can leak into log output.
            import logging

            logging.getLogger(__name__).debug(
                "Credential refresh failed for provider %s entry %s (%s)",
                self.provider,
                entry.id,
                type(exc).__name__,
            )
            self._mark_exhausted(entry, None)
            return None
        entry.last_status = "ok"
        entry.last_status_at = None
        entry.last_error_code = None
        self._persist()
        return entry

    def _entry_needs_refresh(self, entry: PooledCredential) -> bool:
        """Return True when *entry* is an OAuth credential due for a refresh.

        Non-OAuth entries and providers without a rule here never need one.
        """
        if entry.auth_type != "oauth":
            return False
        if self.provider == "anthropic":
            if entry.expires_at_ms is None:
                return False
            # Refresh once the expiry is within 120 000 ms (two minutes) of now.
            return int(entry.expires_at_ms) <= int(time.time() * 1000) + 120_000
        if self.provider == "openai-codex":
            return _codex_access_token_is_expiring(
                entry.access_token,
                CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
            )
        if self.provider == "nous":
            if _is_expiring(entry.expires_at, ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
                return True
            return not _agent_key_is_usable(
                {
                    "agent_key": entry.agent_key,
                    "agent_key_expires_at": entry.agent_key_expires_at,
                },
                DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
            )
        return False

    def select(self) -> Optional[PooledCredential]:
        """Pick the highest-priority usable entry and make it current.

        Entries exhausted less than ``EXHAUSTED_TTL_SECONDS`` ago are skipped;
        exhausted entries past that cooldown (or without a timestamp) are
        reset to ``"ok"`` and persisted. Entries that need a refresh are
        refreshed first and skipped when that fails.

        Returns:
            The chosen entry, or None when no entry is usable (the current
            entry is then cleared).
        """
        now = time.time()
        for entry in sorted(self._entries, key=lambda item: item.priority):
            if entry.last_status == "exhausted":
                if entry.last_status_at and now - entry.last_status_at < EXHAUSTED_TTL_SECONDS:
                    continue
                entry.last_status = "ok"
                entry.last_status_at = None
                entry.last_error_code = None
                self._persist()
            if self._entry_needs_refresh(entry):
                refreshed = self._refresh_entry(entry, force=False)
                if refreshed is None:
                    continue
                entry = refreshed
            self._current_id = entry.id
            return entry
        self._current_id = None
        return None

    def mark_exhausted_and_rotate(self, *, status_code: Optional[int]) -> Optional[PooledCredential]:
        """Mark the current entry exhausted and select the next usable one.

        When no entry is current, the one :meth:`select` would pick is the
        entry that gets marked. Returns the newly selected entry or None.
        """
        entry = self.current() or self.select()
        if entry is None:
            return None
        self._mark_exhausted(entry, status_code)
        self._current_id = None
        return self.select()

    def try_refresh_current(self) -> Optional[PooledCredential]:
        """Force a refresh of the current entry.

        Returns the refreshed entry, or None when there is no current entry
        or the refresh was not possible.
        """
        entry = self.current()
        if entry is None:
            return None
        refreshed = self._refresh_entry(entry, force=True)
        if refreshed is not None:
            self._current_id = refreshed.id
        return refreshed

    def reset_statuses(self) -> int:
        """Clear status bookkeeping on every entry that has any.

        Persists only when at least one entry changed. Returns the number of
        entries that were cleared.
        """
        count = 0
        for entry in self._entries:
            if entry.last_status or entry.last_status_at or entry.last_error_code:
                entry.last_status = None
                entry.last_status_at = None
                entry.last_error_code = None
                count += 1
        if count:
            self._persist()
        return count

    def remove_index(self, index: int) -> Optional[PooledCredential]:
        """Remove the entry at 1-based position *index* in priority order.

        Remaining entries are renumbered 0..n-1 and the pool is persisted.
        Returns the removed entry, or None when *index* is out of range.
        """
        ordered = sorted(self._entries, key=lambda item: item.priority)
        if index < 1 or index > len(ordered):
            return None
        removed = ordered.pop(index - 1)
        for new_priority, entry in enumerate(ordered):
            entry.priority = new_priority
        self._entries = ordered
        self._persist()
        if self._current_id == removed.id:
            self._current_id = None
        return removed

    def add_entry(self, entry: PooledCredential) -> PooledCredential:
        """Append *entry* with the next free priority, persist, and return it."""
        entry.priority = _next_priority(self._entries)
        self._entries.append(entry)
        self._persist()
        return entry
def _upsert_entry(entries: List[PooledCredential], provider: str, source: str, payload: Dict[str, Any]) -> bool:
existing = next((entry for entry in entries if entry.source == source), None)
if existing is None:
payload.setdefault("id", uuid.uuid4().hex[:6])
payload.setdefault("priority", _next_priority(entries))
payload.setdefault("label", payload.get("label") or source)
entries.append(PooledCredential.from_dict(provider, payload))
return True
changed = False
for key, value in payload.items():
if key in {"id", "priority"} or value is None:
continue
if key == "label" and existing.label:
continue
if hasattr(existing, key) and getattr(existing, key) != value:
setattr(existing, key, value)
changed = True
return changed
def _seed_from_env(provider: str, entries: List[PooledCredential]) -> bool:
changed = False
if provider == "openrouter":
token = os.getenv("OPENROUTER_API_KEY", "").strip()
if token:
changed |= _upsert_entry(
entries,
provider,
"env:OPENROUTER_API_KEY",
{
"source": "env:OPENROUTER_API_KEY",
"auth_type": "api_key",
"access_token": token,
"base_url": OPENROUTER_BASE_URL,
"label": "OPENROUTER_API_KEY",
},
)
return changed
pconfig = PROVIDER_REGISTRY.get(provider)
if not pconfig or pconfig.auth_type != "api_key":
return changed
env_url = ""
if pconfig.base_url_env_var:
env_url = os.getenv(pconfig.base_url_env_var, "").strip().rstrip("/")
for env_var in pconfig.api_key_env_vars:
token = os.getenv(env_var, "").strip()
if not token:
continue
auth_type = "oauth" if provider == "anthropic" and not token.startswith("sk-ant-api") else "api_key"
base_url = env_url or pconfig.inference_base_url
changed |= _upsert_entry(
entries,
provider,
f"env:{env_var}",
{
"source": f"env:{env_var}",
"auth_type": auth_type,
"access_token": token,
"base_url": base_url,
"label": env_var,
},
)
return changed
def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> bool:
    """Upsert pool entries from the provider's single-credential stores.

    * ``anthropic``: the Hermes OAuth credentials (source ``hermes_pkce``)
      followed by the Claude Code credentials (source ``claude_code``).
    * ``nous`` / ``openai-codex``: the provider state from the auth store
      (source ``device_code``).

    Other providers are left untouched.

    Returns:
        True when at least one entry was added or modified.
    """
    changed = False
    auth_store = _load_auth_store()

    if provider == "anthropic":
        from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials

        # Each reader is invoked right before its own upsert, in this order.
        readers = (
            ("hermes_pkce", read_hermes_oauth_credentials),
            ("claude_code", read_claude_code_credentials),
        )
        for origin, reader in readers:
            creds = reader()
            if not creds or not creds.get("accessToken"):
                continue
            record = {
                "source": origin,
                "auth_type": "oauth",
                "access_token": creds.get("accessToken", ""),
                "refresh_token": creds.get("refreshToken"),
                "expires_at_ms": creds.get("expiresAt"),
                "label": _label_from_token(creds.get("accessToken", ""), origin),
            }
            changed |= _upsert_entry(entries, provider, origin, record)
        return changed

    if provider == "nous":
        state = _load_provider_state(auth_store, "nous")
        if not state:
            return changed
        record = {
            "source": "device_code",
            "auth_type": "oauth",
            "access_token": state.get("access_token", ""),
            "refresh_token": state.get("refresh_token"),
            "expires_at": state.get("expires_at"),
            "token_type": state.get("token_type"),
            "scope": state.get("scope"),
            "client_id": state.get("client_id"),
            "portal_base_url": state.get("portal_base_url"),
            "inference_base_url": state.get("inference_base_url"),
            "agent_key": state.get("agent_key"),
            "agent_key_expires_at": state.get("agent_key_expires_at"),
            "label": _label_from_token(state.get("access_token", ""), "device_code"),
        }
        changed |= _upsert_entry(entries, provider, "device_code", record)
        return changed

    if provider == "openai-codex":
        state = _load_provider_state(auth_store, "openai-codex")
        tokens = state.get("tokens") if isinstance(state, dict) else None
        if not isinstance(tokens, dict) or not tokens.get("access_token"):
            return changed
        record = {
            "source": "device_code",
            "auth_type": "oauth",
            "access_token": tokens.get("access_token", ""),
            "refresh_token": tokens.get("refresh_token"),
            "base_url": "https://chatgpt.com/backend-api/codex",
            "last_refresh": state.get("last_refresh"),
            "label": _label_from_token(tokens.get("access_token", ""), "device_code"),
        }
        changed |= _upsert_entry(entries, provider, "device_code", record)

    return changed
def load_pool(provider: str) -> CredentialPool:
    """Load the credential pool for *provider*, seeding it from known sources.

    The provider name is stripped and lower-cased (``None`` becomes ``""``).
    Stored entries are read first, then merged with credentials from the
    singleton stores and from the environment; the merged pool is written
    back only when seeding added or changed something.
    """
    normalized = (provider or "").strip().lower()

    entries = []
    for stored in read_credential_pool(normalized):
        entries.append(PooledCredential.from_dict(normalized, stored))

    # Both seeders always run; each one mutates ``entries`` in place.
    from_singletons = _seed_from_singletons(normalized, entries)
    from_env = _seed_from_env(normalized, entries)

    if from_singletons or from_env:
        ordered = sorted(entries, key=lambda item: item.priority)
        write_credential_pool(normalized, [item.to_dict() for item in ordered])

    return CredentialPool(normalized, entries)