fix(agent): roll back switch_model() state when client rebuild fails (#33228)

Closes #33175.

switch_model() in agent/agent_runtime_helpers.py mutated agent.model and
agent.provider before rebuilding the client, with no try/except to restore
them on failure. If the rebuild raised (bad API key, network error,
build_anthropic_client failure, etc.) the agent was left with the new
model+provider name paired with the OLD client — producing HTTP 400s like
"claude-sonnet-4-6 is not supported on openai-codex" on the next turn.

Callers in cli.py, gateway/run.py, and tui_gateway/server.py already catch
the exception and warn the user, but the warning was misleading because
the swap had partially succeeded; the agent's state was torn.

Snapshot every mutated field before the swap, wrap the swap+rebuild block
in try/except, and restore the snapshot on failure before re-raising so
the caller's warning surfaces.

Reported by @amirariff91. Tests cover both branches (chat_completions and
anthropic_messages) and the cross-branch case (anthropic -> openai).
This commit is contained in:
Teknium 2026-05-27 05:43:20 -07:00 committed by GitHub
parent 825948edab
commit f0de3cd0a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 324 additions and 72 deletions

View file

@ -1361,81 +1361,129 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo
old_model = agent.model
old_provider = agent.provider
# Clear the per-config context_length override so the new model's
# actual context window is resolved via get_model_context_length()
# instead of inheriting the stale value from the previous model.
agent._config_context_length = None
# ── Swap core runtime fields ──
agent.model = new_model
agent.provider = new_provider
# Use new base_url when provided; only fall back to current when the
# new provider genuinely has no endpoint (e.g. native SDK providers).
# Without this guard the old provider's URL (e.g. Ollama's localhost
# address) would persist silently after switching to a cloud provider
# that returns an empty base_url string.
if base_url:
agent.base_url = base_url
agent.api_mode = api_mode
# Invalidate transport cache — new api_mode may need a different transport
if hasattr(agent, "_transport_cache"):
agent._transport_cache.clear()
if api_key:
agent.api_key = api_key
# ── Build new client ──
if api_mode == "anthropic_messages":
from agent.anthropic_adapter import (
build_anthropic_client,
resolve_anthropic_token,
_is_oauth_token,
# ── Snapshot all fields the swap+rebuild can mutate ──
# If the rebuild raises (bad API key, network error, build_anthropic_client
# failure, etc.) we restore these atomically so the agent isn't left with a
# new model/provider name paired with the OLD client — that mismatch causes
# HTTP 400s like "claude-sonnet-4-6 is not supported on openai-codex" on the
# next turn. Callers in cli.py / gateway/run.py / tui_gateway/server.py
# catch the re-raised exception and show the user a warning; without this
# rollback the warning is misleading because the swap partially succeeded.
# Use a sentinel so we can distinguish "attribute was unset" from
# "attribute was None" and skip the restore for genuinely-missing
# attributes (tests construct bare agents via __new__ without all fields).
_MISSING = object()
_snapshot = {
name: getattr(agent, name, _MISSING)
for name in (
"model",
"provider",
"base_url",
"api_mode",
"api_key",
"client",
"_anthropic_client",
"_anthropic_api_key",
"_anthropic_base_url",
"_is_anthropic_oauth",
"_config_context_length",
)
# Only fall back to ANTHROPIC_TOKEN when the provider is actually Anthropic.
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) must use their own
# API key — falling back would send Anthropic credentials to third-party endpoints.
_is_native_anthropic = new_provider == "anthropic"
effective_key = (api_key or agent.api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or agent.api_key or "")
}
# _client_kwargs is a dict — snapshot a shallow copy so mutating the
# live dict doesn't poison the rollback target.
_snapshot["_client_kwargs"] = dict(getattr(agent, "_client_kwargs", {}) or {})
# MiniMax OAuth: swap static string for a per-request callable token
# provider so the rebuilt client survives 15-min token expiry. See
# the matching block in agent_init.py for the full rationale.
if new_provider == "minimax-oauth" and isinstance(effective_key, str) and effective_key:
try:
# Clear the per-config context_length override so the new model's
# actual context window is resolved via get_model_context_length()
# instead of inheriting the stale value from the previous model.
agent._config_context_length = None
# ── Swap core runtime fields ──
agent.model = new_model
agent.provider = new_provider
# Use new base_url when provided; only fall back to current when the
# new provider genuinely has no endpoint (e.g. native SDK providers).
# Without this guard the old provider's URL (e.g. Ollama's localhost
# address) would persist silently after switching to a cloud provider
# that returns an empty base_url string.
if base_url:
agent.base_url = base_url
agent.api_mode = api_mode
# Invalidate transport cache — new api_mode may need a different transport
if hasattr(agent, "_transport_cache"):
agent._transport_cache.clear()
if api_key:
agent.api_key = api_key
# ── Build new client ──
if api_mode == "anthropic_messages":
from agent.anthropic_adapter import (
build_anthropic_client,
resolve_anthropic_token,
_is_oauth_token,
)
# Only fall back to ANTHROPIC_TOKEN when the provider is actually Anthropic.
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) must use their own
# API key — falling back would send Anthropic credentials to third-party endpoints.
_is_native_anthropic = new_provider == "anthropic"
effective_key = (api_key or agent.api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or agent.api_key or "")
# MiniMax OAuth: swap static string for a per-request callable token
# provider so the rebuilt client survives 15-min token expiry. See
# the matching block in agent_init.py for the full rationale.
if new_provider == "minimax-oauth" and isinstance(effective_key, str) and effective_key:
try:
from hermes_cli.auth import build_minimax_oauth_token_provider
effective_key = build_minimax_oauth_token_provider()
except Exception as _mm_exc: # noqa: BLE001
import logging as _logging
_logging.getLogger(__name__).warning(
"MiniMax OAuth: failed to install per-request token provider "
"on switch (%s); using static bearer.",
_mm_exc,
)
agent.api_key = effective_key
agent._anthropic_api_key = effective_key
agent._anthropic_base_url = base_url or getattr(agent, "_anthropic_base_url", None)
agent._anthropic_client = build_anthropic_client(
effective_key, agent._anthropic_base_url,
timeout=get_provider_request_timeout(agent.provider, agent.model),
)
agent._is_anthropic_oauth = _is_oauth_token(effective_key) if (_is_native_anthropic and isinstance(effective_key, str)) else False
agent.client = None
agent._client_kwargs = {}
else:
effective_key = api_key or agent.api_key
effective_base = base_url or agent.base_url
agent._client_kwargs = {
"api_key": effective_key,
"base_url": effective_base,
}
_sm_timeout = get_provider_request_timeout(agent.provider, agent.model)
if _sm_timeout is not None:
agent._client_kwargs["timeout"] = _sm_timeout
agent.client = agent._create_openai_client(
dict(agent._client_kwargs),
reason="switch_model",
shared=True,
)
except Exception:
# Rollback every mutated field to the pre-swap snapshot so the agent
# is left consistent (old model + old provider + old client) and the
# caller's exception handler can surface a meaningful warning. The
# exception is re-raised; cli.py / gateway/run.py / tui_gateway catch
# it and print "Agent swap failed; change applied to next session".
for _name, _value in _snapshot.items():
if _value is _MISSING:
# Attribute did not exist before the swap — don't fabricate it.
continue
try:
from hermes_cli.auth import build_minimax_oauth_token_provider
effective_key = build_minimax_oauth_token_provider()
except Exception as _mm_exc: # noqa: BLE001
import logging as _logging
_logging.getLogger(__name__).warning(
"MiniMax OAuth: failed to install per-request token provider "
"on switch (%s); using static bearer.",
_mm_exc,
)
agent.api_key = effective_key
agent._anthropic_api_key = effective_key
agent._anthropic_base_url = base_url or getattr(agent, "_anthropic_base_url", None)
agent._anthropic_client = build_anthropic_client(
effective_key, agent._anthropic_base_url,
timeout=get_provider_request_timeout(agent.provider, agent.model),
)
agent._is_anthropic_oauth = _is_oauth_token(effective_key) if (_is_native_anthropic and isinstance(effective_key, str)) else False
agent.client = None
agent._client_kwargs = {}
else:
effective_key = api_key or agent.api_key
effective_base = base_url or agent.base_url
agent._client_kwargs = {
"api_key": effective_key,
"base_url": effective_base,
}
_sm_timeout = get_provider_request_timeout(agent.provider, agent.model)
if _sm_timeout is not None:
agent._client_kwargs["timeout"] = _sm_timeout
agent.client = agent._create_openai_client(
dict(agent._client_kwargs),
reason="switch_model",
shared=True,
)
setattr(agent, _name, _value)
except Exception: # noqa: BLE001
pass
raise
# ── Re-evaluate prompt caching ──
agent._use_prompt_caching, agent._use_native_cache_layout = (

View file

@ -0,0 +1,204 @@
"""Regression test for #33175: switch_model() must roll back to the pre-swap
state if the client rebuild raises.
Before the fix, ``agent.model`` and ``agent.provider`` were assigned BEFORE
the client rebuild was attempted, with no try/except to restore them on
failure. An exception during ``build_anthropic_client`` / OpenAI client
construction left the agent with the new model+provider name but the OLD
client producing HTTP 400s like "claude-sonnet-4-6 is not supported on
openai-codex" on the next turn.
These tests exercise both branches (openai_chat_completions and
anthropic_messages) and assert that every mutated field returns to its
pre-swap value when the rebuild raises.
"""
from unittest.mock import MagicMock, patch
import pytest
from run_agent import AIAgent
def _make_agent_openrouter():
"""Agent on openrouter (openai-compatible) with sentinel client + kwargs."""
agent = AIAgent.__new__(AIAgent)
agent.provider = "openrouter"
agent.model = "x-ai/grok-4"
agent.base_url = "https://openrouter.ai/api/v1"
agent.api_key = "or-key-original"
agent.api_mode = "chat_completions"
agent.client = MagicMock(name="OriginalOpenRouterClient")
agent._client_kwargs = {
"api_key": "or-key-original",
"base_url": "https://openrouter.ai/api/v1",
}
agent.context_compressor = None
agent._anthropic_api_key = ""
agent._anthropic_base_url = None
agent._anthropic_client = None
agent._is_anthropic_oauth = False
agent._cached_system_prompt = "cached"
agent._primary_runtime = {}
agent._fallback_activated = False
agent._fallback_index = 0
agent._fallback_chain = []
agent._fallback_model = None
agent._config_context_length = None
return agent
def _make_agent_anthropic():
"""Agent on native anthropic with a sentinel anthropic client."""
agent = AIAgent.__new__(AIAgent)
agent.provider = "anthropic"
agent.model = "claude-sonnet-4-5"
agent.base_url = "https://api.anthropic.com"
agent.api_key = "sk-ant-original"
agent.api_mode = "anthropic_messages"
agent.client = None
agent._client_kwargs = {}
agent.context_compressor = None
agent._anthropic_api_key = "sk-ant-original"
agent._anthropic_base_url = "https://api.anthropic.com"
agent._anthropic_client = MagicMock(name="OriginalAnthropicClient")
agent._is_anthropic_oauth = False
agent._cached_system_prompt = "cached"
agent._primary_runtime = {}
agent._fallback_activated = False
agent._fallback_index = 0
agent._fallback_chain = []
agent._fallback_model = None
agent._config_context_length = None
return agent
def test_openai_client_rebuild_failure_rolls_back_to_original_state():
"""When OpenAI client construction fails, every mutated field must restore."""
agent = _make_agent_openrouter()
original_client = agent.client
original_kwargs = dict(agent._client_kwargs)
# _create_openai_client raises mid-swap (simulates bad key / network error)
def boom(*_a, **_kw):
raise RuntimeError("simulated client build failure")
agent._create_openai_client = boom
with patch("hermes_cli.timeouts.get_provider_request_timeout", return_value=None):
with pytest.raises(RuntimeError, match="simulated client build failure"):
agent.switch_model(
new_model="openai/gpt-5",
new_provider="openai-codex",
api_key="codex-key-new",
base_url="https://chatgpt.com/backend-api/codex/responses",
api_mode="chat_completions",
)
# Core invariant: agent state is unchanged from before the call
assert agent.model == "x-ai/grok-4"
assert agent.provider == "openrouter"
assert agent.base_url == "https://openrouter.ai/api/v1"
assert agent.api_mode == "chat_completions"
assert agent.api_key == "or-key-original"
assert agent.client is original_client
assert agent._client_kwargs == original_kwargs
def test_anthropic_client_rebuild_failure_rolls_back_to_original_state():
"""When build_anthropic_client raises, every mutated field must restore."""
agent = _make_agent_anthropic()
original_anthropic_client = agent._anthropic_client
original_anthropic_key = agent._anthropic_api_key
original_anthropic_base = agent._anthropic_base_url
with (
patch(
"agent.anthropic_adapter.build_anthropic_client",
side_effect=RuntimeError("simulated anthropic build failure"),
),
patch(
"agent.anthropic_adapter.resolve_anthropic_token",
return_value="sk-ant-resolved",
),
patch("agent.anthropic_adapter._is_oauth_token", return_value=False),
patch("hermes_cli.timeouts.get_provider_request_timeout", return_value=None),
):
with pytest.raises(RuntimeError, match="simulated anthropic build failure"):
agent.switch_model(
new_model="claude-opus-4-6",
new_provider="opencode-zen",
api_key="zen-key-new",
base_url="https://opencode.example/v1",
api_mode="anthropic_messages",
)
# Anthropic-specific state restored
assert agent._anthropic_client is original_anthropic_client
assert agent._anthropic_api_key == original_anthropic_key
assert agent._anthropic_base_url == original_anthropic_base
# Core state also restored
assert agent.model == "claude-sonnet-4-5"
assert agent.provider == "anthropic"
assert agent.base_url == "https://api.anthropic.com"
assert agent.api_mode == "anthropic_messages"
assert agent.api_key == "sk-ant-original"
def test_cross_branch_anthropic_to_openai_rebuild_failure_rolls_back():
"""Switching from anthropic_messages to chat_completions: failure must
restore the anthropic state, not leave the agent half-converted."""
agent = _make_agent_anthropic()
original_anthropic_client = agent._anthropic_client
def boom(*_a, **_kw):
raise RuntimeError("openai client failed")
agent._create_openai_client = boom
with patch("hermes_cli.timeouts.get_provider_request_timeout", return_value=None):
with pytest.raises(RuntimeError, match="openai client failed"):
agent.switch_model(
new_model="x-ai/grok-4",
new_provider="openrouter",
api_key="or-key-new",
base_url="https://openrouter.ai/api/v1",
api_mode="chat_completions",
)
# Anthropic client preserved (not nulled by the openai branch)
assert agent._anthropic_client is original_anthropic_client
assert agent.model == "claude-sonnet-4-5"
assert agent.provider == "anthropic"
assert agent.api_mode == "anthropic_messages"
assert agent.base_url == "https://api.anthropic.com"
def test_successful_switch_still_works_after_rollback_refactor():
"""Sanity check: the try/except wrapper hasn't broken the happy path."""
agent = _make_agent_openrouter()
new_client = MagicMock(name="NewClient")
agent._create_openai_client = lambda *_a, **_kw: new_client
with patch("hermes_cli.timeouts.get_provider_request_timeout", return_value=None):
agent.switch_model(
new_model="openai/gpt-5",
new_provider="openrouter",
api_key="or-key-new",
base_url="https://openrouter.ai/api/v1",
api_mode="chat_completions",
)
assert agent.model == "openai/gpt-5"
assert agent.provider == "openrouter"
assert agent.api_key == "or-key-new"
assert agent.client is new_client