diff --git a/agent/retry_utils.py b/agent/retry_utils.py
new file mode 100644
index 000000000..71d6963f7
--- /dev/null
+++ b/agent/retry_utils.py
@@ -0,0 +1,57 @@
+"""Retry utilities — jittered backoff for decorrelated retries.
+
+Replaces fixed exponential backoff with jittered delays to prevent
+thundering-herd retry spikes when multiple sessions hit the same
+rate-limited provider concurrently.
+"""
+
+import random
+import threading
+import time
+
+# Monotonic counter for jitter seed uniqueness within the same process.
+# Protected by a lock to avoid race conditions in concurrent retry paths
+# (e.g. multiple gateway sessions retrying simultaneously).
+_jitter_counter = 0
+_jitter_lock = threading.Lock()
+
+
+def jittered_backoff(
+    attempt: int,
+    *,
+    base_delay: float = 5.0,
+    max_delay: float = 120.0,
+    jitter_ratio: float = 0.5,
+) -> float:
+    """Compute a jittered exponential backoff delay.
+
+    Args:
+        attempt: 1-based retry attempt number.
+        base_delay: Base delay in seconds for attempt 1 (values <= 0 yield max_delay).
+        max_delay: Cap applied to the exponential term only; jitter is added on top,
+            so the returned delay can reach max_delay * (1 + jitter_ratio).
+
+    Returns:
+        Delay in seconds: min(base * 2^(attempt-1), max_delay) + jitter.
+
+    The jitter decorrelates concurrent retries so multiple sessions
+    hitting the same provider don't all retry at the same instant.
+    """
+    global _jitter_counter
+    with _jitter_lock:
+        _jitter_counter += 1
+        tick = _jitter_counter
+
+    exponent = max(0, attempt - 1)
+    if exponent >= 63 or base_delay <= 0:
+        delay = max_delay
+    else:
+        delay = min(base_delay * (2 ** exponent), max_delay)
+
+    # XOR wall-clock ns with the per-call tick (odd multiplier keeps ticks distinct mod 2^32)
+    seed = (time.time_ns() ^ (tick * 0x9E3779B9)) & 0xFFFFFFFF
+    rng = random.Random(seed)
+    jitter = rng.uniform(0, jitter_ratio * delay)
+
+    return delay + jitter
diff --git a/run_agent.py b/run_agent.py
index 343110ecc..22928bb18 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -75,6 +75,7 @@
 from hermes_constants import OPENROUTER_BASE_URL
 
 # Agent internals extracted to agent/ package for modularity
 from agent.memory_manager import build_memory_context_block
+from agent.retry_utils import jittered_backoff
 from agent.prompt_builder import (
     DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS, MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
@@ -7541,7 +7542,8 @@
                 }
 
             # Longer backoff for rate limiting (likely cause of None choices)
-            wait_time = min(5 * (2 ** (retry_count - 1)), 120)  # 5s, 10s, 20s, 40s, 80s, 120s
+            # Jittered exponential: 5s base, 120s cap, plus up to 50% random jitter added on top of the cap
+            wait_time = jittered_backoff(retry_count, base_delay=5.0, max_delay=120.0)
             self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True)
 
             logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
@@ -8398,7 +8400,7 @@
                     _retry_after = min(int(_ra_raw), 120)  # Cap at 2 minutes
                 except (TypeError, ValueError):
                     pass
-                wait_time = _retry_after if _retry_after else min(2 ** retry_count, 60)
+                wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
                 if is_rate_limited:
                     self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
                 else:
diff --git a/tests/test_retry_utils.py b/tests/test_retry_utils.py
new file mode 100644
index 000000000..f39c3142d
--- /dev/null
+++ b/tests/test_retry_utils.py
@@ -0,0 +1,117 @@
+"""Tests for agent.retry_utils jittered backoff."""
+
+import threading
+
+import agent.retry_utils as retry_utils
+from agent.retry_utils import jittered_backoff
+
+
+def test_backoff_is_exponential():
+    """Base delay should double each attempt (before jitter)."""
+    for attempt in (1, 2, 3, 4):
+        delays = [jittered_backoff(attempt, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0) for _ in range(100)]
+        expected = min(5.0 * (2 ** (attempt - 1)), 120.0)
+        mean = sum(delays) / len(delays)
+        assert abs(mean - expected) < 0.01, f"attempt {attempt}: expected {expected}, got {mean}"
+
+
+def test_backoff_respects_max_delay():
+    """Even with high attempt numbers, delay should not exceed max_delay."""
+    for attempt in (10, 20, 100):
+        delay = jittered_backoff(attempt, base_delay=5.0, max_delay=60.0, jitter_ratio=0.0)
+        assert delay <= 60.0, f"attempt {attempt}: delay {delay} exceeds max 60s"
+
+
+def test_backoff_adds_jitter():
+    """With jitter enabled, delays should vary across calls."""
+    delays = [jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5) for _ in range(50)]
+    assert min(delays) != max(delays), "jitter should produce varying delays"
+    assert all(d >= 10.0 for d in delays), "jittered delay should be >= base delay"
+    assert all(d <= 15.0 for d in delays), "jittered delay should be bounded"
+
+
+def test_backoff_attempt_1_is_base():
+    """First attempt delay should equal base_delay (with no jitter)."""
+    delay = jittered_backoff(1, base_delay=3.0, max_delay=120.0, jitter_ratio=0.0)
+    assert delay == 3.0
+
+
+def test_backoff_with_zero_base_delay_returns_max():
+    """base_delay=0 should return max_delay (guard against busy-wait)."""
+    delay = jittered_backoff(1, base_delay=0.0, max_delay=60.0, jitter_ratio=0.0)
+    assert delay == 60.0
+
+
+def test_backoff_with_extreme_attempt_returns_max():
+    """Very large attempt numbers should not overflow and should return max_delay."""
+    delay = jittered_backoff(999, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
+    assert delay == 120.0
+
+
+def test_backoff_negative_attempt_treated_as_one():
+    """Negative attempt should not crash and behaves like attempt=1."""
+    delay = jittered_backoff(-5, base_delay=10.0, max_delay=120.0, jitter_ratio=0.0)
+    assert delay == 10.0
+
+
+def test_backoff_thread_safety():
+    """Concurrent calls should generally produce different delays."""
+    results = []
+    barrier = threading.Barrier(8)
+
+    def _call_backoff():
+        barrier.wait()
+        results.append(jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5))
+
+    threads = [threading.Thread(target=_call_backoff) for _ in range(8)]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join(timeout=5)
+
+    assert len(results) == 8
+    unique = len(set(results))
+    assert unique >= 6, f"Expected mostly unique delays, got {unique}/8 unique"
+
+
+def test_backoff_uses_locked_tick_for_seed(monkeypatch):
+    """Seed derivation should use per-call tick captured under lock."""
+    import time
+
+    monkeypatch.setattr(retry_utils, "_jitter_counter", 0)
+
+    recorded_seeds = []
+
+    class _RecordingRandom:
+        def __init__(self, seed):
+            recorded_seeds.append(seed)
+
+        def uniform(self, a, b):
+            return 0.0
+
+    monkeypatch.setattr(retry_utils.random, "Random", _RecordingRandom)
+
+    fixed_time_ns = 123456789
+
+    def _time_ns_wait_for_two_ticks():
+        deadline = time.time() + 2.0
+        while retry_utils._jitter_counter < 2 and time.time() < deadline:
+            time.sleep(0.001)
+        return fixed_time_ns
+
+    monkeypatch.setattr(retry_utils.time, "time_ns", _time_ns_wait_for_two_ticks)
+
+    barrier = threading.Barrier(2)
+
+    def _call():
+        barrier.wait()
+        jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
+
+    threads = [threading.Thread(target=_call) for _ in range(2)]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join(timeout=5)
+
+    assert len(recorded_seeds) == 2
+    assert len(set(recorded_seeds)) == 2, f"Expected unique seeds, got {recorded_seeds}"
diff --git a/trajectory_compressor.py b/trajectory_compressor.py
index e4faf97a3..24c1f722a 100644
--- a/trajectory_compressor.py
+++ b/trajectory_compressor.py
@@ -44,6 +44,7 @@
 import fire
 from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn
 from rich.console import Console
 from hermes_constants import OPENROUTER_BASE_URL
+from agent.retry_utils import jittered_backoff
 # Load environment variables
 from dotenv import load_dotenv
@@ -585,7 +586,7 @@
 Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
 
                 self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
                 if attempt < self.config.max_retries - 1:
-                    time.sleep(self.config.retry_delay * (attempt + 1))
+                    time.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))  # NOTE(review): retry_delay <= 0 now waits max_delay (30s), unlike the old no-op — confirm no config relies on zero-delay retries
                 else:
                     # Fallback: create a basic summary
                     return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
@@ -647,7 +648,7 @@
 Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
 
                 self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
                 if attempt < self.config.max_retries - 1:
-                    await asyncio.sleep(self.config.retry_delay * (attempt + 1))
+                    await asyncio.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
                 else:
                     # Fallback: create a basic summary
                     return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"