fix(agent): roll back switch_model() state when client rebuild fails (#33228)

Closes #33175.

switch_model() in agent/agent_runtime_helpers.py mutated agent.model and
agent.provider before rebuilding the client, with no try/except to restore
them on failure. If the rebuild raised (bad API key, network error,
build_anthropic_client failure, etc.) the agent was left with the new
model+provider name paired with the OLD client — producing HTTP 400s like
"claude-sonnet-4-6 is not supported on openai-codex" on the next turn.

Callers in cli.py, gateway/run.py, and tui_gateway/server.py already catch
the exception and warn the user, but the warning was misleading because
the swap had partially succeeded; the agent's state was torn.

Snapshot every mutated field before the swap, wrap the swap+rebuild block
in try/except, and restore the snapshot on failure before re-raising so
the caller's warning surfaces.

Reported by @amirariff91. Tests cover both branches (chat_completions and
anthropic_messages) and the cross-branch case (anthropic -> openai).
This commit is contained in:
Teknium 2026-05-27 05:43:20 -07:00 committed by GitHub
parent 825948edab
commit f0de3cd0a0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 324 additions and 72 deletions

View file

@ -1361,81 +1361,129 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo
old_model = agent.model
old_provider = agent.provider
# Clear the per-config context_length override so the new model's
# actual context window is resolved via get_model_context_length()
# instead of inheriting the stale value from the previous model.
agent._config_context_length = None
# ── Swap core runtime fields ──
agent.model = new_model
agent.provider = new_provider
# Use new base_url when provided; only fall back to current when the
# new provider genuinely has no endpoint (e.g. native SDK providers).
# Without this guard the old provider's URL (e.g. Ollama's localhost
# address) would persist silently after switching to a cloud provider
# that returns an empty base_url string.
if base_url:
agent.base_url = base_url
agent.api_mode = api_mode
# Invalidate transport cache — new api_mode may need a different transport
if hasattr(agent, "_transport_cache"):
agent._transport_cache.clear()
if api_key:
agent.api_key = api_key
# ── Build new client ──
if api_mode == "anthropic_messages":
from agent.anthropic_adapter import (
build_anthropic_client,
resolve_anthropic_token,
_is_oauth_token,
# ── Snapshot all fields the swap+rebuild can mutate ──
# If the rebuild raises (bad API key, network error, build_anthropic_client
# failure, etc.) we restore these atomically so the agent isn't left with a
# new model/provider name paired with the OLD client — that mismatch causes
# HTTP 400s like "claude-sonnet-4-6 is not supported on openai-codex" on the
# next turn. Callers in cli.py / gateway/run.py / tui_gateway/server.py
# catch the re-raised exception and show the user a warning; without this
# rollback the warning is misleading because the swap partially succeeded.
# Use a sentinel so we can distinguish "attribute was unset" from
# "attribute was None" and skip the restore for genuinely-missing
# attributes (tests construct bare agents via __new__ without all fields).
_MISSING = object()
_snapshot = {
name: getattr(agent, name, _MISSING)
for name in (
"model",
"provider",
"base_url",
"api_mode",
"api_key",
"client",
"_anthropic_client",
"_anthropic_api_key",
"_anthropic_base_url",
"_is_anthropic_oauth",
"_config_context_length",
)
# Only fall back to ANTHROPIC_TOKEN when the provider is actually Anthropic.
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) must use their own
# API key — falling back would send Anthropic credentials to third-party endpoints.
_is_native_anthropic = new_provider == "anthropic"
effective_key = (api_key or agent.api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or agent.api_key or "")
}
# _client_kwargs is a dict — snapshot a shallow copy so mutating the
# live dict doesn't poison the rollback target.
_snapshot["_client_kwargs"] = dict(getattr(agent, "_client_kwargs", {}) or {})
# MiniMax OAuth: swap static string for a per-request callable token
# provider so the rebuilt client survives 15-min token expiry. See
# the matching block in agent_init.py for the full rationale.
if new_provider == "minimax-oauth" and isinstance(effective_key, str) and effective_key:
try:
# Clear the per-config context_length override so the new model's
# actual context window is resolved via get_model_context_length()
# instead of inheriting the stale value from the previous model.
agent._config_context_length = None
# ── Swap core runtime fields ──
agent.model = new_model
agent.provider = new_provider
# Use new base_url when provided; only fall back to current when the
# new provider genuinely has no endpoint (e.g. native SDK providers).
# Without this guard the old provider's URL (e.g. Ollama's localhost
# address) would persist silently after switching to a cloud provider
# that returns an empty base_url string.
if base_url:
agent.base_url = base_url
agent.api_mode = api_mode
# Invalidate transport cache — new api_mode may need a different transport
if hasattr(agent, "_transport_cache"):
agent._transport_cache.clear()
if api_key:
agent.api_key = api_key
# ── Build new client ──
if api_mode == "anthropic_messages":
from agent.anthropic_adapter import (
build_anthropic_client,
resolve_anthropic_token,
_is_oauth_token,
)
# Only fall back to ANTHROPIC_TOKEN when the provider is actually Anthropic.
# Other anthropic_messages providers (MiniMax, Alibaba, etc.) must use their own
# API key — falling back would send Anthropic credentials to third-party endpoints.
_is_native_anthropic = new_provider == "anthropic"
effective_key = (api_key or agent.api_key or resolve_anthropic_token() or "") if _is_native_anthropic else (api_key or agent.api_key or "")
# MiniMax OAuth: swap static string for a per-request callable token
# provider so the rebuilt client survives 15-min token expiry. See
# the matching block in agent_init.py for the full rationale.
if new_provider == "minimax-oauth" and isinstance(effective_key, str) and effective_key:
try:
from hermes_cli.auth import build_minimax_oauth_token_provider
effective_key = build_minimax_oauth_token_provider()
except Exception as _mm_exc: # noqa: BLE001
import logging as _logging
_logging.getLogger(__name__).warning(
"MiniMax OAuth: failed to install per-request token provider "
"on switch (%s); using static bearer.",
_mm_exc,
)
agent.api_key = effective_key
agent._anthropic_api_key = effective_key
agent._anthropic_base_url = base_url or getattr(agent, "_anthropic_base_url", None)
agent._anthropic_client = build_anthropic_client(
effective_key, agent._anthropic_base_url,
timeout=get_provider_request_timeout(agent.provider, agent.model),
)
agent._is_anthropic_oauth = _is_oauth_token(effective_key) if (_is_native_anthropic and isinstance(effective_key, str)) else False
agent.client = None
agent._client_kwargs = {}
else:
effective_key = api_key or agent.api_key
effective_base = base_url or agent.base_url
agent._client_kwargs = {
"api_key": effective_key,
"base_url": effective_base,
}
_sm_timeout = get_provider_request_timeout(agent.provider, agent.model)
if _sm_timeout is not None:
agent._client_kwargs["timeout"] = _sm_timeout
agent.client = agent._create_openai_client(
dict(agent._client_kwargs),
reason="switch_model",
shared=True,
)
except Exception:
# Rollback every mutated field to the pre-swap snapshot so the agent
# is left consistent (old model + old provider + old client) and the
# caller's exception handler can surface a meaningful warning. The
# exception is re-raised; cli.py / gateway/run.py / tui_gateway catch
# it and print "Agent swap failed; change applied to next session".
for _name, _value in _snapshot.items():
if _value is _MISSING:
# Attribute did not exist before the swap — don't fabricate it.
continue
try:
from hermes_cli.auth import build_minimax_oauth_token_provider
effective_key = build_minimax_oauth_token_provider()
except Exception as _mm_exc: # noqa: BLE001
import logging as _logging
_logging.getLogger(__name__).warning(
"MiniMax OAuth: failed to install per-request token provider "
"on switch (%s); using static bearer.",
_mm_exc,
)
agent.api_key = effective_key
agent._anthropic_api_key = effective_key
agent._anthropic_base_url = base_url or getattr(agent, "_anthropic_base_url", None)
agent._anthropic_client = build_anthropic_client(
effective_key, agent._anthropic_base_url,
timeout=get_provider_request_timeout(agent.provider, agent.model),
)
agent._is_anthropic_oauth = _is_oauth_token(effective_key) if (_is_native_anthropic and isinstance(effective_key, str)) else False
agent.client = None
agent._client_kwargs = {}
else:
effective_key = api_key or agent.api_key
effective_base = base_url or agent.base_url
agent._client_kwargs = {
"api_key": effective_key,
"base_url": effective_base,
}
_sm_timeout = get_provider_request_timeout(agent.provider, agent.model)
if _sm_timeout is not None:
agent._client_kwargs["timeout"] = _sm_timeout
agent.client = agent._create_openai_client(
dict(agent._client_kwargs),
reason="switch_model",
shared=True,
)
setattr(agent, _name, _value)
except Exception: # noqa: BLE001
pass
raise
# ── Re-evaluate prompt caching ──
agent._use_prompt_caching, agent._use_native_cache_layout = (