# Go-style duration following a ``quotaResetDelay`` key, e.g. "500ms", "30s"
# or "1h16m49.5s".  "ms" is listed before "m" so milliseconds win the match.
_QUOTA_RESET_DELAY_RE = re.compile(
    r"quotaResetDelay[:\s\"]+((?:\d+(?:\.\d+)?(?:ms|h|m|s))+)",
    re.IGNORECASE,
)
_DURATION_PART_RE = re.compile(r"(\d+(?:\.\d+)?)(ms|h|m|s)", re.IGNORECASE)
_DURATION_UNIT_SECONDS = {"s": 1.0, "m": 60.0, "h": 3600.0}
# Free-text hints such as "retry after 30 seconds" or "Please retry in 12.5s".
_RETRY_SECONDS_RE = re.compile(
    r"retry\s+(?:(?:after|in)\s+)?(\d+(?:\.\d+)?)\s*(?:sec|secs|seconds|s\b)",
    re.IGNORECASE,
)
# Numeric values above this are taken to be epoch *milliseconds*.
_EPOCH_MS_THRESHOLD = 1_000_000_000_000


def _epoch_seconds(numeric: float) -> Optional[float]:
    """Validate a numeric epoch value and scale milliseconds down to seconds.

    Returns ``None`` for zero, negative, NaN and infinite values; the chained
    comparison is false for NaN and excludes ``inf`` explicitly.
    """
    if not (0 < numeric < float("inf")):
        return None
    return numeric / 1000.0 if numeric > _EPOCH_MS_THRESHOLD else numeric


def _parse_absolute_timestamp(value: Any) -> Optional[float]:
    """Best-effort parse for provider reset timestamps.

    Accepts epoch seconds, epoch milliseconds (as numbers or numeric strings)
    and ISO-8601 strings.  ISO strings without a UTC offset are interpreted as
    UTC rather than the machine's local time zone.

    Returns:
        Seconds since the epoch, or ``None`` when *value* is missing, not a
        supported type, non-positive, non-finite, or unparseable.
    """
    # bool is a subclass of int; True/False are never meaningful timestamps.
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, (int, float)):
        try:
            return _epoch_seconds(float(value))
        except OverflowError:  # int too large to convert to float
            return None
    if not isinstance(value, str):
        return None

    raw = value.strip()
    if not raw:
        return None
    try:
        numeric = float(raw)
    except ValueError:
        numeric = None
    if numeric is not None:
        # Same validation as the numeric branch ("0", "-5", "inf", "nan").
        return _epoch_seconds(numeric)

    try:
        parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed.timestamp()
    except (ValueError, OverflowError, OSError):
        return None


def _extract_retry_delay_seconds(message: str) -> Optional[float]:
    """Extract a relative retry delay, in seconds, from an error message.

    Recognises a ``quotaResetDelay`` value written as one or more
    ``<number><unit>`` parts (units ``ms``, ``s``, ``m``, ``h``) and phrases of
    the form ``retry [after|in] <number> sec|secs|seconds|s``.

    Returns:
        The delay in seconds, or ``None`` when no hint is found.
    """
    if not message:
        return None
    delay_match = _QUOTA_RESET_DELAY_RE.search(message)
    if delay_match:
        total = 0.0
        for amount, unit in _DURATION_PART_RE.findall(delay_match.group(1)):
            unit = unit.lower()
            if unit == "ms":
                total += float(amount) / 1000.0
            else:
                total += float(amount) * _DURATION_UNIT_SECONDS[unit]
        return total
    sec_match = _RETRY_SECONDS_RE.search(message)
    if sec_match:
        return float(sec_match.group(1))
    return None


def _normalize_error_context(error_context: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Reduce a raw error-context mapping to ``reason``/``message``/``reset_at``.

    ``reason`` and ``message`` are kept only when they are non-blank strings
    (stripped).  ``reset_at`` is the first of the ``reset_at``, ``resets_at``
    and ``retry_until`` keys that parses as an absolute timestamp; if none
    does, a relative delay found in the message is added to the current time.

    Returns:
        A new dict containing only the keys that could be populated; an empty
        dict when *error_context* is not a dict.
    """
    if not isinstance(error_context, dict):
        return {}
    normalized: Dict[str, Any] = {}

    reason = error_context.get("reason")
    if isinstance(reason, str) and reason.strip():
        normalized["reason"] = reason.strip()

    message = error_context.get("message")
    if isinstance(message, str) and message.strip():
        normalized["message"] = message.strip()

    # Try each candidate key in turn so an unparseable value under one key
    # does not hide a valid value under another.
    parsed_reset_at = None
    for key in ("reset_at", "resets_at", "retry_until"):
        parsed_reset_at = _parse_absolute_timestamp(error_context.get(key))
        if parsed_reset_at is not None:
            break

    if parsed_reset_at is None and isinstance(message, str):
        retry_delay_seconds = _extract_retry_delay_seconds(message)
        if retry_delay_seconds is not None:
            parsed_reset_at = time.time() + retry_delay_seconds

    if parsed_reset_at is not None:
        normalized["reset_at"] = parsed_reset_at
    return normalized


def _exhausted_until(entry: "PooledCredential") -> Optional[float]:
    """Return the epoch second at which an exhausted entry may be retried.

    Prefers the entry's ``last_error_reset_at`` when it parses; otherwise
    falls back to ``last_status_at`` plus the TTL for ``last_error_code``.
    Returns ``None`` when the entry is not exhausted or no time is recorded.
    """
    if entry.last_status != STATUS_EXHAUSTED:
        return None
    reset_at = _parse_absolute_timestamp(getattr(entry, "last_error_reset_at", None))
    if reset_at is not None:
        return reset_at
    if entry.last_status_at:
        return entry.last_status_at + _exhausted_ttl(entry.last_error_code)
    return None
def resolve_target(self, target: Any) -> Tuple[Optional[int], Optional["PooledCredential"], Optional[str]]:
    """Resolve a user-supplied credential reference to a pool entry.

    Matching is attempted in this order:

    1. a 1-based index, when *target* consists only of ASCII digits and is
       within range;
    2. an exact entry ``id``;
    3. a case-insensitive, whitespace-trimmed ``label`` (must be unique).

    An all-digit target that is out of index range still falls through to
    the id/label checks, so a numeric-looking id or label stays reachable.

    Returns:
        ``(index, entry, None)`` on success, or ``(None, None, message)``
        describing why nothing was resolved.
    """
    raw = ("" if target is None else str(target)).strip()
    if not raw:
        return None, None, "No credential target provided."

    entries = self._entries
    # str.isdigit() alone accepts characters such as "²" that int() rejects,
    # so restrict the numeric path to ASCII digits.
    index = int(raw) if raw.isascii() and raw.isdigit() else None
    if index is not None and 1 <= index <= len(entries):
        return index, entries[index - 1], None

    for idx, entry in enumerate(entries, start=1):
        if entry.id == raw:
            return idx, entry, None

    wanted = raw.lower()
    label_matches = [
        (idx, entry)
        for idx, entry in enumerate(entries, start=1)
        # Entries without a label never match instead of raising.
        if (entry.label or "").strip().lower() == wanted
    ]
    if len(label_matches) == 1:
        return label_matches[0][0], label_matches[0][1], None
    if len(label_matches) > 1:
        return None, None, f'Ambiguous credential label "{raw}". Use the numeric index or entry id instead.'

    if index is not None:
        return None, None, f"No credential #{index}."
    return None, None, f'No credential matching "{raw}".'
def auth_remove_command(args) -> None:
    """Remove one pooled credential for a provider and report the result.

    Reads ``args.provider`` and ``args.target``; when ``target`` is absent,
    ``args.index`` is used instead.  The reference is resolved through the
    pool's ``resolve_target``.  Exits via ``SystemExit`` with an explanatory
    message when the reference cannot be resolved or the removal yields
    nothing.
    """
    provider = _normalize_provider(getattr(args, "provider", ""))

    wanted = getattr(args, "target", None)
    if wanted is None:
        # Fall back to the ``index`` attribute when no target is supplied.
        wanted = getattr(args, "index", None)

    pool = load_pool(provider)
    position, entry, problem = pool.resolve_target(wanted)
    if entry is None or position is None:
        raise SystemExit(f"{problem} Provider: {provider}.")

    gone = pool.remove_index(position)
    if gone is None:
        raise SystemExit(f'No credential matching "{wanted}" for provider {provider}.')

    print(f"Removed {provider} credential #{position} ({gone.label})")
_interactive_remove() -> None: # Show entries with indices for i, e in enumerate(pool.entries(), 1): exhausted = _format_exhausted_status(e) - print(f" #{i} {e.label:25s} {e.auth_type:10s} {e.source}{exhausted}") + print(f" #{i} {e.label:25s} {e.auth_type:10s} {e.source}{exhausted} [id:{e.id}]") try: - raw = input("Remove # (or blank to cancel): ").strip() + raw = input("Remove #, id, or label (blank to cancel): ").strip() except (EOFError, KeyboardInterrupt): return if not raw: return - try: - index = int(raw) - except ValueError: - print("Invalid number.") - return - - auth_remove_command(SimpleNamespace(provider=provider, index=index)) + auth_remove_command(SimpleNamespace(provider=provider, target=raw)) def _interactive_reset() -> None: diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 3befd835cd..91f97d4505 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -3943,7 +3943,7 @@ Examples: hermes logout Clear stored authentication hermes auth add Add a pooled credential hermes auth list List pooled credentials - hermes auth remove

Remove pooled credential by index + hermes auth remove

@staticmethod
def _extract_api_error_context(error: Exception) -> Dict[str, Any]:
    """Extract structured rate-limit details from provider errors.

    Sources are consulted in this order, and the first one that yields a
    reset time wins:

    1. ``error.body`` (or its nested ``"error"`` dict): ``code``/``error`` as
       the reason, ``message``/``error_description`` as the message,
       ``resets_at``/``reset_at`` passed through unchanged, then
       ``retry_after``.
    2. ``error.response.headers``: ``Retry-After`` (delta-seconds or
       HTTP-date), then ``x-ratelimit-reset``.
    3. A delay hint embedded in the message text.

    Returns:
        A dict with any of the keys ``reason``, ``message`` and ``reset_at``.
        ``reset_at`` is epoch seconds when computed here, or the provider's
        raw value when it was passed through.
    """
    import re as _re
    from datetime import timezone as _timezone
    from email.utils import parsedate_to_datetime as _parse_http_date

    # Header values below this are treated as "seconds from now"; any real
    # epoch-seconds timestamp is far larger (1e9 is in 2001).
    relative_cutoff = 1_000_000_000

    def _finite_delay(raw: Any) -> Optional[float]:
        """Return *raw* as a non-negative, finite number of seconds, else None."""
        if isinstance(raw, bool):
            return None
        try:
            seconds = float(raw)
        except (TypeError, ValueError, OverflowError):
            return None
        # The chained comparison is false for NaN and excludes infinity.
        return seconds if 0 <= seconds < float("inf") else None

    def _retry_after_deadline(raw: Any) -> Optional[float]:
        """Convert a Retry-After value (delta-seconds or HTTP-date) to epoch seconds."""
        seconds = _finite_delay(raw)
        if seconds is not None:
            return time.time() + seconds
        if not isinstance(raw, str):
            return None
        try:
            moment = _parse_http_date(raw.strip())
        except (TypeError, ValueError, IndexError):
            return None
        if moment is None:
            return None
        if moment.tzinfo is None:
            # A date parsed without an offset is taken as UTC.
            moment = moment.replace(tzinfo=_timezone.utc)
        try:
            return moment.timestamp()
        except (OverflowError, OSError, ValueError):
            return None

    context: Dict[str, Any] = {}

    body = getattr(error, "body", None)
    payload = None
    if isinstance(body, dict):
        nested = body.get("error")
        payload = nested if isinstance(nested, dict) else body
    if isinstance(payload, dict):
        reason = payload.get("code") or payload.get("error")
        if isinstance(reason, str) and reason.strip():
            context["reason"] = reason.strip()
        message = payload.get("message") or payload.get("error_description")
        if isinstance(message, str) and message.strip():
            context["message"] = message.strip()
        for key in ("resets_at", "reset_at"):
            value = payload.get(key)
            if value not in (None, ""):
                context["reset_at"] = value
                break
        if "reset_at" not in context:
            deadline = _retry_after_deadline(payload.get("retry_after"))
            if deadline is not None:
                context["reset_at"] = deadline

    response = getattr(error, "response", None)
    headers = getattr(response, "headers", None)
    if headers and "reset_at" not in context:
        retry_after = headers.get("retry-after") or headers.get("Retry-After")
        deadline = _retry_after_deadline(retry_after) if retry_after else None
        if deadline is not None:
            context["reset_at"] = deadline
        else:
            ratelimit_reset = headers.get("x-ratelimit-reset") or headers.get("X-RateLimit-Reset")
            if ratelimit_reset:
                delta = _finite_delay(ratelimit_reset)
                if delta is not None and 0 < delta < relative_cutoff:
                    # Small numbers are a delay, not an epoch timestamp;
                    # forwarding them raw would place the reset in 1970.
                    context["reset_at"] = time.time() + delta
                else:
                    context["reset_at"] = ratelimit_reset

    if "message" not in context:
        raw_message = str(error).strip()
        if raw_message:
            # Cap the fallback text at 500 characters.
            context["message"] = raw_message[:500]

    if "reset_at" not in context:
        message = context.get("message") or ""
        if isinstance(message, str):
            delay_match = _re.search(r"quotaResetDelay[:\s\"]+(\d+(?:\.\d+)?)(ms|s)", message, _re.IGNORECASE)
            if delay_match:
                value = float(delay_match.group(1))
                seconds = value / 1000.0 if delay_match.group(2).lower() == "ms" else value
                context["reset_at"] = time.time() + seconds
            else:
                sec_match = _re.search(
                    r"retry\s+(?:after\s+)?(\d+(?:\.\d+)?)\s*(?:sec|secs|seconds|s\b)",
                    message,
                    _re.IGNORECASE,
                )
                if sec_match:
                    context["reset_at"] = time.time() + float(sec_match.group(1))

    return context
@@ -3860,7 +3928,7 @@ class AIAgent: return False, has_retried_429 if status_code == 402: - next_entry = pool.mark_exhausted_and_rotate(status_code=402) + next_entry = pool.mark_exhausted_and_rotate(status_code=402, error_context=error_context) if next_entry is not None: logger.info(f"Credential 402 (billing) — rotated to pool entry {getattr(next_entry, 'id', '?')}") self._swap_credential(next_entry) @@ -3870,7 +3938,7 @@ class AIAgent: if status_code == 429: if not has_retried_429: return False, True - next_entry = pool.mark_exhausted_and_rotate(status_code=429) + next_entry = pool.mark_exhausted_and_rotate(status_code=429, error_context=error_context) if next_entry is not None: logger.info(f"Credential 429 (rate limit) — rotated to pool entry {getattr(next_entry, 'id', '?')}") self._swap_credential(next_entry) @@ -3885,7 +3953,7 @@ class AIAgent: return True, has_retried_429 # Refresh failed — rotate to next credential instead of giving up. # The failed entry is already marked exhausted by try_refresh_current(). - next_entry = pool.mark_exhausted_and_rotate(status_code=401) + next_entry = pool.mark_exhausted_and_rotate(status_code=401, error_context=error_context) if next_entry is not None: logger.info(f"Credential 401 (refresh failed) — rotated to pool entry {getattr(next_entry, 'id', '?')}") self._swap_credential(next_entry) @@ -7377,9 +7445,11 @@ class AIAgent: # prompt or prefill. Fall through to normal error path. 
status_code = getattr(api_error, "status_code", None) + error_context = self._extract_api_error_context(api_error) recovered_with_pool, has_retried_429 = self._recover_with_credential_pool( status_code=status_code, has_retried_429=has_retried_429, + error_context=error_context, ) if recovered_with_pool: continue diff --git a/tests/test_auth_commands.py b/tests/test_auth_commands.py index c556294046..bd40cb8855 100644 --- a/tests/test_auth_commands.py +++ b/tests/test_auth_commands.py @@ -4,6 +4,7 @@ from __future__ import annotations import base64 import json +from datetime import datetime, timezone import pytest @@ -224,7 +225,7 @@ def test_auth_remove_reindexes_priorities(tmp_path, monkeypatch): class _Args: provider = "anthropic" - index = 1 + target = "1" auth_remove_command(_Args()) @@ -235,6 +236,49 @@ def test_auth_remove_reindexes_priorities(tmp_path, monkeypatch): assert entries[0]["priority"] == 0 +def test_auth_remove_accepts_label_target(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + _write_auth_store( + tmp_path, + { + "version": 1, + "credential_pool": { + "openai-codex": [ + { + "id": "cred-1", + "label": "work-account", + "auth_type": "oauth", + "priority": 0, + "source": "manual:device_code", + "access_token": "tok-1", + }, + { + "id": "cred-2", + "label": "personal-account", + "auth_type": "oauth", + "priority": 1, + "source": "manual:device_code", + "access_token": "tok-2", + }, + ] + }, + }, + ) + + from hermes_cli.auth_commands import auth_remove_command + + class _Args: + provider = "openai-codex" + target = "personal-account" + + auth_remove_command(_Args()) + + payload = json.loads((tmp_path / "hermes" / "auth.json").read_text()) + entries = payload["credential_pool"]["openai-codex"] + assert len(entries) == 1 + assert entries[0]["label"] == "work-account" + + def test_auth_reset_clears_provider_statuses(tmp_path, monkeypatch, capsys): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) 
_write_auth_store( @@ -389,3 +433,41 @@ def test_auth_list_shows_exhausted_cooldown(monkeypatch, capsys): out = capsys.readouterr().out assert "exhausted (429)" in out assert "59m 30s left" in out + + +def test_auth_list_prefers_explicit_reset_time(monkeypatch, capsys): + from hermes_cli.auth_commands import auth_list_command + + class _Entry: + id = "cred-1" + label = "weekly" + auth_type = "oauth" + source = "manual:device_code" + last_status = "exhausted" + last_error_code = 429 + last_error_reason = "device_code_exhausted" + last_error_message = "Weekly credits exhausted." + last_error_reset_at = "2026-04-12T10:30:00Z" + last_status_at = 1000.0 + + class _Pool: + def entries(self): + return [_Entry()] + + def peek(self): + return None + + monkeypatch.setattr("hermes_cli.auth_commands.load_pool", lambda provider: _Pool()) + monkeypatch.setattr( + "hermes_cli.auth_commands.time.time", + lambda: datetime(2026, 4, 5, 10, 30, tzinfo=timezone.utc).timestamp(), + ) + + class _Args: + provider = "openai-codex" + + auth_list_command(_Args()) + + out = capsys.readouterr().out + assert "device_code_exhausted" in out + assert "7d 0h left" in out diff --git a/tests/test_credential_pool.py b/tests/test_credential_pool.py index 14302ab13f..ff6e037be3 100644 --- a/tests/test_credential_pool.py +++ b/tests/test_credential_pool.py @@ -214,6 +214,39 @@ def test_exhausted_entry_resets_after_ttl(tmp_path, monkeypatch): assert entry.last_status == "ok" +def test_explicit_reset_timestamp_overrides_default_429_ttl(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) + _write_auth_store( + tmp_path, + { + "version": 1, + "credential_pool": { + "openai-codex": [ + { + "id": "cred-1", + "label": "weekly-reset", + "auth_type": "oauth", + "priority": 0, + "source": "manual:device_code", + "access_token": "tok-1", + "last_status": "exhausted", + "last_status_at": time.time() - 7200, + "last_error_code": 429, + "last_error_reason": "device_code_exhausted", + 
"last_error_reset_at": time.time() + 7 * 24 * 60 * 60, + } + ] + }, + }, + ) + + from agent.credential_pool import load_pool + + pool = load_pool("openai-codex") + assert pool.has_available() is False + assert pool.select() is None + + def test_mark_exhausted_and_rotate_persists_status(tmp_path, monkeypatch): monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes")) _write_auth_store( diff --git a/tests/test_credential_pool_routing.py b/tests/test_credential_pool_routing.py index f4006a236f..38f5c6dfd0 100644 --- a/tests/test_credential_pool_routing.py +++ b/tests/test_credential_pool_routing.py @@ -275,7 +275,7 @@ class TestPoolRotationCycle: # mark_exhausted_and_rotate returns next entry until exhausted self._rotation_index = 0 - def rotate(status_code=None): + def rotate(status_code=None, error_context=None): self._rotation_index += 1 if self._rotation_index < pool_entries: return entries[self._rotation_index] @@ -307,7 +307,7 @@ class TestPoolRotationCycle: ) assert recovered is True assert has_retried is False # reset after rotation - pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=429) + pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=429, error_context=None) agent._swap_credential.assert_called_once_with(entries[1]) def test_pool_exhaustion_returns_false(self): @@ -333,7 +333,7 @@ class TestPoolRotationCycle: ) assert recovered is True assert has_retried is False - pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=402) + pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=402, error_context=None) def test_no_pool_returns_false(self): """No pool should return (False, unchanged).""" diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index 9217117e25..963ee56f31 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -1956,8 +1956,9 @@ class TestCredentialPoolRecovery: def current(self): return current - def mark_exhausted_and_rotate(self, *, status_code): + def 
mark_exhausted_and_rotate(self, *, status_code, error_context=None): assert status_code == 402 + assert error_context is None return next_entry agent._credential_pool = _Pool() @@ -1979,8 +1980,9 @@ class TestCredentialPoolRecovery: def current(self): return SimpleNamespace(label="primary") - def mark_exhausted_and_rotate(self, *, status_code): + def mark_exhausted_and_rotate(self, *, status_code, error_context=None): assert status_code == 429 + assert error_context is None return next_entry agent._credential_pool = _Pool() @@ -2030,8 +2032,9 @@ class TestCredentialPoolRecovery: def try_refresh_current(self): return None # refresh failed - def mark_exhausted_and_rotate(self, *, status_code): + def mark_exhausted_and_rotate(self, *, status_code, error_context=None): assert status_code == 401 + assert error_context is None return next_entry agent._credential_pool = _Pool() @@ -2053,7 +2056,8 @@ class TestCredentialPoolRecovery: def try_refresh_current(self): return None - def mark_exhausted_and_rotate(self, *, status_code): + def mark_exhausted_and_rotate(self, *, status_code, error_context=None): + assert error_context is None return None # no more credentials agent._credential_pool = _Pool() @@ -2067,6 +2071,52 @@ class TestCredentialPoolRecovery: assert recovered is False agent._swap_credential.assert_not_called() + def test_extract_api_error_context_uses_reset_timestamp_and_reason(self, agent): + response = SimpleNamespace(headers={}) + error = SimpleNamespace( + body={ + "error": { + "code": "device_code_exhausted", + "message": "Weekly credits exhausted.", + "resets_at": "2026-04-12T10:30:00Z", + } + }, + response=response, + ) + + context = agent._extract_api_error_context(error) + + assert context["reason"] == "device_code_exhausted" + assert context["message"] == "Weekly credits exhausted." 
+ assert context["reset_at"] == "2026-04-12T10:30:00Z" + + def test_recover_with_pool_passes_error_context_on_rotated_429(self, agent): + next_entry = SimpleNamespace(label="secondary") + captured = {} + + class _Pool: + def current(self): + return SimpleNamespace(label="primary") + + def mark_exhausted_and_rotate(self, *, status_code, error_context=None): + captured["status_code"] = status_code + captured["error_context"] = error_context + return next_entry + + agent._credential_pool = _Pool() + agent._swap_credential = MagicMock() + + recovered, retry_same = agent._recover_with_credential_pool( + status_code=429, + has_retried_429=True, + error_context={"reason": "device_code_exhausted", "reset_at": "2026-04-12T10:30:00Z"}, + ) + + assert recovered is True + assert retry_same is False + assert captured["status_code"] == 429 + assert captured["error_context"]["reason"] == "device_code_exhausted" + class TestMaxTokensParam: """Verify _max_tokens_param returns the correct key for each provider."""