mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(agent): add jittered retry backoff
Adds agent/retry_utils.py with jittered_backoff() — exponential backoff with additive jitter to prevent thundering-herd retry spikes when multiple gateway sessions hit the same rate-limited provider. Replaces fixed exponential backoff at 4 call sites: - run_agent.py: None-choices retry path (5s base, 120s cap) - run_agent.py: API error retry path (2s base, 60s cap) - trajectory_compressor.py: sync + async summarization retries Thread-safe jitter counter with overflow guards ensures unique seeds across concurrent retries. Trimmed from original PR to keep only wired-in functionality. Co-authored-by: martinp09 <martinp09@users.noreply.github.com>
This commit is contained in:
parent
fff237e111
commit
e1befe5077
4 changed files with 181 additions and 4 deletions
57
agent/retry_utils.py
Normal file
57
agent/retry_utils.py
Normal file
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""Retry utilities — jittered backoff for decorrelated retries.
|
||||||
|
|
||||||
|
Replaces fixed exponential backoff with jittered delays to prevent
|
||||||
|
thundering-herd retry spikes when multiple sessions hit the same
|
||||||
|
rate-limited provider concurrently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Monotonic counter for jitter seed uniqueness within the same process.
|
||||||
|
# Protected by a lock to avoid race conditions in concurrent retry paths
|
||||||
|
# (e.g. multiple gateway sessions retrying simultaneously).
|
||||||
|
_jitter_counter = 0
|
||||||
|
_jitter_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def jittered_backoff(
|
||||||
|
attempt: int,
|
||||||
|
*,
|
||||||
|
base_delay: float = 5.0,
|
||||||
|
max_delay: float = 120.0,
|
||||||
|
jitter_ratio: float = 0.5,
|
||||||
|
) -> float:
|
||||||
|
"""Compute a jittered exponential backoff delay.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attempt: 1-based retry attempt number.
|
||||||
|
base_delay: Base delay in seconds for attempt 1.
|
||||||
|
max_delay: Maximum delay cap in seconds.
|
||||||
|
jitter_ratio: Fraction of computed delay to use as random jitter
|
||||||
|
range. 0.5 means jitter is uniform in [0, 0.5 * delay].
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Delay in seconds: min(base * 2^(attempt-1), max_delay) + jitter.
|
||||||
|
|
||||||
|
The jitter decorrelates concurrent retries so multiple sessions
|
||||||
|
hitting the same provider don't all retry at the same instant.
|
||||||
|
"""
|
||||||
|
global _jitter_counter
|
||||||
|
with _jitter_lock:
|
||||||
|
_jitter_counter += 1
|
||||||
|
tick = _jitter_counter
|
||||||
|
|
||||||
|
exponent = max(0, attempt - 1)
|
||||||
|
if exponent >= 63 or base_delay <= 0:
|
||||||
|
delay = max_delay
|
||||||
|
else:
|
||||||
|
delay = min(base_delay * (2 ** exponent), max_delay)
|
||||||
|
|
||||||
|
# Seed from time + counter for decorrelation even with coarse clocks.
|
||||||
|
seed = (time.time_ns() ^ (tick * 0x9E3779B9)) & 0xFFFFFFFF
|
||||||
|
rng = random.Random(seed)
|
||||||
|
jitter = rng.uniform(0, jitter_ratio * delay)
|
||||||
|
|
||||||
|
return delay + jitter
|
||||||
|
|
@ -75,6 +75,7 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||||
|
|
||||||
# Agent internals extracted to agent/ package for modularity
|
# Agent internals extracted to agent/ package for modularity
|
||||||
from agent.memory_manager import build_memory_context_block
|
from agent.memory_manager import build_memory_context_block
|
||||||
|
from agent.retry_utils import jittered_backoff
|
||||||
from agent.prompt_builder import (
|
from agent.prompt_builder import (
|
||||||
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
|
DEFAULT_AGENT_IDENTITY, PLATFORM_HINTS,
|
||||||
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
MEMORY_GUIDANCE, SESSION_SEARCH_GUIDANCE, SKILLS_GUIDANCE,
|
||||||
|
|
@ -7541,7 +7542,8 @@ class AIAgent:
|
||||||
}
|
}
|
||||||
|
|
||||||
# Longer backoff for rate limiting (likely cause of None choices)
|
# Longer backoff for rate limiting (likely cause of None choices)
|
||||||
wait_time = min(5 * (2 ** (retry_count - 1)), 120) # 5s, 10s, 20s, 40s, 80s, 120s
|
# Jittered exponential: 5s base, 120s cap + random jitter
|
||||||
|
wait_time = jittered_backoff(retry_count, base_delay=5.0, max_delay=120.0)
|
||||||
self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True)
|
self._vprint(f"{self.log_prefix}⏳ Retrying in {wait_time}s (extended backoff for possible rate limit)...", force=True)
|
||||||
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
logging.warning(f"Invalid API response (retry {retry_count}/{max_retries}): {', '.join(error_details)} | Provider: {provider_name}")
|
||||||
|
|
||||||
|
|
@ -8398,7 +8400,7 @@ class AIAgent:
|
||||||
_retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes
|
_retry_after = min(int(_ra_raw), 120) # Cap at 2 minutes
|
||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
pass
|
pass
|
||||||
wait_time = _retry_after if _retry_after else min(2 ** retry_count, 60)
|
wait_time = _retry_after if _retry_after else jittered_backoff(retry_count, base_delay=2.0, max_delay=60.0)
|
||||||
if is_rate_limited:
|
if is_rate_limited:
|
||||||
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
self._emit_status(f"⏱️ Rate limit reached. Waiting {wait_time}s before retry (attempt {retry_count + 1}/{max_retries})...")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
117
tests/test_retry_utils.py
Normal file
117
tests/test_retry_utils.py
Normal file
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""Tests for agent.retry_utils jittered backoff."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import agent.retry_utils as retry_utils
|
||||||
|
from agent.retry_utils import jittered_backoff
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_is_exponential():
    """Base delay should double each attempt (jitter disabled).

    With jitter_ratio=0.0 the delay is fully deterministic, so each call
    must equal the expected value exactly — no need to average 100
    identical samples against a tolerance.
    """
    for attempt in (1, 2, 3, 4):
        delay = jittered_backoff(attempt, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0)
        expected = min(5.0 * (2 ** (attempt - 1)), 120.0)
        assert delay == expected, f"attempt {attempt}: expected {expected}, got {delay}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_respects_max_delay():
    """Even with high attempt numbers, delay should not exceed max_delay."""
    delays = {
        attempt: jittered_backoff(attempt, base_delay=5.0, max_delay=60.0, jitter_ratio=0.0)
        for attempt in (10, 20, 100)
    }
    for attempt, delay in delays.items():
        assert delay <= 60.0, f"attempt {attempt}: delay {delay} exceeds max 60s"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_adds_jitter():
    """With jitter enabled, delays should vary across calls."""
    delays = []
    for _ in range(50):
        delays.append(jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5))
    # Variation check: at least two distinct values across the sample.
    assert min(delays) != max(delays), "jitter should produce varying delays"
    # Bounds check: delay + jitter stays within [base, base * (1 + ratio)].
    assert all(10.0 <= d for d in delays), "jittered delay should be >= base delay"
    assert all(d <= 15.0 for d in delays), "jittered delay should be bounded"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_attempt_1_is_base():
    """First attempt delay should equal base_delay (with no jitter)."""
    assert jittered_backoff(1, base_delay=3.0, max_delay=120.0, jitter_ratio=0.0) == 3.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_with_zero_base_delay_returns_max():
    """base_delay=0 should return max_delay (guard against busy-wait)."""
    assert jittered_backoff(1, base_delay=0.0, max_delay=60.0, jitter_ratio=0.0) == 60.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_with_extreme_attempt_returns_max():
    """Very large attempt numbers should not overflow and should return max_delay."""
    assert jittered_backoff(999, base_delay=5.0, max_delay=120.0, jitter_ratio=0.0) == 120.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_negative_attempt_treated_as_one():
    """Negative attempt should not crash and behaves like attempt=1."""
    assert jittered_backoff(-5, base_delay=10.0, max_delay=120.0, jitter_ratio=0.0) == 10.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_thread_safety():
    """Concurrent calls should generally produce different delays."""
    n_workers = 8
    barrier = threading.Barrier(n_workers)
    results = []

    def _worker():
        # Line all workers up at the barrier so the calls land as close
        # together in time as possible.
        barrier.wait()
        delay = jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)
        results.append(delay)

    workers = []
    for _ in range(n_workers):
        t = threading.Thread(target=_worker)
        workers.append(t)
        t.start()
    for t in workers:
        t.join(timeout=5)

    # All workers completed and appended a result.
    assert len(results) == 8
    unique = len(set(results))
    assert unique >= 6, f"Expected mostly unique delays, got {unique}/8 unique"
|
||||||
|
|
||||||
|
|
||||||
|
def test_backoff_uses_locked_tick_for_seed(monkeypatch):
    """Seed derivation should use per-call tick captured under lock.

    Strategy: freeze time_ns to one fixed value for both concurrent
    calls, so the only source of seed variation left is the per-call
    tick.  If both recorded seeds differ, the tick reached the seed.
    """
    import time

    # Reset the module counter so this test sees ticks 1 and 2.
    monkeypatch.setattr(retry_utils, "_jitter_counter", 0)

    recorded_seeds = []

    class _RecordingRandom:
        # Stand-in for random.Random that captures the seed it was
        # constructed with; uniform() returns 0 so the delay math is inert.
        def __init__(self, seed):
            recorded_seeds.append(seed)

        def uniform(self, a, b):
            return 0.0

    monkeypatch.setattr(retry_utils.random, "Random", _RecordingRandom)

    fixed_time_ns = 123456789

    def _time_ns_wait_for_two_ticks():
        # Block until BOTH threads have taken their tick (or 2s pass),
        # then return the same fixed timestamp to each.  This forces the
        # time component of both seeds to be identical, isolating the tick.
        deadline = time.time() + 2.0
        while retry_utils._jitter_counter < 2 and time.time() < deadline:
            time.sleep(0.001)
        return fixed_time_ns

    monkeypatch.setattr(retry_utils.time, "time_ns", _time_ns_wait_for_two_ticks)

    barrier = threading.Barrier(2)

    def _call():
        # Barrier makes the two calls race through the lock back-to-back.
        barrier.wait()
        jittered_backoff(1, base_delay=10.0, max_delay=120.0, jitter_ratio=0.5)

    threads = [threading.Thread(target=_call) for _ in range(2)]
    for t in threads:
        t.start()
    for t in threads:
        t.join(timeout=5)

    assert len(recorded_seeds) == 2
    # Same timestamp, different ticks => seeds must differ.
    assert len(set(recorded_seeds)) == 2, f"Expected unique seeds, got {recorded_seeds}"
|
||||||
|
|
@ -44,6 +44,7 @@ import fire
|
||||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn
|
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn, TimeElapsedColumn, TimeRemainingColumn
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from hermes_constants import OPENROUTER_BASE_URL
|
from hermes_constants import OPENROUTER_BASE_URL
|
||||||
|
from agent.retry_utils import jittered_backoff
|
||||||
|
|
||||||
# Load environment variables
|
# Load environment variables
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
@ -585,7 +586,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||||
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
|
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
|
||||||
|
|
||||||
if attempt < self.config.max_retries - 1:
|
if attempt < self.config.max_retries - 1:
|
||||||
time.sleep(self.config.retry_delay * (attempt + 1))
|
time.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
|
||||||
else:
|
else:
|
||||||
# Fallback: create a basic summary
|
# Fallback: create a basic summary
|
||||||
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
||||||
|
|
@ -647,7 +648,7 @@ Write only the summary, starting with "[CONTEXT SUMMARY]:" prefix."""
|
||||||
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
|
self.logger.warning(f"Summarization attempt {attempt + 1} failed: {e}")
|
||||||
|
|
||||||
if attempt < self.config.max_retries - 1:
|
if attempt < self.config.max_retries - 1:
|
||||||
await asyncio.sleep(self.config.retry_delay * (attempt + 1))
|
await asyncio.sleep(jittered_backoff(attempt + 1, base_delay=self.config.retry_delay, max_delay=30.0))
|
||||||
else:
|
else:
|
||||||
# Fallback: create a basic summary
|
# Fallback: create a basic summary
|
||||||
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
return "[CONTEXT SUMMARY]: [Summary generation failed - previous turns contained tool calls and responses that have been compressed to save context space.]"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue