"""API error classification for smart failover and recovery.

Provides a structured taxonomy of API errors and a priority-ordered
classification pipeline that determines the correct recovery action
(retry, rotate credential, fallback to another provider, compress
context, or abort).

Replaces scattered inline string-matching with a centralized classifier
that the main retry loop in run_agent.py consults for every API failure.
"""

from __future__ import annotations

import enum
import logging
import re  # NOTE(review): appears unused in this module — confirm before removing
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


# ── Error taxonomy ────────────────────────────────────────────────────── #

class FailoverReason(enum.Enum):
    """Why an API call failed — determines recovery strategy."""

    # Authentication / authorization
    auth = "auth"                          # Transient auth (401/403) — refresh/rotate
    auth_permanent = "auth_permanent"      # Auth failed after refresh — abort

    # Billing / quota
    billing = "billing"                    # 402 or confirmed credit exhaustion — rotate immediately
    rate_limit = "rate_limit"              # 429 or quota-based throttling — backoff then rotate

    # Server-side
    overloaded = "overloaded"              # 503/529 — provider overloaded, backoff
    server_error = "server_error"          # 500/502 — internal server error, retry

    # Transport
    timeout = "timeout"                    # Connection/read timeout — rebuild client + retry

    # Context / payload
    context_overflow = "context_overflow"  # Context too large — compress, not failover
    payload_too_large = "payload_too_large"  # 413 — compress payload

    # Model
    model_not_found = "model_not_found"    # 404 or invalid model — fallback to different model

    # Request format
    format_error = "format_error"          # 400 bad request — abort or strip + retry

    # Provider-specific
    thinking_signature = "thinking_signature"  # Anthropic thinking block sig invalid
    long_context_tier = "long_context_tier"    # Anthropic "extra usage" tier gate

    # Catch-all
    unknown = "unknown"                    # Unclassifiable — retry with backoff


# ── Classification result ─────────────────────────────────────────────── #

@dataclass
class ClassifiedError:
    """Structured classification of an API error with recovery hints."""

    reason: FailoverReason
    status_code: Optional[int] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    message: str = ""
    error_context: Dict[str, Any] = field(default_factory=dict)

    # Recovery action hints — the retry loop checks these instead of
    # re-classifying the error itself.
    retryable: bool = True
    should_compress: bool = False
    should_rotate_credential: bool = False
    should_fallback: bool = False

    @property
    def is_auth(self) -> bool:
        """True for both transient and permanent authentication failures."""
        return self.reason in (FailoverReason.auth, FailoverReason.auth_permanent)

    @property
    def is_transient(self) -> bool:
        """Error is expected to resolve on retry (with or without backoff)."""
        return self.reason in (
            FailoverReason.rate_limit,
            FailoverReason.overloaded,
            FailoverReason.server_error,
            FailoverReason.timeout,
            FailoverReason.unknown,
        )


# ── Provider-specific patterns ────────────────────────────────────────── #
# All pattern lists below are matched with substring `in` tests against a
# lowercased composite error message, so every entry must be lowercase.

# Patterns that indicate billing exhaustion (not transient rate limit)
_BILLING_PATTERNS = [
    "insufficient credits",
    "insufficient_quota",
    "credit balance",
    "credits have been exhausted",
    "top up your credits",
    "payment required",
    "billing hard limit",
    "exceeded your current quota",
    "account is deactivated",
    "plan does not include",
]

# Patterns that indicate rate limiting (transient, will resolve)
_RATE_LIMIT_PATTERNS = [
    "rate limit",
    "rate_limit",
    "too many requests",
    "throttled",
    "requests per minute",
    "tokens per minute",
    "requests per day",
    "try again in",
    "please retry after",
    "resource_exhausted",
]

# Usage-limit patterns that need disambiguation (could be billing OR rate_limit)
_USAGE_LIMIT_PATTERNS = [
    "usage limit",
    "quota",
    "limit exceeded",
    "key limit exceeded",
]

# Patterns confirming a usage limit is transient (not billing)
_USAGE_LIMIT_TRANSIENT_SIGNALS = [
    "try again",
    "retry",
    "resets at",
    "reset in",
    "wait",
    "requests remaining",
    "periodic",
    "window",
]

# Payload-too-large patterns detected from message text (no status_code attr).
# Proxies and some backends embed the HTTP status in the error message.
_PAYLOAD_TOO_LARGE_PATTERNS = [
    "request entity too large",
    "payload too large",
    "error code: 413",
]

# Context overflow patterns
_CONTEXT_OVERFLOW_PATTERNS = [
    "context length",
    "context size",
    "maximum context",
    "token limit",
    "too many tokens",
    "reduce the length",
    "exceeds the limit",
    "context window",
    "prompt is too long",
    "prompt exceeds max length",
    "max_tokens",
    "maximum number of tokens",
    # Chinese error messages (some providers return these)
    "超过最大长度",
    "上下文长度",
]

# Model not found patterns
_MODEL_NOT_FOUND_PATTERNS = [
    "is not a valid model",
    "invalid model",
    "model not found",
    "model_not_found",
    "does not exist",
    "no such model",
    "unknown model",
    "unsupported model",
]

# Auth patterns (non-status-code signals)
_AUTH_PATTERNS = [
    "invalid api key",
    "invalid_api_key",
    "authentication",
    "unauthorized",
    "forbidden",
    "invalid token",
    "token expired",
    "token revoked",
    "access denied",
]

# Anthropic thinking block signature patterns.
# NOTE(review): this constant appears unused — the classifier checks the
# literal "signature"/"thinking" substrings inline. Confirm before removing.
_THINKING_SIG_PATTERNS = [
    "signature",  # Combined with "thinking" check
]

# Transport error type names (matched against type(error).__name__)
_TRANSPORT_ERROR_TYPES = frozenset({
    "ReadTimeout", "ConnectTimeout", "PoolTimeout",
    "ConnectError", "RemoteProtocolError",
    "ConnectionError", "ConnectionResetError",
    "ConnectionAbortedError", "BrokenPipeError",
    "TimeoutError",
    "ReadError",
    "ServerDisconnectedError",
    # OpenAI SDK errors (not subclasses of Python builtins)
    "APIConnectionError",
    "APITimeoutError",
})

# Server disconnect patterns (no status code, but transport-level)
_SERVER_DISCONNECT_PATTERNS = [
    "server disconnected",
    "peer closed connection",
    "connection reset by peer",
    "connection was closed",
    "network connection lost",
    "unexpected eof",
    "incomplete chunked read",
]
def classify_api_error(
    error: Exception,
    *,
    provider: str = "",
    model: str = "",
    approx_tokens: int = 0,
    context_length: int = 200000,
    num_messages: int = 0,
) -> ClassifiedError:
    """Classify an API error into a structured recovery recommendation.

    Priority-ordered pipeline (order matters — see inline comments):
      1. Special-case provider-specific patterns (thinking sigs, tier gates)
      2. HTTP status code + message-aware refinement
      3. Error code classification (from body)
      4. Message pattern matching (billing vs rate_limit vs context vs auth)
      5. Server disconnect + large session → context overflow
      6. Transport error heuristics
      7. Fallback: unknown (retryable with backoff)

    Args:
        error: The exception from the API call.
        provider: Current provider name (e.g. "openrouter", "anthropic").
        model: Current model slug.
        approx_tokens: Approximate token count of the current context.
        context_length: Maximum context length for the current model.
        num_messages: Number of messages in the current API payload
            (used by the large-session heuristics).

    Returns:
        ClassifiedError with reason and recovery action hints.
    """
    status_code = _extract_status_code(error)
    error_type = type(error).__name__
    body = _extract_error_body(error)
    error_code = _extract_error_code(body)

    # Build a comprehensive error message string for pattern matching.
    # str(error) alone may not include the body message (e.g. OpenAI SDK's
    # APIStatusError.__str__ returns the first arg, not the body). Append
    # the body message so patterns like "try again" in 402 disambiguation
    # are detected even when only present in the structured body.
    #
    # Also extract metadata.raw — OpenRouter wraps upstream provider errors
    # inside {"error": {"message": "Provider returned error", "metadata":
    # {"raw": ""}}} and the real error message (e.g.
    # "context length exceeded") is only in the inner JSON.
    _raw_msg = str(error).lower()
    _body_msg = ""
    _metadata_msg = ""
    if isinstance(body, dict):
        _err_obj = body.get("error", {})
        if isinstance(_err_obj, dict):
            _body_msg = (_err_obj.get("message") or "").lower()
            # Parse metadata.raw for wrapped provider errors
            _metadata = _err_obj.get("metadata", {})
            if isinstance(_metadata, dict):
                _raw_json = _metadata.get("raw") or ""
                if isinstance(_raw_json, str) and _raw_json.strip():
                    try:
                        import json
                        _inner = json.loads(_raw_json)
                        if isinstance(_inner, dict):
                            _inner_err = _inner.get("error", {})
                            if isinstance(_inner_err, dict):
                                _metadata_msg = (_inner_err.get("message") or "").lower()
                    except (json.JSONDecodeError, TypeError):
                        # Best-effort: an unparseable inner payload is ignored.
                        pass
        if not _body_msg:
            _body_msg = (body.get("message") or "").lower()
    # Combine all message sources for pattern matching (dedup by substring)
    parts = [_raw_msg]
    if _body_msg and _body_msg not in _raw_msg:
        parts.append(_body_msg)
    if _metadata_msg and _metadata_msg not in _raw_msg and _metadata_msg not in _body_msg:
        parts.append(_metadata_msg)
    error_msg = " ".join(parts)
    provider_lower = (provider or "").strip().lower()
    model_lower = (model or "").strip().lower()

    def _result(reason: FailoverReason, **overrides) -> ClassifiedError:
        # Small factory so every branch carries consistent metadata.
        defaults = {
            "reason": reason,
            "status_code": status_code,
            "provider": provider,
            "model": model,
            "message": _extract_message(error, body),
        }
        defaults.update(overrides)
        return ClassifiedError(**defaults)

    # ── 1. Provider-specific patterns (highest priority) ────────────

    # Anthropic thinking block signature invalid (400).
    # Don't gate on provider — OpenRouter proxies Anthropic errors, so the
    # provider may be "openrouter" even though the error is Anthropic-specific.
    # The message pattern ("signature" + "thinking") is unique enough.
    if (
        status_code == 400
        and "signature" in error_msg
        and "thinking" in error_msg
    ):
        return _result(
            FailoverReason.thinking_signature,
            retryable=True,
            should_compress=False,
        )

    # Anthropic long-context tier gate (429 "extra usage" + "long context")
    if (
        status_code == 429
        and "extra usage" in error_msg
        and "long context" in error_msg
    ):
        return _result(
            FailoverReason.long_context_tier,
            retryable=True,
            should_compress=True,
        )

    # ── 2. HTTP status code classification ──────────────────────────

    if status_code is not None:
        classified = _classify_by_status(
            status_code, error_msg, error_code, body,
            provider=provider_lower, model=model_lower,
            approx_tokens=approx_tokens, context_length=context_length,
            num_messages=num_messages,
            result_fn=_result,
        )
        if classified is not None:
            return classified

    # ── 3. Error code classification ────────────────────────────────

    if error_code:
        classified = _classify_by_error_code(error_code, error_msg, _result)
        if classified is not None:
            return classified

    # ── 4. Message pattern matching (no status code) ────────────────

    classified = _classify_by_message(
        error_msg, error_type,
        approx_tokens=approx_tokens,
        context_length=context_length,
        result_fn=_result,
    )
    if classified is not None:
        return classified

    # ── 5. Server disconnect + large session → context overflow ─────
    # Must come BEFORE generic transport error catch — a disconnect on
    # a large session is more likely context overflow than a transient
    # transport hiccup. Without this ordering, RemoteProtocolError
    # always maps to timeout regardless of session size.

    is_disconnect = any(p in error_msg for p in _SERVER_DISCONNECT_PATTERNS)
    if is_disconnect and not status_code:
        is_large = approx_tokens > context_length * 0.6 or approx_tokens > 120000 or num_messages > 200
        if is_large:
            return _result(
                FailoverReason.context_overflow,
                retryable=True,
                should_compress=True,
            )
        return _result(FailoverReason.timeout, retryable=True)

    # ── 6. Transport / timeout heuristics ───────────────────────────

    if error_type in _TRANSPORT_ERROR_TYPES or isinstance(error, (TimeoutError, ConnectionError, OSError)):
        return _result(FailoverReason.timeout, retryable=True)

    # ── 7. Fallback: unknown ────────────────────────────────────────

    return _result(FailoverReason.unknown, retryable=True)
def _classify_by_status(
    status_code: int,
    error_msg: str,
    error_code: str,
    body: dict,
    *,
    provider: str,
    model: str,
    approx_tokens: int,
    context_length: int,
    num_messages: int = 0,
    result_fn,
) -> Optional[ClassifiedError]:
    """Classify based on HTTP status code with message-aware refinement.

    Args:
        status_code: HTTP status extracted from the exception.
        error_msg: Lowercased composite error message for pattern tests.
        error_code: Structured error code from the body (may be "").
        body: Parsed error body dict (may be {}).
        provider, model: Lowercased provider/model identifiers.
        approx_tokens, context_length, num_messages: Session-size signals
            forwarded to the 400 heuristics.
        result_fn: Factory that builds a ClassifiedError with shared
            metadata filled in.

    Returns:
        ClassifiedError for recognized statuses, None when the status
        alone is not decisive (caller falls through to later stages).
    """

    if status_code == 401:
        # Not retryable on its own — credential pool rotation and
        # provider-specific refresh (Codex, Anthropic, Nous) run before
        # the retryability check in run_agent.py. If those succeed, the
        # loop `continue`s. If they fail, retryable=False ensures we
        # hit the client-error abort path (which tries fallback first).
        return result_fn(
            FailoverReason.auth,
            retryable=False,
            should_rotate_credential=True,
            should_fallback=True,
        )

    if status_code == 403:
        # OpenRouter 403 "key limit exceeded" is actually billing
        if "key limit exceeded" in error_msg or "spending limit" in error_msg:
            return result_fn(
                FailoverReason.billing,
                retryable=False,
                should_rotate_credential=True,
                should_fallback=True,
            )
        return result_fn(
            FailoverReason.auth,
            retryable=False,
            should_fallback=True,
        )

    if status_code == 402:
        return _classify_402(error_msg, result_fn)

    if status_code == 404:
        # Whether or not the message names the model explicitly, a 404
        # from a chat endpoint means the model slug or route is wrong —
        # both cases map to model_not_found with fallback, so a single
        # return suffices (previously two identical branches).
        return result_fn(
            FailoverReason.model_not_found,
            retryable=False,
            should_fallback=True,
        )

    if status_code == 413:
        return result_fn(
            FailoverReason.payload_too_large,
            retryable=True,
            should_compress=True,
        )

    if status_code == 429:
        # Already checked long_context_tier above; this is a normal rate limit
        return result_fn(
            FailoverReason.rate_limit,
            retryable=True,
            should_rotate_credential=True,
            should_fallback=True,
        )

    if status_code == 400:
        return _classify_400(
            error_msg, error_code, body,
            provider=provider, model=model,
            approx_tokens=approx_tokens,
            context_length=context_length,
            num_messages=num_messages,
            result_fn=result_fn,
        )

    if status_code in (500, 502):
        return result_fn(FailoverReason.server_error, retryable=True)

    if status_code in (503, 529):
        return result_fn(FailoverReason.overloaded, retryable=True)

    # Other 4xx — non-retryable
    if 400 <= status_code < 500:
        return result_fn(
            FailoverReason.format_error,
            retryable=False,
            should_fallback=True,
        )

    # Other 5xx — retryable
    if 500 <= status_code < 600:
        return result_fn(FailoverReason.server_error, retryable=True)

    return None
def _classify_402(error_msg: str, result_fn) -> ClassifiedError:
    """Disambiguate a 402: real billing exhaustion vs a transient quota.

    Some providers return 402 for periodic usage limits ("usage limit,
    try again in 5 minutes"). Those resolve on their own and must be
    handled as rate limits; only unambiguous exhaustion is billing.
    """
    looks_like_quota = any(marker in error_msg for marker in _USAGE_LIMIT_PATTERNS)
    looks_transient = any(marker in error_msg for marker in _USAGE_LIMIT_TRANSIENT_SIGNALS)

    if looks_like_quota and looks_transient:
        # Periodic quota that will reset — backoff/rotate, don't abort.
        return result_fn(
            FailoverReason.rate_limit,
            retryable=True,
            should_rotate_credential=True,
            should_fallback=True,
        )

    # No transient signal: genuine credit/billing exhaustion.
    return result_fn(
        FailoverReason.billing,
        retryable=False,
        should_rotate_credential=True,
        should_fallback=True,
    )
def _classify_400(
    error_msg: str,
    error_code: str,
    body: dict,
    *,
    provider: str,
    model: str,
    approx_tokens: int,
    context_length: int,
    num_messages: int = 0,
    result_fn,
) -> ClassifiedError:
    """Classify a 400 Bad Request.

    Checked in priority order: explicit context overflow, model-not-found
    (some providers use 400 instead of 404), rate-limit/billing wording
    hiding behind a 400, the generic-message + large-session overflow
    heuristic, and finally a non-retryable format error.
    """

    def _hit(patterns) -> bool:
        return any(p in error_msg for p in patterns)

    if _hit(_CONTEXT_OVERFLOW_PATTERNS):
        return result_fn(
            FailoverReason.context_overflow,
            retryable=True,
            should_compress=True,
        )

    # e.g. OpenRouter reports unknown models with 400 rather than 404.
    if _hit(_MODEL_NOT_FOUND_PATTERNS):
        return result_fn(
            FailoverReason.model_not_found,
            retryable=False,
            should_fallback=True,
        )

    # Some providers return rate limit / billing errors as 400 instead
    # of 429/402 — check before falling through to format_error.
    if _hit(_RATE_LIMIT_PATTERNS):
        return result_fn(
            FailoverReason.rate_limit,
            retryable=True,
            should_rotate_credential=True,
            should_fallback=True,
        )
    if _hit(_BILLING_PATTERNS):
        return result_fn(
            FailoverReason.billing,
            retryable=False,
            should_rotate_credential=True,
            should_fallback=True,
        )

    # Heuristic: a bare/generic 400 on a large session is most likely a
    # context overflow (Anthropic sometimes answers with just "Error").
    detail = ""
    if isinstance(body, dict):
        inner = body.get("error", {})
        if isinstance(inner, dict):
            detail = (inner.get("message") or "").strip().lower()
    generic_message = len(detail) < 30
    large_session = (
        approx_tokens > context_length * 0.4
        or approx_tokens > 80000
        or num_messages > 80
    )
    if generic_message and large_session:
        return result_fn(
            FailoverReason.context_overflow,
            retryable=True,
            should_compress=True,
        )

    # Nothing matched — treat as a non-retryable request-format error.
    return result_fn(
        FailoverReason.format_error,
        retryable=False,
        should_fallback=True,
    )
def _classify_by_error_code(
    error_code: str, error_msg: str, result_fn,
) -> Optional[ClassifiedError]:
    """Classify via a structured error code pulled from the response body."""
    code = error_code.lower()

    if code in {"resource_exhausted", "throttled", "rate_limit_exceeded"}:
        return result_fn(
            FailoverReason.rate_limit,
            retryable=True,
            should_rotate_credential=True,
        )

    if code in {"insufficient_quota", "billing_not_active", "payment_required"}:
        return result_fn(
            FailoverReason.billing,
            retryable=False,
            should_rotate_credential=True,
            should_fallback=True,
        )

    if code in {"model_not_found", "model_not_available", "invalid_model"}:
        return result_fn(
            FailoverReason.model_not_found,
            retryable=False,
            should_fallback=True,
        )

    if code in {"context_length_exceeded", "max_tokens_exceeded"}:
        return result_fn(
            FailoverReason.context_overflow,
            retryable=True,
            should_compress=True,
        )

    return None


# ── Message pattern classification ──────────────────────────────────────

def _classify_by_message(
    error_msg: str,
    error_type: str,
    *,
    approx_tokens: int,
    context_length: int,
    result_fn,
) -> Optional[ClassifiedError]:
    """Classify purely from message text when no status code is available.

    The rule table is scanned in priority order; the first pattern family
    with a match wins. Returns None when nothing matches.
    """
    rules = (
        (_PAYLOAD_TOO_LARGE_PATTERNS, FailoverReason.payload_too_large,
         {"retryable": True, "should_compress": True}),
        (_BILLING_PATTERNS, FailoverReason.billing,
         {"retryable": False, "should_rotate_credential": True,
          "should_fallback": True}),
        (_RATE_LIMIT_PATTERNS, FailoverReason.rate_limit,
         {"retryable": True, "should_rotate_credential": True,
          "should_fallback": True}),
        (_CONTEXT_OVERFLOW_PATTERNS, FailoverReason.context_overflow,
         {"retryable": True, "should_compress": True}),
        (_AUTH_PATTERNS, FailoverReason.auth,
         {"retryable": True, "should_rotate_credential": True}),
        (_MODEL_NOT_FOUND_PATTERNS, FailoverReason.model_not_found,
         {"retryable": False, "should_fallback": True}),
    )
    for patterns, reason, hints in rules:
        if any(p in error_msg for p in patterns):
            return result_fn(reason, **hints)
    return None
def _extract_status_code(error: Exception) -> Optional[int]:
    """Walk the error and its cause chain to find an HTTP status code.

    Checks ``.status_code`` (OpenAI-style SDKs) then ``.status``
    (aiohttp-style) at each level, following ``__cause__``/``__context__``
    up to a fixed depth so cyclic chains cannot loop forever.

    Returns:
        The first plausible HTTP status found, or None.
    """
    current = error
    for _ in range(5):  # Max depth to prevent infinite loops
        for attr in ("status_code", "status"):
            code = getattr(current, attr, None)
            # Exclude bool (a subclass of int) and anything outside the
            # HTTP range so unrelated integer attributes aren't mistaken
            # for a status code. Previously ``status_code`` accepted any
            # int while ``status`` was range-checked — now consistent.
            if (
                isinstance(code, int)
                and not isinstance(code, bool)
                and 100 <= code < 600
            ):
                return code
        # Walk the cause chain
        cause = getattr(current, "__cause__", None) or getattr(current, "__context__", None)
        if cause is None or cause is current:
            break
        current = cause
    return None


def _extract_error_body(error: Exception) -> dict:
    """Extract the structured error body from an SDK exception.

    Prefers the ``.body`` attribute (OpenAI SDK); falls back to calling
    ``.response.json()``. Returns ``{}`` when neither yields a dict.
    """
    body = getattr(error, "body", None)
    if isinstance(body, dict):
        return body
    # Some errors have .response.json()
    response = getattr(error, "response", None)
    if response is not None:
        try:
            json_body = response.json()
            if isinstance(json_body, dict):
                return json_body
        except Exception:
            # Best-effort: an unparseable response body is treated as absent.
            pass
    return {}


def _extract_error_code(body: dict) -> str:
    """Extract an error code string from the response body.

    Looks for ``body["error"]["code"]`` / ``["type"]`` first, then the
    top-level ``code`` / ``error_code`` keys. Returns "" when absent.
    """
    if not body:
        return ""
    error_obj = body.get("error", {})
    if isinstance(error_obj, dict):
        code = error_obj.get("code") or error_obj.get("type") or ""
        if isinstance(code, str) and code.strip():
            return code.strip()
    # Top-level code (may be an int on some backends — stringified)
    code = body.get("code") or body.get("error_code") or ""
    if isinstance(code, (str, int)):
        return str(code).strip()
    return ""


def _extract_message(error: Exception, body: dict) -> str:
    """Extract the most informative error message (capped at 500 chars).

    Prefers the structured body message; falls back to ``str(error)``.
    """
    # Try structured body first
    if body:
        error_obj = body.get("error", {})
        if isinstance(error_obj, dict):
            msg = error_obj.get("message", "")
            if isinstance(msg, str) and msg.strip():
                return msg.strip()[:500]
        msg = body.get("message", "")
        if isinstance(msg, str) and msg.strip():
            return msg.strip()[:500]
    # Fallback to str(error)
    return str(error)[:500]
+77,7 @@ from hermes_constants import OPENROUTER_BASE_URL # Agent internals extracted to agent/ package for modularity from agent.memory_manager import build_memory_context_block from agent.retry_utils import jittered_backoff +from agent.error_classifier import classify_api_error, FailoverReason from agent.prompt_builder import ( DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE, @@ -8017,6 +8018,25 @@ class AIAgent: status_code = getattr(api_error, "status_code", None) error_context = self._extract_api_error_context(api_error) + + # ── Classify the error for structured recovery decisions ── + _compressor = getattr(self, "context_compressor", None) + _ctx_len = getattr(_compressor, "context_length", 200000) if _compressor else 200000 + classified = classify_api_error( + api_error, + provider=getattr(self, "provider", "") or "", + model=getattr(self, "model", "") or "", + approx_tokens=approx_tokens, + context_length=_ctx_len, + num_messages=len(api_messages) if api_messages else 0, + ) + logger.debug( + "Error classified: reason=%s status=%s retryable=%s compress=%s rotate=%s fallback=%s", + classified.reason.value, classified.status_code, + classified.retryable, classified.should_compress, + classified.should_rotate_credential, classified.should_fallback, + ) + recovered_with_pool, has_retried_429 = self._recover_with_credential_pool( status_code=status_code, has_retried_429=has_retried_429, @@ -8079,27 +8099,24 @@ class AIAgent: # from all messages so the next retry sends no thinking # blocks at all. One-shot — don't retry infinitely. 
if ( - self.api_mode == "anthropic_messages" - and status_code == 400 + classified.reason == FailoverReason.thinking_signature and not thinking_sig_retry_attempted ): - _err_msg_lower = str(api_error).lower() - if "signature" in _err_msg_lower and "thinking" in _err_msg_lower: - thinking_sig_retry_attempted = True - for _m in messages: - if isinstance(_m, dict): - _m.pop("reasoning_details", None) - self._vprint( - f"{self.log_prefix}⚠️ Thinking block signature invalid — " - f"stripped all thinking blocks, retrying...", - force=True, - ) - logging.warning( - "%sThinking block signature recovery: stripped " - "reasoning_details from %d messages", - self.log_prefix, len(messages), - ) - continue + thinking_sig_retry_attempted = True + for _m in messages: + if isinstance(_m, dict): + _m.pop("reasoning_details", None) + self._vprint( + f"{self.log_prefix}⚠️ Thinking block signature invalid — " + f"stripped all thinking blocks, retrying...", + force=True, + ) + logging.warning( + "%sThinking block signature recovery: stripped " + "reasoning_details from %d messages", + self.log_prefix, len(messages), + ) + continue retry_count += 1 elapsed_time = time.time() - api_start_time @@ -8156,14 +8173,7 @@ class AIAgent: # is NOT a transient rate limit — retrying or switching # credentials won't help. Reduce context to 200k (the # standard tier) and compress. - # Only applies to Sonnet — Opus 1M is general access. - _is_long_context_tier_error = ( - status_code == 429 - and "extra usage" in error_msg - and "long context" in error_msg - and "sonnet" in self.model.lower() - ) - if _is_long_context_tier_error: + if classified.reason == FailoverReason.long_context_tier: _reduced_ctx = 200000 compressor = self.context_compressor old_ctx = compressor.context_length @@ -8208,13 +8218,9 @@ class AIAgent: # When a fallback model is configured, switch immediately instead # of burning through retries with exponential backoff -- the # primary provider won't recover within the retry window. 
- is_rate_limited = ( - status_code == 429 - or "rate limit" in error_msg - or "too many requests" in error_msg - or "rate_limit" in error_msg - or "usage limit" in error_msg - or "quota" in error_msg + is_rate_limited = classified.reason in ( + FailoverReason.rate_limit, + FailoverReason.billing, ) if is_rate_limited and self._fallback_index < len(self._fallback_chain): # Don't eagerly fallback if credential pool rotation may @@ -8230,10 +8236,7 @@ class AIAgent: continue is_payload_too_large = ( - status_code == 413 - or 'request entity too large' in error_msg - or 'payload too large' in error_msg - or 'error code: 413' in error_msg + classified.reason == FailoverReason.payload_too_large ) if is_payload_too_large: @@ -8277,64 +8280,12 @@ class AIAgent: } # Check for context-length errors BEFORE generic 4xx handler. - # Local backends (LM Studio, Ollama, llama.cpp) often return - # HTTP 400 with messages like "Context size has been exceeded" - # which must trigger compression, not an immediate abort. - is_context_length_error = any(phrase in error_msg for phrase in [ - 'context length', 'context size', 'maximum context', - 'token limit', 'too many tokens', 'reduce the length', - 'exceeds the limit', 'context window', - 'request entity too large', # OpenRouter/Nous 413 safety net - 'prompt is too long', # Anthropic: "prompt is too long: N tokens > M maximum" - 'prompt exceeds max length', # Z.AI / GLM: generic 400 overflow wording - ]) - - # Fallback heuristic: Anthropic sometimes returns a generic - # 400 invalid_request_error with just "Error" as the message - # when the context is too large. If the error message is very - # short/generic AND the session is large, treat it as a - # probable context-length error and attempt compression rather - # than aborting. This prevents an infinite failure loop where - # each failed message gets persisted, making the session even - # larger. 
(#1630) - if not is_context_length_error and status_code == 400: - ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000) - is_large_session = approx_tokens > ctx_len * 0.4 or len(api_messages) > 80 - is_generic_error = len(error_msg.strip()) < 30 # e.g. just "error" - if is_large_session and is_generic_error: - is_context_length_error = True - self._vprint( - f"{self.log_prefix}⚠️ Generic 400 with large session " - f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — " - f"treating as probable context overflow.", - force=True, - ) - - # Server disconnects on large sessions are often caused by - # the request exceeding the provider's context/payload limit - # without a proper HTTP error response. Treat these as - # context-length errors to trigger compression rather than - # burning through retries that will all fail the same way. - # This breaks the death spiral: disconnect → no token data - # → no compression → bigger session → more disconnects. - # (#2153) - if not is_context_length_error and not status_code: - _is_server_disconnect = ( - 'server disconnected' in error_msg - or 'peer closed connection' in error_msg - or error_type in ('ReadError', 'RemoteProtocolError', 'ServerDisconnectedError') - ) - if _is_server_disconnect: - ctx_len = getattr(getattr(self, 'context_compressor', None), 'context_length', 200000) - _is_large = approx_tokens > ctx_len * 0.6 or len(api_messages) > 200 - if _is_large: - is_context_length_error = True - self._vprint( - f"{self.log_prefix}⚠️ Server disconnected with large session " - f"(~{approx_tokens:,} tokens, {len(api_messages)} msgs) — " - f"treating as context-length error, attempting compression.", - force=True, - ) + # The classifier detects context overflow from: explicit error + # messages, generic 400 + large session heuristic (#1630), and + # server disconnect + large session pattern (#2153). 
+ is_context_length_error = ( + classified.reason == FailoverReason.context_overflow + ) if is_context_length_error: compressor = self.context_compressor @@ -8406,35 +8357,30 @@ class AIAgent: "partial": True } - # Check for non-retryable client errors (4xx HTTP status codes). - # These indicate a problem with the request itself (bad model ID, - # invalid API key, forbidden, etc.) and will never succeed on retry. - # Note: 413 and context-length errors are excluded — handled above. - # 429 (rate limit) is transient and MUST be retried with backoff. - # 529 (Anthropic overloaded) is also transient. - # Also catch local validation errors (ValueError, TypeError) — these - # are programming bugs, not transient failures. - # Exclude UnicodeEncodeError — it's a ValueError subclass but is - # handled separately by the surrogate sanitization path above. - _RETRYABLE_STATUS_CODES = {413, 429, 529} + # Check for non-retryable client errors. The classifier + # already accounts for 413, 429, 529 (transient), context + # overflow, and generic-400 heuristics. Local validation + # errors (ValueError, TypeError) are programming bugs. is_local_validation_error = ( isinstance(api_error, (ValueError, TypeError)) and not isinstance(api_error, UnicodeEncodeError) ) - # Detect generic 400s from Anthropic OAuth (transient server-side failures). - # Real invalid_request_error responses include a descriptive message; - # transient ones contain only "Error" or are empty. 
(ref: issue #1608) - _err_body = getattr(api_error, "body", None) or {} - _err_message = (_err_body.get("error", {}).get("message", "") if isinstance(_err_body, dict) else "") - _is_generic_400 = (status_code == 400 and _err_message.strip().lower() in ("error", "")) - is_client_status_error = isinstance(status_code, int) and 400 <= status_code < 500 and status_code not in _RETRYABLE_STATUS_CODES and not _is_generic_400 - is_client_error = (is_local_validation_error or is_client_status_error or any(phrase in error_msg for phrase in [ - 'error code: 401', 'error code: 403', - 'error code: 404', 'error code: 422', - 'is not a valid model', 'invalid model', 'model not found', - 'invalid api key', 'invalid_api_key', 'authentication', - 'unauthorized', 'forbidden', 'not found', - ])) and not is_context_length_error + is_client_error = ( + is_local_validation_error + or ( + not classified.retryable + and not classified.should_compress + and classified.reason not in ( + FailoverReason.rate_limit, + FailoverReason.billing, + FailoverReason.overloaded, + FailoverReason.context_overflow, + FailoverReason.payload_too_large, + FailoverReason.long_context_tier, + FailoverReason.thinking_signature, + ) + ) + ) and not is_context_length_error if is_client_error: # Try fallback before aborting — a different provider @@ -8454,7 +8400,7 @@ class AIAgent: self._vprint(f"{self.log_prefix} 🔌 Provider: {_provider} Model: {_model}", force=True) self._vprint(f"{self.log_prefix} 🌐 Endpoint: {_base}", force=True) # Actionable guidance for common auth errors - if status_code in (401, 403) or "unauthorized" in error_msg or "forbidden" in error_msg or "permission" in error_msg: + if classified.is_auth or classified.reason == FailoverReason.billing: if _provider == "openai-codex" and status_code == 401: self._vprint(f"{self.log_prefix} 💡 Codex OAuth token was rejected (HTTP 401). 
Your token may have been", force=True) self._vprint(f"{self.log_prefix} refreshed by another client (Codex CLI, VS Code). To fix:", force=True) diff --git a/tests/agent/test_error_classifier.py b/tests/agent/test_error_classifier.py new file mode 100644 index 0000000000..da248f8218 --- /dev/null +++ b/tests/agent/test_error_classifier.py @@ -0,0 +1,750 @@ +"""Tests for agent.error_classifier — structured API error classification.""" + +import pytest +from agent.error_classifier import ( + ClassifiedError, + FailoverReason, + classify_api_error, + _extract_status_code, + _extract_error_body, + _extract_error_code, + _classify_402, +) + + +# ── Helper: mock API errors ──────────────────────────────────────────── + +class MockAPIError(Exception): + """Simulates an OpenAI SDK APIStatusError.""" + def __init__(self, message, status_code=None, body=None): + super().__init__(message) + self.status_code = status_code + self.body = body or {} + + +class MockTransportError(Exception): + """Simulates a transport-level error with a specific type name.""" + pass + + +class ReadTimeout(MockTransportError): + pass + + +class ConnectError(MockTransportError): + pass + + +class RemoteProtocolError(MockTransportError): + pass + + +class ServerDisconnectedError(MockTransportError): + pass + + +# ── Test: FailoverReason enum ────────────────────────────────────────── + +class TestFailoverReason: + def test_all_reasons_have_string_values(self): + for reason in FailoverReason: + assert isinstance(reason.value, str) + + def test_enum_members_exist(self): + expected = { + "auth", "auth_permanent", "billing", "rate_limit", + "overloaded", "server_error", "timeout", + "context_overflow", "payload_too_large", + "model_not_found", "format_error", + "thinking_signature", "long_context_tier", "unknown", + } + actual = {r.value for r in FailoverReason} + assert expected == actual + + +# ── Test: ClassifiedError ────────────────────────────────────────────── + +class TestClassifiedError: + def 
test_is_auth_property(self): + e1 = ClassifiedError(reason=FailoverReason.auth) + assert e1.is_auth is True + + e2 = ClassifiedError(reason=FailoverReason.auth_permanent) + assert e2.is_auth is True + + e3 = ClassifiedError(reason=FailoverReason.billing) + assert e3.is_auth is False + + def test_is_transient_property(self): + transient_reasons = [ + FailoverReason.rate_limit, + FailoverReason.overloaded, + FailoverReason.server_error, + FailoverReason.timeout, + FailoverReason.unknown, + ] + for reason in transient_reasons: + e = ClassifiedError(reason=reason) + assert e.is_transient is True, f"{reason} should be transient" + + non_transient = [ + FailoverReason.auth, + FailoverReason.billing, + FailoverReason.model_not_found, + FailoverReason.format_error, + ] + for reason in non_transient: + e = ClassifiedError(reason=reason) + assert e.is_transient is False, f"{reason} should NOT be transient" + + def test_defaults(self): + e = ClassifiedError(reason=FailoverReason.unknown) + assert e.retryable is True + assert e.should_compress is False + assert e.should_rotate_credential is False + assert e.should_fallback is False + assert e.status_code is None + assert e.message == "" + + +# ── Test: Status code extraction ─────────────────────────────────────── + +class TestExtractStatusCode: + def test_from_status_code_attr(self): + e = MockAPIError("fail", status_code=429) + assert _extract_status_code(e) == 429 + + def test_from_status_attr(self): + class ErrWithStatus(Exception): + status = 503 + assert _extract_status_code(ErrWithStatus()) == 503 + + def test_from_cause_chain(self): + inner = MockAPIError("inner", status_code=401) + outer = Exception("outer") + outer.__cause__ = inner + assert _extract_status_code(outer) == 401 + + def test_none_when_missing(self): + assert _extract_status_code(Exception("generic")) is None + + def test_rejects_non_http_status(self): + """Integers outside 100-599 on .status should be ignored.""" + class ErrWeirdStatus(Exception): + 
status = 42 + assert _extract_status_code(ErrWeirdStatus()) is None + + +# ── Test: Error body extraction ──────────────────────────────────────── + +class TestExtractErrorBody: + def test_from_body_attr(self): + e = MockAPIError("fail", body={"error": {"message": "bad"}}) + assert _extract_error_body(e) == {"error": {"message": "bad"}} + + def test_empty_when_no_body(self): + assert _extract_error_body(Exception("generic")) == {} + + +# ── Test: Error code extraction ──────────────────────────────────────── + +class TestExtractErrorCode: + def test_from_nested_error_code(self): + body = {"error": {"code": "rate_limit_exceeded"}} + assert _extract_error_code(body) == "rate_limit_exceeded" + + def test_from_nested_error_type(self): + body = {"error": {"type": "invalid_request_error"}} + assert _extract_error_code(body) == "invalid_request_error" + + def test_from_top_level_code(self): + body = {"code": "model_not_found"} + assert _extract_error_code(body) == "model_not_found" + + def test_empty_when_no_code(self): + assert _extract_error_code({}) == "" + assert _extract_error_code({"error": {"message": "oops"}}) == "" + + +# ── Test: 402 disambiguation ─────────────────────────────────────────── + +class TestClassify402: + """The critical 402 billing vs rate_limit disambiguation.""" + + def test_billing_exhaustion(self): + """Plain 402 = billing.""" + result = _classify_402( + "payment required", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + assert result.should_rotate_credential is True + + def test_transient_usage_limit(self): + """402 with 'usage limit' + 'try again' = rate limit, not billing.""" + result = _classify_402( + "usage limit exceeded. 
try again in 5 minutes", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.rate_limit + assert result.should_rotate_credential is True + + def test_quota_with_retry(self): + """402 with 'quota' + 'retry' = rate limit.""" + result = _classify_402( + "quota exceeded, please retry after the window resets", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.rate_limit + + def test_quota_without_retry(self): + """402 with just 'quota' but no transient signal = billing.""" + result = _classify_402( + "quota exceeded", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + + def test_insufficient_credits(self): + result = _classify_402( + "insufficient credits to complete request", + lambda reason, **kw: ClassifiedError(reason=reason, **kw), + ) + assert result.reason == FailoverReason.billing + + +# ── Test: Full classification pipeline ───────────────────────────────── + +class TestClassifyApiError: + """End-to-end classification tests.""" + + # ── Auth errors ── + + def test_401_classified_as_auth(self): + e = MockAPIError("Unauthorized", status_code=401) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.auth + assert result.should_rotate_credential is True + # 401 is non-retryable on its own — credential rotation runs + # before the retryability check in the agent loop. 
+ assert result.retryable is False + assert result.should_fallback is True + + def test_403_classified_as_auth(self): + e = MockAPIError("Forbidden", status_code=403) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.auth + assert result.should_fallback is True + + def test_403_key_limit_classified_as_billing(self): + """OpenRouter 403 'key limit exceeded' is billing, not auth.""" + e = MockAPIError("Key limit exceeded for this key", status_code=403) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.billing + assert result.should_rotate_credential is True + assert result.should_fallback is True + + def test_403_spending_limit_classified_as_billing(self): + e = MockAPIError("spending limit reached", status_code=403) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.billing + + # ── Billing ── + + def test_402_plain_billing(self): + e = MockAPIError("Payment Required", status_code=402) + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + assert result.retryable is False + + def test_402_transient_usage_limit(self): + e = MockAPIError("usage limit exceeded, try again later", status_code=402) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + assert result.retryable is True + + # ── Rate limit ── + + def test_429_rate_limit(self): + e = MockAPIError("Too Many Requests", status_code=429) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + assert result.should_fallback is True + + # ── Server errors ── + + def test_500_server_error(self): + e = MockAPIError("Internal Server Error", status_code=500) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + assert result.retryable is True + + def test_502_server_error(self): + e = MockAPIError("Bad Gateway", status_code=502) + result = 
classify_api_error(e) + assert result.reason == FailoverReason.server_error + + def test_503_overloaded(self): + e = MockAPIError("Service Unavailable", status_code=503) + result = classify_api_error(e) + assert result.reason == FailoverReason.overloaded + + def test_529_anthropic_overloaded(self): + e = MockAPIError("Overloaded", status_code=529) + result = classify_api_error(e) + assert result.reason == FailoverReason.overloaded + + # ── Model not found ── + + def test_404_model_not_found(self): + e = MockAPIError("model not found", status_code=404) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + assert result.should_fallback is True + assert result.retryable is False + + def test_404_generic(self): + e = MockAPIError("Not Found", status_code=404) + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + # ── Payload too large ── + + def test_413_payload_too_large(self): + e = MockAPIError("Request Entity Too Large", status_code=413) + result = classify_api_error(e) + assert result.reason == FailoverReason.payload_too_large + assert result.should_compress is True + + # ── Context overflow ── + + def test_400_context_length(self): + e = MockAPIError("context length exceeded: 250000 > 200000", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_400_too_many_tokens(self): + e = MockAPIError("This model's maximum context is 128000 tokens, too many tokens", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + def test_400_prompt_too_long(self): + e = MockAPIError("prompt is too long: 300000 tokens > 200000 maximum", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + def test_400_generic_large_session(self): + """Generic 400 with large session → context overflow 
heuristic.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + result = classify_api_error(e, approx_tokens=100000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + + def test_400_generic_small_session_is_format_error(self): + """Generic 400 with small session → format error, not context overflow.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + result = classify_api_error(e, approx_tokens=1000, context_length=200000) + assert result.reason == FailoverReason.format_error + + # ── Server disconnect + large session ── + + def test_disconnect_large_session_context_overflow(self): + """Server disconnect with large session → context overflow.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error(e, approx_tokens=150000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_disconnect_small_session_timeout(self): + """Server disconnect with small session → timeout.""" + e = Exception("server disconnected without sending complete message") + result = classify_api_error(e, approx_tokens=5000, context_length=200000) + assert result.reason == FailoverReason.timeout + + # ── Provider-specific: Anthropic thinking signature ── + + def test_anthropic_thinking_signature(self): + e = MockAPIError( + "thinking block has invalid signature", + status_code=400, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.thinking_signature + assert result.retryable is True + + def test_non_anthropic_400_with_signature_not_classified_as_thinking(self): + """400 with 'signature' but from non-Anthropic → format error.""" + e = MockAPIError("invalid signature", status_code=400) + result = classify_api_error(e, provider="openrouter", approx_tokens=0) + # Without "thinking" in the message, it 
shouldn't be thinking_signature + assert result.reason != FailoverReason.thinking_signature + + # ── Provider-specific: Anthropic long-context tier ── + + def test_anthropic_long_context_tier(self): + e = MockAPIError( + "Extra usage is required for long context requests over 200k tokens", + status_code=429, + ) + result = classify_api_error(e, provider="anthropic", model="claude-sonnet-4") + assert result.reason == FailoverReason.long_context_tier + assert result.should_compress is True + + def test_normal_429_not_long_context(self): + """Normal 429 without 'extra usage' + 'long context' → rate_limit.""" + e = MockAPIError("Too Many Requests", status_code=429) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.rate_limit + + # ── Transport errors ── + + def test_read_timeout(self): + e = ReadTimeout("Read timed out") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + assert result.retryable is True + + def test_connect_error(self): + e = ConnectError("Connection refused") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + def test_connection_error_builtin(self): + e = ConnectionError("Connection reset by peer") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + def test_timeout_error_builtin(self): + e = TimeoutError("timed out") + result = classify_api_error(e) + assert result.reason == FailoverReason.timeout + + # ── Error code classification ── + + def test_error_code_resource_exhausted(self): + e = MockAPIError( + "Resource exhausted", + body={"error": {"code": "resource_exhausted", "message": "Too many requests"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + + def test_error_code_model_not_found(self): + e = MockAPIError( + "Model not available", + body={"error": {"code": "model_not_found"}}, + ) + result = classify_api_error(e) + assert result.reason == 
FailoverReason.model_not_found + + def test_error_code_context_length_exceeded(self): + e = MockAPIError( + "Context too large", + body={"error": {"code": "context_length_exceeded"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Message-only patterns (no status code) ── + + def test_message_billing_pattern(self): + e = Exception("insufficient credits to complete this request") + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + + def test_message_rate_limit_pattern(self): + e = Exception("rate limit reached for this model") + result = classify_api_error(e) + assert result.reason == FailoverReason.rate_limit + + def test_message_auth_pattern(self): + e = Exception("invalid api key provided") + result = classify_api_error(e) + assert result.reason == FailoverReason.auth + + def test_message_model_not_found_pattern(self): + e = Exception("gpt-99 is not a valid model") + result = classify_api_error(e) + assert result.reason == FailoverReason.model_not_found + + def test_message_context_overflow_pattern(self): + e = Exception("maximum context length exceeded") + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Unknown / fallback ── + + def test_generic_exception_is_unknown(self): + e = Exception("something weird happened") + result = classify_api_error(e) + assert result.reason == FailoverReason.unknown + assert result.retryable is True + + # ── Format error ── + + def test_400_descriptive_format_error(self): + """400 with descriptive message (not context overflow) → format error.""" + e = MockAPIError( + "Invalid value for parameter 'temperature': must be between 0 and 2", + status_code=400, + body={"error": {"message": "Invalid value for parameter 'temperature': must be between 0 and 2"}}, + ) + result = classify_api_error(e, approx_tokens=1000) + assert result.reason == FailoverReason.format_error + assert result.retryable is 
False + + def test_422_format_error(self): + e = MockAPIError("Unprocessable Entity", status_code=422) + result = classify_api_error(e) + assert result.reason == FailoverReason.format_error + assert result.retryable is False + + # ── Peer closed + large session ── + + def test_peer_closed_large_session(self): + e = Exception("peer closed connection without sending complete message") + result = classify_api_error(e, approx_tokens=130000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + + # ── Chinese error messages ── + + def test_chinese_context_overflow(self): + e = MockAPIError("超过最大长度限制", status_code=400) + result = classify_api_error(e) + assert result.reason == FailoverReason.context_overflow + + # ── Result metadata ── + + def test_provider_and_model_in_result(self): + e = MockAPIError("fail", status_code=500) + result = classify_api_error(e, provider="openrouter", model="gpt-5") + assert result.provider == "openrouter" + assert result.model == "gpt-5" + assert result.status_code == 500 + + def test_message_extracted(self): + e = MockAPIError( + "outer", + status_code=500, + body={"error": {"message": "Internal server error occurred"}}, + ) + result = classify_api_error(e) + assert result.message == "Internal server error occurred" + + +# ── Test: Adversarial / edge cases (from live testing) ───────────────── + +class TestAdversarialEdgeCases: + """Edge cases discovered during live testing with real SDK objects.""" + + def test_empty_exception_message(self): + result = classify_api_error(Exception("")) + assert result.reason == FailoverReason.unknown + assert result.retryable is True + + def test_500_with_none_body(self): + e = MockAPIError("fail", status_code=500, body=None) + result = classify_api_error(e) + assert result.reason == FailoverReason.server_error + + def test_non_dict_body(self): + """Some providers return strings instead of JSON.""" + class StringBodyError(Exception): + status_code = 400 + body = "just a 
string" + result = classify_api_error(StringBodyError("bad")) + assert result.reason == FailoverReason.format_error + + def test_list_body(self): + class ListBodyError(Exception): + status_code = 500 + body = [{"error": "something"}] + result = classify_api_error(ListBodyError("server error")) + assert result.reason == FailoverReason.server_error + + def test_circular_cause_chain(self): + """Must not infinite-loop on circular __cause__.""" + e = Exception("circular") + e.__cause__ = e + result = classify_api_error(e) + assert result.reason == FailoverReason.unknown + + def test_three_level_cause_chain(self): + inner = MockAPIError("inner", status_code=429) + middle = Exception("middle") + middle.__cause__ = inner + outer = RuntimeError("outer") + outer.__cause__ = middle + result = classify_api_error(outer) + assert result.status_code == 429 + assert result.reason == FailoverReason.rate_limit + + def test_400_with_rate_limit_text(self): + """Some providers send rate limits as 400 instead of 429.""" + e = MockAPIError( + "rate limit policy", + status_code=400, + body={"error": {"message": "rate limit exceeded on this model"}}, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.rate_limit + + def test_400_with_billing_text(self): + """Some providers send billing errors as 400.""" + e = MockAPIError( + "billing", + status_code=400, + body={"error": {"message": "insufficient credits for this request"}}, + ) + result = classify_api_error(e) + assert result.reason == FailoverReason.billing + + def test_200_with_error_body(self): + """200 status with error in body — should be unknown, not crash.""" + class WeirdSuccess(Exception): + status_code = 200 + body = {"error": {"message": "loading"}} + result = classify_api_error(WeirdSuccess("model loading")) + assert result.reason == FailoverReason.unknown + + def test_ollama_context_size_exceeded(self): + e = MockAPIError( + "Error", + status_code=400, + body={"error": 
{"message": "context size has been exceeded"}}, + ) + result = classify_api_error(e, provider="ollama") + assert result.reason == FailoverReason.context_overflow + + def test_connection_refused_error(self): + e = ConnectionRefusedError("Connection refused: localhost:11434") + result = classify_api_error(e, provider="ollama") + assert result.reason == FailoverReason.timeout + + def test_body_message_enrichment(self): + """Body message must be included in pattern matching even when + str(error) doesn't contain it (OpenAI SDK APIStatusError).""" + e = MockAPIError( + "Usage limit", # str(e) = "usage limit" + status_code=402, + body={"error": {"message": "Usage limit reached, try again in 5 minutes"}}, + ) + result = classify_api_error(e) + # "try again" is only in body, not in str(e) + assert result.reason == FailoverReason.rate_limit + + def test_disconnect_pattern_ordering(self): + """Disconnect + large session must beat generic transport catch.""" + class FakeRemoteProtocol(Exception): + pass + # Type name isn't in _TRANSPORT_ERROR_TYPES but message has disconnect pattern + e = Exception("peer closed connection without sending complete message") + result = classify_api_error(e, approx_tokens=150000, context_length=200000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_credit_balance_too_low(self): + e = MockAPIError( + "Credits low", + status_code=402, + body={"error": {"message": "Your credit balance is too low"}}, + ) + result = classify_api_error(e, provider="anthropic") + assert result.reason == FailoverReason.billing + + def test_deepseek_402_chinese(self): + """Chinese billing message should still match billing patterns.""" + # "余额不足" doesn't match English billing patterns, but 402 defaults to billing + e = MockAPIError("余额不足", status_code=402) + result = classify_api_error(e, provider="deepseek") + assert result.reason == FailoverReason.billing + + def 
test_openrouter_wrapped_context_overflow_in_metadata_raw(self): + """OpenRouter wraps provider errors in metadata.raw JSON string.""" + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "code": 400, + "metadata": { + "raw": '{"error":{"message":"context length exceeded: 50000 > 32768"}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter", approx_tokens=10000) + assert result.reason == FailoverReason.context_overflow + assert result.should_compress is True + + def test_openrouter_wrapped_rate_limit_in_metadata_raw(self): + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "metadata": { + "raw": '{"error":{"message":"Rate limit exceeded. Please retry after 30s."}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.rate_limit + + def test_thinking_signature_via_openrouter(self): + """Thinking signature errors proxied through OpenRouter must be caught.""" + e = MockAPIError( + "thinking block has invalid signature", + status_code=400, + ) + # provider is openrouter, not anthropic — old code missed this + result = classify_api_error(e, provider="openrouter", model="anthropic/claude-sonnet-4") + assert result.reason == FailoverReason.thinking_signature + + def test_generic_400_large_by_message_count(self): + """Many small messages (>80) should trigger context overflow heuristic.""" + e = MockAPIError( + "Error", + status_code=400, + body={"error": {"message": "Error"}}, + ) + # Low token count but high message count + result = classify_api_error( + e, approx_tokens=5000, context_length=200000, num_messages=100, + ) + assert result.reason == FailoverReason.context_overflow + + def test_disconnect_large_by_message_count(self): + """Server disconnect with 200+ messages should trigger context overflow.""" + e = Exception("server 
disconnected without sending complete message") + result = classify_api_error( + e, approx_tokens=5000, context_length=200000, num_messages=250, + ) + assert result.reason == FailoverReason.context_overflow + + def test_openrouter_wrapped_model_not_found_in_metadata_raw(self): + e = MockAPIError( + "Provider returned error", + status_code=400, + body={ + "error": { + "message": "Provider returned error", + "metadata": { + "raw": '{"error":{"message":"The model gpt-99 does not exist"}}' + } + } + }, + ) + result = classify_api_error(e, provider="openrouter") + assert result.reason == FailoverReason.model_not_found