diff --git a/agent/credential_pool.py b/agent/credential_pool.py index ad4dbcfc1..003a5a8e7 100644 --- a/agent/credential_pool.py +++ b/agent/credential_pool.py @@ -267,6 +267,10 @@ class CredentialPool: def has_credentials(self) -> bool: return bool(self._entries) + def has_available(self) -> bool: + """Return True if at least one entry is not currently in exhaustion cooldown (i.e. ``_available_entries()`` is non-empty).""" + return bool(self._available_entries()) + def entries(self) -> List[PooledCredential]: return list(self._entries) diff --git a/agent/smart_model_routing.py b/agent/smart_model_routing.py index d57cd1b83..ada865af0 100644 --- a/agent/smart_model_routing.py +++ b/agent/smart_model_routing.py @@ -127,6 +127,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any "api_mode": primary.get("api_mode"), "command": primary.get("command"), "args": list(primary.get("args") or []), + "credential_pool": primary.get("credential_pool"), }, "label": None, "signature": ( @@ -162,6 +163,7 @@ def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any "api_mode": primary.get("api_mode"), "command": primary.get("command"), "args": list(primary.get("args") or []), + "credential_pool": primary.get("credential_pool"), }, "label": None, "signature": ( diff --git a/cli.py b/cli.py index b18e53077..151ae4615 100644 --- a/cli.py +++ b/cli.py @@ -2024,6 +2024,7 @@ class HermesCLI: "api_mode": self.api_mode, "command": self.acp_command, "args": list(self.acp_args or []), + "credential_pool": getattr(self, "_credential_pool", None), }, ) diff --git a/gateway/run.py b/gateway/run.py index cc1a6666f..49135ce5a 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -788,6 +788,7 @@ class GatewayRunner: "api_mode": runtime_kwargs.get("api_mode"), "command": runtime_kwargs.get("command"), "args": list(runtime_kwargs.get("args") or []), + "credential_pool": runtime_kwargs.get("credential_pool"), } return resolve_turn_route(user_message, getattr(self, "_smart_model_routing", {}), 
primary) diff --git a/run_agent.py b/run_agent.py index 5ed40500b..558a89457 100644 --- a/run_agent.py +++ b/run_agent.py @@ -7178,10 +7178,17 @@ class AIAgent: or "quota" in error_msg ) if is_rate_limited and self._fallback_index < len(self._fallback_chain): - self._emit_status("⚠️ Rate limited — switching to fallback provider...") - if self._try_activate_fallback(): - retry_count = 0 - continue + # Don't eagerly fall back if credential pool rotation may + # still recover. The pool's retry-then-rotate cycle needs + # at least one more attempt to fire — jumping to a fallback + # provider here short-circuits it. + pool = self._credential_pool + pool_may_recover = pool is not None and pool.has_available() + if not pool_may_recover: + self._emit_status("⚠️ Rate limited — switching to fallback provider...") + if self._try_activate_fallback(): + retry_count = 0 + continue is_payload_too_large = ( status_code == 413 diff --git a/tests/test_credential_pool_routing.py b/tests/test_credential_pool_routing.py new file mode 100644 index 000000000..f4006a236 --- /dev/null +++ b/tests/test_credential_pool_routing.py @@ -0,0 +1,350 @@ +"""Tests for credential pool preservation through smart routing and 429 recovery. + +Covers: +1. credential_pool flows through resolve_turn_route (no-route and fallback paths) +2. CLI _resolve_turn_agent_config passes credential_pool to primary dict +3. Gateway _resolve_turn_agent_config passes credential_pool to primary dict +4. Eager fallback deferred when credential pool still has available credentials +5. Eager fallback fires when no credential pool exists or the pool is exhausted +6. 429 rotation cycle: retry-same → rotate → exhaust (plus immediate rotation on 402) +""" + +import os +import time +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, PropertyMock + +import pytest + + +# --------------------------------------------------------------------------- +# 1. 
# smart_model_routing: credential_pool preserved in no-route path
# ---------------------------------------------------------------------------


def _make_primary(**overrides):
    """Build the primary-runtime dict shared by the routing tests.

    Keyword arguments are merged on top of the defaults. The
    ``credential_pool`` key is absent unless explicitly supplied, so the
    "no pool configured" case can be exercised by passing nothing.
    """
    primary = {
        "model": "gpt-5.4",
        "api_key": "sk-test",
        "base_url": None,
        "provider": "openai-codex",
        "api_mode": "codex_responses",
        "command": None,
        "args": [],
    }
    primary.update(overrides)
    return primary


def _install_route_spy(monkeypatch):
    """Replace ``resolve_turn_route`` with a recording pass-through.

    Returns a dict that receives the ``primary`` argument under the key
    ``"primary"`` each time the patched function is called.
    """
    # Grab the real function BEFORE patching so the spy can delegate to it.
    from agent.smart_model_routing import resolve_turn_route

    captured = {}

    def spy_resolve(user_message, routing_config, primary):
        captured["primary"] = primary
        return resolve_turn_route(user_message, routing_config, primary)

    monkeypatch.setattr(
        "agent.smart_model_routing.resolve_turn_route", spy_resolve
    )
    return captured


class TestSmartRoutingPoolPreservation:
    """resolve_turn_route must carry ``credential_pool`` into ``runtime``."""

    def test_no_route_preserves_credential_pool(self):
        """With no routing config, the pool object is passed through as-is."""
        from agent.smart_model_routing import resolve_turn_route

        fake_pool = MagicMock(name="CredentialPool")
        primary = _make_primary(credential_pool=fake_pool)
        # routing disabled
        result = resolve_turn_route("hello", None, primary)
        assert result["runtime"]["credential_pool"] is fake_pool

    def test_no_route_none_pool(self):
        """A primary dict without the key yields ``credential_pool`` of None."""
        from agent.smart_model_routing import resolve_turn_route

        primary = _make_primary()
        result = resolve_turn_route("hello", None, primary)
        assert result["runtime"]["credential_pool"] is None

    def test_routing_disabled_preserves_pool(self):
        """An explicit ``{"enabled": False}`` config also preserves the pool."""
        from agent.smart_model_routing import resolve_turn_route

        fake_pool = MagicMock(name="CredentialPool")
        primary = _make_primary(credential_pool=fake_pool)
        # routing explicitly disabled
        result = resolve_turn_route("hello", {"enabled": False}, primary)
        assert result["runtime"]["credential_pool"] is fake_pool

    def test_route_fallback_on_resolve_error_preserves_pool(self, monkeypatch):
        """When smart routing picks a cheap model but resolve_runtime_provider
        fails, the fallback to primary must still include credential_pool."""
        from agent.smart_model_routing import resolve_turn_route

        fake_pool = MagicMock(name="CredentialPool")
        primary = _make_primary(credential_pool=fake_pool)
        routing_config = {
            "enabled": True,
            "cheap_model": "openai/gpt-4.1-mini",
            "cheap_provider": "openrouter",
            "max_tokens": 200,
            "patterns": ["^(hi|hello|hey)"],
        }
        # Force resolve_runtime_provider to fail so it falls back to primary
        monkeypatch.setattr(
            "hermes_cli.runtime_provider.resolve_runtime_provider",
            MagicMock(side_effect=RuntimeError("no credentials")),
        )
        result = resolve_turn_route("hi", routing_config, primary)
        assert result["runtime"]["credential_pool"] is fake_pool


# ---------------------------------------------------------------------------
# 2 & 3. CLI and Gateway _resolve_turn_agent_config include credential_pool
# ---------------------------------------------------------------------------

class TestCliTurnRoutePool:
    def test_resolve_turn_includes_pool(self, monkeypatch):
        """CLI's _resolve_turn_agent_config must pass credential_pool to primary."""
        captured = _install_route_spy(monkeypatch)

        # Build a minimal HermesCLI-like object with the method
        shell = SimpleNamespace(
            model="gpt-5.4",
            api_key="sk-test",
            base_url=None,
            provider="openai-codex",
            api_mode="codex_responses",
            acp_command=None,
            acp_args=[],
            _credential_pool=MagicMock(name="FakePool"),
            _smart_model_routing={"enabled": False},
        )

        # Import and bind the real method
        from cli import HermesCLI
        bound = HermesCLI._resolve_turn_agent_config.__get__(shell)
        bound("test message")

        # Fail with a readable message (not a KeyError) if the patched
        # function was bypassed.
        assert "primary" in captured, "resolve_turn_route spy was never called"
        assert "credential_pool" in captured["primary"]
        assert captured["primary"]["credential_pool"] is shell._credential_pool


class TestGatewayTurnRoutePool:
    def test_resolve_turn_includes_pool(self, monkeypatch):
        """Gateway's _resolve_turn_agent_config must pass credential_pool."""
        captured = _install_route_spy(monkeypatch)

        from gateway.run import GatewayRunner

        runner = SimpleNamespace(
            _smart_model_routing={"enabled": False},
        )

        runtime_kwargs = {
            "api_key": "sk-test",
            "base_url": None,
            "provider": "openai-codex",
            "api_mode": "codex_responses",
            "command": None,
            "args": [],
            "credential_pool": MagicMock(name="FakePool"),
        }

        bound = GatewayRunner._resolve_turn_agent_config.__get__(runner)
        bound("test message", "gpt-5.4", runtime_kwargs)

        # Fail with a readable message (not a KeyError) if the patched
        # function was bypassed.
        assert "primary" in captured, "resolve_turn_route spy was never called"
        assert "credential_pool" in captured["primary"]
        assert captured["primary"]["credential_pool"] is runtime_kwargs["credential_pool"]


# ---------------------------------------------------------------------------
# 4 & 5.
# Eager fallback deferred/fires based on credential pool
# ---------------------------------------------------------------------------

class TestEagerFallbackWithPool:
    """Test the eager fallback guard in run_agent.py's error handling loop."""

    def _make_agent(self, has_pool=True, pool_has_creds=True, has_fallback=True):
        """Create a minimal AIAgent with only the fields the guard reads.

        ``__init__`` is patched out, so every attribute used below is set
        explicitly. ``pool_has_creds`` controls what the mocked pool's
        ``has_available()`` returns.
        """
        from run_agent import AIAgent

        with patch.object(AIAgent, "__init__", lambda self, **kw: None):
            agent = AIAgent()

        agent._credential_pool = None
        if has_pool:
            pool = MagicMock()
            pool.has_available.return_value = pool_has_creds
            agent._credential_pool = pool

        agent._fallback_chain = [{"model": "fallback/model"}] if has_fallback else []
        agent._fallback_index = 0
        agent._try_activate_fallback = MagicMock(return_value=True)
        agent._emit_status = MagicMock()

        return agent

    @staticmethod
    def _simulate_rate_limit_guard(agent, is_rate_limited=True):
        """Replay the rate-limit fallback decision against ``agent``.

        This is a hand-written copy of the guard, not a call into the real
        error-handling loop: fallback is attempted only when a fallback
        entry remains AND the credential pool is absent or reports nothing
        available.
        """
        if is_rate_limited and agent._fallback_index < len(agent._fallback_chain):
            pool = agent._credential_pool
            pool_may_recover = pool is not None and pool.has_available()
            if not pool_may_recover:
                agent._try_activate_fallback()

    def test_eager_fallback_deferred_when_pool_has_credentials(self):
        """429 with active pool should NOT trigger eager fallback."""
        agent = self._make_agent(has_pool=True, pool_has_creds=True, has_fallback=True)

        self._simulate_rate_limit_guard(agent)

        agent._credential_pool.has_available.assert_called_once()
        agent._try_activate_fallback.assert_not_called()

    def test_eager_fallback_fires_when_no_pool(self):
        """429 without pool should trigger eager fallback."""
        agent = self._make_agent(has_pool=False, has_fallback=True)

        self._simulate_rate_limit_guard(agent)

        agent._try_activate_fallback.assert_called_once()

    def test_eager_fallback_fires_when_pool_exhausted(self):
        """429 with exhausted pool should trigger eager fallback."""
        agent = self._make_agent(has_pool=True, pool_has_creds=False, has_fallback=True)

        self._simulate_rate_limit_guard(agent)

        agent._credential_pool.has_available.assert_called_once()
        agent._try_activate_fallback.assert_called_once()


# ---------------------------------------------------------------------------
# 6. Full 429 rotation cycle via _recover_with_credential_pool
# ---------------------------------------------------------------------------

class TestPoolRotationCycle:
    """Verify the retry-same → rotate → exhaust flow in _recover_with_credential_pool."""

    def _make_agent_with_pool(self, pool_entries=3):
        """Return ``(agent, pool, entries)`` wired to a mocked rotating pool.

        ``pool.mark_exhausted_and_rotate`` hands out ``entries[1]``,
        ``entries[2]``, ... on successive calls and returns ``None`` (also
        flipping ``has_credentials`` to False) once the list is used up.
        """
        from run_agent import AIAgent

        with patch.object(AIAgent, "__init__", lambda self, **kw: None):
            agent = AIAgent()

        entries = []
        for i in range(pool_entries):
            e = MagicMock(name=f"entry_{i}")
            e.id = f"cred-{i}"
            entries.append(e)

        pool = MagicMock()
        pool.has_credentials.return_value = True

        # Rotation state lives in the closure, not on the test instance,
        # so each call to this factory gets an independent counter.
        rotations = 0

        def rotate(status_code=None):
            nonlocal rotations
            rotations += 1
            if rotations < pool_entries:
                return entries[rotations]
            pool.has_credentials.return_value = False
            return None

        pool.mark_exhausted_and_rotate = MagicMock(side_effect=rotate)
        agent._credential_pool = pool
        agent._swap_credential = MagicMock()
        agent.log_prefix = ""

        return agent, pool, entries

    def test_first_429_sets_retry_flag_no_rotation(self):
        """First 429 should just set has_retried_429=True, no rotation."""
        agent, pool, _ = self._make_agent_with_pool(3)
        recovered, has_retried = agent._recover_with_credential_pool(
            status_code=429, has_retried_429=False
        )
        assert recovered is False
        assert has_retried is True
        pool.mark_exhausted_and_rotate.assert_not_called()

    def test_second_429_rotates_to_next(self):
        """Second consecutive 429 should rotate to next credential."""
        agent, pool, entries = self._make_agent_with_pool(3)
        recovered, has_retried = agent._recover_with_credential_pool(
            status_code=429, has_retried_429=True
        )
        assert recovered is True
        assert has_retried is False  # reset after rotation
        pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=429)
        agent._swap_credential.assert_called_once_with(entries[1])

    def test_pool_exhaustion_returns_false(self):
        """When all credentials exhausted, recovery should return False."""
        agent, pool, _ = self._make_agent_with_pool(1)
        # First 429 sets flag
        _, has_retried = agent._recover_with_credential_pool(
            status_code=429, has_retried_429=False
        )
        assert has_retried is True

        # Second 429 tries to rotate but pool is exhausted (only 1 entry)
        recovered, _ = agent._recover_with_credential_pool(
            status_code=429, has_retried_429=True
        )
        assert recovered is False

    def test_402_immediate_rotation(self):
        """402 (billing) should immediately rotate, no retry-first."""
        agent, pool, _ = self._make_agent_with_pool(3)
        recovered, has_retried = agent._recover_with_credential_pool(
            status_code=402, has_retried_429=False
        )
        assert recovered is True
        assert has_retried is False
        pool.mark_exhausted_and_rotate.assert_called_once_with(status_code=402)

    def test_no_pool_returns_false(self):
        """No pool should return (False, unchanged)."""
        from run_agent import AIAgent

        with patch.object(AIAgent, "__init__", lambda self, **kw: None):
            agent = AIAgent()
        agent._credential_pool = None

        recovered, has_retried = agent._recover_with_credential_pool(
            status_code=429, has_retried_429=False
        )
        assert recovered is False
        assert has_retried is False