Honor provider reset windows in pooled credential failover

Persist structured exhaustion metadata from provider errors, use explicit reset timestamps when available, and expose label-based credential targeting in the auth CLI. This keeps long-lived Codex cooldowns from being misreported as one-hour waits and avoids forcing operators to manage entries by list position alone.

Constraint: Existing credential pool JSON needs to remain backward compatible with stored entries that only record status code and timestamp
Constraint: Runtime recovery must keep the existing retry-then-rotate semantics for 429s while enriching pool state with provider metadata
Rejected: Add a separate credential scheduler subsystem | too large for the Hermes pool architecture and unnecessary for this fix
Rejected: Only change CLI formatting | would leave runtime rotation blind to resets_at and preserve the serial-failure behavior
Confidence: high
Scope-risk: moderate
Reversibility: clean
Directive: Preserve structured rate-limit metadata when new providers expose reset hints; do not collapse back to status-code-only exhaustion tracking
Tested: Focused pytest slice for auth commands, credential pool recovery, and routing (272 passed); py_compile on changed Python files; hermes -w auth list/remove smoke test with temporary HERMES_HOME
Not-tested: Full repository pytest suite, broader gateway/integration flows outside the touched auth and pool paths
This commit is contained in:
kshitijk4poor 2026-04-05 12:03:20 +05:30 committed by Teknium
parent ed4a605696
commit 65952ac00c
8 changed files with 446 additions and 42 deletions

View file

@ -8,7 +8,9 @@ import threading
import time
import uuid
import os
import re
from dataclasses import dataclass, fields, replace
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set, Tuple
from hermes_constants import OPENROUTER_BASE_URL
@ -95,6 +97,9 @@ class PooledCredential:
last_status: Optional[str] = None
last_status_at: Optional[float] = None
last_error_code: Optional[int] = None
last_error_reason: Optional[str] = None
last_error_message: Optional[str] = None
last_error_reset_at: Optional[float] = None
base_url: Optional[str] = None
expires_at: Optional[str] = None
expires_at_ms: Optional[int] = None
@ -129,7 +134,14 @@ class PooledCredential:
return cls(provider=provider, **data)
def to_dict(self) -> Dict[str, Any]:
_ALWAYS_EMIT = {"last_status", "last_status_at", "last_error_code"}
_ALWAYS_EMIT = {
"last_status",
"last_status_at",
"last_error_code",
"last_error_reason",
"last_error_message",
"last_error_reset_at",
}
result: Dict[str, Any] = {}
for field_def in fields(self):
if field_def.name in ("provider", "extra"):
@ -180,6 +192,85 @@ def _exhausted_ttl(error_code: Optional[int]) -> int:
return EXHAUSTED_TTL_DEFAULT_SECONDS
def _parse_absolute_timestamp(value: Any) -> Optional[float]:
"""Best-effort parse for provider reset timestamps.
Accepts epoch seconds, epoch milliseconds, and ISO-8601 strings.
Returns seconds since epoch.
"""
if value is None or value == "":
return None
if isinstance(value, (int, float)):
numeric = float(value)
if numeric <= 0:
return None
return numeric / 1000.0 if numeric > 1_000_000_000_000 else numeric
if isinstance(value, str):
raw = value.strip()
if not raw:
return None
try:
numeric = float(raw)
except ValueError:
numeric = None
if numeric is not None:
return numeric / 1000.0 if numeric > 1_000_000_000_000 else numeric
try:
return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp()
except ValueError:
return None
return None
def _extract_retry_delay_seconds(message: str) -> Optional[float]:
if not message:
return None
delay_match = re.search(r"quotaResetDelay[:\s\"]+(\d+(?:\.\d+)?)(ms|s)", message, re.IGNORECASE)
if delay_match:
value = float(delay_match.group(1))
return value / 1000.0 if delay_match.group(2).lower() == "ms" else value
sec_match = re.search(r"retry\s+(?:after\s+)?(\d+(?:\.\d+)?)\s*(?:sec|secs|seconds|s\b)", message, re.IGNORECASE)
if sec_match:
return float(sec_match.group(1))
return None
def _normalize_error_context(error_context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if not isinstance(error_context, dict):
return {}
normalized: Dict[str, Any] = {}
reason = error_context.get("reason")
if isinstance(reason, str) and reason.strip():
normalized["reason"] = reason.strip()
message = error_context.get("message")
if isinstance(message, str) and message.strip():
normalized["message"] = message.strip()
reset_at = (
error_context.get("reset_at")
or error_context.get("resets_at")
or error_context.get("retry_until")
)
parsed_reset_at = _parse_absolute_timestamp(reset_at)
if parsed_reset_at is None and isinstance(message, str):
retry_delay_seconds = _extract_retry_delay_seconds(message)
if retry_delay_seconds is not None:
parsed_reset_at = time.time() + retry_delay_seconds
if parsed_reset_at is not None:
normalized["reset_at"] = parsed_reset_at
return normalized
def _exhausted_until(entry: PooledCredential) -> Optional[float]:
if entry.last_status != STATUS_EXHAUSTED:
return None
reset_at = _parse_absolute_timestamp(getattr(entry, "last_error_reset_at", None))
if reset_at is not None:
return reset_at
if entry.last_status_at:
return entry.last_status_at + _exhausted_ttl(entry.last_error_code)
return None
def _normalize_custom_pool_name(name: str) -> str:
"""Normalize a custom provider name for use as a pool key suffix."""
return name.strip().lower().replace(" ", "-")
@ -292,12 +383,21 @@ class CredentialPool:
[entry.to_dict() for entry in self._entries],
)
def _mark_exhausted(self, entry: PooledCredential, status_code: Optional[int]) -> PooledCredential:
def _mark_exhausted(
self,
entry: PooledCredential,
status_code: Optional[int],
error_context: Optional[Dict[str, Any]] = None,
) -> PooledCredential:
normalized_error = _normalize_error_context(error_context)
updated = replace(
entry,
last_status=STATUS_EXHAUSTED,
last_status_at=time.time(),
last_error_code=status_code,
last_error_reason=normalized_error.get("reason"),
last_error_message=normalized_error.get("message"),
last_error_reset_at=normalized_error.get("reset_at"),
)
self._replace_entry(entry, updated)
self._persist()
@ -462,7 +562,15 @@ class CredentialPool:
self._mark_exhausted(entry, None)
return None
updated = replace(updated, last_status=STATUS_OK, last_status_at=None, last_error_code=None)
updated = replace(
updated,
last_status=STATUS_OK,
last_status_at=None,
last_error_code=None,
last_error_reason=None,
last_error_message=None,
last_error_reset_at=None,
)
self._replace_entry(entry, updated)
self._persist()
return updated
@ -522,11 +630,19 @@ class CredentialPool:
entry = synced
cleared_any = True
if entry.last_status == STATUS_EXHAUSTED:
ttl = _exhausted_ttl(entry.last_error_code)
if entry.last_status_at and now - entry.last_status_at < ttl:
exhausted_until = _exhausted_until(entry)
if exhausted_until is not None and now < exhausted_until:
continue
if clear_expired:
cleared = replace(entry, last_status=STATUS_OK, last_status_at=None, last_error_code=None)
cleared = replace(
entry,
last_status=STATUS_OK,
last_status_at=None,
last_error_code=None,
last_error_reason=None,
last_error_message=None,
last_error_reset_at=None,
)
self._replace_entry(entry, cleared)
entry = cleared
cleared_any = True
@ -576,12 +692,17 @@ class CredentialPool:
available = self._available_entries()
return available[0] if available else None
def mark_exhausted_and_rotate(self, *, status_code: Optional[int]) -> Optional[PooledCredential]:
def mark_exhausted_and_rotate(
self,
*,
status_code: Optional[int],
error_context: Optional[Dict[str, Any]] = None,
) -> Optional[PooledCredential]:
with self._lock:
entry = self.current() or self._select_unlocked()
if entry is None:
return None
self._mark_exhausted(entry, status_code)
self._mark_exhausted(entry, status_code, error_context)
self._current_id = None
return self._select_unlocked()
@ -603,7 +724,17 @@ class CredentialPool:
new_entries = []
for entry in self._entries:
if entry.last_status or entry.last_status_at or entry.last_error_code:
new_entries.append(replace(entry, last_status=None, last_status_at=None, last_error_code=None))
new_entries.append(
replace(
entry,
last_status=None,
last_status_at=None,
last_error_code=None,
last_error_reason=None,
last_error_message=None,
last_error_reset_at=None,
)
)
count += 1
else:
new_entries.append(entry)
@ -625,6 +756,31 @@ class CredentialPool:
self._current_id = None
return removed
def resolve_target(self, target: Any) -> Tuple[Optional[int], Optional[PooledCredential], Optional[str]]:
raw = str(target or "").strip()
if not raw:
return None, None, "No credential target provided."
if raw.isdigit():
index = int(raw)
if 1 <= index <= len(self._entries):
return index, self._entries[index - 1], None
return None, None, f"No credential #{index}."
for idx, entry in enumerate(self._entries, start=1):
if entry.id == raw:
return idx, entry, None
label_matches = [
(idx, entry)
for idx, entry in enumerate(self._entries, start=1)
if entry.label.strip().lower() == raw.lower()
]
if len(label_matches) == 1:
return label_matches[0][0], label_matches[0][1], None
if len(label_matches) > 1:
return None, None, f'Ambiguous credential label "{raw}". Use the numeric index or entry id instead.'
return None, None, f'No credential matching "{raw}".'
def add_entry(self, entry: PooledCredential) -> PooledCredential:
entry = replace(entry, priority=_next_priority(self._entries))
self._entries.append(entry)

View file

@ -20,12 +20,12 @@ from agent.credential_pool import (
STRATEGY_LEAST_USED,
SUPPORTED_POOL_STRATEGIES,
PooledCredential,
_exhausted_until,
_normalize_custom_pool_name,
get_pool_strategy,
label_from_token,
list_custom_pool_providers,
load_pool,
_exhausted_ttl,
)
import hermes_cli.auth as auth_mod
from hermes_cli.auth import PROVIDER_REGISTRY
@ -113,21 +113,27 @@ def _display_source(source: str) -> str:
def _format_exhausted_status(entry) -> str:
if entry.last_status != STATUS_EXHAUSTED:
return ""
reason = getattr(entry, "last_error_reason", None)
reason_text = f" {reason}" if isinstance(reason, str) and reason.strip() else ""
code = f" ({entry.last_error_code})" if entry.last_error_code else ""
if not entry.last_status_at:
return f" exhausted{code}"
remaining = max(0, int(math.ceil((entry.last_status_at + _exhausted_ttl(entry.last_error_code)) - time.time())))
exhausted_until = _exhausted_until(entry)
if exhausted_until is None:
return f" exhausted{reason_text}{code}"
remaining = max(0, int(math.ceil(exhausted_until - time.time())))
if remaining <= 0:
return f" exhausted{code} (ready to retry)"
return f" exhausted{reason_text}{code} (ready to retry)"
minutes, seconds = divmod(remaining, 60)
hours, minutes = divmod(minutes, 60)
if hours:
days, hours = divmod(hours, 24)
if days:
wait = f"{days}d {hours}h"
elif hours:
wait = f"{hours}h {minutes}m"
elif minutes:
wait = f"{minutes}m {seconds}s"
else:
wait = f"{seconds}s"
return f" exhausted{code} ({wait} left)"
return f" exhausted{reason_text}{code} ({wait} left)"
def auth_add_command(args) -> None:
@ -277,11 +283,16 @@ def auth_list_command(args) -> None:
def auth_remove_command(args) -> None:
provider = _normalize_provider(getattr(args, "provider", ""))
index = int(getattr(args, "index"))
target = getattr(args, "target", None)
if target is None:
target = getattr(args, "index", None)
pool = load_pool(provider)
index, matched, error = pool.resolve_target(target)
if matched is None or index is None:
raise SystemExit(f"{error} Provider: {provider}.")
removed = pool.remove_index(index)
if removed is None:
raise SystemExit(f"No credential #{index} for provider {provider}.")
raise SystemExit(f'No credential matching "{target}" for provider {provider}.')
print(f"Removed {provider} credential #{index} ({removed.label})")
@ -369,8 +380,16 @@ def _interactive_add() -> None:
else:
auth_type = "api_key"
label = None
try:
typed_label = input("Label / account name (optional): ").strip()
except (EOFError, KeyboardInterrupt):
return
if typed_label:
label = typed_label
auth_add_command(SimpleNamespace(
provider=provider, auth_type=auth_type, label=None, api_key=None,
provider=provider, auth_type=auth_type, label=label, api_key=None,
portal_url=None, inference_url=None, client_id=None, scope=None,
no_browser=False, timeout=None, insecure=False, ca_bundle=None,
))
@ -386,22 +405,16 @@ def _interactive_remove() -> None:
# Show entries with indices
for i, e in enumerate(pool.entries(), 1):
exhausted = _format_exhausted_status(e)
print(f" #{i} {e.label:25s} {e.auth_type:10s} {e.source}{exhausted}")
print(f" #{i} {e.label:25s} {e.auth_type:10s} {e.source}{exhausted} [id:{e.id}]")
try:
raw = input("Remove # (or blank to cancel): ").strip()
raw = input("Remove #, id, or label (blank to cancel): ").strip()
except (EOFError, KeyboardInterrupt):
return
if not raw:
return
try:
index = int(raw)
except ValueError:
print("Invalid number.")
return
auth_remove_command(SimpleNamespace(provider=provider, index=index))
auth_remove_command(SimpleNamespace(provider=provider, target=raw))
def _interactive_reset() -> None:

View file

@ -3943,7 +3943,7 @@ Examples:
hermes logout Clear stored authentication
hermes auth add <provider> Add a pooled credential
hermes auth list List pooled credentials
hermes auth remove <p> <n> Remove pooled credential by index
hermes auth remove <p> <t> Remove pooled credential by index, id, or label
hermes auth reset <provider> Clear exhaustion status for a provider
hermes model Select default model
hermes config View configuration
@ -4333,9 +4333,9 @@ For more help on a command:
auth_add.add_argument("--ca-bundle", help="Custom CA bundle for OAuth login")
auth_list = auth_subparsers.add_parser("list", help="List pooled credentials")
auth_list.add_argument("provider", nargs="?", help="Optional provider filter")
auth_remove = auth_subparsers.add_parser("remove", help="Remove a pooled credential by index")
auth_remove = auth_subparsers.add_parser("remove", help="Remove a pooled credential by index, id, or label")
auth_remove.add_argument("provider", help="Provider id")
auth_remove.add_argument("index", type=int, help="1-based credential index")
auth_remove.add_argument("target", help="Credential index, entry id, or exact label")
auth_reset = auth_subparsers.add_parser("reset", help="Clear exhaustion status for all credentials for a provider")
auth_reset.add_argument("provider", help="Provider id")
auth_parser.set_defaults(func=cmd_auth)

View file

@ -2138,6 +2138,73 @@ class AIAgent:
return cleaned
@staticmethod
def _extract_api_error_context(error: Exception) -> Dict[str, Any]:
"""Extract structured rate-limit details from provider errors."""
context: Dict[str, Any] = {}
body = getattr(error, "body", None)
payload = None
if isinstance(body, dict):
payload = body.get("error") if isinstance(body.get("error"), dict) else body
if isinstance(payload, dict):
reason = payload.get("code") or payload.get("error")
if isinstance(reason, str) and reason.strip():
context["reason"] = reason.strip()
message = payload.get("message") or payload.get("error_description")
if isinstance(message, str) and message.strip():
context["message"] = message.strip()
for key in ("resets_at", "reset_at"):
value = payload.get(key)
if value not in (None, ""):
context["reset_at"] = value
break
retry_after = payload.get("retry_after")
if retry_after not in (None, "") and "reset_at" not in context:
try:
context["reset_at"] = time.time() + float(retry_after)
except (TypeError, ValueError):
pass
response = getattr(error, "response", None)
headers = getattr(response, "headers", None)
if headers:
retry_after = headers.get("retry-after") or headers.get("Retry-After")
if retry_after and "reset_at" not in context:
try:
context["reset_at"] = time.time() + float(retry_after)
except (TypeError, ValueError):
pass
ratelimit_reset = headers.get("x-ratelimit-reset")
if ratelimit_reset and "reset_at" not in context:
context["reset_at"] = ratelimit_reset
if "message" not in context:
raw_message = str(error).strip()
if raw_message:
context["message"] = raw_message[:500]
if "reset_at" not in context:
message = context.get("message") or ""
if isinstance(message, str):
import re as _re
delay_match = _re.search(r"quotaResetDelay[:\s\"]+(\d+(?:\.\d+)?)(ms|s)", message, _re.IGNORECASE)
if delay_match:
value = float(delay_match.group(1))
seconds = value / 1000.0 if delay_match.group(2).lower() == "ms" else value
context["reset_at"] = time.time() + seconds
else:
sec_match = _re.search(
r"retry\s+(?:after\s+)?(\d+(?:\.\d+)?)\s*(?:sec|secs|seconds|s\b)",
message,
_re.IGNORECASE,
)
if sec_match:
context["reset_at"] = time.time() + float(sec_match.group(1))
return context
def _dump_api_request_debug(
self,
api_kwargs: Dict[str, Any],
@ -3846,6 +3913,7 @@ class AIAgent:
*,
status_code: Optional[int],
has_retried_429: bool,
error_context: Optional[Dict[str, Any]] = None,
) -> tuple[bool, bool]:
"""Attempt credential recovery via pool rotation.
@ -3860,7 +3928,7 @@ class AIAgent:
return False, has_retried_429
if status_code == 402:
next_entry = pool.mark_exhausted_and_rotate(status_code=402)
next_entry = pool.mark_exhausted_and_rotate(status_code=402, error_context=error_context)
if next_entry is not None:
logger.info(f"Credential 402 (billing) — rotated to pool entry {getattr(next_entry, 'id', '?')}")
self._swap_credential(next_entry)
@ -3870,7 +3938,7 @@ class AIAgent:
if status_code == 429:
if not has_retried_429:
return False, True
next_entry = pool.mark_exhausted_and_rotate(status_code=429)
next_entry = pool.mark_exhausted_and_rotate(status_code=429, error_context=error_context)
if next_entry is not None:
logger.info(f"Credential 429 (rate limit) — rotated to pool entry {getattr(next_entry, 'id', '?')}")
self._swap_credential(next_entry)
@ -3885,7 +3953,7 @@ class AIAgent:
return True, has_retried_429
# Refresh failed — rotate to next credential instead of giving up.
# The failed entry is already marked exhausted by try_refresh_current().
next_entry = pool.mark_exhausted_and_rotate(status_code=401)
next_entry = pool.mark_exhausted_and_rotate(status_code=401, error_context=error_context)
if next_entry is not None:
logger.info(f"Credential 401 (refresh failed) — rotated to pool entry {getattr(next_entry, 'id', '?')}")
self._swap_credential(next_entry)
@ -7377,9 +7445,11 @@ class AIAgent:
# prompt or prefill. Fall through to normal error path.
status_code = getattr(api_error, "status_code", None)
error_context = self._extract_api_error_context(api_error)
recovered_with_pool, has_retried_429 = self._recover_with_credential_pool(
status_code=status_code,
has_retried_429=has_retried_429,
error_context=error_context,
)
if recovered_with_pool:
continue

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import base64
import json
from datetime import datetime, timezone
import pytest
@ -224,7 +225,7 @@ def test_auth_remove_reindexes_priorities(tmp_path, monkeypatch):
class _Args:
provider = "anthropic"
index = 1
target = "1"
auth_remove_command(_Args())
@ -235,6 +236,49 @@ def test_auth_remove_reindexes_priorities(tmp_path, monkeypatch):
assert entries[0]["priority"] == 0
def test_auth_remove_accepts_label_target(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(
tmp_path,
{
"version": 1,
"credential_pool": {
"openai-codex": [
{
"id": "cred-1",
"label": "work-account",
"auth_type": "oauth",
"priority": 0,
"source": "manual:device_code",
"access_token": "tok-1",
},
{
"id": "cred-2",
"label": "personal-account",
"auth_type": "oauth",
"priority": 1,
"source": "manual:device_code",
"access_token": "tok-2",
},
]
},
},
)
from hermes_cli.auth_commands import auth_remove_command
class _Args:
provider = "openai-codex"
target = "personal-account"
auth_remove_command(_Args())
payload = json.loads((tmp_path / "hermes" / "auth.json").read_text())
entries = payload["credential_pool"]["openai-codex"]
assert len(entries) == 1
assert entries[0]["label"] == "work-account"
def test_auth_reset_clears_provider_statuses(tmp_path, monkeypatch, capsys):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(
@ -389,3 +433,41 @@ def test_auth_list_shows_exhausted_cooldown(monkeypatch, capsys):
out = capsys.readouterr().out
assert "exhausted (429)" in out
assert "59m 30s left" in out
def test_auth_list_prefers_explicit_reset_time(monkeypatch, capsys):
from hermes_cli.auth_commands import auth_list_command
class _Entry:
id = "cred-1"
label = "weekly"
auth_type = "oauth"
source = "manual:device_code"
last_status = "exhausted"
last_error_code = 429
last_error_reason = "device_code_exhausted"
last_error_message = "Weekly credits exhausted."
last_error_reset_at = "2026-04-12T10:30:00Z"
last_status_at = 1000.0
class _Pool:
def entries(self):
return [_Entry()]
def peek(self):
return None
monkeypatch.setattr("hermes_cli.auth_commands.load_pool", lambda provider: _Pool())
monkeypatch.setattr(
"hermes_cli.auth_commands.time.time",
lambda: datetime(2026, 4, 5, 10, 30, tzinfo=timezone.utc).timestamp(),
)
class _Args:
provider = "openai-codex"
auth_list_command(_Args())
out = capsys.readouterr().out
assert "device_code_exhausted" in out
assert "7d 0h left" in out

View file

@ -214,6 +214,39 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch):
assert entry.last_status == "ok"
def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(
tmp_path,
{
"version": 1,
"credential_pool": {
"openai-codex": [
{
"id": "cred-1",
"label": "weekly-reset",
"auth_type": "oauth",
"priority": 0,
"source": "manual:device_code",
"access_token": "tok-1",
"last_status": "exhausted",
"last_status_at": time.time() - 7200,
"last_error_code": 429,
"last_error_reason": "device_code_exhausted",
"last_error_reset_at": time.time() + 7 * 24 * 60 * 60,
}
]
},
},
)
from agent.credential_pool import load_pool
pool = load_pool("openai-codex")
assert pool.has_available() is False
assert pool.select() is None
def test_mark_exhausted_and_rotate_persists_status(tmp_path, monkeypatch):
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes"))
_write_auth_store(

View file

@ -275,7 +275,7 @@ class TestPoolRotationCycle:
# mark_exhausted_and_rotate returns next entry until exhausted
self._rotation_index = 0
def rotate(status_code=None):
def rotate(status_code=None, error_context=None):
self._rotation_index += 1
if self._rotation_index < pool_entries:
return entries[self._rotation_index]
@ -307,7 +307,7 @@ class TestPoolRotationCycle:
)
assert recovered is True
assert has_retried is False # reset after rotation
pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=429)
pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=429, error_context=None)
agent._swap_credential.assert_called_once_with(entries[1])
def test_pool_exhaustion_returns_false(self):
@ -333,7 +333,7 @@ class TestPoolRotationCycle:
)
assert recovered is True
assert has_retried is False
pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=402)
pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=402, error_context=None)
def test_no_pool_returns_false(self):
"""No pool should return (False, unchanged)."""

View file

@ -1956,8 +1956,9 @@ class TestCredentialPoolRecovery:
def current(self):
return current
def mark_exhausted_and_rotate(self, *, status_code):
def mark_exhausted_and_rotate(self, *, status_code, error_context=None):
assert status_code == 402
assert error_context is None
return next_entry
agent._credential_pool = _Pool()
@ -1979,8 +1980,9 @@ class TestCredentialPoolRecovery:
def current(self):
return SimpleNamespace(label="primary")
def mark_exhausted_and_rotate(self, *, status_code):
def mark_exhausted_and_rotate(self, *, status_code, error_context=None):
assert status_code == 429
assert error_context is None
return next_entry
agent._credential_pool = _Pool()
@ -2030,8 +2032,9 @@ class TestCredentialPoolRecovery:
def try_refresh_current(self):
return None # refresh failed
def mark_exhausted_and_rotate(self, *, status_code):
def mark_exhausted_and_rotate(self, *, status_code, error_context=None):
assert status_code == 401
assert error_context is None
return next_entry
agent._credential_pool = _Pool()
@ -2053,7 +2056,8 @@ class TestCredentialPoolRecovery:
def try_refresh_current(self):
return None
def mark_exhausted_and_rotate(self, *, status_code):
def mark_exhausted_and_rotate(self, *, status_code, error_context=None):
assert error_context is None
return None # no more credentials
agent._credential_pool = _Pool()
@ -2067,6 +2071,52 @@ class TestCredentialPoolRecovery:
assert recovered is False
agent._swap_credential.assert_not_called()
def test_extract_api_error_context_uses_reset_timestamp_and_reason(self, agent):
response = SimpleNamespace(headers={})
error = SimpleNamespace(
body={
"error": {
"code": "device_code_exhausted",
"message": "Weekly credits exhausted.",
"resets_at": "2026-04-12T10:30:00Z",
}
},
response=response,
)
context = agent._extract_api_error_context(error)
assert context["reason"] == "device_code_exhausted"
assert context["message"] == "Weekly credits exhausted."
assert context["reset_at"] == "2026-04-12T10:30:00Z"
def test_recover_with_pool_passes_error_context_on_rotated_429(self, agent):
next_entry = SimpleNamespace(label="secondary")
captured = {}
class _Pool:
def current(self):
return SimpleNamespace(label="primary")
def mark_exhausted_and_rotate(self, *, status_code, error_context=None):
captured["status_code"] = status_code
captured["error_context"] = error_context
return next_entry
agent._credential_pool = _Pool()
agent._swap_credential = MagicMock()
recovered, retry_same = agent._recover_with_credential_pool(
status_code=429,
has_retried_429=True,
error_context={"reason": "device_code_exhausted", "reset_at": "2026-04-12T10:30:00Z"},
)
assert recovered is True
assert retry_same is False
assert captured["status_code"] == 429
assert captured["error_context"]["reason"] == "device_code_exhausted"
class TestMaxTokensParam:
"""Verify _max_tokens_param returns the correct key for each provider."""