diff --git a/cron/scheduler.py b/cron/scheduler.py index 0355d2475..3dbb54c7d 100644 --- a/cron/scheduler.py +++ b/cron/scheduler.py @@ -895,6 +895,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: resolve_runtime_provider, format_runtime_provider_error, ) + from hermes_cli.auth import AuthError try: runtime_kwargs = { "requested": job.get("provider") or os.getenv("HERMES_INFERENCE_PROVIDER"), @@ -902,6 +903,28 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]: if job.get("base_url"): runtime_kwargs["explicit_base_url"] = job.get("base_url") runtime = resolve_runtime_provider(**runtime_kwargs) + except AuthError as auth_exc: + # Primary provider auth failed — try fallback chain before giving up. + logger.warning("Job '%s': primary auth failed (%s), trying fallback", job_id, auth_exc) + fb = _cfg.get("fallback_providers") or _cfg.get("fallback_model") + fb_list = (fb if isinstance(fb, list) else [fb]) if fb else [] + runtime = None + for entry in fb_list: + if not isinstance(entry, dict): + continue + try: + fb_kwargs = {"requested": entry.get("provider")} + if entry.get("base_url"): + fb_kwargs["explicit_base_url"] = entry["base_url"] + if entry.get("api_key"): + fb_kwargs["explicit_api_key"] = entry["api_key"] + runtime = resolve_runtime_provider(**fb_kwargs) + logger.info("Job '%s': fallback resolved to %s", job_id, runtime.get("provider")) + break + except Exception as fb_exc: + logger.debug("Job '%s': fallback %s failed: %s", job_id, entry.get("provider"), fb_exc) + if runtime is None: + raise RuntimeError(format_runtime_provider_error(auth_exc)) from auth_exc except Exception as exc: message = format_runtime_provider_error(exc) raise RuntimeError(message) from exc diff --git a/gateway/run.py b/gateway/run.py index 0dad9af10..75db1972a 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -350,16 +350,30 @@ _AGENT_PENDING_SENTINEL = object() def _resolve_runtime_agent_kwargs() -> dict: - """Resolve provider credentials for 
gateway-created AIAgent instances.""" + """Resolve provider credentials for gateway-created AIAgent instances. + + If the primary provider fails with an authentication error, attempt to + resolve credentials using the fallback provider chain from config.yaml + before giving up. + """ from hermes_cli.runtime_provider import ( resolve_runtime_provider, format_runtime_provider_error, ) + from hermes_cli.auth import AuthError try: runtime = resolve_runtime_provider( requested=os.getenv("HERMES_INFERENCE_PROVIDER"), ) + except AuthError as auth_exc: + # Primary provider auth failed (expired token, revoked key, etc.). + # Try the fallback provider chain before raising. + logger.warning("Primary provider auth failed: %s — trying fallback", auth_exc) + fb_config = _try_resolve_fallback_provider() + if fb_config is not None: + return fb_config + raise RuntimeError(format_runtime_provider_error(auth_exc)) from auth_exc except Exception as exc: raise RuntimeError(format_runtime_provider_error(exc)) from exc @@ -374,6 +388,48 @@ def _resolve_runtime_agent_kwargs() -> dict: } +def _try_resolve_fallback_provider() -> dict | None: + """Return agent kwargs for the first config.yaml fallback entry that resolves, or None (best-effort, never raises).""" + from hermes_cli.runtime_provider import resolve_runtime_provider + try: + import yaml as _y + cfg_path = _hermes_home / "config.yaml" + if not cfg_path.exists(): + return None + with open(cfg_path, encoding="utf-8") as _f: + cfg = _y.safe_load(_f) or {} + fb = cfg.get("fallback_providers") or cfg.get("fallback_model") + if not fb: + return None + # A single mapping (e.g. fallback_model) is treated as a one-entry chain. + fb_list = fb if isinstance(fb, list) else [fb] + for entry in fb_list: + if not isinstance(entry, dict): + continue + try: + runtime = resolve_runtime_provider( + requested=entry.get("provider"), + explicit_base_url=entry.get("base_url"), + explicit_api_key=entry.get("api_key"), + ) + logger.info("Fallback provider resolved: %s", runtime.get("provider")) + return { + "api_key": 
runtime.get("api_key"), + "base_url": runtime.get("base_url"), + "provider": runtime.get("provider"), + "api_mode": runtime.get("api_mode"), + "command": runtime.get("command"), + "args": list(runtime.get("args") or []), + "credential_pool": runtime.get("credential_pool"), + } + except Exception as fb_exc: + logger.debug("Fallback entry %s failed: %s", entry.get("provider"), fb_exc) + continue + except Exception as cfg_exc: + logger.warning("Fallback provider config could not be loaded: %s", cfg_exc) + return None + + def _build_media_placeholder(event) -> str: """Build a text placeholder for media-only events so they aren't dropped. diff --git a/tests/gateway/test_auth_fallback.py b/tests/gateway/test_auth_fallback.py new file mode 100644 index 000000000..3edb8b1ee --- /dev/null +++ b/tests/gateway/test_auth_fallback.py @@ -0,0 +1,73 @@ +"""Test that AuthError triggers fallback provider resolution (#7230).""" + +import os +from unittest.mock import patch, MagicMock + +import pytest + + +class TestResolveRuntimeAgentKwargsAuthFallback: + """_resolve_runtime_agent_kwargs should try fallback on AuthError.""" + + def test_auth_error_tries_fallback(self, tmp_path, monkeypatch): + """When primary provider raises AuthError, fallback is attempted.""" + from hermes_cli.auth import AuthError + + # Create a config with fallback + config_path = tmp_path / "config.yaml" + config_path.write_text( + "model:\n provider: openai-codex\n" + "fallback_model:\n provider: openrouter\n" + " model: meta-llama/llama-4-maverick\n" + ) + + monkeypatch.setattr("gateway.run._hermes_home", tmp_path) + + call_count = {"n": 0} + + def _mock_resolve(**kwargs): + call_count["n"] += 1 + requested = kwargs.get("requested", "") + if requested and "codex" in str(requested).lower(): + raise AuthError("Codex token refresh failed with status 401") + return { + "api_key": "fallback-key", + "base_url": "https://openrouter.ai/api/v1", + "provider": "openrouter", + "api_mode": "openai_chat", + "command": None, + "args": None, + "credential_pool": None, + } + + 
monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex") + + with patch( + "hermes_cli.runtime_provider.resolve_runtime_provider", + side_effect=_mock_resolve, + ): + from gateway.run import _resolve_runtime_agent_kwargs + result = _resolve_runtime_agent_kwargs() + + assert result["provider"] == "openrouter" + assert result["api_key"] == "fallback-key" + # Should have been called at least twice (primary + fallback) + assert call_count["n"] >= 2 + + def test_auth_error_no_fallback_raises(self, tmp_path, monkeypatch): + """When primary fails and no fallback configured, RuntimeError is raised.""" + from hermes_cli.auth import AuthError + + config_path = tmp_path / "config.yaml" + config_path.write_text("model:\n provider: openai-codex\n") + + monkeypatch.setattr("gateway.run._hermes_home", tmp_path) + monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex") + + with patch( + "hermes_cli.runtime_provider.resolve_runtime_provider", + side_effect=AuthError("token expired"), + ): + from gateway.run import _resolve_runtime_agent_kwargs + with pytest.raises(RuntimeError): + _resolve_runtime_agent_kwargs()