mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway,cron): activate fallback_model when primary provider auth fails
Previously, when the primary provider raised AuthError (expired OAuth token, revoked API key), the error was re-raised before AIAgent was created, so fallback_model was never consulted. Now both gateway/run.py and cron/scheduler.py catch AuthError specifically and attempt to resolve credentials from the fallback_providers/fallback_model config chain before propagating the error. Closes #7230
This commit is contained in:
parent
f7f7588893
commit
ee83a710f0
3 changed files with 153 additions and 1 deletions
|
|
@ -895,6 +895,7 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||||
resolve_runtime_provider,
|
resolve_runtime_provider,
|
||||||
format_runtime_provider_error,
|
format_runtime_provider_error,
|
||||||
)
|
)
|
||||||
|
from hermes_cli.auth import AuthError
|
||||||
try:
|
try:
|
||||||
runtime_kwargs = {
|
runtime_kwargs = {
|
||||||
"requested": job.get("provider") or os.getenv("HERMES_INFERENCE_PROVIDER"),
|
"requested": job.get("provider") or os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||||
|
|
@ -902,6 +903,28 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||||
if job.get("base_url"):
|
if job.get("base_url"):
|
||||||
runtime_kwargs["explicit_base_url"] = job.get("base_url")
|
runtime_kwargs["explicit_base_url"] = job.get("base_url")
|
||||||
runtime = resolve_runtime_provider(**runtime_kwargs)
|
runtime = resolve_runtime_provider(**runtime_kwargs)
|
||||||
|
except AuthError as auth_exc:
|
||||||
|
# Primary provider auth failed — try fallback chain before giving up.
|
||||||
|
logger.warning("Job '%s': primary auth failed (%s), trying fallback", job_id, auth_exc)
|
||||||
|
fb = _cfg.get("fallback_providers") or _cfg.get("fallback_model")
|
||||||
|
fb_list = (fb if isinstance(fb, list) else [fb]) if fb else []
|
||||||
|
runtime = None
|
||||||
|
for entry in fb_list:
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
fb_kwargs = {"requested": entry.get("provider")}
|
||||||
|
if entry.get("base_url"):
|
||||||
|
fb_kwargs["explicit_base_url"] = entry["base_url"]
|
||||||
|
if entry.get("api_key"):
|
||||||
|
fb_kwargs["explicit_api_key"] = entry["api_key"]
|
||||||
|
runtime = resolve_runtime_provider(**fb_kwargs)
|
||||||
|
logger.info("Job '%s': fallback resolved to %s", job_id, runtime.get("provider"))
|
||||||
|
break
|
||||||
|
except Exception as fb_exc:
|
||||||
|
logger.debug("Job '%s': fallback %s failed: %s", job_id, entry.get("provider"), fb_exc)
|
||||||
|
if runtime is None:
|
||||||
|
raise RuntimeError(format_runtime_provider_error(auth_exc)) from auth_exc
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
message = format_runtime_provider_error(exc)
|
message = format_runtime_provider_error(exc)
|
||||||
raise RuntimeError(message) from exc
|
raise RuntimeError(message) from exc
|
||||||
|
|
|
||||||
|
|
@ -350,16 +350,30 @@ _AGENT_PENDING_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
def _resolve_runtime_agent_kwargs() -> dict:
|
def _resolve_runtime_agent_kwargs() -> dict:
|
||||||
"""Resolve provider credentials for gateway-created AIAgent instances."""
|
"""Resolve provider credentials for gateway-created AIAgent instances.
|
||||||
|
|
||||||
|
If the primary provider fails with an authentication error, attempt to
|
||||||
|
resolve credentials using the fallback provider chain from config.yaml
|
||||||
|
before giving up.
|
||||||
|
"""
|
||||||
from hermes_cli.runtime_provider import (
|
from hermes_cli.runtime_provider import (
|
||||||
resolve_runtime_provider,
|
resolve_runtime_provider,
|
||||||
format_runtime_provider_error,
|
format_runtime_provider_error,
|
||||||
)
|
)
|
||||||
|
from hermes_cli.auth import AuthError
|
||||||
|
|
||||||
try:
|
try:
|
||||||
runtime = resolve_runtime_provider(
|
runtime = resolve_runtime_provider(
|
||||||
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
|
requested=os.getenv("HERMES_INFERENCE_PROVIDER"),
|
||||||
)
|
)
|
||||||
|
except AuthError as auth_exc:
|
||||||
|
# Primary provider auth failed (expired token, revoked key, etc.).
|
||||||
|
# Try the fallback provider chain before raising.
|
||||||
|
logger.warning("Primary provider auth failed: %s — trying fallback", auth_exc)
|
||||||
|
fb_config = _try_resolve_fallback_provider()
|
||||||
|
if fb_config is not None:
|
||||||
|
return fb_config
|
||||||
|
raise RuntimeError(format_runtime_provider_error(auth_exc)) from auth_exc
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise RuntimeError(format_runtime_provider_error(exc)) from exc
|
raise RuntimeError(format_runtime_provider_error(exc)) from exc
|
||||||
|
|
||||||
|
|
@ -374,6 +388,48 @@ def _resolve_runtime_agent_kwargs() -> dict:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _try_resolve_fallback_provider() -> dict | None:
    """Attempt to resolve credentials from the fallback_model/fallback_providers config.

    Loads ``config.yaml`` from ``_hermes_home`` and reads ``fallback_providers``
    (or ``fallback_model`` when the former is missing or empty). A single
    mapping is treated as a one-element list. Entries are tried in order and
    the first one that ``resolve_runtime_provider`` accepts wins.

    Returns:
        A dict with the keys ``api_key``, ``base_url``, ``provider``,
        ``api_mode``, ``command``, ``args`` and ``credential_pool`` taken from
        the resolved runtime, or ``None`` when the config file is absent,
        cannot be loaded, declares no fallback, or every entry fails.

    This is a best-effort helper: config and per-entry failures are logged and
    turned into ``None`` rather than raised, so the caller can surface the
    original authentication error instead.

    NOTE(review): the ``model`` key of a fallback entry is not read here —
    confirm the fallback model name is applied elsewhere.
    """
    from hermes_cli.runtime_provider import resolve_runtime_provider

    # Step 1: load the fallback section. Any failure (PyYAML missing, unreadable
    # file, malformed YAML, non-mapping document) is reported instead of being
    # swallowed silently, then treated as "no fallback available".
    try:
        import yaml as _y

        cfg_path = _hermes_home / "config.yaml"
        if not cfg_path.exists():
            return None
        with open(cfg_path, encoding="utf-8") as _f:
            cfg = _y.safe_load(_f) or {}
        fb = cfg.get("fallback_providers") or cfg.get("fallback_model")
    except Exception as cfg_exc:
        logger.warning("Could not load fallback provider config: %s", cfg_exc)
        return None

    if not fb:
        return None

    # Normalize to list
    fb_list = fb if isinstance(fb, list) else [fb]

    # Step 2: try each entry in order; the first successful resolution wins.
    for entry in fb_list:
        if not isinstance(entry, dict):
            # Only mapping-style entries carry provider/base_url/api_key.
            continue
        try:
            runtime = resolve_runtime_provider(
                requested=entry.get("provider"),
                explicit_base_url=entry.get("base_url"),
                explicit_api_key=entry.get("api_key"),
            )
            logger.info("Fallback provider resolved: %s", runtime.get("provider"))
            return {
                "api_key": runtime.get("api_key"),
                "base_url": runtime.get("base_url"),
                "provider": runtime.get("provider"),
                "api_mode": runtime.get("api_mode"),
                "command": runtime.get("command"),
                # Copy so the caller never shares the resolver's list object.
                "args": list(runtime.get("args") or []),
                "credential_pool": runtime.get("credential_pool"),
            }
        except Exception as fb_exc:
            # A broken entry must not stop later entries from being tried.
            logger.debug("Fallback entry %s failed: %s", entry.get("provider"), fb_exc)

    return None
|
||||||
|
|
||||||
|
|
||||||
def _build_media_placeholder(event) -> str:
|
def _build_media_placeholder(event) -> str:
|
||||||
"""Build a text placeholder for media-only events so they aren't dropped.
|
"""Build a text placeholder for media-only events so they aren't dropped.
|
||||||
|
|
||||||
|
|
|
||||||
73
tests/gateway/test_auth_fallback.py
Normal file
73
tests/gateway/test_auth_fallback.py
Normal file
|
|
@ -0,0 +1,73 @@
|
||||||
|
"""Test that AuthError triggers fallback provider resolution (#7230)."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveRuntimeAgentKwargsAuthFallback:
    """Cover the AuthError -> fallback path of _resolve_runtime_agent_kwargs."""

    def test_auth_error_tries_fallback(self, tmp_path, monkeypatch):
        """A failing primary provider is replaced by the configured fallback."""
        from hermes_cli.auth import AuthError

        # Config declaring a primary provider plus a single fallback entry.
        cfg_file = tmp_path / "config.yaml"
        cfg_file.write_text(
            "model:\n provider: openai-codex\n"
            "fallback_model:\n provider: openrouter\n"
            " model: meta-llama/llama-4-maverick\n"
        )

        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)

        # One element is recorded per resolver invocation.
        calls = []

        def fake_resolve(**kwargs):
            calls.append(kwargs)
            wanted = str(kwargs.get("requested") or "")
            # Any codex-flavoured request simulates an expired/revoked token.
            if "codex" in wanted.lower():
                raise AuthError("Codex token refresh failed with status 401")
            return {
                "api_key": "fallback-key",
                "base_url": "https://openrouter.ai/api/v1",
                "provider": "openrouter",
                "api_mode": "openai_chat",
                "command": None,
                "args": None,
                "credential_pool": None,
            }

        monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex")

        target = "hermes_cli.runtime_provider.resolve_runtime_provider"
        with patch(target, side_effect=fake_resolve):
            from gateway.run import _resolve_runtime_agent_kwargs
            result = _resolve_runtime_agent_kwargs()

        assert result["provider"] == "openrouter"
        assert result["api_key"] == "fallback-key"
        # Should have been called at least twice (primary + fallback)
        assert len(calls) >= 2

    def test_auth_error_no_fallback_raises(self, tmp_path, monkeypatch):
        """Without any fallback in the config, the failure surfaces as RuntimeError."""
        from hermes_cli.auth import AuthError

        cfg_file = tmp_path / "config.yaml"
        cfg_file.write_text("model:\n provider: openai-codex\n")

        monkeypatch.setattr("gateway.run._hermes_home", tmp_path)
        monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "openai-codex")

        target = "hermes_cli.runtime_provider.resolve_runtime_provider"
        with patch(target, side_effect=AuthError("token expired")):
            from gateway.run import _resolve_runtime_agent_kwargs
            with pytest.raises(RuntimeError):
                _resolve_runtime_agent_kwargs()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue