From 17d89148505f3b2d3252652d552133847e546472 Mon Sep 17 00:00:00 2001 From: Maxim Esipov Date: Sat, 9 May 2026 22:44:12 +0300 Subject: [PATCH] fix(auxiliary): rotate pooled auth after quota failures --- agent/auxiliary_client.py | 388 ++++++++++++++++++++++----- tests/agent/test_auxiliary_client.py | 147 ++++++++++ 2 files changed, 472 insertions(+), 63 deletions(-) diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index 00f461e77ef..3eefab1632b 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -490,6 +490,29 @@ def _select_pool_entry(provider: str) -> Tuple[bool, Optional[Any]]: return True, None +def _peek_pool_entry(provider: str) -> Optional[Any]: + """Best-effort current/next pool entry without mutating selection order.""" + try: + pool = load_pool(provider) + except Exception as exc: + logger.debug("Auxiliary client: could not load pool for %s (peek): %s", provider, exc) + return None + if not pool or not pool.has_credentials(): + return None + try: + current_fn = getattr(pool, "current", None) + if callable(current_fn): + current = current_fn() + if current is not None: + return current + peek_fn = getattr(pool, "peek", None) + if callable(peek_fn): + return peek_fn() + except Exception as exc: + logger.debug("Auxiliary client: could not peek pool entry for %s: %s", provider, exc) + return None + + def _pool_runtime_api_key(entry: Any) -> str: if entry is None: return "" @@ -1908,6 +1931,211 @@ def _evict_cached_clients(provider: str) -> None: _client_cache.pop(key, None) +def _pool_cache_hint( + provider: str, + *, + main_runtime: Optional[Dict[str, Any]] = None, +) -> str: + """Return a stable cache discriminator for pooled providers.""" + normalized = _normalize_aux_provider(provider) + if normalized == "auto": + runtime = _normalize_main_runtime(main_runtime) + normalized = _normalize_aux_provider(runtime.get("provider") or _read_main_provider()) + if normalized in ("", "auto", "custom"): + return "" + entry = _peek_pool_entry(normalized) + if entry is None: + return "" + entry_id = str(getattr(entry, "id", "") or "").strip() + if not entry_id: + return "" + return f"{normalized}:{entry_id}" + + +def _pool_error_context(exc: Exception) -> Dict[str, Any]: + status = getattr(exc, "status_code", None) + payload: Dict[str, Any] = {"message": str(exc)} + if status is not None: + payload["status_code"] = status + return payload + + +def _recoverable_pool_provider(resolved_provider: str, client: Any) -> Optional[str]: + """Infer which provider pool can recover the current auxiliary client.""" + normalized = _normalize_aux_provider(resolved_provider) + if normalized not in ("", "auto", "custom"): + return normalized + base = str(getattr(client, "base_url", "") or "") + if base_url_host_matches(base, "chatgpt.com"): + return "openai-codex" + if base_url_host_matches(base, "openrouter.ai"): + return "openrouter" + if base_url_host_matches(base, "inference-api.nousresearch.com"): + return "nous" + if base_url_host_matches(base, "api.anthropic.com"): + return "anthropic" + if base_url_host_matches(base, "api.githubcopilot.com"): + return "copilot" + if base_url_host_matches(base, "api.kimi.com"): + return "kimi-coding" + return None + + +def _recover_provider_pool(provider: str, exc: Exception) -> bool: + """Try same-provider credential-pool recovery for auxiliary calls.""" + normalized = _normalize_aux_provider(provider) + try: + pool = load_pool(normalized) + except Exception as load_exc: + logger.debug("Auxiliary client: could not load pool for %s recovery: %s", normalized, load_exc) + return False + if not pool or not pool.has_credentials(): + return False + + status_code = getattr(exc, "status_code", None) + error_context = _pool_error_context(exc) + + if _is_auth_error(exc): + refreshed = pool.try_refresh_current() + if refreshed is not None: + _evict_cached_clients(normalized) + return True + next_entry = pool.mark_exhausted_and_rotate( + status_code=status_code if status_code is not None else 401, + error_context=error_context, + ) + if next_entry is not None: + _evict_cached_clients(normalized) + return True + return False + + if _is_payment_error(exc) or _is_rate_limit_error(exc): + fallback_status = 402 if _is_payment_error(exc) else 429 + next_entry = pool.mark_exhausted_and_rotate( + status_code=status_code if status_code is not None else fallback_status, + error_context=error_context, + ) + if next_entry is not None: + _evict_cached_clients(normalized) + return True + return False + + +def _retry_same_provider_sync( + *, + task: Optional[str], + resolved_provider: str, + resolved_model: Optional[str], + resolved_base_url: Optional[str], + resolved_api_key: Optional[str], + resolved_api_mode: Optional[str], + main_runtime: Optional[Dict[str, Any]], + final_model: Optional[str], + messages: list, + temperature: Optional[float], + max_tokens: Optional[int], + tools: Optional[list], + effective_timeout: float, + effective_extra_body: dict, +) -> Any: + if task == "vision": + _, retry_client, retry_model = resolve_vision_provider_client( + provider=resolved_provider, + model=final_model, + base_url=resolved_base_url, + api_key=resolved_api_key, + async_mode=False, + ) + else: + retry_client, retry_model = _get_cached_client( + resolved_provider, + resolved_model, + base_url=resolved_base_url, + api_key=resolved_api_key, + api_mode=resolved_api_mode, + main_runtime=main_runtime, + ) + if retry_client is None: + raise RuntimeError( + f"Auxiliary {task or 'call'}: provider {resolved_provider} could not be rebuilt after recovery" + ) + + retry_base = str(getattr(retry_client, "base_url", "") or "") + retry_kwargs = _build_call_kwargs( + resolved_provider, + retry_model or final_model, + messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + timeout=effective_timeout, + extra_body=effective_extra_body, + base_url=retry_base or resolved_base_url, + ) + if _is_anthropic_compat_endpoint(resolved_provider, retry_base): + retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"]) + return _validate_llm_response( + retry_client.chat.completions.create(**retry_kwargs), task, + ) + + +async def _retry_same_provider_async( + *, + task: Optional[str], + resolved_provider: str, + resolved_model: Optional[str], + resolved_base_url: Optional[str], + resolved_api_key: Optional[str], + resolved_api_mode: Optional[str], + final_model: Optional[str], + messages: list, + temperature: Optional[float], + max_tokens: Optional[int], + tools: Optional[list], + effective_timeout: float, + effective_extra_body: dict, +) -> Any: + if task == "vision": + _, retry_client, retry_model = resolve_vision_provider_client( + provider=resolved_provider, + model=final_model, + base_url=resolved_base_url, + api_key=resolved_api_key, + async_mode=True, + ) + else: + retry_client, retry_model = _get_cached_client( + resolved_provider, + resolved_model, + async_mode=True, + base_url=resolved_base_url, + api_key=resolved_api_key, + api_mode=resolved_api_mode, + ) + if retry_client is None: + raise RuntimeError( + f"Auxiliary {task or 'call'}: provider {resolved_provider} could not be rebuilt after recovery" + ) + + retry_base = str(getattr(retry_client, "base_url", "") or "") + retry_kwargs = _build_call_kwargs( + resolved_provider, + retry_model or final_model, + messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + timeout=effective_timeout, + extra_body=effective_extra_body, + base_url=retry_base or resolved_base_url, + ) + if _is_anthropic_compat_endpoint(resolved_provider, retry_base): + retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"]) + return _validate_llm_response( + await retry_client.chat.completions.create(**retry_kwargs), task, + ) + + def _refresh_provider_credentials(provider: str) -> bool: """Refresh short-lived credentials for OAuth-backed auxiliary providers.""" normalized = _normalize_aux_provider(provider) @@ -3033,7 +3261,8 @@ def _client_cache_key( ) -> tuple: runtime = _normalize_main_runtime(main_runtime) runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else () - return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision) + pool_hint = _pool_cache_hint(provider, main_runtime=main_runtime) + return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision, pool_hint) def _store_cached_client(cache_key: tuple, client: Any, default_model: Optional[str], *, bound_loop: Any = None) -> None: @@ -3821,39 +4050,56 @@ def call_llm( "Auxiliary %s: refreshed %s credentials after auth error, retrying", task or "call", resolved_provider, ) - retry_client, retry_model = ( - resolve_vision_provider_client( - provider=resolved_provider, - model=final_model, - async_mode=False, - )[1:] - if task == "vision" - else _get_cached_client( - resolved_provider, - resolved_model, - base_url=resolved_base_url, - api_key=resolved_api_key, - api_mode=resolved_api_mode, - main_runtime=main_runtime, - ) + return _retry_same_provider_sync( + task=task, + resolved_provider=resolved_provider, + resolved_model=resolved_model, + resolved_base_url=resolved_base_url, + resolved_api_key=resolved_api_key, + resolved_api_mode=resolved_api_mode, + main_runtime=main_runtime, + final_model=final_model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + effective_timeout=effective_timeout, + effective_extra_body=effective_extra_body, ) - if retry_client is not None: - retry_kwargs = _build_call_kwargs( - resolved_provider, - retry_model or final_model, - messages, - temperature=temperature, - max_tokens=max_tokens, - tools=tools, - timeout=effective_timeout, - extra_body=effective_extra_body, - base_url=resolved_base_url, - ) - _retry_base = str(getattr(retry_client, "base_url", "") or "") - if _is_anthropic_compat_endpoint(resolved_provider, _retry_base): - retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"]) + + # ── Same-provider credential-pool recovery ───────────────────── + pool_provider = _recoverable_pool_provider(resolved_provider, client) + if pool_provider and (_is_auth_error(first_err) or _is_payment_error(first_err) or _is_rate_limit_error(first_err)): + recovery_err = first_err + if _is_rate_limit_error(first_err): + try: return _validate_llm_response( - retry_client.chat.completions.create(**retry_kwargs), task) + client.chat.completions.create(**kwargs), task) + except Exception as retry_err: + if not (_is_auth_error(retry_err) or _is_payment_error(retry_err) or _is_rate_limit_error(retry_err)): + raise + recovery_err = retry_err + if _recover_provider_pool(pool_provider, recovery_err): + logger.info( + "Auxiliary %s: recovered %s via credential-pool rotation after %s", + task or "call", pool_provider, type(recovery_err).__name__, + ) + return _retry_same_provider_sync( + task=task, + resolved_provider=resolved_provider, + resolved_model=resolved_model, + resolved_base_url=resolved_base_url, + resolved_api_key=resolved_api_key, + resolved_api_mode=resolved_api_mode, + main_runtime=main_runtime, + final_model=final_model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + effective_timeout=effective_timeout, + effective_extra_body=effective_extra_body, + ) # ── Payment / credit exhaustion fallback ────────────────────── # When the resolved provider returns 402 or a credit-related error, @@ -4136,38 +4382,54 @@ async def async_call_llm( "Auxiliary %s (async): refreshed %s credentials after auth error, retrying", task or "call", resolved_provider, ) - if task == "vision": - _, retry_client, retry_model = resolve_vision_provider_client( - provider=resolved_provider, - model=final_model, - async_mode=True, - ) - else: - retry_client, retry_model = _get_cached_client( - resolved_provider, - resolved_model, - async_mode=True, - base_url=resolved_base_url, - api_key=resolved_api_key, - api_mode=resolved_api_mode, - ) - if retry_client is not None: - retry_kwargs = _build_call_kwargs( - resolved_provider, - retry_model or final_model, - messages, - temperature=temperature, - max_tokens=max_tokens, - tools=tools, - timeout=effective_timeout, - extra_body=effective_extra_body, - base_url=resolved_base_url, - ) - _retry_base = str(getattr(retry_client, "base_url", "") or "") - if _is_anthropic_compat_endpoint(resolved_provider, _retry_base): - retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"]) + return await _retry_same_provider_async( + task=task, + resolved_provider=resolved_provider, + resolved_model=resolved_model, + resolved_base_url=resolved_base_url, + resolved_api_key=resolved_api_key, + resolved_api_mode=resolved_api_mode, + final_model=final_model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + effective_timeout=effective_timeout, + effective_extra_body=effective_extra_body, + ) + + # ── Same-provider credential-pool recovery (mirrors sync) ───── + pool_provider = _recoverable_pool_provider(resolved_provider, client) + if pool_provider and (_is_auth_error(first_err) or _is_payment_error(first_err) or _is_rate_limit_error(first_err)): + recovery_err = first_err + if _is_rate_limit_error(first_err): + try: return _validate_llm_response( - await retry_client.chat.completions.create(**retry_kwargs), task) + await client.chat.completions.create(**kwargs), task) + except Exception as retry_err: + if not (_is_auth_error(retry_err) or _is_payment_error(retry_err) or _is_rate_limit_error(retry_err)): + raise + recovery_err = retry_err + if _recover_provider_pool(pool_provider, recovery_err): + logger.info( + "Auxiliary %s (async): recovered %s via credential-pool rotation after %s", + task or "call", pool_provider, type(recovery_err).__name__, + ) + return await _retry_same_provider_async( + task=task, + resolved_provider=resolved_provider, + resolved_model=resolved_model, + resolved_base_url=resolved_base_url, + resolved_api_key=resolved_api_key, + resolved_api_mode=resolved_api_mode, + final_model=final_model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + effective_timeout=effective_timeout, + effective_extra_body=effective_extra_body, + ) # ── Payment / connection / rate-limit fallback (mirrors sync call_llm) ── should_fallback = ( diff --git a/tests/agent/test_auxiliary_client.py b/tests/agent/test_auxiliary_client.py index 6437c872ce8..5f49f74a2be 100644 --- a/tests/agent/test_auxiliary_client.py +++ b/tests/agent/test_auxiliary_client.py @@ -301,6 +301,52 @@ class TestBuildCodexClient: assert client is None assert model is None + def test_cached_codex_client_rebuilds_when_pool_entry_changes(self): + import agent.auxiliary_client as aux + + class _Entry: + def __init__(self, entry_id, token): + self.id = entry_id + self.runtime_api_key = token + self.runtime_base_url = "https://chatgpt.com/backend-api/codex" + + class _Pool: + def __init__(self): + self.entry = _Entry("cred-a", "tok-a") + + def has_credentials(self): + return True + + def current(self): + return self.entry + + def peek(self): + return self.entry + + def select(self): + return self.entry + + pool = _Pool() + client_a = MagicMock(name="codex-client-a") + client_b = MagicMock(name="codex-client-b") + + with ( + patch("agent.auxiliary_client.load_pool", return_value=pool), + patch("agent.auxiliary_client.OpenAI", side_effect=[client_a, client_b]) as mock_openai, + ): + aux.shutdown_cached_clients() + try: + first_client, first_model = aux._get_cached_client("openai-codex", "gpt-5.4") + pool.entry = _Entry("cred-b", "tok-b") + second_client, second_model = aux._get_cached_client("openai-codex", "gpt-5.4") + finally: + aux.shutdown_cached_clients() + + assert first_client is not second_client + assert first_model == "gpt-5.4" + assert second_model == "gpt-5.4" + assert mock_openai.call_count == 2 + class TestExpiredCodexFallback: """Test that expired Codex tokens don't block the auto chain.""" @@ -1632,6 +1678,107 @@ class TestAuxiliaryAuthRefreshRetry: assert fresh_client.chat.completions.create.await_count == 1 +class TestAuxiliaryPoolRotationRetry: + def test_call_llm_rotates_explicit_codex_pool_on_429(self): + rate_err = Exception("usage limit reached") + rate_err.status_code = 429 + + stale_client = MagicMock() + stale_client.base_url = "https://chatgpt.com/backend-api/codex" + stale_client.chat.completions.create.side_effect = [rate_err, rate_err] + + fresh_client = MagicMock() + fresh_client.base_url = "https://chatgpt.com/backend-api/codex" + fresh_client.chat.completions.create.return_value = _DummyResponse("rotated-sync") + + class _Pool: + def __init__(self): + self.rotate_calls = [] + + def has_credentials(self): + return True + + def try_refresh_current(self): + return None + + def mark_exhausted_and_rotate(self, **kwargs): + self.rotate_calls.append(kwargs) + return SimpleNamespace(id="cred-b") + + pool = _Pool() + + with ( + patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.4", None, None, None)), + patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.4"), (fresh_client, "gpt-5.4")]), + patch("agent.auxiliary_client._refresh_provider_credentials", return_value=False), + patch("agent.auxiliary_client.load_pool", return_value=pool), + patch("agent.auxiliary_client._try_payment_fallback") as mock_fallback, + ): + resp = call_llm( + task="compression", + provider="openai-codex", + model="gpt-5.4", + messages=[{"role": "user", "content": "hi"}], + ) + + assert resp.choices[0].message.content == "rotated-sync" + assert stale_client.chat.completions.create.call_count == 2 + assert fresh_client.chat.completions.create.call_count == 1 + assert len(pool.rotate_calls) == 1 + assert pool.rotate_calls[0]["status_code"] == 429 + mock_fallback.assert_not_called() + + @pytest.mark.asyncio + async def test_async_call_llm_rotates_explicit_codex_pool_on_429(self): + rate_err = Exception("usage limit reached") + rate_err.status_code = 429 + + stale_client = MagicMock() + stale_client.base_url = "https://chatgpt.com/backend-api/codex" + stale_client.chat.completions.create = AsyncMock(side_effect=[rate_err, rate_err]) + + fresh_client = MagicMock() + fresh_client.base_url = "https://chatgpt.com/backend-api/codex" + fresh_client.chat.completions.create = AsyncMock(return_value=_DummyResponse("rotated-async")) + + class _Pool: + def __init__(self): + self.rotate_calls = [] + + def has_credentials(self): + return True + + def try_refresh_current(self): + return None + + def mark_exhausted_and_rotate(self, **kwargs): + self.rotate_calls.append(kwargs) + return SimpleNamespace(id="cred-b") + + pool = _Pool() + + with ( + patch("agent.auxiliary_client._resolve_task_provider_model", return_value=("openai-codex", "gpt-5.4", None, None, None)), + patch("agent.auxiliary_client._get_cached_client", side_effect=[(stale_client, "gpt-5.4"), (fresh_client, "gpt-5.4")]), + patch("agent.auxiliary_client._refresh_provider_credentials", return_value=False), + patch("agent.auxiliary_client.load_pool", return_value=pool), + patch("agent.auxiliary_client._try_payment_fallback") as mock_fallback, + ): + resp = await async_call_llm( + task="compression", + provider="openai-codex", + model="gpt-5.4", + messages=[{"role": "user", "content": "hi"}], + ) + + assert resp.choices[0].message.content == "rotated-async" + assert stale_client.chat.completions.create.await_count == 2 + assert fresh_client.chat.completions.create.await_count == 1 + assert len(pool.rotate_calls) == 1 + assert pool.rotate_calls[0]["status_code"] == 429 + mock_fallback.assert_not_called() + + class TestCodexAdapterReasoningTranslation: """Verify _CodexCompletionsAdapter translates extra_body.reasoning into the Responses API's top-level reasoning + include fields, matching