mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-20 05:01:30 +00:00
Add pooled same-provider credential fallback
This commit is contained in:
parent
934fbe3c06
commit
b17e5c101d
18 changed files with 2872 additions and 195 deletions
456
agent/credential_pool.py
Normal file
456
agent/credential_pool.py
Normal file
|
|
@ -0,0 +1,456 @@
|
|||
"""Persistent multi-credential pool for same-provider failover."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
import os
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
import hermes_cli.auth as auth_mod
|
||||
from hermes_cli.auth import (
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
PROVIDER_REGISTRY,
|
||||
_agent_key_is_usable,
|
||||
_codex_access_token_is_expiring,
|
||||
_decode_jwt_claims,
|
||||
_is_expiring,
|
||||
_load_auth_store,
|
||||
_load_provider_state,
|
||||
read_credential_pool,
|
||||
write_credential_pool,
|
||||
)
|
||||
|
||||
EXHAUSTED_TTL_SECONDS = 24 * 60 * 60
|
||||
|
||||
|
||||
@dataclass
|
||||
class PooledCredential:
|
||||
provider: str
|
||||
id: str
|
||||
label: str
|
||||
auth_type: str
|
||||
priority: int
|
||||
source: str
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
last_status: Optional[str] = None
|
||||
last_status_at: Optional[float] = None
|
||||
last_error_code: Optional[int] = None
|
||||
base_url: Optional[str] = None
|
||||
expires_at: Optional[str] = None
|
||||
expires_at_ms: Optional[int] = None
|
||||
last_refresh: Optional[str] = None
|
||||
token_type: Optional[str] = None
|
||||
scope: Optional[str] = None
|
||||
client_id: Optional[str] = None
|
||||
portal_base_url: Optional[str] = None
|
||||
inference_base_url: Optional[str] = None
|
||||
obtained_at: Optional[str] = None
|
||||
expires_in: Optional[int] = None
|
||||
agent_key: Optional[str] = None
|
||||
agent_key_id: Optional[str] = None
|
||||
agent_key_expires_at: Optional[str] = None
|
||||
agent_key_expires_in: Optional[int] = None
|
||||
agent_key_reused: Optional[bool] = None
|
||||
agent_key_obtained_at: Optional[str] = None
|
||||
tls: Optional[Dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, provider: str, payload: Dict[str, Any]) -> "PooledCredential":
|
||||
allowed = {f.name for f in fields(cls) if f.name != "provider"}
|
||||
data = {k: payload.get(k) for k in allowed if k in payload}
|
||||
data.setdefault("id", uuid.uuid4().hex[:6])
|
||||
data.setdefault("label", payload.get("source", provider))
|
||||
data.setdefault("auth_type", "api_key")
|
||||
data.setdefault("priority", 0)
|
||||
data.setdefault("source", "manual")
|
||||
data.setdefault("access_token", "")
|
||||
return cls(provider=provider, **data)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
result: Dict[str, Any] = {}
|
||||
for field_def in fields(self):
|
||||
if field_def.name == "provider":
|
||||
continue
|
||||
value = getattr(self, field_def.name)
|
||||
if value is not None:
|
||||
result[field_def.name] = value
|
||||
for key in ("last_status", "last_status_at", "last_error_code"):
|
||||
result.setdefault(key, getattr(self, key))
|
||||
return result
|
||||
|
||||
@property
|
||||
def runtime_api_key(self) -> str:
|
||||
if self.provider == "nous":
|
||||
return str(self.agent_key or self.access_token or "")
|
||||
return str(self.access_token or "")
|
||||
|
||||
@property
|
||||
def runtime_base_url(self) -> Optional[str]:
|
||||
if self.provider == "nous":
|
||||
return self.inference_base_url or self.base_url
|
||||
return self.base_url
|
||||
|
||||
|
||||
def _label_from_token(token: str, fallback: str) -> str:
|
||||
claims = _decode_jwt_claims(token)
|
||||
for key in ("email", "preferred_username", "upn"):
|
||||
value = claims.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return fallback
|
||||
|
||||
|
||||
def _next_priority(entries: List[PooledCredential]) -> int:
|
||||
return max((entry.priority for entry in entries), default=-1) + 1
|
||||
|
||||
|
||||
class CredentialPool:
|
||||
def __init__(self, provider: str, entries: List[PooledCredential]):
|
||||
self.provider = provider
|
||||
self._entries = sorted(entries, key=lambda entry: entry.priority)
|
||||
self._current_id: Optional[str] = None
|
||||
|
||||
def has_credentials(self) -> bool:
|
||||
return bool(self._entries)
|
||||
|
||||
def entries(self) -> List[PooledCredential]:
|
||||
return list(sorted(self._entries, key=lambda entry: entry.priority))
|
||||
|
||||
def current(self) -> Optional[PooledCredential]:
|
||||
if not self._current_id:
|
||||
return None
|
||||
return next((entry for entry in self._entries if entry.id == self._current_id), None)
|
||||
|
||||
def _persist(self) -> None:
|
||||
write_credential_pool(
|
||||
self.provider,
|
||||
[entry.to_dict() for entry in sorted(self._entries, key=lambda item: item.priority)],
|
||||
)
|
||||
|
||||
def _mark_exhausted(self, entry: PooledCredential, status_code: Optional[int]) -> None:
|
||||
entry.last_status = "exhausted"
|
||||
entry.last_status_at = time.time()
|
||||
entry.last_error_code = status_code
|
||||
self._persist()
|
||||
|
||||
def _refresh_entry(self, entry: PooledCredential, *, force: bool) -> Optional[PooledCredential]:
|
||||
if entry.auth_type != "oauth" or not entry.refresh_token:
|
||||
if force:
|
||||
self._mark_exhausted(entry, None)
|
||||
return None
|
||||
|
||||
try:
|
||||
if self.provider == "anthropic":
|
||||
from agent.anthropic_adapter import refresh_anthropic_oauth_pure
|
||||
|
||||
refreshed = refresh_anthropic_oauth_pure(
|
||||
entry.refresh_token,
|
||||
use_json=entry.source.endswith("hermes_pkce"),
|
||||
)
|
||||
entry.access_token = refreshed["access_token"]
|
||||
entry.refresh_token = refreshed["refresh_token"]
|
||||
entry.expires_at_ms = refreshed["expires_at_ms"]
|
||||
elif self.provider == "openai-codex":
|
||||
refreshed = auth_mod.refresh_codex_oauth_pure(
|
||||
entry.access_token,
|
||||
entry.refresh_token,
|
||||
)
|
||||
entry.access_token = refreshed["access_token"]
|
||||
entry.refresh_token = refreshed["refresh_token"]
|
||||
entry.last_refresh = refreshed.get("last_refresh")
|
||||
elif self.provider == "nous":
|
||||
refreshed = auth_mod.refresh_nous_oauth_pure(
|
||||
entry.access_token,
|
||||
entry.refresh_token,
|
||||
entry.client_id or "hermes-cli",
|
||||
entry.portal_base_url or "https://portal.nousresearch.com",
|
||||
entry.inference_base_url or "https://inference-api.nousresearch.com/v1",
|
||||
token_type=entry.token_type or "Bearer",
|
||||
scope=entry.scope or "",
|
||||
obtained_at=entry.obtained_at,
|
||||
expires_at=entry.expires_at,
|
||||
agent_key=entry.agent_key,
|
||||
agent_key_expires_at=entry.agent_key_expires_at,
|
||||
min_key_ttl_seconds=DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
force_refresh=force,
|
||||
force_mint=force,
|
||||
)
|
||||
for key, value in refreshed.items():
|
||||
if hasattr(entry, key):
|
||||
setattr(entry, key, value)
|
||||
else:
|
||||
return entry
|
||||
except Exception:
|
||||
self._mark_exhausted(entry, None)
|
||||
return None
|
||||
|
||||
entry.last_status = "ok"
|
||||
entry.last_status_at = None
|
||||
entry.last_error_code = None
|
||||
self._persist()
|
||||
return entry
|
||||
|
||||
def _entry_needs_refresh(self, entry: PooledCredential) -> bool:
|
||||
if entry.auth_type != "oauth":
|
||||
return False
|
||||
if self.provider == "anthropic":
|
||||
if entry.expires_at_ms is None:
|
||||
return False
|
||||
return int(entry.expires_at_ms) <= int(time.time() * 1000) + 120_000
|
||||
if self.provider == "openai-codex":
|
||||
return _codex_access_token_is_expiring(
|
||||
entry.access_token,
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
)
|
||||
if self.provider == "nous":
|
||||
if _is_expiring(entry.expires_at, ACCESS_TOKEN_REFRESH_SKEW_SECONDS):
|
||||
return True
|
||||
return not _agent_key_is_usable(
|
||||
{
|
||||
"agent_key": entry.agent_key,
|
||||
"agent_key_expires_at": entry.agent_key_expires_at,
|
||||
},
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
)
|
||||
return False
|
||||
|
||||
def select(self) -> Optional[PooledCredential]:
|
||||
now = time.time()
|
||||
for entry in sorted(self._entries, key=lambda item: item.priority):
|
||||
if entry.last_status == "exhausted":
|
||||
if entry.last_status_at and now - entry.last_status_at < EXHAUSTED_TTL_SECONDS:
|
||||
continue
|
||||
entry.last_status = "ok"
|
||||
entry.last_status_at = None
|
||||
entry.last_error_code = None
|
||||
self._persist()
|
||||
if self._entry_needs_refresh(entry):
|
||||
refreshed = self._refresh_entry(entry, force=False)
|
||||
if refreshed is None:
|
||||
continue
|
||||
entry = refreshed
|
||||
self._current_id = entry.id
|
||||
return entry
|
||||
self._current_id = None
|
||||
return None
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code: Optional[int]) -> Optional[PooledCredential]:
|
||||
entry = self.current() or self.select()
|
||||
if entry is None:
|
||||
return None
|
||||
self._mark_exhausted(entry, status_code)
|
||||
self._current_id = None
|
||||
return self.select()
|
||||
|
||||
def try_refresh_current(self) -> Optional[PooledCredential]:
|
||||
entry = self.current()
|
||||
if entry is None:
|
||||
return None
|
||||
refreshed = self._refresh_entry(entry, force=True)
|
||||
if refreshed is not None:
|
||||
self._current_id = refreshed.id
|
||||
return refreshed
|
||||
|
||||
def reset_statuses(self) -> int:
|
||||
count = 0
|
||||
for entry in self._entries:
|
||||
if entry.last_status or entry.last_status_at or entry.last_error_code:
|
||||
entry.last_status = None
|
||||
entry.last_status_at = None
|
||||
entry.last_error_code = None
|
||||
count += 1
|
||||
if count:
|
||||
self._persist()
|
||||
return count
|
||||
|
||||
def remove_index(self, index: int) -> Optional[PooledCredential]:
|
||||
ordered = sorted(self._entries, key=lambda item: item.priority)
|
||||
if index < 1 or index > len(ordered):
|
||||
return None
|
||||
removed = ordered.pop(index - 1)
|
||||
for new_priority, entry in enumerate(ordered):
|
||||
entry.priority = new_priority
|
||||
self._entries = ordered
|
||||
self._persist()
|
||||
if self._current_id == removed.id:
|
||||
self._current_id = None
|
||||
return removed
|
||||
|
||||
def add_entry(self, entry: PooledCredential) -> PooledCredential:
|
||||
entry.priority = _next_priority(self._entries)
|
||||
self._entries.append(entry)
|
||||
self._persist()
|
||||
return entry
|
||||
|
||||
|
||||
def _upsert_entry(entries: List[PooledCredential], provider: str, source: str, payload: Dict[str, Any]) -> bool:
|
||||
existing = next((entry for entry in entries if entry.source == source), None)
|
||||
if existing is None:
|
||||
payload.setdefault("id", uuid.uuid4().hex[:6])
|
||||
payload.setdefault("priority", _next_priority(entries))
|
||||
payload.setdefault("label", payload.get("label") or source)
|
||||
entries.append(PooledCredential.from_dict(provider, payload))
|
||||
return True
|
||||
|
||||
changed = False
|
||||
for key, value in payload.items():
|
||||
if key in {"id", "priority"} or value is None:
|
||||
continue
|
||||
if key == "label" and existing.label:
|
||||
continue
|
||||
if hasattr(existing, key) and getattr(existing, key) != value:
|
||||
setattr(existing, key, value)
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
|
||||
def _seed_from_env(provider: str, entries: List[PooledCredential]) -> bool:
|
||||
changed = False
|
||||
if provider == "openrouter":
|
||||
token = os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
if token:
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"env:OPENROUTER_API_KEY",
|
||||
{
|
||||
"source": "env:OPENROUTER_API_KEY",
|
||||
"auth_type": "api_key",
|
||||
"access_token": token,
|
||||
"base_url": OPENROUTER_BASE_URL,
|
||||
"label": "OPENROUTER_API_KEY",
|
||||
},
|
||||
)
|
||||
return changed
|
||||
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if not pconfig or pconfig.auth_type != "api_key":
|
||||
return changed
|
||||
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip().rstrip("/")
|
||||
|
||||
for env_var in pconfig.api_key_env_vars:
|
||||
token = os.getenv(env_var, "").strip()
|
||||
if not token:
|
||||
continue
|
||||
auth_type = "oauth" if provider == "anthropic" and not token.startswith("sk-ant-api") else "api_key"
|
||||
base_url = env_url or pconfig.inference_base_url
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
f"env:{env_var}",
|
||||
{
|
||||
"source": f"env:{env_var}",
|
||||
"auth_type": auth_type,
|
||||
"access_token": token,
|
||||
"base_url": base_url,
|
||||
"label": env_var,
|
||||
},
|
||||
)
|
||||
return changed
|
||||
|
||||
|
||||
def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> bool:
|
||||
changed = False
|
||||
auth_store = _load_auth_store()
|
||||
|
||||
if provider == "anthropic":
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials
|
||||
|
||||
hermes_creds = read_hermes_oauth_credentials()
|
||||
if hermes_creds and hermes_creds.get("accessToken"):
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"hermes_pkce",
|
||||
{
|
||||
"source": "hermes_pkce",
|
||||
"auth_type": "oauth",
|
||||
"access_token": hermes_creds.get("accessToken", ""),
|
||||
"refresh_token": hermes_creds.get("refreshToken"),
|
||||
"expires_at_ms": hermes_creds.get("expiresAt"),
|
||||
"label": _label_from_token(hermes_creds.get("accessToken", ""), "hermes_pkce"),
|
||||
},
|
||||
)
|
||||
claude_creds = read_claude_code_credentials()
|
||||
if claude_creds and claude_creds.get("accessToken"):
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"claude_code",
|
||||
{
|
||||
"source": "claude_code",
|
||||
"auth_type": "oauth",
|
||||
"access_token": claude_creds.get("accessToken", ""),
|
||||
"refresh_token": claude_creds.get("refreshToken"),
|
||||
"expires_at_ms": claude_creds.get("expiresAt"),
|
||||
"label": _label_from_token(claude_creds.get("accessToken", ""), "claude_code"),
|
||||
},
|
||||
)
|
||||
|
||||
elif provider == "nous":
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
if state:
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"device_code",
|
||||
{
|
||||
"source": "device_code",
|
||||
"auth_type": "oauth",
|
||||
"access_token": state.get("access_token", ""),
|
||||
"refresh_token": state.get("refresh_token"),
|
||||
"expires_at": state.get("expires_at"),
|
||||
"token_type": state.get("token_type"),
|
||||
"scope": state.get("scope"),
|
||||
"client_id": state.get("client_id"),
|
||||
"portal_base_url": state.get("portal_base_url"),
|
||||
"inference_base_url": state.get("inference_base_url"),
|
||||
"agent_key": state.get("agent_key"),
|
||||
"agent_key_expires_at": state.get("agent_key_expires_at"),
|
||||
"label": _label_from_token(state.get("access_token", ""), "device_code"),
|
||||
},
|
||||
)
|
||||
|
||||
elif provider == "openai-codex":
|
||||
state = _load_provider_state(auth_store, "openai-codex")
|
||||
tokens = state.get("tokens") if isinstance(state, dict) else None
|
||||
if isinstance(tokens, dict) and tokens.get("access_token"):
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"device_code",
|
||||
{
|
||||
"source": "device_code",
|
||||
"auth_type": "oauth",
|
||||
"access_token": tokens.get("access_token", ""),
|
||||
"refresh_token": tokens.get("refresh_token"),
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"last_refresh": state.get("last_refresh"),
|
||||
"label": _label_from_token(tokens.get("access_token", ""), "device_code"),
|
||||
},
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
def load_pool(provider: str) -> CredentialPool:
|
||||
provider = (provider or "").strip().lower()
|
||||
raw_entries = read_credential_pool(provider)
|
||||
entries = [PooledCredential.from_dict(provider, payload) for payload in raw_entries]
|
||||
changed = _seed_from_singletons(provider, entries)
|
||||
changed |= _seed_from_env(provider, entries)
|
||||
if changed:
|
||||
write_credential_pool(
|
||||
provider,
|
||||
[entry.to_dict() for entry in sorted(entries, key=lambda item: item.priority)],
|
||||
)
|
||||
return CredentialPool(provider, entries)
|
||||
Loading…
Add table
Add a link
Reference in a new issue