fix(gateway): apply /model session overrides so switch persists across messages

The gateway /model command stored session overrides in
_session_model_overrides but run_sync() never consulted them when
resolving the model and runtime for the next message.  It always read
from config.yaml, so the switch was lost as soon as a new agent was
created.

Two fixes:

1. In run_sync(), apply _session_model_overrides after resolving from
   config.yaml/env — the override takes precedence for model, provider,
   api_key, base_url, and api_mode.

2. In post-run fallback detection, check whether the model mismatch
   (agent.model != config_model) is due to an intentional /model switch
   before evicting the cached agent.  Without this, the first message
   after /model would work (cached agent reused) but the fallback
   detector would evict it, causing the next message to revert.

Affects all gateway platforms (Telegram, Discord, Slack, WhatsApp,
Signal, Matrix, BlueBubbles, HomeAssistant) since they all share
GatewayRunner._run_agent().

Fixes #6213
This commit is contained in:
kshitijk4poor 2026-04-09 22:26:32 +05:30 committed by Teknium
parent a04854800f
commit 51d826f889
2 changed files with 279 additions and 2 deletions

View file

@ -6332,6 +6332,32 @@ class GatewayRunner:
) )
return hashlib.sha256(blob.encode()).hexdigest()[:16] return hashlib.sha256(blob.encode()).hexdigest()[:16]
def _apply_session_model_override(
self, session_key: str, model: str, runtime_kwargs: dict
) -> tuple:
"""Apply /model session overrides if present, returning (model, runtime_kwargs).
The gateway /model command stores per-session overrides in
``_session_model_overrides``. These must take precedence over
config.yaml defaults so the switched model is actually used for
subsequent messages. Fields with ``None`` values are skipped so
partial overrides don't clobber valid config defaults.
"""
override = self._session_model_overrides.get(session_key)
if not override:
return model, runtime_kwargs
model = override.get("model", model)
for key in ("provider", "api_key", "base_url", "api_mode"):
val = override.get(key)
if val is not None:
runtime_kwargs[key] = val
return model, runtime_kwargs
def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
"""Return True if *agent_model* matches an active /model session override."""
override = self._session_model_overrides.get(session_key)
return override is not None and override.get("model") == agent_model
def _evict_cached_agent(self, session_key: str) -> None: def _evict_cached_agent(self, session_key: str) -> None:
"""Remove a cached agent for a session (called on /new, /model, etc).""" """Remove a cached agent for a session (called on /new, /model, etc)."""
_lock = getattr(self, "_agent_cache_lock", None) _lock = getattr(self, "_agent_cache_lock", None)
@ -6709,6 +6735,11 @@ class GatewayRunner:
"tools": [], "tools": [],
} }
# /model overrides take precedence over config.yaml defaults.
model, runtime_kwargs = self._apply_session_model_override(
session_key, model, runtime_kwargs
)
pr = self._provider_routing pr = self._provider_routing
reasoning_config = self._load_reasoning_config() reasoning_config = self._load_reasoning_config()
self._reasoning_config = reasoning_config self._reasoning_config = reasoning_config
@ -7328,14 +7359,15 @@ class GatewayRunner:
_agent = agent_holder[0] _agent = agent_holder[0]
if _agent is not None and hasattr(_agent, 'model'): if _agent is not None and hasattr(_agent, 'model'):
_cfg_model = _resolve_gateway_model() _cfg_model = _resolve_gateway_model()
if _agent.model != _cfg_model: if _agent.model != _cfg_model and not self._is_intentional_model_switch(session_key, _agent.model):
self._effective_model = _agent.model self._effective_model = _agent.model
self._effective_provider = getattr(_agent, 'provider', None) self._effective_provider = getattr(_agent, 'provider', None)
# Fallback activated — evict cached agent so the next # Fallback activated — evict cached agent so the next
# message starts fresh and retries the primary model. # message starts fresh and retries the primary model.
self._evict_cached_agent(session_key) self._evict_cached_agent(session_key)
else: else:
# Primary model worked — clear any stale fallback state # Primary model worked (or intentional /model switch)
# — clear any stale fallback state.
self._effective_model = None self._effective_model = None
self._effective_provider = None self._effective_provider = None

View file

@ -0,0 +1,245 @@
"""Tests that gateway /model switch persists across messages.
The gateway /model command stores session overrides in
``_session_model_overrides``. These must:
1. Be applied in ``run_sync()`` so the next agent uses the switched model.
2. Not be mistaken for fallback activation (which evicts the cached agent).
3. Survive across multiple messages until /reset clears them.
Tests exercise the real ``_apply_session_model_override()`` and
``_is_intentional_model_switch()`` methods on ``GatewayRunner``.
"""
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.session import SessionEntry, SessionSource, build_session_key
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_source() -> SessionSource:
return SessionSource(
platform=Platform.TELEGRAM,
user_id="u1",
chat_id="c1",
user_name="tester",
chat_type="dm",
)
def _make_runner():
"""Create a minimal GatewayRunner with stubbed internals."""
from gateway.run import GatewayRunner
runner = object.__new__(GatewayRunner)
runner.config = GatewayConfig(
platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")}
)
adapter = MagicMock()
adapter.send = AsyncMock()
runner.adapters = {Platform.TELEGRAM: adapter}
runner._voice_mode = {}
runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False)
runner._session_model_overrides = {}
runner._pending_model_notes = {}
runner._background_tasks = set()
runner._running_agents = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = None
runner._agent_cache = {}
runner._agent_cache_lock = None
runner._effective_model = None
runner._effective_provider = None
runner.session_store = MagicMock()
session_key = build_session_key(_make_source())
session_entry = SessionEntry(
session_key=session_key,
session_id="sess-1",
created_at=datetime.now(),
updated_at=datetime.now(),
platform=Platform.TELEGRAM,
chat_type="dm",
)
runner.session_store.get_or_create_session.return_value = session_entry
runner.session_store._entries = {session_key: session_entry}
return runner
# ---------------------------------------------------------------------------
# Tests: _apply_session_model_override
# ---------------------------------------------------------------------------
class TestApplySessionModelOverride:
"""Verify _apply_session_model_override replaces config defaults."""
def test_override_replaces_all_fields(self):
runner = _make_runner()
sk = build_session_key(_make_source())
runner._session_model_overrides[sk] = {
"model": "gpt-5.4-turbo",
"provider": "openrouter",
"api_key": "or-key-123",
"base_url": "https://openrouter.ai/api/v1",
"api_mode": "chat_completions",
}
model, rt = runner._apply_session_model_override(
sk,
"anthropic/claude-sonnet-4",
{"provider": "anthropic", "api_key": "ant-key", "base_url": "https://api.anthropic.com", "api_mode": "anthropic_messages"},
)
assert model == "gpt-5.4-turbo"
assert rt["provider"] == "openrouter"
assert rt["api_key"] == "or-key-123"
assert rt["base_url"] == "https://openrouter.ai/api/v1"
assert rt["api_mode"] == "chat_completions"
def test_no_override_returns_originals(self):
runner = _make_runner()
sk = build_session_key(_make_source())
orig_model = "anthropic/claude-sonnet-4"
orig_rt = {"provider": "anthropic", "api_key": "key", "base_url": "https://api.anthropic.com", "api_mode": "anthropic_messages"}
model, rt = runner._apply_session_model_override(sk, orig_model, dict(orig_rt))
assert model == orig_model
assert rt == orig_rt
def test_none_values_do_not_overwrite(self):
"""Override with None api_key/base_url should preserve config defaults."""
runner = _make_runner()
sk = build_session_key(_make_source())
runner._session_model_overrides[sk] = {
"model": "gpt-5.4",
"provider": "openai",
"api_key": None,
"base_url": None,
"api_mode": "chat_completions",
}
model, rt = runner._apply_session_model_override(
sk,
"anthropic/claude-sonnet-4",
{"provider": "anthropic", "api_key": "ant-key", "base_url": "https://api.anthropic.com", "api_mode": "anthropic_messages"},
)
assert model == "gpt-5.4"
assert rt["provider"] == "openai"
assert rt["api_key"] == "ant-key" # preserved — None didn't overwrite
assert rt["base_url"] == "https://api.anthropic.com" # preserved
assert rt["api_mode"] == "chat_completions" # overwritten (not None)
def test_empty_string_overwrites(self):
"""Empty string is not None — it should overwrite the config value."""
runner = _make_runner()
sk = build_session_key(_make_source())
runner._session_model_overrides[sk] = {
"model": "local-model",
"provider": "custom",
"api_key": "local-key",
"base_url": "",
"api_mode": "chat_completions",
}
_, rt = runner._apply_session_model_override(
sk,
"anthropic/claude-sonnet-4",
{"provider": "anthropic", "api_key": "ant-key", "base_url": "https://api.anthropic.com", "api_mode": "anthropic_messages"},
)
assert rt["base_url"] == "" # empty string overwrites
def test_different_session_key_not_affected(self):
runner = _make_runner()
sk = build_session_key(_make_source())
other_sk = "other_session"
runner._session_model_overrides[other_sk] = {
"model": "gpt-5.4",
"provider": "openai",
"api_key": "key",
"base_url": "",
"api_mode": "chat_completions",
}
model, rt = runner._apply_session_model_override(
sk,
"anthropic/claude-sonnet-4",
{"provider": "anthropic", "api_key": "ant-key", "base_url": "url", "api_mode": "anthropic_messages"},
)
assert model == "anthropic/claude-sonnet-4" # unchanged — wrong session key
# ---------------------------------------------------------------------------
# Tests: _is_intentional_model_switch
# ---------------------------------------------------------------------------
class TestIsIntentionalModelSwitch:
"""Verify fallback detection respects intentional /model overrides."""
def test_matches_override(self):
runner = _make_runner()
sk = build_session_key(_make_source())
runner._session_model_overrides[sk] = {
"model": "gpt-5.4",
"provider": "openai",
"api_key": "key",
"base_url": "",
"api_mode": "chat_completions",
}
assert runner._is_intentional_model_switch(sk, "gpt-5.4") is True
def test_no_override_returns_false(self):
runner = _make_runner()
sk = build_session_key(_make_source())
assert runner._is_intentional_model_switch(sk, "gpt-5.4") is False
def test_different_model_returns_false(self):
"""Agent fell back to a different model than the override."""
runner = _make_runner()
sk = build_session_key(_make_source())
runner._session_model_overrides[sk] = {
"model": "gpt-5.4",
"provider": "openai",
"api_key": "key",
"base_url": "",
"api_mode": "chat_completions",
}
assert runner._is_intentional_model_switch(sk, "gpt-5.4-mini") is False
def test_wrong_session_key(self):
runner = _make_runner()
sk = build_session_key(_make_source())
runner._session_model_overrides["other_session"] = {
"model": "gpt-5.4",
"provider": "openai",
"api_key": "key",
"base_url": "",
"api_mode": "chat_completions",
}
assert runner._is_intentional_model_switch(sk, "gpt-5.4") is False