refactor(agent): consolidate inner-retry-loop recovery flags into TurnRetryState (god-file Phase 1b)

run_conversation's inner retry loop tracked recovery state in ~15 scattered
bare booleans (per-provider OAuth refresh guards, format-recovery guards,
restart signals). They are now fields on a single TurnRetryState dataclass the
loop mutates in place (_retry.<flag>), giving the recovery bookkeeping a named,
testable home.

Loop-control vars (retry_count, max_retries, max_compression_attempts) stay as
plain locals — they're while-mechanics, not recovery bookkeeping.

Behavior-neutral: pure local→attribute rewrite of 42 references; kwarg NAMES
preserved (e.g. has_retried_429=_retry.has_retried_429). Live simple + tool
turns OK.

Validation: tests/run_agent/ 1615 passed / 0 failed under per-file process
isolation; new test_turn_retry_state.py pins the field contract.
This commit is contained in:
teknium1 2026-06-07 22:31:25 -07:00 committed by Teknium
parent 4d926f248d
commit 524453dab5
3 changed files with 175 additions and 56 deletions

View file

@ -32,6 +32,7 @@ from agent.display import KawaiiSpinner
from agent.error_classifier import FailoverReason, classify_api_error
from agent.iteration_budget import IterationBudget
from agent.turn_context import build_turn_context
from agent.turn_retry_state import TurnRetryState
from agent.memory_manager import build_memory_context_block
from agent.message_sanitization import (
_repair_tool_call_arguments,
@ -798,22 +799,8 @@ def run_conversation(
api_start_time = time.time()
retry_count = 0
max_retries = agent._api_max_retries
primary_recovery_attempted = False
_retry = TurnRetryState()
max_compression_attempts = 3
codex_auth_retry_attempted=False
anthropic_auth_retry_attempted=False
nous_auth_retry_attempted=False
nous_paid_entitlement_refresh_attempted=False
copilot_auth_retry_attempted=False
thinking_sig_retry_attempted = False
invalid_encrypted_content_retry_attempted = False
image_shrink_retry_attempted = False
multimodal_tool_content_retry_attempted = False
oauth_1m_beta_retry_attempted = False
llama_cpp_grammar_retry_attempted = False
has_retried_429 = False
restart_with_compressed_messages = False
restart_with_length_continuation = False
finish_reason = "stop"
response = None # Guard against UnboundLocalError if all retries fail
@ -846,7 +833,7 @@ def run_conversation(
if agent._try_activate_fallback():
retry_count = 0
compression_attempts = 0
primary_recovery_attempted = False
_retry.primary_recovery_attempted = False
continue
# No fallback available — surface buffered context
# so user sees the rate-limit message that led here.
@ -1171,7 +1158,7 @@ def run_conversation(
if agent._try_activate_fallback():
retry_count = 0
compression_attempts = 0
primary_recovery_attempted = False
_retry.primary_recovery_attempted = False
continue
# Check for error field in response (some providers include this)
@ -1242,7 +1229,7 @@ def run_conversation(
if agent._try_activate_fallback():
retry_count = 0
compression_attempts = 0
primary_recovery_attempted = False
_retry.primary_recovery_attempted = False
continue
# Terminal — flush buffered retry trace so user sees what happened.
agent._flush_status_buffer()
@ -1466,7 +1453,7 @@ def run_conversation(
}
messages.append(continue_msg)
agent._session_messages = messages
restart_with_length_continuation = True
_retry.restart_with_length_continuation = True
break
partial_response = agent._strip_think_blocks("".join(truncated_response_parts)).strip()
@ -1715,7 +1702,7 @@ def run_conversation(
f"({hit_pct:.0f}% hit, {written:,} written)"
)
has_retried_429 = False # Reset on success
_retry.has_retried_429 = False # Reset on success
# Note: don't clear the retry buffer here — an "API call
# success" only means we got bytes back, not that we got
# usable content. Empty responses still loop through the
@ -2045,9 +2032,9 @@ def run_conversation(
getattr(agent, "provider", "") or "",
getattr(agent, "base_url", "") or "",
)
and not nous_paid_entitlement_refresh_attempted
and not _retry.nous_paid_entitlement_refresh_attempted
):
nous_paid_entitlement_refresh_attempted = True
_retry.nous_paid_entitlement_refresh_attempted = True
if _try_refresh_nous_paid_entitlement_credentials(agent):
agent._vprint(
f"{agent.log_prefix}🔐 Nous paid access verified — "
@ -2056,9 +2043,9 @@ def run_conversation(
)
continue
recovered_with_pool, has_retried_429 = agent._recover_with_credential_pool(
recovered_with_pool, _retry.has_retried_429 = agent._recover_with_credential_pool(
status_code=status_code,
has_retried_429=has_retried_429,
has_retried_429=_retry.has_retried_429,
classified_reason=classified.reason,
error_context=error_context,
)
@ -2073,9 +2060,9 @@ def run_conversation(
# fails, fall through to normal error handling.
if (
classified.reason == FailoverReason.image_too_large
and not image_shrink_retry_attempted
and not _retry.image_shrink_retry_attempted
):
image_shrink_retry_attempted = True
_retry.image_shrink_retry_attempted = True
if agent._try_shrink_image_parts_in_messages(api_messages):
agent._vprint(
f"{agent.log_prefix}📐 Image(s) exceeded provider size limit — "
@ -2098,9 +2085,9 @@ def run_conversation(
# downgrade, and retry once. See issue #27344.
if (
classified.reason == FailoverReason.multimodal_tool_content_unsupported
and not multimodal_tool_content_retry_attempted
and not _retry.multimodal_tool_content_retry_attempted
):
multimodal_tool_content_retry_attempted = True
_retry.multimodal_tool_content_retry_attempted = True
if agent._try_strip_image_parts_from_tool_messages(api_messages):
agent._vprint(
f"{agent.log_prefix}📐 Provider rejected list-type tool content — "
@ -2127,9 +2114,9 @@ def run_conversation(
classified.reason == FailoverReason.oauth_long_context_beta_forbidden
and agent.api_mode == "anthropic_messages"
and agent._is_anthropic_oauth
and not oauth_1m_beta_retry_attempted
and not _retry.oauth_1m_beta_retry_attempted
):
oauth_1m_beta_retry_attempted = True
_retry.oauth_1m_beta_retry_attempted = True
if not getattr(agent, "_oauth_1m_beta_disabled", False):
agent._oauth_1m_beta_disabled = True
try:
@ -2148,9 +2135,9 @@ def run_conversation(
agent.api_mode == "codex_responses"
and agent.provider in {"openai-codex", "xai-oauth"}
and status_code == 401
and not codex_auth_retry_attempted
and not _retry.codex_auth_retry_attempted
):
codex_auth_retry_attempted = True
_retry.codex_auth_retry_attempted = True
if agent._try_refresh_codex_client_credentials(force=True):
_label = "xAI OAuth" if agent.provider == "xai-oauth" else "Codex"
agent._buffer_vprint(f"🔐 {_label} auth refreshed after 401. Retrying request...")
@ -2159,9 +2146,9 @@ def run_conversation(
agent.api_mode == "chat_completions"
and agent.provider == "nous"
and status_code == 401
and not nous_auth_retry_attempted
and not _retry.nous_auth_retry_attempted
):
nous_auth_retry_attempted = True
_retry.nous_auth_retry_attempted = True
if agent._try_refresh_nous_client_credentials(force=True):
print(f"{agent.log_prefix}🔐 Nous agent key refreshed after 401. Retrying request...")
continue
@ -2190,9 +2177,9 @@ def run_conversation(
if (
agent.provider == "copilot"
and status_code == 401
and not copilot_auth_retry_attempted
and not _retry.copilot_auth_retry_attempted
):
copilot_auth_retry_attempted = True
_retry.copilot_auth_retry_attempted = True
if agent._try_refresh_copilot_client_credentials():
agent._buffer_vprint(f"🔐 Copilot credentials refreshed after 401. Retrying request...")
continue
@ -2200,9 +2187,9 @@ def run_conversation(
agent.api_mode == "anthropic_messages"
and status_code == 401
and hasattr(agent, '_anthropic_api_key')
and not anthropic_auth_retry_attempted
and not _retry.anthropic_auth_retry_attempted
):
anthropic_auth_retry_attempted = True
_retry.anthropic_auth_retry_attempted = True
from agent.anthropic_adapter import _is_oauth_token
from agent.azure_identity_adapter import is_token_provider
if agent._try_refresh_anthropic_client_credentials():
@ -2243,9 +2230,9 @@ def run_conversation(
# blocks at all. One-shot — don't retry infinitely.
if (
classified.reason == FailoverReason.thinking_signature
and not thinking_sig_retry_attempted
and not _retry.thinking_sig_retry_attempted
):
thinking_sig_retry_attempted = True
_retry.thinking_sig_retry_attempted = True
for _m in messages:
if isinstance(_m, dict):
_m.pop("reasoning_details", None)
@ -2277,7 +2264,7 @@ def run_conversation(
# handles it (the provider is rejecting something else).
if (
classified.reason == FailoverReason.invalid_encrypted_content
and not invalid_encrypted_content_retry_attempted
and not _retry.invalid_encrypted_content_retry_attempted
and agent.api_mode == "codex_responses"
and bool(getattr(agent, "_codex_reasoning_replay_enabled", True))
and any(
@ -2288,7 +2275,7 @@ def run_conversation(
for _m in messages
)
):
invalid_encrypted_content_retry_attempted = True
_retry.invalid_encrypted_content_retry_attempted = True
replay_stats = agent._disable_codex_reasoning_replay(messages)
agent._vprint(
f"{agent.log_prefix}⚠️ Encrypted reasoning replay was rejected by the provider — "
@ -2315,9 +2302,9 @@ def run_conversation(
# fires only for users on llama.cpp's OAI server.
if (
classified.reason == FailoverReason.llama_cpp_grammar_pattern
and not llama_cpp_grammar_retry_attempted
and not _retry.llama_cpp_grammar_retry_attempted
):
llama_cpp_grammar_retry_attempted = True
_retry.llama_cpp_grammar_retry_attempted = True
try:
from tools.schema_sanitizer import strip_pattern_and_format
_, _stripped = strip_pattern_and_format(agent.tools)
@ -2528,7 +2515,7 @@ def run_conversation(
f"(was {old_ctx:,}), retrying..."
)
time.sleep(2)
restart_with_compressed_messages = True
_retry.restart_with_compressed_messages = True
break
# Fall through to normal error handling if compression
# is exhausted or didn't help.
@ -2561,7 +2548,7 @@ def run_conversation(
if agent._try_activate_fallback(reason=classified.reason):
retry_count = 0
compression_attempts = 0
primary_recovery_attempted = False
_retry.primary_recovery_attempted = False
continue
# ── Nous Portal: record rate limit & skip retries ─────
@ -2699,7 +2686,7 @@ def run_conversation(
if len(messages) < original_len:
agent._buffer_status(f"🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
time.sleep(2) # Brief pause between compression retries
restart_with_compressed_messages = True
_retry.restart_with_compressed_messages = True
break
else:
# Terminal — surface buffered context so the user
@ -2771,7 +2758,7 @@ def run_conversation(
"failed": True,
"compression_exhausted": True,
}
restart_with_compressed_messages = True
_retry.restart_with_compressed_messages = True
break
# Error is about the INPUT being too large. Only reduce
@ -2856,7 +2843,7 @@ def run_conversation(
if len(messages) < original_len:
agent._buffer_status(f"🗜️ Compressed {original_len}{len(messages)} messages, retrying...")
time.sleep(2) # Brief pause between compression retries
restart_with_compressed_messages = True
_retry.restart_with_compressed_messages = True
break
else:
# Can't compress further and already at minimum tier
@ -2961,7 +2948,7 @@ def run_conversation(
if agent._try_activate_fallback():
retry_count = 0
compression_attempts = 0
primary_recovery_attempted = False
_retry.primary_recovery_attempted = False
continue
if api_kwargs is not None:
agent._dump_api_request_debug(
@ -3093,10 +3080,10 @@ def run_conversation(
# client once for transient transport errors (stale
# connection pool, TCP reset). Only attempted once
# per API call block.
if not primary_recovery_attempted and agent._try_recover_primary_transport(
if not _retry.primary_recovery_attempted and agent._try_recover_primary_transport(
api_error, retry_count=retry_count, max_retries=max_retries,
):
primary_recovery_attempted = True
_retry.primary_recovery_attempted = True
retry_count = 0
continue
# Try fallback before giving up entirely
@ -3105,7 +3092,7 @@ def run_conversation(
if agent._try_activate_fallback():
retry_count = 0
compression_attempts = 0
primary_recovery_attempted = False
_retry.primary_recovery_attempted = False
continue
# Terminal — flush buffered retry/fallback trace.
agent._flush_status_buffer()
@ -3256,17 +3243,17 @@ def run_conversation(
_turn_exit_reason = "interrupted_during_api_call"
break
if restart_with_compressed_messages:
if _retry.restart_with_compressed_messages:
api_call_count -= 1
agent.iteration_budget.refund()
# Count compression restarts toward the retry limit to prevent
# infinite loops when compression reduces messages but not enough
# to fit the context window.
retry_count += 1
restart_with_compressed_messages = False
_retry.restart_with_compressed_messages = False
continue
if restart_with_length_continuation:
if _retry.restart_with_length_continuation:
# Progressively boost the output token budget on each retry.
# Retry 1 → 2× base, retry 2 → 3× base, capped at 32 768.
# Applies to all providers via _ephemeral_max_output_tokens.

68
agent/turn_retry_state.py Normal file
View file

@ -0,0 +1,68 @@
"""Per-attempt recovery bookkeeping for the conversation turn loop.
The inner retry loop in ``run_conversation`` (``while retry_count <
max_retries``) makes several distinct recovery attempts on a single model API
call: a credential-pool 429 retry, a per-provider OAuth refresh (codex,
anthropic, nous, copilot), a long-context compression restart, a length-
continuation restart, and a handful of format-recovery branches (thinking-
signature stripping, multimodal-tool-content stripping, llama.cpp grammar
fallback, image shrink, invalid-encrypted-content, 1M-beta header).
Each of those branches is guarded by a one-shot boolean so it fires at most
once per attempt. They used to be ~16 bare ``*_attempted`` / ``has_retried_*``
/ ``restart_with_*`` locals declared inline before the loop and threaded
through its 2,400-line body. ``TurnRetryState`` collapses them into one object
the loop mutates in place (``state.codex_auth_retry_attempted = True``), giving
the recovery bookkeeping a single named, testable home.
Loop-control variables (``retry_count``, ``max_retries``,
``max_compression_attempts``) intentionally stay as plain locals they are the
``while`` mechanics, not recovery bookkeeping, and putting them on the object
would add indirection without clarifying anything.
This module is dependency-free so it can be unit-tested in isolation and
imported by the turn loop without an import cycle.
"""
from __future__ import annotations
from dataclasses import dataclass, fields
@dataclass
class TurnRetryState:
"""One-shot recovery guards + restart signals for a single API-call attempt.
A fresh instance is created for each iteration of the outer turn loop
(once per ``api_call_count``). Each guard fires its recovery branch at most
once; the ``restart_with_*`` signals are read by the loop after the attempt
to decide whether to rebuild the request and retry.
"""
# ── Per-provider OAuth / credential refresh guards ───────────────────
codex_auth_retry_attempted: bool = False
anthropic_auth_retry_attempted: bool = False
nous_auth_retry_attempted: bool = False
nous_paid_entitlement_refresh_attempted: bool = False
copilot_auth_retry_attempted: bool = False
# ── Format / payload recovery guards ─────────────────────────────────
thinking_sig_retry_attempted: bool = False
invalid_encrypted_content_retry_attempted: bool = False
image_shrink_retry_attempted: bool = False
multimodal_tool_content_retry_attempted: bool = False
oauth_1m_beta_retry_attempted: bool = False
llama_cpp_grammar_retry_attempted: bool = False
# ── Transport / rate-limit recovery ──────────────────────────────────
primary_recovery_attempted: bool = False
has_retried_429: bool = False
# ── Restart signals (read by the outer loop after the attempt) ───────
restart_with_compressed_messages: bool = False
restart_with_length_continuation: bool = False
def __iter__(self):
# Convenience for debugging / tests: iterate (name, value) pairs.
for f in fields(self):
yield f.name, getattr(self, f.name)

View file

@ -0,0 +1,64 @@
"""Unit tests for TurnRetryState (god-file Phase 1b).
The dataclass holds the inner-retry-loop's one-shot recovery guards + restart
signals. These tests pin its shape and default semantics the behavioral
guarantee for the loop itself is the existing recovery-branch tests in
tests/run_agent/ which now exercise these fields via `_retry.<flag>`.
"""
from __future__ import annotations
from dataclasses import fields
from agent.turn_retry_state import TurnRetryState
EXPECTED_FIELDS = {
"codex_auth_retry_attempted",
"anthropic_auth_retry_attempted",
"nous_auth_retry_attempted",
"nous_paid_entitlement_refresh_attempted",
"copilot_auth_retry_attempted",
"thinking_sig_retry_attempted",
"invalid_encrypted_content_retry_attempted",
"image_shrink_retry_attempted",
"multimodal_tool_content_retry_attempted",
"oauth_1m_beta_retry_attempted",
"llama_cpp_grammar_retry_attempted",
"primary_recovery_attempted",
"has_retried_429",
"restart_with_compressed_messages",
"restart_with_length_continuation",
}
def test_all_guards_default_false():
s = TurnRetryState()
for name, value in s:
assert value is False, f"{name} should default to False"
def test_field_set_matches_contract():
names = {f.name for f in fields(TurnRetryState)}
assert names == EXPECTED_FIELDS, (
f"unexpected drift: missing={EXPECTED_FIELDS - names} extra={names - EXPECTED_FIELDS}"
)
def test_loop_control_vars_are_not_on_state():
# retry_count / max_retries / max_compression_attempts stay as loop locals,
# NOT on the state object (they are while-mechanics, not recovery bookkeeping).
names = {f.name for f in fields(TurnRetryState)}
for loop_local in ("retry_count", "max_retries", "max_compression_attempts"):
assert loop_local not in names
def test_guards_are_independently_mutable():
s = TurnRetryState()
s.codex_auth_retry_attempted = True
s.restart_with_compressed_messages = True
assert s.codex_auth_retry_attempted is True
assert s.restart_with_compressed_messages is True
# untouched guards stay False
assert s.has_retried_429 is False
assert s.anthropic_auth_retry_attempted is False