From f44415e71a26bd51e59c951947722845e05cdfb3 Mon Sep 17 00:00:00 2001 From: liuhao1024 Date: Wed, 17 Jun 2026 12:53:47 +0800 Subject: [PATCH] fix(gateway): add init-time provider fallback to _make_agent MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When the primary provider raises AuthError (e.g. expired OAuth token), _make_agent now walks the configured fallback_providers/fallback_model chain before giving up — matching the behavior that cron/scheduler.py and cli_agent_setup_mixin.py already have. Fixes #47627 --- tests/test_tui_gateway_server.py | 140 +++++++++++++++++++++++++++++++ tui_gateway/server.py | 62 +++++++++++--- 2 files changed, 191 insertions(+), 11 deletions(-) diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 698b4c0d45f..07d9765eea6 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -8173,3 +8173,143 @@ def test_persist_model_switch_clears_stale_base_url(tmp_path, monkeypatch): assert saved["model"]["provider"] == "anthropic" # Stale custom base_url must be cleared (null coalesces to absent on read). assert not saved["model"].get("base_url"), saved["model"].get("base_url") + + +# --------------------------------------------------------------------------- +# _resolve_runtime_with_fallback — init-time provider fallback +# --------------------------------------------------------------------------- + +class TestResolveRuntimeWithFallback: + """Tests for _resolve_runtime_with_fallback(): init-time provider + fallback when the primary provider raises AuthError.""" + + def test_primary_success_returns_runtime(self, monkeypatch): + """When primary resolve succeeds, return its result directly.""" + expected = {"provider": "openai", "api_key": "tok"} + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + lambda **kw: expected, + ) + result = server._resolve_runtime_with_fallback({"requested": "openai"}) + assert result == expected + + def test_auth_error_tries_fallback_chain(self, monkeypatch): + """On AuthError from primary, walk fallback_providers chain.""" + from hermes_cli.auth import AuthError + + fallback_runtime = {"provider": "deepseek", "api_key": "fb-tok"} + + def fake_resolve(**kwargs): + if kwargs.get("requested") == "openai-codex": + raise AuthError("No Codex credentials stored") + return fallback_runtime + + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve, + ) + monkeypatch.setattr( + server, + "_load_fallback_model", + lambda: [{"provider": "deepseek", "model": "deepseek-v4-pro"}], + ) + result = server._resolve_runtime_with_fallback( + {"requested": "openai-codex"}, + ) + assert result == fallback_runtime + + def test_auth_error_all_fallbacks_fail_raises(self, monkeypatch): + """When all fallbacks also fail, re-raise the original AuthError.""" + from hermes_cli.auth import AuthError + + def fake_resolve(**kwargs): + raise AuthError("No credentials for " + str(kwargs.get("requested"))) + + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve, + ) + monkeypatch.setattr( + server, + "_load_fallback_model", + lambda: [{"provider": "deepseek", "model": "deepseek-v4-pro"}], + ) + import pytest + + with pytest.raises(AuthError, match="No credentials for openai-codex"): + server._resolve_runtime_with_fallback( + {"requested": "openai-codex"}, + ) + + def test_auth_error_skips_non_dict_entries(self, monkeypatch): + """Fallback chain entries that are not dicts are skipped.""" + from hermes_cli.auth import AuthError + + fallback_runtime = {"provider": "anthropic", "api_key": "ant-tok"} + + def fake_resolve(**kwargs): + if kwargs.get("requested") == "openai-codex": + raise AuthError("No Codex credentials stored") + return fallback_runtime + + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve, + ) + monkeypatch.setattr( + server, + "_load_fallback_model", + lambda: [ + "invalid-string-entry", + {"provider": "anthropic", "model": "claude-sonnet-4-6"}, + ], + ) + result = server._resolve_runtime_with_fallback( + {"requested": "openai-codex"}, + ) + assert result == fallback_runtime + + def test_make_agent_uses_fallback_on_auth_error(self, monkeypatch): + """Integration: _make_agent falls back to configured fallback + provider when the primary provider raises AuthError.""" + import types + + from hermes_cli.auth import AuthError + + captured = {} + fallback_runtime = {"provider": "deepseek", "api_key": "fb-tok"} + + def fake_resolve(**kwargs): + if kwargs.get("requested") == "openai-codex": + raise AuthError("No Codex credentials stored") + return fallback_runtime + + def fake_agent(**kwargs): + captured.update(kwargs) + return types.SimpleNamespace(model=kwargs.get("model")) + + monkeypatch.delenv("HERMES_MODEL", raising=False) + monkeypatch.delenv("HERMES_INFERENCE_MODEL", raising=False) + monkeypatch.delenv("HERMES_TUI_PROVIDER", raising=False) + monkeypatch.setattr( + server, + "_load_cfg", + lambda: { + "model": {"default": "gpt-5.5", "provider": "openai-codex"}, + "fallback_providers": [ + {"provider": "deepseek", "model": "deepseek-v4-pro"}, + ], + }, + ) + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve, + ) + monkeypatch.setattr("run_agent.AIAgent", fake_agent) + monkeypatch.setattr(server, "_load_enabled_toolsets", lambda: ["file"]) + monkeypatch.setattr(server, "_get_db", lambda: None) + + agent = server._make_agent("sid", "session-key") + + assert agent.model == "gpt-5.5" + assert captured["provider"] == "deepseek" diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 0d572065d44..8171caa9cde 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -4081,7 +4081,6 @@ def _schedule_mcp_late_refresh(sid: str, agent) -> None: info = _session_info(agent, session) # Emit outside the lock — write_json must not block under _sessions_lock. _emit("session.info", sid, info) - threading.Thread( target=_wait_then_refresh, name=f"tui-mcp-late-refresh-{sid}", @@ -4089,6 +4088,50 @@ def _schedule_mcp_late_refresh(sid: str, agent) -> None: ).start() +def _resolve_runtime_with_fallback( + resolve_kwargs: dict | None = None, +) -> dict: + """Resolve runtime provider with init-time fallback on auth failure. + + Mirrors the fallback pattern in ``cron/scheduler.py`` and + ``hermes_cli/cli_agent_setup_mixin.py``: when the primary provider + raises ``AuthError``, walk the configured ``fallback_providers`` / + ``fallback_model`` chain before giving up. + """ + from hermes_cli.auth import AuthError + from hermes_cli.runtime_provider import resolve_runtime_provider + + kwargs = resolve_kwargs or {} + try: + return resolve_runtime_provider(**kwargs) + except AuthError as primary_exc: + fb_chain = _load_fallback_model() or [] + for entry in fb_chain: + if not isinstance(entry, dict): + continue + fb_provider = (entry.get("provider") or "").strip() + if not fb_provider: + continue + try: + fb_kwargs: dict = {"requested": fb_provider} + if entry.get("base_url"): + fb_kwargs["explicit_base_url"] = entry["base_url"] + if entry.get("api_key"): + fb_kwargs["explicit_api_key"] = entry["api_key"] + runtime = resolve_runtime_provider(**fb_kwargs) + import logging + + logging.getLogger(__name__).warning( + "Primary auth failed (%s), falling back to %s", + primary_exc, + fb_provider, + ) + return runtime + except Exception: + continue + raise + + def _make_agent( sid: str, key: str, @@ -4100,7 +4143,6 @@ def _make_agent( service_tier_override: str | None = None, ): from run_agent import AIAgent - from hermes_cli.runtime_provider import resolve_runtime_provider # MCP tool discovery runs in a background daemon thread at startup so a # dead server can't freeze the shell. The agent snapshots its tool list @@ -4169,11 +4211,9 @@ def _make_agent( # Failing identity recovery, still hand the base_url to the # direct-alias branch so pool/env credentials resolve for it. resolve_kwargs["explicit_base_url"] = override_base_url - runtime = resolve_runtime_provider( - requested=requested_provider, - target_model=model or None, - **resolve_kwargs, - ) + resolve_kwargs["requested"] = requested_provider + resolve_kwargs["target_model"] = model or None + runtime = _resolve_runtime_with_fallback(resolve_kwargs) # The switch already resolved concrete credentials/endpoint; honor them # so a custom/named endpoint survives the rebuild even if global # resolution would pick a different one. @@ -4189,10 +4229,10 @@ def _make_agent( model = model_override if provider_override: requested_provider = provider_override - runtime = resolve_runtime_provider( - requested=requested_provider, - target_model=model or None, - ) + runtime = _resolve_runtime_with_fallback({ + "requested": requested_provider, + "target_model": model or None, + }) _pr = _load_provider_routing() return AIAgent( model=model,