feat(agent): add jittered retry backoff

Adds agent/retry_utils.py with jittered_backoff() — exponential backoff
with additive jitter to prevent thundering-herd retry spikes when
multiple gateway sessions hit the same rate-limited provider.

Replaces fixed exponential backoff at 4 call sites:
- run_agent.py: None-choices retry path (5s base, 120s cap)
- run_agent.py: API error retry path (2s base, 60s cap)
- trajectory_compressor.py: sync + async summarization retries

Thread-safe jitter counter with overflow guards ensures unique seeds
across concurrent retries.

Trimmed from original PR to keep only wired-in functionality.

Co-authored-by: martinp09 <martinp09@users.noreply.github.com>
This commit is contained in:
zocomputer 2026-04-07 22:49:31 -07:00 committed by Teknium
parent fff237e111
commit e1befe5077
4 changed files with 181 additions and 4 deletions

57
agent/retry_utils.py Normal file
View file

@ -0,0 +1,57 @@
"""Retry utilities — jittered backoff for decorrelated retries.
Replaces fixed exponential backoff with jittered delays to prevent
thundering-herd retry spikes when multiple sessions hit the same
rate-limited provider concurrently.
"""
import random
import threading
import time
# Process-wide monotonic tick used to decorrelate jitter seeds.
# The lock keeps increments atomic when several retry paths run at once
# (e.g. multiple gateway sessions backing off simultaneously).
_jitter_counter = 0
_jitter_lock = threading.Lock()


def jittered_backoff(
    attempt: int,
    *,
    base_delay: float = 5.0,
    max_delay: float = 120.0,
    jitter_ratio: float = 0.5,
) -> float:
    """Compute a jittered exponential backoff delay.

    Args:
        attempt: 1-based retry attempt number.
        base_delay: Base delay in seconds for attempt 1.
        max_delay: Maximum delay cap in seconds.
        jitter_ratio: Fraction of computed delay to use as random jitter
            range. 0.5 means jitter is uniform in [0, 0.5 * delay].

    Returns:
        Delay in seconds: min(base * 2^(attempt-1), max_delay) + jitter.
        The jitter decorrelates concurrent retries so multiple sessions
        hitting the same provider don't all retry at the same instant.
    """
    global _jitter_counter
    with _jitter_lock:
        _jitter_counter += 1
        tick = _jitter_counter

    # Clamp negative/zero attempts to the first step.
    step = attempt - 1 if attempt > 1 else 0
    # Guard the doubling against absurd attempt counts and non-positive
    # bases; both collapse straight to the capped delay.
    if base_delay <= 0 or step >= 63:
        delay = max_delay
    else:
        delay = min(max_delay, base_delay * (1 << step))

    # Mix wall-clock nanoseconds with the per-call tick so seeds stay
    # distinct even when the clock is coarse.
    seed = (time.time_ns() ^ (tick * 0x9E3779B9)) & 0xFFFFFFFF
    return delay + random.Random(seed).uniform(0.0, jitter_ratio * delay)

View file

@ -75,6 +75,7 @@ from hermes_constants import OPENROUTER_BASE_URL
# Agent internals extracted to agent/ package for modularity # Agent internals extracted to agent/ package for modularity
from agent.memory_manager import build_memory_context_block from agent.memory_manager import build_memory_context_block
from agent.retry_utils import jittered_backoff
from agent.prompt_builder import ( from agent.prompt_builder import (
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
@ -7541,7 +7542,8 @@ class AIAgent:
} }
# Longer backoff for rate limiting (likely cause of None choices) # Longer backoff for rate limiting (likely cause of None choices)
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s # Jittered exponential: 5s base, 120s cap + random jitter
wait_time = jittered_backoff(retry_count, base_delay=5.0, max_delay=120.0)
self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True) self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True)
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}") logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
@ -8398,7 +8400,7 @@ class AIAgent:
_retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes _retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes
except (TypeError, ValueError): except (TypeError, ValueError):
pass pass
wait_time = _retry_after if _retry_after else min(2 ** retry_count, 60) wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
if is_rate_limited: if is_rate_limited:
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...") self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
else: else:

117
tests/test_retry_utils.py Normal file
View file

@ -0,0 +1,117 @@
"""Tests for agent.retry_utils jittered backoff."""
import threading
import agent.retry_utils as retry_utils
from agent.retry_utils import jittered_backoff
def test_backoff_is_exponential():
    """Base delay should double each attempt (before jitter)."""
    for attempt in (1, 2, 3, 4):
        # jitter_ratio=0.0 makes the delay fully deterministic, so a single
        # exact-equality check suffices — averaging 100 identical samples
        # (as the original did) adds work without adding signal.
        delay = jittered_backoff(attempt, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
        expected = min(5.0 * (2 ** (attempt - 1)), 120.0)
        assert delay == expected, f"attempt {attempt}: expected {expected}, got {delay}"
def test_backoff_respects_max_delay():
    """Even with high attempt numbers, delay should not exceed max_delay."""
    # Attempts far past the cap point must all clamp to max_delay.
    for big_attempt in (10, 20, 100):
        capped = jittered_backoff(big_attempt, base_delay=5.0, max_delay=60.0, jitter_ratio=0.0)
        assert capped <= 60.0, f"attempt {big_attempt}: delay {capped} exceeds max 60s"
def test_backoff_adds_jitter():
    """With jitter enabled, delays should vary across calls."""
    samples = []
    for _ in range(50):
        samples.append(jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5))
    # 50 draws with a 0.5 jitter ratio should never all coincide.
    assert min(samples) != max(samples), "jitter should produce varying delays"
    # Jitter is additive: base 10s plus up to 0.5 * 10s.
    assert all(d >= 10.0 for d in samples), "jittered delay should be >= base delay"
    assert all(d <= 15.0 for d in samples), "jittered delay should be bounded"
def test_backoff_attempt_1_is_base():
    """First attempt delay should equal base_delay (with no jitter)."""
    first = jittered_backoff(1, base_delay=3.0, max_delay=120.0, jitter_ratio=0.0)
    assert first == 3.0
def test_backoff_with_zero_base_delay_returns_max():
    """base_delay=0 should return max_delay (guard against busy-wait)."""
    result = jittered_backoff(1, base_delay=0.0, max_delay=60.0, jitter_ratio=0.0)
    assert result == 60.0
def test_backoff_with_extreme_attempt_returns_max():
    """Very large attempt numbers should not overflow and should return max_delay."""
    huge = jittered_backoff(999, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
    assert huge == 120.0
def test_backoff_negative_attempt_treated_as_one():
    """Negative attempt should not crash and behaves like attempt=1."""
    clamped = jittered_backoff(-5, base_delay=10.0, max_delay=120.0, jitter_ratio=0.0)
    assert clamped == 10.0
def test_backoff_thread_safety():
    """Concurrent calls should generally produce different delays."""
    worker_count = 8
    results = []
    # Barrier releases all workers at once to maximize seed contention.
    barrier = threading.Barrier(worker_count)

    def _worker():
        barrier.wait()
        delay = jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
        results.append(delay)  # list.append is atomic under the GIL

    workers = [threading.Thread(target=_worker) for _ in range(worker_count)]
    for w in workers:
        w.start()
    for w in workers:
        w.join(timeout=5)

    assert len(results) == 8
    unique = len(set(results))
    # Allow a couple of collisions; the tick-mixed seeds should keep most apart.
    assert unique >= 6, f"Expected mostly unique delays, got {unique}/8 unique"
def test_backoff_uses_locked_tick_for_seed(monkeypatch):
    """Seed derivation should use per-call tick captured under lock."""
    import time
    # Reset the module counter so tick values are predictable (1 and 2).
    monkeypatch.setattr(retry_utils, "_jitter_counter", 0)
    recorded_seeds = []
    # Stub random.Random to capture the seed each call constructs,
    # while returning zero jitter so the backoff value itself is inert.
    class _RecordingRandom:
        def __init__(self, seed):
            recorded_seeds.append(seed)
        def uniform(self, a, b):
            return 0.0
    monkeypatch.setattr(retry_utils.random, "Random", _RecordingRandom)
    fixed_time_ns = 123456789
    # time_ns stub: spin until BOTH threads have taken their tick (counter
    # increments happen under the lock, before time_ns is consulted), then
    # return the same fixed timestamp to both. With the time component
    # pinned equal, seed uniqueness can only come from the per-call tick.
    def _time_ns_wait_for_two_ticks():
        deadline = time.time() + 2.0
        while retry_utils._jitter_counter < 2 and time.time() < deadline:
            time.sleep(0.001)
        return fixed_time_ns
    monkeypatch.setattr(retry_utils.time, "time_ns", _time_ns_wait_for_two_ticks)
    barrier = threading.Barrier(2)
    def _call():
        barrier.wait()
        jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
    threads = [threading.Thread(target=_call) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=5)
    assert len(recorded_seeds) == 2
    # Identical timestamps, distinct ticks => the two seeds must differ.
    assert len(set(recorded_seeds)) == 2, f"Expected unique seeds, got {recorded_seeds}"

View file

@ -44,6 +44,7 @@ import fire
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn
from rich.console import Console from rich.console import Console
from hermes_constants import OPENROUTER_BASE_URL from hermes_constants import OPENROUTER_BASE_URL
from agent.retry_utils import jittered_backoff
# Load environment variables # Load environment variables
from dotenv import load_dotenv from dotenv import load_dotenv
@ -585,7 +586,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}") self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
if attempt < self.config.max_retries - 1: if attempt < self.config.max_retries - 1:
time.sleep(self.config.retry_delay * (attempt + 1)) time.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
else: else:
# Fallback: create a basic summary # Fallback: create a basic summary
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]" return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
@ -647,7 +648,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}") self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
if attempt < self.config.max_retries - 1: if attempt < self.config.max_retries - 1:
await asyncio.sleep(self.config.retry_delay * (attempt + 1)) await asyncio.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
else: else:
# Fallback: create a basic summary # Fallback: create a basic summary
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]" return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"