fix(auxiliary): rotate pooled auth after quota failures

This commit is contained in:
Maxim Esipov 2026-05-09 22:44:12 +03:00 committed by Teknium
parent 775c0e22cf
commit 17d8914850
2 changed files with 472 additions and 63 deletions

View file

@ -490,6 +490,29 @@ def _select_pool_entry(provider: str) -> Tuple[bool, Optional[Any]]:
return True, None
def _peek_pool_entry(provider: str) -> Optional[Any]:
"""Best-effort current/next pool entry without mutating selection order."""
try:
pool = load_pool(provider)
except Exception as exc:
logger.debug("Auxiliary client: could not load pool for %s (peek): %s", provider, exc)
return None
if not pool or not pool.has_credentials():
return None
try:
current_fn = getattr(pool, "current", None)
if callable(current_fn):
current = current_fn()
if current is not None:
return current
peek_fn = getattr(pool, "peek", None)
if callable(peek_fn):
return peek_fn()
except Exception as exc:
logger.debug("Auxiliary client: could not peek pool entry for %s: %s", provider, exc)
return None
def _pool_runtime_api_key(entry: Any) -> str:
if entry is None:
return ""
@ -1908,6 +1931,211 @@ def _evict_cached_clients(provider: str) -> None:
_client_cache.pop(key, None)
def _pool_cache_hint(
provider: str,
*,
main_runtime: Optional[Dict[str, Any]] = None,
) -> str:
"""Return a stable cache discriminator for pooled providers."""
normalized = _normalize_aux_provider(provider)
if normalized == "auto":
runtime = _normalize_main_runtime(main_runtime)
normalized = _normalize_aux_provider(runtime.get("provider") or _read_main_provider())
if normalized in ("", "auto", "custom"):
return ""
entry = _peek_pool_entry(normalized)
if entry is None:
return ""
entry_id = str(getattr(entry, "id", "") or "").strip()
if not entry_id:
return ""
return f"{normalized}:{entry_id}"
def _pool_error_context(exc: Exception) -> Dict[str, Any]:
status = getattr(exc, "status_code", None)
payload: Dict[str, Any] = {"message": str(exc)}
if status is not None:
payload["status_code"] = status
return payload
def _recoverable_pool_provider(resolved_provider: str, client: Any) -> Optional[str]:
"""Infer which provider pool can recover the current auxiliary client."""
normalized = _normalize_aux_provider(resolved_provider)
if normalized not in ("", "auto", "custom"):
return normalized
base = str(getattr(client, "base_url", "") or "")
if base_url_host_matches(base, "chatgpt.com"):
return "openai-codex"
if base_url_host_matches(base, "openrouter.ai"):
return "openrouter"
if base_url_host_matches(base, "inference-api.nousresearch.com"):
return "nous"
if base_url_host_matches(base, "api.anthropic.com"):
return "anthropic"
if base_url_host_matches(base, "api.githubcopilot.com"):
return "copilot"
if base_url_host_matches(base, "api.kimi.com"):
return "kimi-coding"
return None
def _recover_provider_pool(provider: str, exc: Exception) -> bool:
"""Try same-provider credential-pool recovery for auxiliary calls."""
normalized = _normalize_aux_provider(provider)
try:
pool = load_pool(normalized)
except Exception as load_exc:
logger.debug("Auxiliary client: could not load pool for %s recovery: %s", normalized, load_exc)
return False
if not pool or not pool.has_credentials():
return False
status_code = getattr(exc, "status_code", None)
error_context = _pool_error_context(exc)
if _is_auth_error(exc):
refreshed = pool.try_refresh_current()
if refreshed is not None:
_evict_cached_clients(normalized)
return True
next_entry = pool.mark_exhausted_and_rotate(
status_code=status_code if status_code is not None else 401,
error_context=error_context,
)
if next_entry is not None:
_evict_cached_clients(normalized)
return True
return False
if _is_payment_error(exc) or _is_rate_limit_error(exc):
fallback_status = 402 if _is_payment_error(exc) else 429
next_entry = pool.mark_exhausted_and_rotate(
status_code=status_code if status_code is not None else fallback_status,
error_context=error_context,
)
if next_entry is not None:
_evict_cached_clients(normalized)
return True
return False
def _retry_same_provider_sync(
*,
task: Optional[str],
resolved_provider: str,
resolved_model: Optional[str],
resolved_base_url: Optional[str],
resolved_api_key: Optional[str],
resolved_api_mode: Optional[str],
main_runtime: Optional[Dict[str, Any]],
final_model: Optional[str],
messages: list,
temperature: Optional[float],
max_tokens: Optional[int],
tools: Optional[list],
effective_timeout: float,
effective_extra_body: dict,
) -> Any:
if task == "vision":
_, retry_client, retry_model = resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
async_mode=False,
)
else:
retry_client, retry_model = _get_cached_client(
resolved_provider,
resolved_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
main_runtime=main_runtime,
)
if retry_client is None:
raise RuntimeError(
f"Auxiliary {task or 'call'}: provider {resolved_provider} could not be rebuilt after recovery"
)
retry_base = str(getattr(retry_client, "base_url", "") or "")
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=retry_base or resolved_base_url,
)
if _is_anthropic_compat_endpoint(resolved_provider, retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
return _validate_llm_response(
retry_client.chat.completions.create(**retry_kwargs), task,
)
async def _retry_same_provider_async(
*,
task: Optional[str],
resolved_provider: str,
resolved_model: Optional[str],
resolved_base_url: Optional[str],
resolved_api_key: Optional[str],
resolved_api_mode: Optional[str],
final_model: Optional[str],
messages: list,
temperature: Optional[float],
max_tokens: Optional[int],
tools: Optional[list],
effective_timeout: float,
effective_extra_body: dict,
) -> Any:
if task == "vision":
_, retry_client, retry_model = resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
async_mode=True,
)
else:
retry_client, retry_model = _get_cached_client(
resolved_provider,
resolved_model,
async_mode=True,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
)
if retry_client is None:
raise RuntimeError(
f"Auxiliary {task or 'call'}: provider {resolved_provider} could not be rebuilt after recovery"
)
retry_base = str(getattr(retry_client, "base_url", "") or "")
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=retry_base or resolved_base_url,
)
if _is_anthropic_compat_endpoint(resolved_provider, retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
return _validate_llm_response(
await retry_client.chat.completions.create(**retry_kwargs), task,
)
def _refresh_provider_credentials(provider: str) -> bool:
"""Refresh short-lived credentials for OAuth-backed auxiliary providers."""
normalized = _normalize_aux_provider(provider)
@ -3033,7 +3261,8 @@ def _client_cache_key(
) -> tuple:
runtime = _normalize_main_runtime(main_runtime)
runtime_key = tuple(runtime.get(field, "") for field in _MAIN_RUNTIME_FIELDS) if provider == "auto" else ()
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision)
pool_hint = _pool_cache_hint(provider, main_runtime=main_runtime)
return (provider, async_mode, base_url or "", api_key or "", api_mode or "", runtime_key, is_vision, pool_hint)
def _store_cached_client(cache_key: tuple, client: Any, default_model: Optional[str], *, bound_loop: Any = None) -> None:
@ -3821,39 +4050,56 @@ def call_llm(
"Auxiliary %s: refreshed %s credentials after auth error, retrying",
task or "call", resolved_provider,
)
retry_client, retry_model = (
resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
async_mode=False,
)[1:]
if task == "vision"
else _get_cached_client(
resolved_provider,
resolved_model,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
main_runtime=main_runtime,
)
return _retry_same_provider_sync(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
main_runtime=main_runtime,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
if retry_client is not None:
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=resolved_base_url,
)
_retry_base = str(getattr(retry_client, "base_url", "") or "")
if _is_anthropic_compat_endpoint(resolved_provider, _retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
# ── Same-provider credential-pool recovery ─────────────────────
pool_provider = _recoverable_pool_provider(resolved_provider, client)
if pool_provider and (_is_auth_error(first_err) or _is_payment_error(first_err) or _is_rate_limit_error(first_err)):
recovery_err = first_err
if _is_rate_limit_error(first_err):
try:
return _validate_llm_response(
retry_client.chat.completions.create(**retry_kwargs), task)
client.chat.completions.create(**kwargs), task)
except Exception as retry_err:
if not (_is_auth_error(retry_err) or _is_payment_error(retry_err) or _is_rate_limit_error(retry_err)):
raise
recovery_err = retry_err
if _recover_provider_pool(pool_provider, recovery_err):
logger.info(
"Auxiliary %s: recovered %s via credential-pool rotation after %s",
task or "call", pool_provider, type(recovery_err).__name__,
)
return _retry_same_provider_sync(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
main_runtime=main_runtime,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
# ── Payment / credit exhaustion fallback ──────────────────────
# When the resolved provider returns 402 or a credit-related error,
@ -4136,38 +4382,54 @@ async def async_call_llm(
"Auxiliary %s (async): refreshed %s credentials after auth error, retrying",
task or "call", resolved_provider,
)
if task == "vision":
_, retry_client, retry_model = resolve_vision_provider_client(
provider=resolved_provider,
model=final_model,
async_mode=True,
)
else:
retry_client, retry_model = _get_cached_client(
resolved_provider,
resolved_model,
async_mode=True,
base_url=resolved_base_url,
api_key=resolved_api_key,
api_mode=resolved_api_mode,
)
if retry_client is not None:
retry_kwargs = _build_call_kwargs(
resolved_provider,
retry_model or final_model,
messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
timeout=effective_timeout,
extra_body=effective_extra_body,
base_url=resolved_base_url,
)
_retry_base = str(getattr(retry_client, "base_url", "") or "")
if _is_anthropic_compat_endpoint(resolved_provider, _retry_base):
retry_kwargs["messages"] = _convert_openai_images_to_anthropic(retry_kwargs["messages"])
return await _retry_same_provider_async(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
# ── Same-provider credential-pool recovery (mirrors sync) ─────
pool_provider = _recoverable_pool_provider(resolved_provider, client)
if pool_provider and (_is_auth_error(first_err) or _is_payment_error(first_err) or _is_rate_limit_error(first_err)):
recovery_err = first_err
if _is_rate_limit_error(first_err):
try:
return _validate_llm_response(
await retry_client.chat.completions.create(**retry_kwargs), task)
await client.chat.completions.create(**kwargs), task)
except Exception as retry_err:
if not (_is_auth_error(retry_err) or _is_payment_error(retry_err) or _is_rate_limit_error(retry_err)):
raise
recovery_err = retry_err
if _recover_provider_pool(pool_provider, recovery_err):
logger.info(
"Auxiliary %s (async): recovered %s via credential-pool rotation after %s",
task or "call", pool_provider, type(recovery_err).__name__,
)
return await _retry_same_provider_async(
task=task,
resolved_provider=resolved_provider,
resolved_model=resolved_model,
resolved_base_url=resolved_base_url,
resolved_api_key=resolved_api_key,
resolved_api_mode=resolved_api_mode,
final_model=final_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
tools=tools,
effective_timeout=effective_timeout,
effective_extra_body=effective_extra_body,
)
# ── Payment / connection / rate-limit fallback (mirrors sync call_llm) ──
should_fallback = (