fix(auxiliary): rotate pooled auth after quota failures

This commit is contained in:
Maxim Esipov 2026-05-09 22:44:12 +03:00 committed by Teknium
parent 775c0e22cf
commit 17d8914850
2 changed files with 472 additions and 63 deletions

View file

@ -490,6 +490,29 @@ def _select_pool_entry(provider: str) -> Tuple[bool, Optional[Any]]:
return True, None
def _peek_pool_entry(provider: str) -> Optional[Any]:
"""Best-effort current/next pool entry without mutating selection order."""
try:
pool = load_pool(provider)
except Exception as exc:
logger.debug("Auxiliary client: could not load pool for %s (peek): %s", provider, exc)
return None
if not pool or not pool.has_credentials():
return None
try:
current_fn = getattr(pool, "current", None)
if callable(current_fn):
current = current_fn()
if current is not None:
return current
peek_fn = getattr(pool, "peek", None)
if callable(peek_fn):
return peek_fn()
except Exception as exc:
logger.debug("Auxiliary client: could not peek pool entry for %s: %s", provider, exc)
return None
def _pool_runtime_api_key(entry: Any) -> str:
if entry is None:
return ""
@ -1908,6 +1931,211 @@ def _evict_cached_clients(provider: str) -> None:
_client_cache.pop(key, None)
def _pool_cache_hint(
provider: str,
*,
main_runtime: Optional[Dict[str, Any]] = None,
) -> str:
"""Return a stable cache discriminator for pooled providers."""
normalized = _normalize_aux_provider(provider)
if normalized == "auto":
runtime = _normalize_main_runtime(main_runtime)
normalized = _normalize_aux_provider(runtime.get("provider") or _read_main_provider())
if normalized in ("", "auto", "custom"):
return ""
entry = _peek_pool_entry(normalized)
if entry is None:
return ""
entry_id = str(getattr(entry, "id", "") or "").strip()
if not entry_id:
return ""
return f"{normalized}:{entry_id}"
def _pool_error_context(exc: Exception) -> Dict[str, Any]:
status = getattr(exc, "status_code", None)
payload: Dict[str, Any] = {"message": str(exc)}
if status is not None:
payload["status_code"] = status
return payload
def _recoverable_pool_provider(resolved_provider: str, client: Any) -> Optional[str]:
"""Infer which provider pool can recover the current auxiliary client."""
normalized = _normalize_aux_provider(resolved_provider)
if normalized not in ("", "auto", "custom"):
return normalized
base = str(getattr(client, "base_url", "") or "")
if base_url_host_matches(base, "chatgpt.com"):
return "openai-codex"
if base_url_host_matches(base, "openrouter.ai"):
return "openrouter"
if base_url_host_matches(base, "inference-api.nousresearch.com"):
return "nous"
if base_url_host_matches(base, "api.anthropic.com"):
return "anthropic"
if base_url_host_matches(base, "api.githubcopilot.com"):
return "copilot"
if base_url_host_matches(base, "api.kimi.com"):
return "kimi-coding"
return None
def _recover_provider_pool(provider: str, exc: Exception) -> bool:
"""Try same-provider credential-pool recovery for auxiliary calls."""
normalized = _normalize_aux_provider(provider)
try:
pool = load_pool(normalized)
except Exception as load_exc:
logger.debug("Auxiliary client: could not load pool for %s recovery: %s", normalized, load_exc)
return False
if not pool or not pool.has_credentials():
return False
status_code = getattr(exc, "status_code", None)
error_context = _pool_error_context(exc)
if _is_auth_error(exc):
refreshed = pool.try_refresh_current()
if refreshed is not None:
_evict_cached_clients(normalized)
return True
next_entry = pool.mark_exhausted_and_rotate(
status_code=status_code if status_code is not None else 401,
error_context=error_context,
)
if next_entry is not None:
_evict_cached_clients(normalized)
return True
return False
if _is_payment_error(exc) or _is_rate_limit_error(exc):
fallback_status = 402 if _is_payment_error(exc) else 429
next_entry = pool.mark_exhausted_and_rotate(
status_code=status_code if status_code is not None else fallback_status,
error_context=error_context,
)
if next_entry is not None:
_evict_cached_clients(normalized)
return True
return False
def _retry_same_provider_sync(
*,
task: Optional[str],
resolved_provider: str,
resolved_model: Optional[str],
resolved_base_url: Optional[str],
resolved_api_key: Optional[str],
resolved_api_mode: Optional[str],
main_runtime: Optional[Dict[str, Any]],
final_model: Optional[str],
messages: list,
temperature: Optional[float],
max_tokens: Optional[int],
tools: Optional[list],
effective_timeout: float,
effective_extra_body: dict,
) -> Any:
if task == "vision":
_, retry_client, retry_model = resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
async_mode=False,
)
else:
retry_client, retry_model = _get_cached_client(
resolved_provider,
resolved_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
main_runtime=main_runtime,
)
if retry_client is None:
raise RuntimeError(
f"Auxiliary {task or 'call'}: provider {resolved_provider} could not be rebuilt after recovery"
)
retry_base = str(getattr(retry_client, "base_url", "") or "")
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=retry_base or resolved_base_url,
)
if _is_anthropic_compat_endpoint(resolved_provider, retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
return _validate_llm_response(
retry_client.chat.completions.create(**retry_kwargs), task,
)
async def _retry_same_provider_async(
*,
task: Optional[str],
resolved_provider: str,
resolved_model: Optional[str],
resolved_base_url: Optional[str],
resolved_api_key: Optional[str],
resolved_api_mode: Optional[str],
final_model: Optional[str],
messages: list,
temperature: Optional[float],
max_tokens: Optional[int],
tools: Optional[list],
effective_timeout: float,
effective_extra_body: dict,
) -> Any:
if task == "vision":
_, retry_client, retry_model = resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
async_mode=True,
)
else:
retry_client, retry_model = _get_cached_client(
resolved_provider,
resolved_model,
async_mode=True,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
)
if retry_client is None:
raise RuntimeError(
f"Auxiliary {task or 'call'}: provider {resolved_provider} could not be rebuilt after recovery"
)
retry_base = str(getattr(retry_client, "base_url", "") or "")
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=retry_base or resolved_base_url,
)
if _is_anthropic_compat_endpoint(resolved_provider, retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
return _validate_llm_response(
await retry_client.chat.completions.create(**retry_kwargs), task,
)
def _refresh_provider_credentials(provider: str) -> bool:
"""Refresh short-lived credentials for OAuth-backed auxiliary providers."""
normalized = _normalize_aux_provider(provider)
@ -3033,7 +3261,8 @@ def _client_cache_key(
) -> tuple:
runtime = _normalize_main_runtime(main_runtime)
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision)
pool_hint = _pool_cache_hint(provider, main_runtime=main_runtime)
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision, pool_hint)
def _store_cached_client(cache_key: tuple, client: Any, default_model: Optional[str], *, bound_loop: Any = None) -> None:
@ -3821,39 +4050,56 @@ def call_llm(
"Auxiliary %s: refreshed %s credentials after auth error, retrying",
task or "call", resolved_provider,
)
retry_client, retry_model = (
resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
async_mode=False,
)[1:]
if task == "vision"
else _get_cached_client(
resolved_provider,
resolved_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
main_runtime=main_runtime,
)
return _retry_same_provider_sync(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
main_runtime=main_runtime,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
if retry_client is not None:
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=resolved_base_url,
)
_retry_base = str(getattr(retry_client, "base_url", "") or "")
if _is_anthropic_compat_endpoint(resolved_provider, _retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
# ── Same-provider credential-pool recovery ─────────────────────
pool_provider = _recoverable_pool_provider(resolved_provider, client)
if pool_provider and (_is_auth_error(first_err) or _is_payment_error(first_err) or _is_rate_limit_error(first_err)):
recovery_err = first_err
if _is_rate_limit_error(first_err):
try:
return _validate_llm_response(
retry_client.chat.completions.create(**retry_kwargs), task)
client.chat.completions.create(**kwargs), task)
except Exception as retry_err:
if not (_is_auth_error(retry_err) or _is_payment_error(retry_err) or _is_rate_limit_error(retry_err)):
raise
recovery_err = retry_err
if _recover_provider_pool(pool_provider, recovery_err):
logger.info(
"Auxiliary %s: recovered %s via credential-pool rotation after %s",
task or "call", pool_provider, type(recovery_err).__name__,
)
return _retry_same_provider_sync(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
main_runtime=main_runtime,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
# ── Payment / credit exhaustion fallback ──────────────────────
# When the resolved provider returns 402 or a credit-related error,
@ -4136,38 +4382,54 @@ async def async_call_llm(
"Auxiliary %s (async): refreshed %s credentials after auth error, retrying",
task or "call", resolved_provider,
)
if task == "vision":
_, retry_client, retry_model = resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
async_mode=True,
)
else:
retry_client, retry_model = _get_cached_client(
resolved_provider,
resolved_model,
async_mode=True,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
)
if retry_client is not None:
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=resolved_base_url,
)
_retry_base = str(getattr(retry_client, "base_url", "") or "")
if _is_anthropic_compat_endpoint(resolved_provider, _retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
return await _retry_same_provider_async(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
# ── Same-provider credential-pool recovery (mirrors sync) ─────
pool_provider = _recoverable_pool_provider(resolved_provider, client)
if pool_provider and (_is_auth_error(first_err) or _is_payment_error(first_err) or _is_rate_limit_error(first_err)):
recovery_err = first_err
if _is_rate_limit_error(first_err):
try:
return _validate_llm_response(
await retry_client.chat.completions.create(**retry_kwargs), task)
await client.chat.completions.create(**kwargs), task)
except Exception as retry_err:
if not (_is_auth_error(retry_err) or _is_payment_error(retry_err) or _is_rate_limit_error(retry_err)):
raise
recovery_err = retry_err
if _recover_provider_pool(pool_provider, recovery_err):
logger.info(
"Auxiliary %s (async): recovered %s via credential-pool rotation after %s",
task or "call", pool_provider, type(recovery_err).__name__,
)
return await _retry_same_provider_async(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
# ── Payment / connection / rate-limit fallback (mirrors sync call_llm) ──
should_fallback = (

View file

@ -301,6 +301,52 @@ class TestBuildCodexClient:
assert client is None
assert model is None
def test_cached_codex_client_rebuilds_when_pool_entry_changes(self):
import agent.auxiliary_client as aux
class _Entry:
def __init__(self, entry_id, token):
self.id = entry_id
self.runtime_api_key = token
self.runtime_base_url = "https://chatgpt.com/backend-api/codex"
class _Pool:
def __init__(self):
self.entry = _Entry("cred-a", "tok-a")
def has_credentials(self):
return True
def current(self):
return self.entry
def peek(self):
return self.entry
def select(self):
return self.entry
pool = _Pool()
client_a = MagicMock(name="codex-client-a")
client_b = MagicMock(name="codex-client-b")
with (
patch("agent.auxiliary_client.load_pool", return_value=pool),
patch("agent.auxiliary_client.OpenAI", side_effect=[client_a, client_b]) as mock_openai,
):
aux.shutdown_cached_clients()
try:
first_client, first_model = aux._get_cached_client("openai-codex", "gpt-5.4")
pool.entry = _Entry("cred-b", "tok-b")
second_client, second_model = aux._get_cached_client("openai-codex", "gpt-5.4")
finally:
aux.shutdown_cached_clients()
assert first_client is not second_client
assert first_model == "gpt-5.4"
assert second_model == "gpt-5.4"
assert mock_openai.call_count == 2
class TestExpiredCodexFallback:
"""Test that expired Codex tokens don't block the auto chain."""
@ -1632,6 +1678,107 @@ class TestAuxiliaryAuthRefreshRetry:
assert fresh_client.chat.completions.create.await_count == 1
class TestAuxiliaryPoolRotationRetry:
def test_call_llm_rotates_explicit_codex_pool_on_429(self):
rate_err = Exception("usage limit reached")
rate_err.status_code = 429
stale_client = MagicMock()
stale_client.base_url = "https://chatgpt.com/backend-api/codex"
stale_client.chat.completions.create.side_effect = [rate_err, rate_err]
fresh_client = MagicMock()
fresh_client.base_url = "https://chatgpt.com/backend-api/codex"
fresh_client.chat.completions.create.return_value = _DummyResponse("rotated-sync")
class _Pool:
def __init__(self):
self.rotate_calls = []
def has_credentials(self):
return True
def try_refresh_current(self):
return None
def mark_exhausted_and_rotate(self, **kwargs):
self.rotate_calls.append(kwargs)
return SimpleNamespace(id="cred-b")
pool = _Pool()
with (
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.4", None, None, None)),
patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.4"), (fresh_client, "gpt-5.4")]),
patch("agent.auxiliary_client._refresh_provider_credentials", return_value=False),
patch("agent.auxiliary_client.load_pool", return_value=pool),
patch("agent.auxiliary_client._try_payment_fallback") as mock_fallback,
):
resp = call_llm(
task="compression",
provider="openai-codex",
model="gpt-5.4",
messages=[{"role": "user", "content": "hi"}],
)
assert resp.choices[0].message.content == "rotated-sync"
assert stale_client.chat.completions.create.call_count == 2
assert fresh_client.chat.completions.create.call_count == 1
assert len(pool.rotate_calls) == 1
assert pool.rotate_calls[0]["status_code"] == 429
mock_fallback.assert_not_called()
@pytest.mark.asyncio
async def test_async_call_llm_rotates_explicit_codex_pool_on_429(self):
rate_err = Exception("usage limit reached")
rate_err.status_code = 429
stale_client = MagicMock()
stale_client.base_url = "https://chatgpt.com/backend-api/codex"
stale_client.chat.completions.create = AsyncMock(side_effect=[rate_err, rate_err])
fresh_client = MagicMock()
fresh_client.base_url = "https://chatgpt.com/backend-api/codex"
fresh_client.chat.completions.create = AsyncMock(return_value=_DummyResponse("rotated-async"))
class _Pool:
def __init__(self):
self.rotate_calls = []
def has_credentials(self):
return True
def try_refresh_current(self):
return None
def mark_exhausted_and_rotate(self, **kwargs):
self.rotate_calls.append(kwargs)
return SimpleNamespace(id="cred-b")
pool = _Pool()
with (
patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.4", None, None, None)),
patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.4"), (fresh_client, "gpt-5.4")]),
patch("agent.auxiliary_client._refresh_provider_credentials", return_value=False),
patch("agent.auxiliary_client.load_pool", return_value=pool),
patch("agent.auxiliary_client._try_payment_fallback") as mock_fallback,
):
resp = await async_call_llm(
task="compression",
provider="openai-codex",
model="gpt-5.4",
messages=[{"role": "user", "content": "hi"}],
)
assert resp.choices[0].message.content == "rotated-async"
assert stale_client.chat.completions.create.await_count == 2
assert fresh_client.chat.completions.create.await_count == 1
assert len(pool.rotate_calls) == 1
assert pool.rotate_calls[0]["status_code"] == 429
mock_fallback.assert_not_called()
class TestCodexAdapterReasoningTranslation:
"""Verify _CodexCompletionsAdapter translates extra_body.reasoning
into the Responses API's top-level reasoning + include fields, matching