def _apply_session_model_override(
    self, session_key: str, model: str, runtime_kwargs: dict
) -> tuple:
    """Apply /model session overrides if present, returning (model, runtime_kwargs).

    The gateway /model command stores per-session overrides in
    ``_session_model_overrides``. These must take precedence over
    config.yaml defaults so the switched model is actually used for
    subsequent messages.

    Any override field whose value is ``None`` -- including ``model`` -- is
    skipped, so a partial override never clobbers a valid config default.
    Other falsy values (e.g. an empty ``base_url``) are applied as-is.

    Args:
        session_key: Key identifying the session in ``_session_model_overrides``.
        model: Model name resolved from configuration.
        runtime_kwargs: Runtime settings resolved from configuration. Updated
            in place when an override applies.

    Returns:
        ``(model, runtime_kwargs)``; ``runtime_kwargs`` is the same dict object
        that was passed in.
    """
    override = self._session_model_overrides.get(session_key)
    if not override:
        # Missing or empty override: configuration values stand.
        return model, runtime_kwargs

    # ``dict.get("model", model)`` would hand back None when the key exists
    # with a None value, so test the value explicitly instead.
    override_model = override.get("model")
    if override_model is not None:
        model = override_model

    for key in ("provider", "api_key", "base_url", "api_mode"):
        value = override.get(key)
        if value is not None:
            runtime_kwargs[key] = value
    return model, runtime_kwargs


def _is_intentional_model_switch(self, session_key: str, agent_model: str) -> bool:
    """Return True if *agent_model* matches an active /model session override.

    A session with no override, or an override that does not name a model,
    never counts as an intentional switch -- even when *agent_model* is
    ``None``.
    """
    override = self._session_model_overrides.get(session_key)
    if not override:
        return False
    override_model = override.get("model")
    return override_model is not None and override_model == agent_model
+ model, runtime_kwargs = self._apply_session_model_override( + session_key, model, runtime_kwargs + ) + pr = self._provider_routing reasoning_config = self._load_reasoning_config() self._reasoning_config = reasoning_config @@ -7328,14 +7359,15 @@ class GatewayRunner: _agent = agent_holder[0] if _agent is not None and hasattr(_agent, 'model'): _cfg_model = _resolve_gateway_model() - if _agent.model != _cfg_model: + if _agent.model != _cfg_model and not self._is_intentional_model_switch(session_key, _agent.model): self._effective_model = _agent.model self._effective_provider = getattr(_agent, 'provider', None) # Fallback activated — evict cached agent so the next # message starts fresh and retries the primary model. self._evict_cached_agent(session_key) else: - # Primary model worked — clear any stale fallback state + # Primary model worked (or intentional /model switch) + # — clear any stale fallback state. self._effective_model = None self._effective_provider = None diff --git a/tests/gateway/test_model_switch_persistence.py b/tests/gateway/test_model_switch_persistence.py new file mode 100644 index 0000000000..07fa5d5f43 --- /dev/null +++ b/tests/gateway/test_model_switch_persistence.py @@ -0,0 +1,245 @@ +"""Tests that gateway /model switch persists across messages. + +The gateway /model command stores session overrides in +``_session_model_overrides``. These must: + +1. Be applied in ``run_sync()`` so the next agent uses the switched model. +2. Not be mistaken for fallback activation (which evicts the cached agent). +3. Survive across multiple messages until /reset clears them. + +Tests exercise the real ``_apply_session_model_override()`` and +``_is_intentional_model_switch()`` methods on ``GatewayRunner``. 
from datetime import datetime
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest

from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.session import SessionEntry, SessionSource, build_session_key


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_source() -> SessionSource:
    """Build the Telegram DM session source shared by every test."""
    source = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="u1",
        chat_id="c1",
        user_name="tester",
        chat_type="dm",
    )
    return source


def _make_runner():
    """Return a GatewayRunner that skipped ``__init__`` and carries stub state.

    Only the attributes assigned below exist on the instance; the session
    store is a mock pre-seeded with a single entry for ``_make_source()``.
    """
    from gateway.run import GatewayRunner

    telegram_adapter = MagicMock(send=AsyncMock())
    stub_state = {
        "config": GatewayConfig(
            platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="tok")}
        ),
        "adapters": {Platform.TELEGRAM: telegram_adapter},
        "_voice_mode": {},
        "hooks": SimpleNamespace(emit=AsyncMock(), loaded_hooks=False),
        "_session_model_overrides": {},
        "_pending_model_notes": {},
        "_background_tasks": set(),
        "_running_agents": {},
        "_pending_messages": {},
        "_pending_approvals": {},
        "_session_db": None,
        "_agent_cache": {},
        "_agent_cache_lock": None,
        "_effective_model": None,
        "_effective_provider": None,
    }

    # Bypass __init__ entirely; attributes are attached one by one instead.
    runner = object.__new__(GatewayRunner)
    for attr_name, attr_value in stub_state.items():
        setattr(runner, attr_name, attr_value)

    key = build_session_key(_make_source())
    entry = SessionEntry(
        session_key=key,
        session_id="sess-1",
        created_at=datetime.now(),
        updated_at=datetime.now(),
        platform=Platform.TELEGRAM,
        chat_type="dm",
    )
    store = MagicMock()
    store.get_or_create_session.return_value = entry
    store._entries = {key: entry}
    runner.session_store = store
    return runner
# ---------------------------------------------------------------------------
# Tests: _apply_session_model_override
# ---------------------------------------------------------------------------

_CONFIG_MODEL = "anthropic/claude-sonnet-4"


def _config_runtime(**changes) -> dict:
    """Return config-style runtime kwargs (Anthropic defaults) with *changes* applied."""
    runtime = {
        "provider": "anthropic",
        "api_key": "ant-key",
        "base_url": "https://api.anthropic.com",
        "api_mode": "anthropic_messages",
    }
    runtime.update(changes)
    return runtime


def _openai_override() -> dict:
    """Return a fresh /model override dict that switches the session to gpt-5.4."""
    return {
        "model": "gpt-5.4",
        "provider": "openai",
        "api_key": "key",
        "base_url": "",
        "api_mode": "chat_completions",
    }


def _runner_and_key():
    """Return a stubbed runner and the session key of the default test source."""
    return _make_runner(), build_session_key(_make_source())


class TestApplySessionModelOverride:
    """Verify _apply_session_model_override replaces config defaults."""

    def test_override_replaces_all_fields(self):
        runner, sk = _runner_and_key()
        runner._session_model_overrides[sk] = {
            "model": "gpt-5.4-turbo",
            "provider": "openrouter",
            "api_key": "or-key-123",
            "base_url": "https://openrouter.ai/api/v1",
            "api_mode": "chat_completions",
        }

        model, rt = runner._apply_session_model_override(
            sk, _CONFIG_MODEL, _config_runtime()
        )

        assert model == "gpt-5.4-turbo"
        assert rt["provider"] == "openrouter"
        assert rt["api_key"] == "or-key-123"
        assert rt["base_url"] == "https://openrouter.ai/api/v1"
        assert rt["api_mode"] == "chat_completions"

    def test_no_override_returns_originals(self):
        runner, sk = _runner_and_key()
        orig_rt = _config_runtime(api_key="key")

        model, rt = runner._apply_session_model_override(
            sk, _CONFIG_MODEL, dict(orig_rt)
        )

        assert model == _CONFIG_MODEL
        assert rt == orig_rt

    def test_none_values_do_not_overwrite(self):
        """Override with None api_key/base_url should preserve config defaults."""
        runner, sk = _runner_and_key()
        runner._session_model_overrides[sk] = {
            "model": "gpt-5.4",
            "provider": "openai",
            "api_key": None,
            "base_url": None,
            "api_mode": "chat_completions",
        }

        model, rt = runner._apply_session_model_override(
            sk, _CONFIG_MODEL, _config_runtime()
        )

        assert model == "gpt-5.4"
        assert rt["provider"] == "openai"
        assert rt["api_key"] == "ant-key"  # preserved: None didn't overwrite
        assert rt["base_url"] == "https://api.anthropic.com"  # preserved
        assert rt["api_mode"] == "chat_completions"  # overwritten (not None)

    def test_empty_string_overwrites(self):
        """Empty string is not None, so it should overwrite the config value."""
        runner, sk = _runner_and_key()
        runner._session_model_overrides[sk] = {
            "model": "local-model",
            "provider": "custom",
            "api_key": "local-key",
            "base_url": "",
            "api_mode": "chat_completions",
        }

        _, rt = runner._apply_session_model_override(
            sk, _CONFIG_MODEL, _config_runtime()
        )

        assert rt["base_url"] == ""  # empty string overwrites

    def test_different_session_key_not_affected(self):
        runner, sk = _runner_and_key()
        runner._session_model_overrides["other_session"] = _openai_override()
        config_rt = _config_runtime(base_url="url")

        model, rt = runner._apply_session_model_override(
            sk, _CONFIG_MODEL, dict(config_rt)
        )

        # Wrong session key: neither the model nor the runtime kwargs may change.
        assert model == _CONFIG_MODEL
        assert rt == config_rt


# ---------------------------------------------------------------------------
# Tests: _is_intentional_model_switch
# ---------------------------------------------------------------------------


class TestIsIntentionalModelSwitch:
    """Verify fallback detection respects intentional /model overrides."""

    def test_matches_override(self):
        runner, sk = _runner_and_key()
        runner._session_model_overrides[sk] = _openai_override()

        assert runner._is_intentional_model_switch(sk, "gpt-5.4") is True

    def test_no_override_returns_false(self):
        runner, sk = _runner_and_key()

        assert runner._is_intentional_model_switch(sk, "gpt-5.4") is False

    def test_different_model_returns_false(self):
        """Agent fell back to a different model than the override."""
        runner, sk = _runner_and_key()
        runner._session_model_overrides[sk] = _openai_override()

        assert runner._is_intentional_model_switch(sk, "gpt-5.4-mini") is False

    def test_wrong_session_key(self):
        runner, sk = _runner_and_key()
        runner._session_model_overrides["other_session"] = _openai_override()

        assert runner._is_intentional_model_switch(sk, "gpt-5.4") is False