mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
Merge branch 'main' into rewbs/tool-use-charge-to-subscription
This commit is contained in:
commit
a2e56d044b
175 changed files with 18848 additions and 3772 deletions
|
|
@ -162,6 +162,21 @@ def _is_oauth_token(key: str) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _is_third_party_anthropic_endpoint(base_url: str | None) -> bool:
|
||||
"""Return True for non-Anthropic endpoints using the Anthropic Messages API.
|
||||
|
||||
Third-party proxies (Azure AI Foundry, AWS Bedrock, self-hosted) authenticate
|
||||
with their own API keys via x-api-key, not Anthropic OAuth tokens. OAuth
|
||||
detection should be skipped for these endpoints.
|
||||
"""
|
||||
if not base_url:
|
||||
return False # No base_url = direct Anthropic API
|
||||
normalized = base_url.rstrip("/").lower()
|
||||
if "anthropic.com" in normalized:
|
||||
return False # Direct Anthropic API — OAuth applies
|
||||
return True # Any other endpoint is a third-party proxy
|
||||
|
||||
|
||||
def _requires_bearer_auth(base_url: str | None) -> bool:
|
||||
"""Return True for Anthropic-compatible providers that require Bearer auth.
|
||||
|
||||
|
|
@ -205,6 +220,14 @@ def build_anthropic_client(api_key: str, base_url: str = None):
|
|||
kwargs["auth_token"] = api_key
|
||||
if _COMMON_BETAS:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)}
|
||||
elif _is_third_party_anthropic_endpoint(base_url):
|
||||
# Third-party proxies (Azure AI Foundry, AWS Bedrock, etc.) use their
|
||||
# own API keys with x-api-key auth. Skip OAuth detection — their keys
|
||||
# don't follow Anthropic's sk-ant-* prefix convention and would be
|
||||
# misclassified as OAuth tokens.
|
||||
kwargs["api_key"] = api_key
|
||||
if _COMMON_BETAS:
|
||||
kwargs["default_headers"] = {"anthropic-beta": ",".join(_COMMON_BETAS)}
|
||||
elif _is_oauth_token(api_key):
|
||||
# OAuth access token / setup-token → Bearer auth + Claude Code identity.
|
||||
# Anthropic routes OAuth requests based on user-agent and headers;
|
||||
|
|
@ -284,71 +307,105 @@ def is_claude_code_token_valid(creds: Dict[str, Any]) -> bool:
|
|||
return now_ms < (expires_at - 60_000)
|
||||
|
||||
|
||||
def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
"""Attempt to refresh an expired Claude Code OAuth token.
|
||||
|
||||
Uses the same token endpoint and client_id as Claude Code / OpenCode.
|
||||
Only works for credentials that have a refresh token (from claude /login
|
||||
or claude setup-token with OAuth flow).
|
||||
|
||||
Tries the new platform.claude.com endpoint first (Claude Code >=2.1.81),
|
||||
then falls back to console.anthropic.com for older tokens.
|
||||
|
||||
Returns the new access token, or None if refresh fails.
|
||||
"""
|
||||
def refresh_anthropic_oauth_pure(refresh_token: str, *, use_json: bool = False) -> Dict[str, Any]:
|
||||
"""Refresh an Anthropic OAuth token without mutating local credential files."""
|
||||
import time
|
||||
import urllib.parse
|
||||
import urllib.request
|
||||
|
||||
if not refresh_token:
|
||||
raise ValueError("refresh_token is required")
|
||||
|
||||
client_id = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
if use_json:
|
||||
data = json.dumps({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}).encode()
|
||||
content_type = "application/json"
|
||||
else:
|
||||
data = urllib.parse.urlencode({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": client_id,
|
||||
}).encode()
|
||||
content_type = "application/x-www-form-urlencoded"
|
||||
|
||||
token_endpoints = [
|
||||
"https://platform.claude.com/v1/oauth/token",
|
||||
"https://console.anthropic.com/v1/oauth/token",
|
||||
]
|
||||
last_error = None
|
||||
for endpoint in token_endpoints:
|
||||
req = urllib.request.Request(
|
||||
endpoint,
|
||||
data=data,
|
||||
headers={
|
||||
"Content-Type": content_type,
|
||||
"User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
logger.debug("Anthropic token refresh failed at %s: %s", endpoint, exc)
|
||||
continue
|
||||
|
||||
access_token = result.get("access_token", "")
|
||||
if not access_token:
|
||||
raise ValueError("Anthropic refresh response was missing access_token")
|
||||
next_refresh = result.get("refresh_token", refresh_token)
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": next_refresh,
|
||||
"expires_at_ms": int(time.time() * 1000) + (expires_in * 1000),
|
||||
}
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise ValueError("Anthropic token refresh failed")
|
||||
|
||||
|
||||
def _refresh_oauth_token(creds: Dict[str, Any]) -> Optional[str]:
|
||||
"""Attempt to refresh an expired Claude Code OAuth token."""
|
||||
refresh_token = creds.get("refreshToken", "")
|
||||
if not refresh_token:
|
||||
logger.debug("No refresh token available — cannot refresh")
|
||||
return None
|
||||
|
||||
# Client ID used by Claude Code's OAuth flow
|
||||
CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
|
||||
# Anthropic migrated OAuth from console.anthropic.com to platform.claude.com
|
||||
# (Claude Code v2.1.81+). Try new endpoint first, fall back to old.
|
||||
token_endpoints = [
|
||||
"https://platform.claude.com/v1/oauth/token",
|
||||
"https://console.anthropic.com/v1/oauth/token",
|
||||
]
|
||||
|
||||
payload = json.dumps({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token,
|
||||
"client_id": CLIENT_ID,
|
||||
}).encode()
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
}
|
||||
|
||||
for endpoint in token_endpoints:
|
||||
req = urllib.request.Request(
|
||||
endpoint, data=payload, headers=headers, method="POST",
|
||||
try:
|
||||
refreshed = refresh_anthropic_oauth_pure(refresh_token, use_json=False)
|
||||
_write_claude_code_credentials(
|
||||
refreshed["access_token"],
|
||||
refreshed["refresh_token"],
|
||||
refreshed["expires_at_ms"],
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
new_access = result.get("access_token", "")
|
||||
new_refresh = result.get("refresh_token", refresh_token)
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if new_access:
|
||||
new_expires_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
_write_claude_code_credentials(new_access, new_refresh, new_expires_ms)
|
||||
logger.debug("Refreshed Claude Code OAuth token via %s", endpoint)
|
||||
return new_access
|
||||
except Exception as e:
|
||||
logger.debug("Token refresh failed at %s: %s", endpoint, e)
|
||||
|
||||
return None
|
||||
logger.debug("Successfully refreshed Claude Code OAuth token")
|
||||
return refreshed["access_token"]
|
||||
except Exception as e:
|
||||
logger.debug("Failed to refresh Claude Code token: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _write_claude_code_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
|
||||
"""Write refreshed credentials back to ~/.claude/.credentials.json."""
|
||||
def _write_claude_code_credentials(
|
||||
access_token: str,
|
||||
refresh_token: str,
|
||||
expires_at_ms: int,
|
||||
*,
|
||||
scopes: Optional[list] = None,
|
||||
) -> None:
|
||||
"""Write refreshed credentials back to ~/.claude/.credentials.json.
|
||||
|
||||
The optional *scopes* list (e.g. ``["user:inference", "user:profile", ...]``)
|
||||
is persisted so that Claude Code's own auth check recognises the credential
|
||||
as valid. Claude Code >=2.1.81 gates on the presence of ``"user:inference"``
|
||||
in the stored scopes before it will use the token.
|
||||
"""
|
||||
cred_path = Path.home() / ".claude" / ".credentials.json"
|
||||
try:
|
||||
# Read existing file to preserve other fields
|
||||
|
|
@ -356,11 +413,19 @@ def _write_claude_code_credentials(access_token: str, refresh_token: str, expire
|
|||
if cred_path.exists():
|
||||
existing = json.loads(cred_path.read_text(encoding="utf-8"))
|
||||
|
||||
existing["claudeAiOauth"] = {
|
||||
oauth_data: Dict[str, Any] = {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token,
|
||||
"expiresAt": expires_at_ms,
|
||||
}
|
||||
if scopes is not None:
|
||||
oauth_data["scopes"] = scopes
|
||||
elif "claudeAiOauth" in existing and "scopes" in existing["claudeAiOauth"]:
|
||||
# Preserve previously-stored scopes when the refresh response
|
||||
# does not include a scope field.
|
||||
oauth_data["scopes"] = existing["claudeAiOauth"]["scopes"]
|
||||
|
||||
existing["claudeAiOauth"] = oauth_data
|
||||
|
||||
cred_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
cred_path.write_text(json.dumps(existing, indent=2), encoding="utf-8")
|
||||
|
|
@ -520,10 +585,208 @@ def run_oauth_setup_token() -> Optional[str]:
|
|||
return None
|
||||
|
||||
|
||||
# ── Hermes-native PKCE OAuth flow ────────────────────────────────────────
|
||||
# Mirrors the flow used by Claude Code, pi-ai, and OpenCode.
|
||||
# Stores credentials in ~/.hermes/.anthropic_oauth.json (our own file).
|
||||
|
||||
_OAUTH_CLIENT_ID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
|
||||
_OAUTH_TOKEN_URL = "https://console.anthropic.com/v1/oauth/token"
|
||||
_OAUTH_REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback"
|
||||
_OAUTH_SCOPES = "org:create_api_key user:profile user:inference"
|
||||
_HERMES_OAUTH_FILE = get_hermes_home() / ".anthropic_oauth.json"
|
||||
|
||||
|
||||
def _generate_pkce() -> tuple:
|
||||
"""Generate PKCE code_verifier and code_challenge (S256)."""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
verifier = base64.urlsafe_b64encode(secrets.token_bytes(32)).rstrip(b"=").decode()
|
||||
challenge = base64.urlsafe_b64encode(
|
||||
hashlib.sha256(verifier.encode()).digest()
|
||||
).rstrip(b"=").decode()
|
||||
return verifier, challenge
|
||||
|
||||
|
||||
def run_hermes_oauth_login_pure() -> Optional[Dict[str, Any]]:
|
||||
"""Run Hermes-native OAuth PKCE flow and return credential state."""
|
||||
import time
|
||||
import webbrowser
|
||||
|
||||
verifier, challenge = _generate_pkce()
|
||||
|
||||
params = {
|
||||
"code": "true",
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
"response_type": "code",
|
||||
"redirect_uri": _OAUTH_REDIRECT_URI,
|
||||
"scope": _OAUTH_SCOPES,
|
||||
"code_challenge": challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"state": verifier,
|
||||
}
|
||||
from urllib.parse import urlencode
|
||||
|
||||
auth_url = f"https://claude.ai/oauth/authorize?{urlencode(params)}"
|
||||
|
||||
print()
|
||||
print("Authorize Hermes with your Claude Pro/Max subscription.")
|
||||
print()
|
||||
print("╭─ Claude Pro/Max Authorization ────────────────────╮")
|
||||
print("│ │")
|
||||
print("│ Open this link in your browser: │")
|
||||
print("╰───────────────────────────────────────────────────╯")
|
||||
print()
|
||||
print(f" {auth_url}")
|
||||
print()
|
||||
|
||||
try:
|
||||
webbrowser.open(auth_url)
|
||||
print(" (Browser opened automatically)")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
print()
|
||||
print("After authorizing, you'll see a code. Paste it below.")
|
||||
print()
|
||||
try:
|
||||
auth_code = input("Authorization code: ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
return None
|
||||
|
||||
if not auth_code:
|
||||
print("No code entered.")
|
||||
return None
|
||||
|
||||
splits = auth_code.split("#")
|
||||
code = splits[0]
|
||||
state = splits[1] if len(splits) > 1 else ""
|
||||
|
||||
try:
|
||||
import urllib.request
|
||||
|
||||
exchange_data = json.dumps({
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": _OAUTH_CLIENT_ID,
|
||||
"code": code,
|
||||
"state": state,
|
||||
"redirect_uri": _OAUTH_REDIRECT_URI,
|
||||
"code_verifier": verifier,
|
||||
}).encode()
|
||||
|
||||
req = urllib.request.Request(
|
||||
_OAUTH_TOKEN_URL,
|
||||
data=exchange_data,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"claude-cli/{_get_claude_code_version()} (external, cli)",
|
||||
},
|
||||
method="POST",
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(req, timeout=15) as resp:
|
||||
result = json.loads(resp.read().decode())
|
||||
except Exception as e:
|
||||
print(f"Token exchange failed: {e}")
|
||||
return None
|
||||
|
||||
access_token = result.get("access_token", "")
|
||||
refresh_token = result.get("refresh_token", "")
|
||||
expires_in = result.get("expires_in", 3600)
|
||||
|
||||
if not access_token:
|
||||
print("No access token in response.")
|
||||
return None
|
||||
|
||||
expires_at_ms = int(time.time() * 1000) + (expires_in * 1000)
|
||||
return {
|
||||
"access_token": access_token,
|
||||
"refresh_token": refresh_token,
|
||||
"expires_at_ms": expires_at_ms,
|
||||
}
|
||||
|
||||
|
||||
def run_hermes_oauth_login() -> Optional[str]:
|
||||
"""Run Hermes-native OAuth PKCE flow for Claude Pro/Max subscription.
|
||||
|
||||
Opens a browser to claude.ai for authorization, prompts for the code,
|
||||
exchanges it for tokens, and stores them in ~/.hermes/.anthropic_oauth.json.
|
||||
|
||||
Returns the access token on success, None on failure.
|
||||
"""
|
||||
result = run_hermes_oauth_login_pure()
|
||||
if not result:
|
||||
return None
|
||||
|
||||
access_token = result["access_token"]
|
||||
refresh_token = result["refresh_token"]
|
||||
expires_at_ms = result["expires_at_ms"]
|
||||
|
||||
_save_hermes_oauth_credentials(access_token, refresh_token, expires_at_ms)
|
||||
_write_claude_code_credentials(access_token, refresh_token, expires_at_ms)
|
||||
|
||||
print("Authentication successful!")
|
||||
return access_token
|
||||
|
||||
|
||||
def _save_hermes_oauth_credentials(access_token: str, refresh_token: str, expires_at_ms: int) -> None:
|
||||
"""Save OAuth credentials to ~/.hermes/.anthropic_oauth.json."""
|
||||
data = {
|
||||
"accessToken": access_token,
|
||||
"refreshToken": refresh_token,
|
||||
"expiresAt": expires_at_ms,
|
||||
}
|
||||
try:
|
||||
_HERMES_OAUTH_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||
_HERMES_OAUTH_FILE.write_text(json.dumps(data, indent=2), encoding="utf-8")
|
||||
_HERMES_OAUTH_FILE.chmod(0o600)
|
||||
except (OSError, IOError) as e:
|
||||
logger.debug("Failed to save Hermes OAuth credentials: %s", e)
|
||||
|
||||
|
||||
def read_hermes_oauth_credentials() -> Optional[Dict[str, Any]]:
|
||||
"""Read Hermes-managed OAuth credentials from ~/.hermes/.anthropic_oauth.json."""
|
||||
if _HERMES_OAUTH_FILE.exists():
|
||||
try:
|
||||
data = json.loads(_HERMES_OAUTH_FILE.read_text(encoding="utf-8"))
|
||||
if data.get("accessToken"):
|
||||
return data
|
||||
except (json.JSONDecodeError, OSError, IOError) as e:
|
||||
logger.debug("Failed to read Hermes OAuth credentials: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def refresh_hermes_oauth_token() -> Optional[str]:
|
||||
"""Refresh the Hermes-managed OAuth token using the stored refresh token.
|
||||
|
||||
Returns the new access token, or None if refresh fails.
|
||||
"""
|
||||
creds = read_hermes_oauth_credentials()
|
||||
if not creds or not creds.get("refreshToken"):
|
||||
return None
|
||||
|
||||
try:
|
||||
refreshed = refresh_anthropic_oauth_pure(
|
||||
creds["refreshToken"],
|
||||
use_json=True,
|
||||
)
|
||||
_save_hermes_oauth_credentials(
|
||||
refreshed["access_token"],
|
||||
refreshed["refresh_token"],
|
||||
refreshed["expires_at_ms"],
|
||||
)
|
||||
_write_claude_code_credentials(
|
||||
refreshed["access_token"],
|
||||
refreshed["refresh_token"],
|
||||
refreshed["expires_at_ms"],
|
||||
)
|
||||
logger.debug("Successfully refreshed Hermes OAuth token")
|
||||
return refreshed["access_token"]
|
||||
except Exception as e:
|
||||
logger.debug("Failed to refresh Hermes OAuth token: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -1056,4 +1319,4 @@ def normalize_anthropic_response(
|
|||
reasoning_details=None,
|
||||
),
|
||||
finish_reason,
|
||||
)
|
||||
)
|
||||
|
|
@ -7,7 +7,7 @@ the best available backend without duplicating fallback logic.
|
|||
Resolution order for text tasks (auto mode):
|
||||
1. OpenRouter (OPENROUTER_API_KEY)
|
||||
2. Nous Portal (~/.hermes/auth.json active provider)
|
||||
3. Custom endpoint (OPENAI_BASE_URL + OPENAI_API_KEY)
|
||||
3. Custom endpoint (config.yaml model.base_url + OPENAI_API_KEY)
|
||||
4. Codex OAuth (Responses API via chatgpt.com with gpt-5.3-codex,
|
||||
wrapped to look like a chat.completions client)
|
||||
5. Native Anthropic
|
||||
|
|
@ -47,7 +47,8 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
from openai import OpenAI
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL, get_hermes_home
|
||||
from hermes_cli.config import get_hermes_home
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -95,6 +96,45 @@ _CODEX_AUX_MODEL = "gpt-5.2-codex"
|
|||
_CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def _select_pool_entry(provider: str) -> Tuple[bool, Optional[Any]]:
|
||||
"""Return (pool_exists_for_provider, selected_entry)."""
|
||||
try:
|
||||
pool = load_pool(provider)
|
||||
except Exception as exc:
|
||||
logger.debug("Auxiliary client: could not load pool for %s: %s", provider, exc)
|
||||
return False, None
|
||||
if not pool or not pool.has_credentials():
|
||||
return False, None
|
||||
try:
|
||||
return True, pool.select()
|
||||
except Exception as exc:
|
||||
logger.debug("Auxiliary client: could not select pool entry for %s: %s", provider, exc)
|
||||
return True, None
|
||||
|
||||
|
||||
def _pool_runtime_api_key(entry: Any) -> str:
|
||||
if entry is None:
|
||||
return ""
|
||||
# Use the PooledCredential.runtime_api_key property which handles
|
||||
# provider-specific fallback (e.g. agent_key for nous).
|
||||
key = getattr(entry, "runtime_api_key", None) or getattr(entry, "access_token", "")
|
||||
return str(key or "").strip()
|
||||
|
||||
|
||||
def _pool_runtime_base_url(entry: Any, fallback: str = "") -> str:
|
||||
if entry is None:
|
||||
return str(fallback or "").strip().rstrip("/")
|
||||
# runtime_base_url handles provider-specific logic (e.g. nous prefers inference_base_url).
|
||||
# Fall back through inference_base_url and base_url for non-PooledCredential entries.
|
||||
url = (
|
||||
getattr(entry, "runtime_base_url", None)
|
||||
or getattr(entry, "inference_base_url", None)
|
||||
or getattr(entry, "base_url", None)
|
||||
or fallback
|
||||
)
|
||||
return str(url or "").strip().rstrip("/")
|
||||
|
||||
|
||||
# ── Codex Responses → chat.completions adapter ─────────────────────────────
|
||||
# All auxiliary consumers call client.chat.completions.create(**kwargs) and
|
||||
# read response.choices[0].message.content. This adapter translates those
|
||||
|
|
@ -438,6 +478,22 @@ def _read_nous_auth() -> Optional[dict]:
|
|||
Returns the provider state dict if Nous is active with tokens,
|
||||
otherwise None.
|
||||
"""
|
||||
pool_present, entry = _select_pool_entry("nous")
|
||||
if pool_present:
|
||||
if entry is None:
|
||||
return None
|
||||
return {
|
||||
"access_token": getattr(entry, "access_token", ""),
|
||||
"refresh_token": getattr(entry, "refresh_token", None),
|
||||
"agent_key": getattr(entry, "agent_key", None),
|
||||
"inference_base_url": _pool_runtime_base_url(entry, _NOUS_DEFAULT_BASE_URL),
|
||||
"portal_base_url": getattr(entry, "portal_base_url", None),
|
||||
"client_id": getattr(entry, "client_id", None),
|
||||
"scope": getattr(entry, "scope", None),
|
||||
"token_type": getattr(entry, "token_type", "Bearer"),
|
||||
"source": "pool",
|
||||
}
|
||||
|
||||
try:
|
||||
if not _AUTH_JSON_PATH.is_file():
|
||||
return None
|
||||
|
|
@ -466,6 +522,11 @@ def _nous_base_url() -> str:
|
|||
|
||||
def _read_codex_access_token() -> Optional[str]:
|
||||
"""Read a valid, non-expired Codex OAuth access token from Hermes auth store."""
|
||||
pool_present, entry = _select_pool_entry("openai-codex")
|
||||
if pool_present:
|
||||
token = _pool_runtime_api_key(entry)
|
||||
return token or None
|
||||
|
||||
try:
|
||||
from hermes_cli.auth import _read_codex_tokens
|
||||
data = _read_codex_tokens()
|
||||
|
|
@ -512,6 +573,24 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
|||
if provider_id == "anthropic":
|
||||
return _try_anthropic()
|
||||
|
||||
pool_present, entry = _select_pool_entry(provider_id)
|
||||
if pool_present:
|
||||
api_key = _pool_runtime_api_key(entry)
|
||||
if not api_key:
|
||||
continue
|
||||
|
||||
base_url = _pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
|
||||
extra = {}
|
||||
if "api.kimi.com" in base_url.lower():
|
||||
extra["default_headers"] = {"User-Agent": "KimiCLI/1.0"}
|
||||
elif "api.githubcopilot.com" in base_url.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
extra["default_headers"] = copilot_default_headers()
|
||||
return OpenAI(api_key=api_key, base_url=base_url, **extra), model
|
||||
|
||||
creds = resolve_api_key_provider_credentials(provider_id)
|
||||
api_key = str(creds.get("api_key", "")).strip()
|
||||
if not api_key:
|
||||
|
|
@ -561,6 +640,16 @@ def _get_auxiliary_env_override(task: str, suffix: str) -> Optional[str]:
|
|||
|
||||
|
||||
def _try_openrouter() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
pool_present, entry = _select_pool_entry("openrouter")
|
||||
if pool_present:
|
||||
or_key = _pool_runtime_api_key(entry)
|
||||
if not or_key:
|
||||
return None, None
|
||||
base_url = _pool_runtime_base_url(entry, OPENROUTER_BASE_URL) or OPENROUTER_BASE_URL
|
||||
logger.debug("Auxiliary client: OpenRouter via pool")
|
||||
return OpenAI(api_key=or_key, base_url=base_url,
|
||||
default_headers=_OR_HEADERS), _OPENROUTER_MODEL
|
||||
|
||||
or_key = os.getenv("OPENROUTER_API_KEY")
|
||||
if not or_key:
|
||||
return None, None
|
||||
|
|
@ -576,22 +665,22 @@ def _try_nous() -> Tuple[Optional[OpenAI], Optional[str]]:
|
|||
global auxiliary_is_nous
|
||||
auxiliary_is_nous = True
|
||||
logger.debug("Auxiliary client: Nous Portal")
|
||||
model = "gemini-3-flash" if nous.get("source") == "pool" else _NOUS_MODEL
|
||||
return (
|
||||
OpenAI(api_key=_nous_api_key(nous), base_url=_nous_base_url()),
|
||||
_NOUS_MODEL,
|
||||
OpenAI(
|
||||
api_key=_nous_api_key(nous),
|
||||
base_url=str(nous.get("inference_base_url") or _nous_base_url()).rstrip("/"),
|
||||
),
|
||||
model,
|
||||
)
|
||||
|
||||
|
||||
def _read_main_model() -> str:
|
||||
"""Read the user's configured main model from config/env.
|
||||
"""Read the user's configured main model from config.yaml.
|
||||
|
||||
Falls back through HERMES_MODEL → LLM_MODEL → config.yaml model.default
|
||||
so the auxiliary client can use the same model as the main agent when no
|
||||
dedicated auxiliary model is available.
|
||||
config.yaml model.default is the single source of truth for the active
|
||||
model. Environment variables are no longer consulted.
|
||||
"""
|
||||
from_env = os.getenv("OPENAI_MODEL") or os.getenv("HERMES_MODEL") or os.getenv("LLM_MODEL")
|
||||
if from_env:
|
||||
return from_env.strip()
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
|
|
@ -658,11 +747,19 @@ def _try_custom_endpoint() -> Tuple[Optional[OpenAI], Optional[str]]:
|
|||
|
||||
|
||||
def _try_codex() -> Tuple[Optional[Any], Optional[str]]:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
pool_present, entry = _select_pool_entry("openai-codex")
|
||||
if pool_present:
|
||||
codex_token = _pool_runtime_api_key(entry)
|
||||
if not codex_token:
|
||||
return None, None
|
||||
base_url = _pool_runtime_base_url(entry, _CODEX_AUX_BASE_URL) or _CODEX_AUX_BASE_URL
|
||||
else:
|
||||
codex_token = _read_codex_access_token()
|
||||
if not codex_token:
|
||||
return None, None
|
||||
base_url = _CODEX_AUX_BASE_URL
|
||||
logger.debug("Auxiliary client: Codex OAuth (%s via Responses API)", _CODEX_AUX_MODEL)
|
||||
real_client = OpenAI(api_key=codex_token, base_url=_CODEX_AUX_BASE_URL)
|
||||
real_client = OpenAI(api_key=codex_token, base_url=base_url)
|
||||
return CodexAuxiliaryClient(real_client, _CODEX_AUX_MODEL), _CODEX_AUX_MODEL
|
||||
|
||||
|
||||
|
|
@ -672,14 +769,21 @@ def _try_anthropic() -> Tuple[Optional[Any], Optional[str]]:
|
|||
except ImportError:
|
||||
return None, None
|
||||
|
||||
token = resolve_anthropic_token()
|
||||
pool_present, entry = _select_pool_entry("anthropic")
|
||||
if pool_present:
|
||||
if entry is None:
|
||||
return None, None
|
||||
token = _pool_runtime_api_key(entry)
|
||||
else:
|
||||
entry = None
|
||||
token = resolve_anthropic_token()
|
||||
if not token:
|
||||
return None, None
|
||||
|
||||
# Allow base URL override from config.yaml model.base_url, but only
|
||||
# when the configured provider is anthropic — otherwise a non-Anthropic
|
||||
# base_url (e.g. Codex endpoint) would leak into Anthropic requests.
|
||||
base_url = _ANTHROPIC_DEFAULT_BASE_URL
|
||||
base_url = _pool_runtime_base_url(entry, _ANTHROPIC_DEFAULT_BASE_URL) if pool_present else _ANTHROPIC_DEFAULT_BASE_URL
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
cfg = load_config()
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ REFERENCE_PATTERN = re.compile(
|
|||
r"(?<![\w/])@(?:(?P<simple>diff|staged)\b|(?P<kind>file|folder|git|url):(?P<value>\S+))"
|
||||
)
|
||||
TRAILING_PUNCTUATION = ",.;!?"
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube")
|
||||
_SENSITIVE_HOME_DIRS = (".ssh", ".aws", ".gnupg", ".kube", ".docker", ".azure", ".config/gh")
|
||||
_SENSITIVE_HERMES_DIRS = (Path("skills") / ".hub",)
|
||||
_SENSITIVE_HOME_FILES = (
|
||||
Path(".ssh") / "authorized_keys",
|
||||
|
|
|
|||
848
agent/credential_pool.py
Normal file
848
agent/credential_pool.py
Normal file
|
|
@ -0,0 +1,848 @@
|
|||
"""Persistent multi-credential pool for same-provider failover."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import os
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
import hermes_cli.auth as auth_mod
|
||||
from hermes_cli.auth import (
|
||||
ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
PROVIDER_REGISTRY,
|
||||
_agent_key_is_usable,
|
||||
_codex_access_token_is_expiring,
|
||||
_decode_jwt_claims,
|
||||
_is_expiring,
|
||||
_load_auth_store,
|
||||
_load_provider_state,
|
||||
read_credential_pool,
|
||||
write_credential_pool,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _load_config_safe() -> Optional[dict]:
|
||||
"""Load config.yaml, returning None on any error."""
|
||||
try:
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
return load_config()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# --- Status and type constants ---
|
||||
|
||||
STATUS_OK = "ok"
|
||||
STATUS_EXHAUSTED = "exhausted"
|
||||
|
||||
AUTH_TYPE_OAUTH = "oauth"
|
||||
AUTH_TYPE_API_KEY = "api_key"
|
||||
|
||||
SOURCE_MANUAL = "manual"
|
||||
|
||||
STRATEGY_FILL_FIRST = "fill_first"
|
||||
STRATEGY_ROUND_ROBIN = "round_robin"
|
||||
STRATEGY_RANDOM = "random"
|
||||
STRATEGY_LEAST_USED = "least_used"
|
||||
SUPPORTED_POOL_STRATEGIES = {
|
||||
STRATEGY_FILL_FIRST,
|
||||
STRATEGY_ROUND_ROBIN,
|
||||
STRATEGY_RANDOM,
|
||||
STRATEGY_LEAST_USED,
|
||||
}
|
||||
|
||||
# Cooldown before retrying an exhausted credential.
|
||||
# 429 (rate-limited) cools down faster since quotas reset frequently.
|
||||
# 402 (billing/quota) and other codes use a longer default.
|
||||
EXHAUSTED_TTL_429_SECONDS = 60 * 60 # 1 hour
|
||||
EXHAUSTED_TTL_DEFAULT_SECONDS = 24 * 60 * 60 # 24 hours
|
||||
|
||||
# Pool key prefix for custom OpenAI-compatible endpoints.
|
||||
# Custom endpoints all share provider='custom' but are keyed by their
|
||||
# custom_providers name: 'custom:<normalized_name>'.
|
||||
CUSTOM_POOL_PREFIX = "custom:"
|
||||
|
||||
|
||||
# Fields that are only round-tripped through JSON — never used for logic as attributes.
|
||||
_EXTRA_KEYS = frozenset({
|
||||
"token_type", "scope", "client_id", "portal_base_url", "obtained_at",
|
||||
"expires_in", "agent_key_id", "agent_key_expires_in", "agent_key_reused",
|
||||
"agent_key_obtained_at", "tls",
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
class PooledCredential:
|
||||
provider: str
|
||||
id: str
|
||||
label: str
|
||||
auth_type: str
|
||||
priority: int
|
||||
source: str
|
||||
access_token: str
|
||||
refresh_token: Optional[str] = None
|
||||
last_status: Optional[str] = None
|
||||
last_status_at: Optional[float] = None
|
||||
last_error_code: Optional[int] = None
|
||||
base_url: Optional[str] = None
|
||||
expires_at: Optional[str] = None
|
||||
expires_at_ms: Optional[int] = None
|
||||
last_refresh: Optional[str] = None
|
||||
inference_base_url: Optional[str] = None
|
||||
agent_key: Optional[str] = None
|
||||
agent_key_expires_at: Optional[str] = None
|
||||
request_count: int = 0
|
||||
extra: Dict[str, Any] = None # type: ignore[assignment]
|
||||
|
||||
def __post_init__(self):
|
||||
if self.extra is None:
|
||||
self.extra = {}
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
if name in _EXTRA_KEYS:
|
||||
return self.extra.get(name)
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute {name!r}")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, provider: str, payload: Dict[str, Any]) -> "PooledCredential":
|
||||
field_names = {f.name for f in fields(cls) if f.name != "provider"}
|
||||
data = {k: payload.get(k) for k in field_names if k in payload}
|
||||
extra = {k: payload[k] for k in _EXTRA_KEYS if k in payload and payload[k] is not None}
|
||||
data["extra"] = extra
|
||||
data.setdefault("id", uuid.uuid4().hex[:6])
|
||||
data.setdefault("label", payload.get("source", provider))
|
||||
data.setdefault("auth_type", AUTH_TYPE_API_KEY)
|
||||
data.setdefault("priority", 0)
|
||||
data.setdefault("source", SOURCE_MANUAL)
|
||||
data.setdefault("access_token", "")
|
||||
return cls(provider=provider, **data)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
_ALWAYS_EMIT = {"last_status", "last_status_at", "last_error_code"}
|
||||
result: Dict[str, Any] = {}
|
||||
for field_def in fields(self):
|
||||
if field_def.name in ("provider", "extra"):
|
||||
continue
|
||||
value = getattr(self, field_def.name)
|
||||
if value is not None or field_def.name in _ALWAYS_EMIT:
|
||||
result[field_def.name] = value
|
||||
for k, v in self.extra.items():
|
||||
if v is not None:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
@property
|
||||
def runtime_api_key(self) -> str:
|
||||
if self.provider == "nous":
|
||||
return str(self.agent_key or self.access_token or "")
|
||||
return str(self.access_token or "")
|
||||
|
||||
@property
|
||||
def runtime_base_url(self) -> Optional[str]:
|
||||
if self.provider == "nous":
|
||||
return self.inference_base_url or self.base_url
|
||||
return self.base_url
|
||||
|
||||
|
||||
def label_from_token(token: str, fallback: str) -> str:
|
||||
claims = _decode_jwt_claims(token)
|
||||
for key in ("email", "preferred_username", "upn"):
|
||||
value = claims.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return fallback
|
||||
|
||||
|
||||
def _next_priority(entries: List[PooledCredential]) -> int:
|
||||
return max((entry.priority for entry in entries), default=-1) + 1
|
||||
|
||||
|
||||
def _is_manual_source(source: str) -> bool:
|
||||
normalized = (source or "").strip().lower()
|
||||
return normalized == SOURCE_MANUAL or normalized.startswith(f"{SOURCE_MANUAL}:")
|
||||
|
||||
|
||||
def _exhausted_ttl(error_code: Optional[int]) -> int:
|
||||
"""Return cooldown seconds based on the HTTP status that caused exhaustion."""
|
||||
if error_code == 429:
|
||||
return EXHAUSTED_TTL_429_SECONDS
|
||||
return EXHAUSTED_TTL_DEFAULT_SECONDS
|
||||
|
||||
|
||||
def _normalize_custom_pool_name(name: str) -> str:
|
||||
"""Normalize a custom provider name for use as a pool key suffix."""
|
||||
return name.strip().lower().replace(" ", "-")
|
||||
|
||||
|
||||
def _iter_custom_providers(config: Optional[dict] = None):
|
||||
"""Yield (normalized_name, entry_dict) for each valid custom_providers entry."""
|
||||
if config is None:
|
||||
config = _load_config_safe()
|
||||
if config is None:
|
||||
return
|
||||
custom_providers = config.get("custom_providers")
|
||||
if not isinstance(custom_providers, list):
|
||||
return
|
||||
for entry in custom_providers:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
name = entry.get("name")
|
||||
if not isinstance(name, str):
|
||||
continue
|
||||
yield _normalize_custom_pool_name(name), entry
|
||||
|
||||
|
||||
def get_custom_provider_pool_key(base_url: str) -> Optional[str]:
|
||||
"""Look up the custom_providers list in config.yaml and return 'custom:<name>' for a matching base_url.
|
||||
|
||||
Returns None if no match is found.
|
||||
"""
|
||||
if not base_url:
|
||||
return None
|
||||
normalized_url = base_url.strip().rstrip("/")
|
||||
for norm_name, entry in _iter_custom_providers():
|
||||
entry_url = str(entry.get("base_url") or "").strip().rstrip("/")
|
||||
if entry_url and entry_url == normalized_url:
|
||||
return f"{CUSTOM_POOL_PREFIX}{norm_name}"
|
||||
return None
|
||||
|
||||
|
||||
def list_custom_pool_providers() -> List[str]:
|
||||
"""Return all 'custom:*' pool keys that have entries in auth.json."""
|
||||
pool_data = read_credential_pool(None)
|
||||
return sorted(
|
||||
key for key in pool_data
|
||||
if key.startswith(CUSTOM_POOL_PREFIX)
|
||||
and isinstance(pool_data.get(key), list)
|
||||
and pool_data[key]
|
||||
)
|
||||
|
||||
|
||||
def _get_custom_provider_config(pool_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Return the custom_providers config entry matching a pool key like 'custom:together.ai'."""
|
||||
if not pool_key.startswith(CUSTOM_POOL_PREFIX):
|
||||
return None
|
||||
suffix = pool_key[len(CUSTOM_POOL_PREFIX):]
|
||||
for norm_name, entry in _iter_custom_providers():
|
||||
if norm_name == suffix:
|
||||
return entry
|
||||
return None
|
||||
|
||||
|
||||
def get_pool_strategy(provider: str) -> str:
|
||||
"""Return the configured selection strategy for a provider."""
|
||||
config = _load_config_safe()
|
||||
if config is None:
|
||||
return STRATEGY_FILL_FIRST
|
||||
|
||||
strategies = config.get("credential_pool_strategies")
|
||||
if not isinstance(strategies, dict):
|
||||
return STRATEGY_FILL_FIRST
|
||||
|
||||
strategy = str(strategies.get(provider, "") or "").strip().lower()
|
||||
if strategy in SUPPORTED_POOL_STRATEGIES:
|
||||
return strategy
|
||||
return STRATEGY_FILL_FIRST
|
||||
|
||||
|
||||
class CredentialPool:
|
||||
def __init__(self, provider: str, entries: List[PooledCredential]):
|
||||
self.provider = provider
|
||||
self._entries = sorted(entries, key=lambda entry: entry.priority)
|
||||
self._current_id: Optional[str] = None
|
||||
self._strategy = get_pool_strategy(provider)
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def has_credentials(self) -> bool:
|
||||
return bool(self._entries)
|
||||
|
||||
def has_available(self) -> bool:
|
||||
"""True if at least one entry is not currently in exhaustion cooldown."""
|
||||
return bool(self._available_entries())
|
||||
|
||||
def entries(self) -> List[PooledCredential]:
|
||||
return list(self._entries)
|
||||
|
||||
def current(self) -> Optional[PooledCredential]:
|
||||
if not self._current_id:
|
||||
return None
|
||||
return next((entry for entry in self._entries if entry.id == self._current_id), None)
|
||||
|
||||
def _replace_entry(self, old: PooledCredential, new: PooledCredential) -> None:
|
||||
"""Swap an entry in-place by id, preserving sort order."""
|
||||
for idx, entry in enumerate(self._entries):
|
||||
if entry.id == old.id:
|
||||
self._entries[idx] = new
|
||||
return
|
||||
|
||||
def _persist(self) -> None:
|
||||
write_credential_pool(
|
||||
self.provider,
|
||||
[entry.to_dict() for entry in self._entries],
|
||||
)
|
||||
|
||||
def _mark_exhausted(self, entry: PooledCredential, status_code: Optional[int]) -> PooledCredential:
|
||||
updated = replace(
|
||||
entry,
|
||||
last_status=STATUS_EXHAUSTED,
|
||||
last_status_at=time.time(),
|
||||
last_error_code=status_code,
|
||||
)
|
||||
self._replace_entry(entry, updated)
|
||||
self._persist()
|
||||
return updated
|
||||
|
||||
def _refresh_entry(self, entry: PooledCredential, *, force: bool) -> Optional[PooledCredential]:
|
||||
if entry.auth_type != AUTH_TYPE_OAUTH or not entry.refresh_token:
|
||||
if force:
|
||||
self._mark_exhausted(entry, None)
|
||||
return None
|
||||
|
||||
try:
|
||||
if self.provider == "anthropic":
|
||||
from agent.anthropic_adapter import refresh_anthropic_oauth_pure
|
||||
|
||||
refreshed = refresh_anthropic_oauth_pure(
|
||||
entry.refresh_token,
|
||||
use_json=entry.source.endswith("hermes_pkce"),
|
||||
)
|
||||
updated = replace(
|
||||
entry,
|
||||
access_token=refreshed["access_token"],
|
||||
refresh_token=refreshed["refresh_token"],
|
||||
expires_at_ms=refreshed["expires_at_ms"],
|
||||
)
|
||||
elif self.provider == "openai-codex":
|
||||
refreshed = auth_mod.refresh_codex_oauth_pure(
|
||||
entry.access_token,
|
||||
entry.refresh_token,
|
||||
)
|
||||
updated = replace(
|
||||
entry,
|
||||
access_token=refreshed["access_token"],
|
||||
refresh_token=refreshed["refresh_token"],
|
||||
last_refresh=refreshed.get("last_refresh"),
|
||||
)
|
||||
elif self.provider == "nous":
|
||||
nous_state = {
|
||||
"access_token": entry.access_token,
|
||||
"refresh_token": entry.refresh_token,
|
||||
"client_id": entry.client_id,
|
||||
"portal_base_url": entry.portal_base_url,
|
||||
"inference_base_url": entry.inference_base_url,
|
||||
"token_type": entry.token_type,
|
||||
"scope": entry.scope,
|
||||
"obtained_at": entry.obtained_at,
|
||||
"expires_at": entry.expires_at,
|
||||
"agent_key": entry.agent_key,
|
||||
"agent_key_expires_at": entry.agent_key_expires_at,
|
||||
"tls": entry.tls,
|
||||
}
|
||||
refreshed = auth_mod.refresh_nous_oauth_from_state(
|
||||
nous_state,
|
||||
min_key_ttl_seconds=DEFAULT_AGENT_KEY_MIN_TTL_SECONDS,
|
||||
force_refresh=force,
|
||||
force_mint=force,
|
||||
)
|
||||
# Apply returned fields: dataclass fields via replace, extras via dict update
|
||||
field_updates = {}
|
||||
extra_updates = dict(entry.extra)
|
||||
_field_names = {f.name for f in fields(entry)}
|
||||
for k, v in refreshed.items():
|
||||
if k in _field_names:
|
||||
field_updates[k] = v
|
||||
elif k in _EXTRA_KEYS:
|
||||
extra_updates[k] = v
|
||||
updated = replace(entry, extra=extra_updates, **field_updates)
|
||||
else:
|
||||
return entry
|
||||
except Exception as exc:
|
||||
logger.debug("Credential refresh failed for %s/%s: %s", self.provider, entry.id, exc)
|
||||
self._mark_exhausted(entry, None)
|
||||
return None
|
||||
|
||||
updated = replace(updated, last_status=STATUS_OK, last_status_at=None, last_error_code=None)
|
||||
self._replace_entry(entry, updated)
|
||||
self._persist()
|
||||
return updated
|
||||
|
||||
def _entry_needs_refresh(self, entry: PooledCredential) -> bool:
|
||||
if entry.auth_type != AUTH_TYPE_OAUTH:
|
||||
return False
|
||||
if self.provider == "anthropic":
|
||||
if entry.expires_at_ms is None:
|
||||
return False
|
||||
return int(entry.expires_at_ms) <= int(time.time() * 1000) + 120_000
|
||||
if self.provider == "openai-codex":
|
||||
return _codex_access_token_is_expiring(
|
||||
entry.access_token,
|
||||
CODEX_ACCESS_TOKEN_REFRESH_SKEW_SECONDS,
|
||||
)
|
||||
if self.provider == "nous":
|
||||
# Nous refresh/mint can require network access and should happen when
|
||||
# runtime credentials are actually resolved, not merely when the pool
|
||||
# is enumerated for listing, migration, or selection.
|
||||
return False
|
||||
return False
|
||||
|
||||
def mark_used(self, entry_id: Optional[str] = None) -> None:
|
||||
"""Increment request_count for tracking. Used by least_used strategy."""
|
||||
target_id = entry_id or self._current_id
|
||||
if not target_id:
|
||||
return
|
||||
with self._lock:
|
||||
for idx, entry in enumerate(self._entries):
|
||||
if entry.id == target_id:
|
||||
self._entries[idx] = replace(entry, request_count=entry.request_count + 1)
|
||||
return
|
||||
|
||||
def select(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._select_unlocked()
|
||||
|
||||
def _available_entries(self, *, clear_expired: bool = False, refresh: bool = False) -> List[PooledCredential]:
|
||||
"""Return entries not currently in exhaustion cooldown.
|
||||
|
||||
When *clear_expired* is True, entries whose cooldown has elapsed are
|
||||
reset to STATUS_OK and persisted. When *refresh* is True, entries
|
||||
that need a token refresh are refreshed (skipped on failure).
|
||||
"""
|
||||
now = time.time()
|
||||
cleared_any = False
|
||||
available: List[PooledCredential] = []
|
||||
for entry in self._entries:
|
||||
if entry.last_status == STATUS_EXHAUSTED:
|
||||
ttl = _exhausted_ttl(entry.last_error_code)
|
||||
if entry.last_status_at and now - entry.last_status_at < ttl:
|
||||
continue
|
||||
if clear_expired:
|
||||
cleared = replace(entry, last_status=STATUS_OK, last_status_at=None, last_error_code=None)
|
||||
self._replace_entry(entry, cleared)
|
||||
entry = cleared
|
||||
cleared_any = True
|
||||
if refresh and self._entry_needs_refresh(entry):
|
||||
refreshed = self._refresh_entry(entry, force=False)
|
||||
if refreshed is None:
|
||||
continue
|
||||
entry = refreshed
|
||||
available.append(entry)
|
||||
if cleared_any:
|
||||
self._persist()
|
||||
return available
|
||||
|
||||
def _select_unlocked(self) -> Optional[PooledCredential]:
|
||||
available = self._available_entries(clear_expired=True, refresh=True)
|
||||
if not available:
|
||||
self._current_id = None
|
||||
return None
|
||||
|
||||
if self._strategy == STRATEGY_RANDOM:
|
||||
entry = random.choice(available)
|
||||
self._current_id = entry.id
|
||||
return entry
|
||||
|
||||
if self._strategy == STRATEGY_LEAST_USED and len(available) > 1:
|
||||
entry = min(available, key=lambda e: e.request_count)
|
||||
self._current_id = entry.id
|
||||
return entry
|
||||
|
||||
if self._strategy == STRATEGY_ROUND_ROBIN and len(available) > 1:
|
||||
entry = available[0]
|
||||
rotated = [candidate for candidate in self._entries if candidate.id != entry.id]
|
||||
rotated.append(replace(entry, priority=len(self._entries) - 1))
|
||||
self._entries = [replace(candidate, priority=idx) for idx, candidate in enumerate(rotated)]
|
||||
self._persist()
|
||||
self._current_id = entry.id
|
||||
return self.current() or entry
|
||||
|
||||
entry = available[0]
|
||||
self._current_id = entry.id
|
||||
return entry
|
||||
|
||||
def peek(self) -> Optional[PooledCredential]:
|
||||
current = self.current()
|
||||
if current is not None:
|
||||
return current
|
||||
available = self._available_entries()
|
||||
return available[0] if available else None
|
||||
|
||||
def mark_exhausted_and_rotate(self, *, status_code: Optional[int]) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
entry = self.current() or self._select_unlocked()
|
||||
if entry is None:
|
||||
return None
|
||||
self._mark_exhausted(entry, status_code)
|
||||
self._current_id = None
|
||||
return self._select_unlocked()
|
||||
|
||||
def try_refresh_current(self) -> Optional[PooledCredential]:
|
||||
with self._lock:
|
||||
return self._try_refresh_current_unlocked()
|
||||
|
||||
def _try_refresh_current_unlocked(self) -> Optional[PooledCredential]:
|
||||
entry = self.current()
|
||||
if entry is None:
|
||||
return None
|
||||
refreshed = self._refresh_entry(entry, force=True)
|
||||
if refreshed is not None:
|
||||
self._current_id = refreshed.id
|
||||
return refreshed
|
||||
|
||||
def reset_statuses(self) -> int:
|
||||
count = 0
|
||||
new_entries = []
|
||||
for entry in self._entries:
|
||||
if entry.last_status or entry.last_status_at or entry.last_error_code:
|
||||
new_entries.append(replace(entry, last_status=None, last_status_at=None, last_error_code=None))
|
||||
count += 1
|
||||
else:
|
||||
new_entries.append(entry)
|
||||
if count:
|
||||
self._entries = new_entries
|
||||
self._persist()
|
||||
return count
|
||||
|
||||
def remove_index(self, index: int) -> Optional[PooledCredential]:
|
||||
if index < 1 or index > len(self._entries):
|
||||
return None
|
||||
removed = self._entries.pop(index - 1)
|
||||
self._entries = [
|
||||
replace(entry, priority=new_priority)
|
||||
for new_priority, entry in enumerate(self._entries)
|
||||
]
|
||||
self._persist()
|
||||
if self._current_id == removed.id:
|
||||
self._current_id = None
|
||||
return removed
|
||||
|
||||
def add_entry(self, entry: PooledCredential) -> PooledCredential:
|
||||
entry = replace(entry, priority=_next_priority(self._entries))
|
||||
self._entries.append(entry)
|
||||
self._persist()
|
||||
return entry
|
||||
|
||||
|
||||
def _upsert_entry(entries: List[PooledCredential], provider: str, source: str, payload: Dict[str, Any]) -> bool:
|
||||
existing_idx = None
|
||||
for idx, entry in enumerate(entries):
|
||||
if entry.source == source:
|
||||
existing_idx = idx
|
||||
break
|
||||
|
||||
if existing_idx is None:
|
||||
payload.setdefault("id", uuid.uuid4().hex[:6])
|
||||
payload.setdefault("priority", _next_priority(entries))
|
||||
payload.setdefault("label", payload.get("label") or source)
|
||||
entries.append(PooledCredential.from_dict(provider, payload))
|
||||
return True
|
||||
|
||||
existing = entries[existing_idx]
|
||||
field_updates = {}
|
||||
extra_updates = {}
|
||||
_field_names = {f.name for f in fields(existing)}
|
||||
for key, value in payload.items():
|
||||
if key in {"id", "priority"} or value is None:
|
||||
continue
|
||||
if key == "label" and existing.label:
|
||||
continue
|
||||
if key in _field_names:
|
||||
if getattr(existing, key) != value:
|
||||
field_updates[key] = value
|
||||
elif key in _EXTRA_KEYS:
|
||||
if existing.extra.get(key) != value:
|
||||
extra_updates[key] = value
|
||||
if field_updates or extra_updates:
|
||||
if extra_updates:
|
||||
field_updates["extra"] = {**existing.extra, **extra_updates}
|
||||
entries[existing_idx] = replace(existing, **field_updates)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_pool_priorities(provider: str, entries: List[PooledCredential]) -> bool:
|
||||
if provider != "anthropic":
|
||||
return False
|
||||
|
||||
source_rank = {
|
||||
"env:ANTHROPIC_TOKEN": 0,
|
||||
"env:CLAUDE_CODE_OAUTH_TOKEN": 1,
|
||||
"hermes_pkce": 2,
|
||||
"claude_code": 3,
|
||||
"env:ANTHROPIC_API_KEY": 4,
|
||||
}
|
||||
manual_entries = sorted(
|
||||
(entry for entry in entries if _is_manual_source(entry.source)),
|
||||
key=lambda entry: entry.priority,
|
||||
)
|
||||
seeded_entries = sorted(
|
||||
(entry for entry in entries if not _is_manual_source(entry.source)),
|
||||
key=lambda entry: (
|
||||
source_rank.get(entry.source, len(source_rank)),
|
||||
entry.priority,
|
||||
entry.label,
|
||||
),
|
||||
)
|
||||
|
||||
ordered = [*manual_entries, *seeded_entries]
|
||||
id_to_idx = {entry.id: idx for idx, entry in enumerate(entries)}
|
||||
changed = False
|
||||
for new_priority, entry in enumerate(ordered):
|
||||
if entry.priority != new_priority:
|
||||
entries[id_to_idx[entry.id]] = replace(entry, priority=new_priority)
|
||||
changed = True
|
||||
return changed
|
||||
|
||||
|
||||
def _seed_from_singletons(provider: str, entries: List[PooledCredential]) -> Tuple[bool, Set[str]]:
|
||||
changed = False
|
||||
active_sources: Set[str] = set()
|
||||
auth_store = _load_auth_store()
|
||||
|
||||
if provider == "anthropic":
|
||||
from agent.anthropic_adapter import read_claude_code_credentials, read_hermes_oauth_credentials
|
||||
|
||||
for source_name, creds in (
|
||||
("hermes_pkce", read_hermes_oauth_credentials()),
|
||||
("claude_code", read_claude_code_credentials()),
|
||||
):
|
||||
if creds and creds.get("accessToken"):
|
||||
active_sources.add(source_name)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
source_name,
|
||||
{
|
||||
"source": source_name,
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"access_token": creds.get("accessToken", ""),
|
||||
"refresh_token": creds.get("refreshToken"),
|
||||
"expires_at_ms": creds.get("expiresAt"),
|
||||
"label": label_from_token(creds.get("accessToken", ""), source_name),
|
||||
},
|
||||
)
|
||||
|
||||
elif provider == "nous":
|
||||
state = _load_provider_state(auth_store, "nous")
|
||||
if state:
|
||||
active_sources.add("device_code")
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"device_code",
|
||||
{
|
||||
"source": "device_code",
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"access_token": state.get("access_token", ""),
|
||||
"refresh_token": state.get("refresh_token"),
|
||||
"expires_at": state.get("expires_at"),
|
||||
"token_type": state.get("token_type"),
|
||||
"scope": state.get("scope"),
|
||||
"client_id": state.get("client_id"),
|
||||
"portal_base_url": state.get("portal_base_url"),
|
||||
"inference_base_url": state.get("inference_base_url"),
|
||||
"agent_key": state.get("agent_key"),
|
||||
"agent_key_expires_at": state.get("agent_key_expires_at"),
|
||||
"tls": state.get("tls") if isinstance(state.get("tls"), dict) else None,
|
||||
"label": label_from_token(state.get("access_token", ""), "device_code"),
|
||||
},
|
||||
)
|
||||
|
||||
elif provider == "openai-codex":
|
||||
state = _load_provider_state(auth_store, "openai-codex")
|
||||
tokens = state.get("tokens") if isinstance(state, dict) else None
|
||||
if isinstance(tokens, dict) and tokens.get("access_token"):
|
||||
active_sources.add("device_code")
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
"device_code",
|
||||
{
|
||||
"source": "device_code",
|
||||
"auth_type": AUTH_TYPE_OAUTH,
|
||||
"access_token": tokens.get("access_token", ""),
|
||||
"refresh_token": tokens.get("refresh_token"),
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"last_refresh": state.get("last_refresh"),
|
||||
"label": label_from_token(tokens.get("access_token", ""), "device_code"),
|
||||
},
|
||||
)
|
||||
|
||||
return changed, active_sources
|
||||
|
||||
|
||||
def _seed_from_env(provider: str, entries: List[PooledCredential]) -> Tuple[bool, Set[str]]:
|
||||
changed = False
|
||||
active_sources: Set[str] = set()
|
||||
if provider == "openrouter":
|
||||
token = os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
if token:
|
||||
source = "env:OPENROUTER_API_KEY"
|
||||
active_sources.add(source)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
source,
|
||||
{
|
||||
"source": source,
|
||||
"auth_type": AUTH_TYPE_API_KEY,
|
||||
"access_token": token,
|
||||
"base_url": OPENROUTER_BASE_URL,
|
||||
"label": "OPENROUTER_API_KEY",
|
||||
},
|
||||
)
|
||||
return changed, active_sources
|
||||
|
||||
pconfig = PROVIDER_REGISTRY.get(provider)
|
||||
if not pconfig or pconfig.auth_type != AUTH_TYPE_API_KEY:
|
||||
return changed, active_sources
|
||||
|
||||
env_url = ""
|
||||
if pconfig.base_url_env_var:
|
||||
env_url = os.getenv(pconfig.base_url_env_var, "").strip().rstrip("/")
|
||||
|
||||
env_vars = list(pconfig.api_key_env_vars)
|
||||
if provider == "anthropic":
|
||||
env_vars = [
|
||||
"ANTHROPIC_TOKEN",
|
||||
"CLAUDE_CODE_OAUTH_TOKEN",
|
||||
"ANTHROPIC_API_KEY",
|
||||
]
|
||||
|
||||
for env_var in env_vars:
|
||||
token = os.getenv(env_var, "").strip()
|
||||
if not token:
|
||||
continue
|
||||
source = f"env:{env_var}"
|
||||
active_sources.add(source)
|
||||
auth_type = AUTH_TYPE_OAUTH if provider == "anthropic" and not token.startswith("sk-ant-api") else AUTH_TYPE_API_KEY
|
||||
base_url = env_url or pconfig.inference_base_url
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
provider,
|
||||
source,
|
||||
{
|
||||
"source": source,
|
||||
"auth_type": auth_type,
|
||||
"access_token": token,
|
||||
"base_url": base_url,
|
||||
"label": env_var,
|
||||
},
|
||||
)
|
||||
return changed, active_sources
|
||||
|
||||
|
||||
def _prune_stale_seeded_entries(entries: List[PooledCredential], active_sources: Set[str]) -> bool:
|
||||
retained = [
|
||||
entry
|
||||
for entry in entries
|
||||
if _is_manual_source(entry.source)
|
||||
or entry.source in active_sources
|
||||
or not (
|
||||
entry.source.startswith("env:")
|
||||
or entry.source in {"claude_code", "hermes_pkce"}
|
||||
)
|
||||
]
|
||||
if len(retained) == len(entries):
|
||||
return False
|
||||
entries[:] = retained
|
||||
return True
|
||||
|
||||
|
||||
def _seed_custom_pool(pool_key: str, entries: List[PooledCredential]) -> Tuple[bool, Set[str]]:
|
||||
"""Seed a custom endpoint pool from custom_providers config and model config."""
|
||||
changed = False
|
||||
active_sources: Set[str] = set()
|
||||
|
||||
# Seed from the custom_providers config entry's api_key field
|
||||
cp_config = _get_custom_provider_config(pool_key)
|
||||
if cp_config:
|
||||
api_key = str(cp_config.get("api_key") or "").strip()
|
||||
base_url = str(cp_config.get("base_url") or "").strip().rstrip("/")
|
||||
name = str(cp_config.get("name") or "").strip()
|
||||
if api_key:
|
||||
source = f"config:{name}"
|
||||
active_sources.add(source)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
pool_key,
|
||||
source,
|
||||
{
|
||||
"source": source,
|
||||
"auth_type": AUTH_TYPE_API_KEY,
|
||||
"access_token": api_key,
|
||||
"base_url": base_url,
|
||||
"label": name or source,
|
||||
},
|
||||
)
|
||||
|
||||
# Seed from model.api_key if model.provider=='custom' and model.base_url matches
|
||||
try:
|
||||
config = _load_config_safe()
|
||||
model_cfg = config.get("model") if config else None
|
||||
if isinstance(model_cfg, dict):
|
||||
model_provider = str(model_cfg.get("provider") or "").strip().lower()
|
||||
model_base_url = str(model_cfg.get("base_url") or "").strip().rstrip("/")
|
||||
model_api_key = ""
|
||||
for k in ("api_key", "api"):
|
||||
v = model_cfg.get(k)
|
||||
if isinstance(v, str) and v.strip():
|
||||
model_api_key = v.strip()
|
||||
break
|
||||
if model_provider == "custom" and model_base_url and model_api_key:
|
||||
# Check if this model's base_url matches our custom provider
|
||||
matched_key = get_custom_provider_pool_key(model_base_url)
|
||||
if matched_key == pool_key:
|
||||
source = "model_config"
|
||||
active_sources.add(source)
|
||||
changed |= _upsert_entry(
|
||||
entries,
|
||||
pool_key,
|
||||
source,
|
||||
{
|
||||
"source": source,
|
||||
"auth_type": AUTH_TYPE_API_KEY,
|
||||
"access_token": model_api_key,
|
||||
"base_url": model_base_url,
|
||||
"label": "model_config",
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return changed, active_sources
|
||||
|
||||
|
||||
def load_pool(provider: str) -> CredentialPool:
|
||||
provider = (provider or "").strip().lower()
|
||||
raw_entries = read_credential_pool(provider)
|
||||
entries = [PooledCredential.from_dict(provider, payload) for payload in raw_entries]
|
||||
|
||||
if provider.startswith(CUSTOM_POOL_PREFIX):
|
||||
# Custom endpoint pool — seed from custom_providers config and model config
|
||||
custom_changed, custom_sources = _seed_custom_pool(provider, entries)
|
||||
changed = custom_changed
|
||||
changed |= _prune_stale_seeded_entries(entries, custom_sources)
|
||||
else:
|
||||
singleton_changed, singleton_sources = _seed_from_singletons(provider, entries)
|
||||
env_changed, env_sources = _seed_from_env(provider, entries)
|
||||
changed = singleton_changed or env_changed
|
||||
changed |= _prune_stale_seeded_entries(entries, singleton_sources | env_sources)
|
||||
changed |= _normalize_pool_priorities(provider, entries)
|
||||
|
||||
if changed:
|
||||
write_credential_pool(
|
||||
provider,
|
||||
[entry.to_dict() for entry in sorted(entries, key=lambda item: item.priority)],
|
||||
)
|
||||
return CredentialPool(provider, entries)
|
||||
313
agent/display.py
313
agent/display.py
|
|
@ -10,6 +10,9 @@ import os
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from difflib import unified_diff
|
||||
from pathlib import Path
|
||||
|
||||
# ANSI escape codes for coloring tool failure indicators
|
||||
_RED = "\033[31m"
|
||||
|
|
@ -17,6 +20,22 @@ _RESET = "\033[0m"
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ANSI_RESET = "\033[0m"
|
||||
_ANSI_DIM = "\033[38;2;150;150;150m"
|
||||
_ANSI_FILE = "\033[38;2;180;160;255m"
|
||||
_ANSI_HUNK = "\033[38;2;120;120;140m"
|
||||
_ANSI_MINUS = "\033[38;2;255;255;255;48;2;120;20;20m"
|
||||
_ANSI_PLUS = "\033[38;2;255;255;255;48;2;20;90;20m"
|
||||
_MAX_INLINE_DIFF_FILES = 6
|
||||
_MAX_INLINE_DIFF_LINES = 80
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalEditSnapshot:
|
||||
"""Pre-tool filesystem snapshot used to render diffs locally after writes."""
|
||||
paths: list[Path] = field(default_factory=list)
|
||||
before: dict[str, str | None] = field(default_factory=dict)
|
||||
|
||||
# =========================================================================
|
||||
# Configurable tool preview length (0 = no limit)
|
||||
# Set once at startup by CLI or gateway from display.tool_preview_length config.
|
||||
|
|
@ -218,6 +237,300 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int | None = None) -
|
|||
return preview
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Inline diff previews for write actions
|
||||
# =========================================================================
|
||||
|
||||
def _resolved_path(path: str) -> Path:
|
||||
"""Resolve a possibly-relative filesystem path against the current cwd."""
|
||||
candidate = Path(os.path.expanduser(path))
|
||||
if candidate.is_absolute():
|
||||
return candidate
|
||||
return Path.cwd() / candidate
|
||||
|
||||
|
||||
def _snapshot_text(path: Path) -> str | None:
|
||||
"""Return UTF-8 file content, or None for missing/unreadable files."""
|
||||
try:
|
||||
return path.read_text(encoding="utf-8")
|
||||
except (FileNotFoundError, IsADirectoryError, UnicodeDecodeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
def _display_diff_path(path: Path) -> str:
|
||||
"""Prefer cwd-relative paths in diffs when available."""
|
||||
try:
|
||||
return str(path.resolve().relative_to(Path.cwd().resolve()))
|
||||
except Exception:
|
||||
return str(path)
|
||||
|
||||
|
||||
def _resolve_skill_manage_paths(args: dict) -> list[Path]:
|
||||
"""Resolve skill_manage write targets to filesystem paths."""
|
||||
action = args.get("action")
|
||||
name = args.get("name")
|
||||
if not action or not name:
|
||||
return []
|
||||
|
||||
from tools.skill_manager_tool import _find_skill, _resolve_skill_dir
|
||||
|
||||
if action == "create":
|
||||
skill_dir = _resolve_skill_dir(name, args.get("category"))
|
||||
return [skill_dir / "SKILL.md"]
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return []
|
||||
|
||||
skill_dir = Path(existing["path"])
|
||||
if action in {"edit", "patch"}:
|
||||
file_path = args.get("file_path")
|
||||
return [skill_dir / file_path] if file_path else [skill_dir / "SKILL.md"]
|
||||
if action in {"write_file", "remove_file"}:
|
||||
file_path = args.get("file_path")
|
||||
return [skill_dir / file_path] if file_path else []
|
||||
if action == "delete":
|
||||
files = [path for path in sorted(skill_dir.rglob("*")) if path.is_file()]
|
||||
return files
|
||||
return []
|
||||
|
||||
|
||||
def _resolve_local_edit_paths(tool_name: str, function_args: dict | None) -> list[Path]:
|
||||
"""Resolve local filesystem targets for write-capable tools."""
|
||||
if not isinstance(function_args, dict):
|
||||
return []
|
||||
|
||||
if tool_name == "write_file":
|
||||
path = function_args.get("path")
|
||||
return [_resolved_path(path)] if path else []
|
||||
|
||||
if tool_name == "patch":
|
||||
path = function_args.get("path")
|
||||
return [_resolved_path(path)] if path else []
|
||||
|
||||
if tool_name == "skill_manage":
|
||||
return _resolve_skill_manage_paths(function_args)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
def capture_local_edit_snapshot(tool_name: str, function_args: dict | None) -> LocalEditSnapshot | None:
|
||||
"""Capture before-state for local write previews."""
|
||||
paths = _resolve_local_edit_paths(tool_name, function_args)
|
||||
if not paths:
|
||||
return None
|
||||
|
||||
snapshot = LocalEditSnapshot(paths=paths)
|
||||
for path in paths:
|
||||
snapshot.before[str(path)] = _snapshot_text(path)
|
||||
return snapshot
|
||||
|
||||
|
||||
def _result_succeeded(result: str | None) -> bool:
|
||||
"""Conservatively detect whether a tool result represents success."""
|
||||
if not result:
|
||||
return False
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return False
|
||||
if not isinstance(data, dict):
|
||||
return False
|
||||
if data.get("error"):
|
||||
return False
|
||||
if "success" in data:
|
||||
return bool(data.get("success"))
|
||||
return True
|
||||
|
||||
|
||||
def _diff_from_snapshot(snapshot: LocalEditSnapshot | None) -> str | None:
|
||||
"""Generate unified diff text from a stored before-state and current files."""
|
||||
if not snapshot:
|
||||
return None
|
||||
|
||||
chunks: list[str] = []
|
||||
for path in snapshot.paths:
|
||||
before = snapshot.before.get(str(path))
|
||||
after = _snapshot_text(path)
|
||||
if before == after:
|
||||
continue
|
||||
|
||||
display_path = _display_diff_path(path)
|
||||
diff = "".join(
|
||||
unified_diff(
|
||||
[] if before is None else before.splitlines(keepends=True),
|
||||
[] if after is None else after.splitlines(keepends=True),
|
||||
fromfile=f"a/{display_path}",
|
||||
tofile=f"b/{display_path}",
|
||||
)
|
||||
)
|
||||
if diff:
|
||||
chunks.append(diff)
|
||||
|
||||
if not chunks:
|
||||
return None
|
||||
return "".join(chunk if chunk.endswith("\n") else chunk + "\n" for chunk in chunks)
|
||||
|
||||
|
||||
def extract_edit_diff(
|
||||
tool_name: str,
|
||||
result: str | None,
|
||||
*,
|
||||
function_args: dict | None = None,
|
||||
snapshot: LocalEditSnapshot | None = None,
|
||||
) -> str | None:
|
||||
"""Extract a unified diff from a file-edit tool result."""
|
||||
if tool_name == "patch" and result:
|
||||
try:
|
||||
data = json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
data = None
|
||||
if isinstance(data, dict):
|
||||
diff = data.get("diff")
|
||||
if isinstance(diff, str) and diff.strip():
|
||||
return diff
|
||||
|
||||
if tool_name not in {"write_file", "patch", "skill_manage"}:
|
||||
return None
|
||||
if not _result_succeeded(result):
|
||||
return None
|
||||
return _diff_from_snapshot(snapshot)
|
||||
|
||||
|
||||
def _emit_inline_diff(diff_text: str, print_fn) -> bool:
|
||||
"""Emit rendered diff text through the CLI's prompt_toolkit-safe printer."""
|
||||
if print_fn is None or not diff_text:
|
||||
return False
|
||||
try:
|
||||
print_fn(" ┊ review diff")
|
||||
for line in diff_text.rstrip("\n").splitlines():
|
||||
print_fn(line)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _render_inline_unified_diff(diff: str) -> list[str]:
|
||||
"""Render unified diff lines in Hermes' inline transcript style."""
|
||||
rendered: list[str] = []
|
||||
from_file = None
|
||||
to_file = None
|
||||
|
||||
for raw_line in diff.splitlines():
|
||||
if raw_line.startswith("--- "):
|
||||
from_file = raw_line[4:].strip()
|
||||
continue
|
||||
if raw_line.startswith("+++ "):
|
||||
to_file = raw_line[4:].strip()
|
||||
if from_file or to_file:
|
||||
rendered.append(f"{_ANSI_FILE}{from_file or 'a/?'} → {to_file or 'b/?'}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith("@@"):
|
||||
rendered.append(f"{_ANSI_HUNK}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith("-"):
|
||||
rendered.append(f"{_ANSI_MINUS}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith("+"):
|
||||
rendered.append(f"{_ANSI_PLUS}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line.startswith(" "):
|
||||
rendered.append(f"{_ANSI_DIM}{raw_line}{_ANSI_RESET}")
|
||||
continue
|
||||
if raw_line:
|
||||
rendered.append(raw_line)
|
||||
|
||||
return rendered
|
||||
|
||||
|
||||
def _split_unified_diff_sections(diff: str) -> list[str]:
|
||||
"""Split a unified diff into per-file sections."""
|
||||
sections: list[list[str]] = []
|
||||
current: list[str] = []
|
||||
|
||||
for line in diff.splitlines():
|
||||
if line.startswith("--- ") and current:
|
||||
sections.append(current)
|
||||
current = [line]
|
||||
continue
|
||||
current.append(line)
|
||||
|
||||
if current:
|
||||
sections.append(current)
|
||||
|
||||
return ["\n".join(section) for section in sections if section]
|
||||
|
||||
|
||||
def _summarize_rendered_diff_sections(
|
||||
diff: str,
|
||||
*,
|
||||
max_files: int = _MAX_INLINE_DIFF_FILES,
|
||||
max_lines: int = _MAX_INLINE_DIFF_LINES,
|
||||
) -> list[str]:
|
||||
"""Render diff sections while capping file count and total line count."""
|
||||
sections = _split_unified_diff_sections(diff)
|
||||
rendered: list[str] = []
|
||||
omitted_files = 0
|
||||
omitted_lines = 0
|
||||
|
||||
for idx, section in enumerate(sections):
|
||||
if idx >= max_files:
|
||||
omitted_files += 1
|
||||
omitted_lines += len(_render_inline_unified_diff(section))
|
||||
continue
|
||||
|
||||
section_lines = _render_inline_unified_diff(section)
|
||||
remaining_budget = max_lines - len(rendered)
|
||||
if remaining_budget <= 0:
|
||||
omitted_lines += len(section_lines)
|
||||
omitted_files += 1
|
||||
continue
|
||||
|
||||
if len(section_lines) <= remaining_budget:
|
||||
rendered.extend(section_lines)
|
||||
continue
|
||||
|
||||
rendered.extend(section_lines[:remaining_budget])
|
||||
omitted_lines += len(section_lines) - remaining_budget
|
||||
omitted_files += 1 + max(0, len(sections) - idx - 1)
|
||||
for leftover in sections[idx + 1:]:
|
||||
omitted_lines += len(_render_inline_unified_diff(leftover))
|
||||
break
|
||||
|
||||
if omitted_files or omitted_lines:
|
||||
summary = f"… omitted {omitted_lines} diff line(s)"
|
||||
if omitted_files:
|
||||
summary += f" across {omitted_files} additional file(s)/section(s)"
|
||||
rendered.append(f"{_ANSI_HUNK}{summary}{_ANSI_RESET}")
|
||||
|
||||
return rendered
|
||||
|
||||
|
||||
def render_edit_diff_with_delta(
|
||||
tool_name: str,
|
||||
result: str | None,
|
||||
*,
|
||||
function_args: dict | None = None,
|
||||
snapshot: LocalEditSnapshot | None = None,
|
||||
print_fn=None,
|
||||
) -> bool:
|
||||
"""Render an edit diff inline without taking over the terminal UI."""
|
||||
diff = extract_edit_diff(
|
||||
tool_name,
|
||||
result,
|
||||
function_args=function_args,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
if not diff:
|
||||
return False
|
||||
try:
|
||||
rendered_lines = _summarize_rendered_diff_sections(diff)
|
||||
except Exception as exc:
|
||||
logger.debug("Could not render inline diff: %s", exc)
|
||||
return False
|
||||
return _emit_inline_diff("\n".join(rendered_lines), print_fn)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# KawaiiSpinner
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -644,6 +644,9 @@ class InsightsEngine:
|
|||
lines.append(f" Sessions: {o['total_sessions']:<12} Messages: {o['total_messages']:,}")
|
||||
lines.append(f" Tool calls: {o['total_tool_calls']:<12,} User messages: {o['user_messages']:,}")
|
||||
lines.append(f" Input tokens: {o['total_input_tokens']:<12,} Output tokens: {o['total_output_tokens']:,}")
|
||||
cache_total = o.get("total_cache_read_tokens", 0) + o.get("total_cache_write_tokens", 0)
|
||||
if cache_total > 0:
|
||||
lines.append(f" Cache read: {o['total_cache_read_tokens']:<12,} Cache write: {o['total_cache_write_tokens']:,}")
|
||||
cost_str = f"${o['estimated_cost']:.2f}"
|
||||
if o.get("models_without_pricing"):
|
||||
cost_str += " *"
|
||||
|
|
@ -746,7 +749,11 @@ class InsightsEngine:
|
|||
|
||||
# Overview
|
||||
lines.append(f"**Sessions:** {o['total_sessions']} | **Messages:** {o['total_messages']:,} | **Tool calls:** {o['total_tool_calls']:,}")
|
||||
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
|
||||
cache_total = o.get("total_cache_read_tokens", 0) + o.get("total_cache_write_tokens", 0)
|
||||
if cache_total > 0:
|
||||
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,} / cache: {cache_total:,})")
|
||||
else:
|
||||
lines.append(f"**Tokens:** {o['total_tokens']:,} (in: {o['total_input_tokens']:,} / out: {o['total_output_tokens']:,})")
|
||||
cost_note = ""
|
||||
if o.get("models_without_pricing"):
|
||||
cost_note = " _(excludes custom/self-hosted models)_"
|
||||
|
|
|
|||
|
|
@ -176,6 +176,7 @@ _URL_TO_PROVIDER: Dict[str, str] = {
|
|||
"api.deepseek.com": "deepseek",
|
||||
"api.githubcopilot.com": "copilot",
|
||||
"models.github.ai": "copilot",
|
||||
"api.fireworks.ai": "fireworks",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -43,6 +43,7 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
|||
"opencode-zen": "opencode",
|
||||
"opencode-go": "opencode-go",
|
||||
"kilocode": "kilo",
|
||||
"fireworks": "fireworks-ai",
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -189,6 +189,13 @@ TOOL_USE_ENFORCEMENT_GUIDANCE = (
|
|||
# Add new patterns here when a model family needs explicit steering.
|
||||
TOOL_USE_ENFORCEMENT_MODELS = ("gpt", "codex")
|
||||
|
||||
# Model name substrings that should use the 'developer' role instead of
|
||||
# 'system' for the system prompt. OpenAI's newer models (GPT-5, Codex)
|
||||
# give stronger instruction-following weight to the 'developer' role.
|
||||
# The swap happens at the API boundary in _build_api_kwargs() so internal
|
||||
# message representation stays consistent ("system" everywhere).
|
||||
DEVELOPER_ROLE_MODELS = ("gpt-5", "codex")
|
||||
|
||||
PLATFORM_HINTS = {
|
||||
"whatsapp": (
|
||||
"You are on a text messaging communication platform, WhatsApp. "
|
||||
|
|
|
|||
|
|
@ -13,11 +13,19 @@ import re
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Snapshot at import time so runtime env mutations (e.g. LLM-generated
|
||||
# `export HERMES_REDACT_SECRETS=false`) cannot disable redaction mid-session.
|
||||
_REDACT_ENABLED = os.getenv("HERMES_REDACT_SECRETS", "").lower() not in ("0", "false", "no", "off")
|
||||
|
||||
# Known API key prefixes -- match the prefix + contiguous token chars
|
||||
_PREFIX_PATTERNS = [
|
||||
r"sk-[A-Za-z0-9_-]{10,}", # OpenAI / OpenRouter / Anthropic (sk-ant-*)
|
||||
r"ghp_[A-Za-z0-9]{10,}", # GitHub PAT (classic)
|
||||
r"github_pat_[A-Za-z0-9_]{10,}", # GitHub PAT (fine-grained)
|
||||
r"gho_[A-Za-z0-9]{10,}", # GitHub OAuth access token
|
||||
r"ghu_[A-Za-z0-9]{10,}", # GitHub user-to-server token
|
||||
r"ghs_[A-Za-z0-9]{10,}", # GitHub server-to-server token
|
||||
r"ghr_[A-Za-z0-9]{10,}", # GitHub refresh token
|
||||
r"xox[baprs]-[A-Za-z0-9-]{10,}", # Slack tokens
|
||||
r"AIza[A-Za-z0-9_-]{30,}", # Google API keys
|
||||
r"pplx-[A-Za-z0-9]{10,}", # Perplexity
|
||||
|
|
@ -109,7 +117,7 @@ def redact_sensitive_text(text: str) -> str:
|
|||
text = str(text)
|
||||
if not text:
|
||||
return text
|
||||
if os.getenv("HERMES_REDACT_SECRETS", "").lower() in ("0", "false", "no", "off"):
|
||||
if not _REDACT_ENABLED:
|
||||
return text
|
||||
|
||||
# Known prefixes (sk-, ghp_, etc.)
|
||||
|
|
|
|||
|
|
@ -230,7 +230,13 @@ def get_all_skills_dirs() -> List[Path]:
|
|||
|
||||
def extract_skill_conditions(frontmatter: Dict[str, Any]) -> Dict[str, List]:
|
||||
"""Extract conditional activation fields from parsed frontmatter."""
|
||||
hermes = (frontmatter.get("metadata") or {}).get("hermes") or {}
|
||||
metadata = frontmatter.get("metadata")
|
||||
# Handle cases where metadata is not a dict (e.g., a string from malformed YAML)
|
||||
if not isinstance(metadata, dict):
|
||||
metadata = {}
|
||||
hermes = metadata.get("hermes") or {}
|
||||
if not isinstance(hermes, dict):
|
||||
hermes = {}
|
||||
return {
|
||||
"fallback_for_toolsets": hermes.get("fallback_for_toolsets", []),
|
||||
"requires_toolsets": hermes.get("requires_toolsets", []),
|
||||
|
|
|
|||
|
|
@ -123,6 +123,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any
|
|||
"api_mode": primary.get("api_mode"),
|
||||
"command": primary.get("command"),
|
||||
"args": list(primary.get("args") or []),
|
||||
"credential_pool": primary.get("credential_pool"),
|
||||
},
|
||||
"label": None,
|
||||
"signature": (
|
||||
|
|
@ -158,6 +159,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any
|
|||
"api_mode": primary.get("api_mode"),
|
||||
"command": primary.get("command"),
|
||||
"args": list(primary.get("args") or []),
|
||||
"credential_pool": primary.get("credential_pool"),
|
||||
},
|
||||
"label": None,
|
||||
"signature": (
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue