diff --git a/gateway/run.py b/gateway/run.py index ca1e48946..4838ce212 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3264,6 +3264,10 @@ class GatewayRunner: except Exception: pass + # Clear any session-scoped model override so the next agent picks up + # the configured default instead of the previously switched model. + self._session_model_overrides.pop(session_key, None) + # Reset the session new_entry = self.session_store.reset_session(session_key) diff --git a/tests/gateway/test_session_model_reset.py b/tests/gateway/test_session_model_reset.py new file mode 100644 index 000000000..6529f3a11 --- /dev/null +++ b/tests/gateway/test_session_model_reset.py @@ -0,0 +1,126 @@ +"""Tests that /new (and its /reset alias) clears the session-scoped model override.""" +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from gateway.config import GatewayConfig, Platform, PlatformConfig +from gateway.platforms.base import MessageEvent +from gateway.session import SessionEntry, SessionSource, build_session_key + + +def _make_source() -> SessionSource: + return SessionSource( + platform=Platform.TELEGRAM, + user_id="u1", + chat_id="c1", + user_name="tester", + chat_type="dm", + ) + + +def _make_event(text: str) -> MessageEvent: + return MessageEvent(text=text, source=_make_source(), message_id="m1") + + +def _make_runner(): + from gateway.run import GatewayRunner + + runner = object.__new__(GatewayRunner) + runner.config = GatewayConfig( + platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="***")} + ) + adapter = MagicMock() + adapter.send = AsyncMock() + runner.adapters = {Platform.TELEGRAM: adapter} + runner._voice_mode = {} + runner.hooks = SimpleNamespace(emit=AsyncMock(), loaded_hooks=False) + runner._session_model_overrides = {} + runner._pending_model_notes = {} + runner._background_tasks = set() + + session_key = build_session_key(_make_source()) + session_entry = SessionEntry( + session_key=session_key, + session_id="sess-1", + created_at=datetime.now(), + updated_at=datetime.now(), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + runner.session_store = MagicMock() + runner.session_store.get_or_create_session.return_value = session_entry + runner.session_store.reset_session.return_value = session_entry + runner.session_store._entries = {session_key: session_entry} + runner.session_store._generate_session_key.return_value = session_key + runner._running_agents = {} + runner._pending_messages = {} + runner._pending_approvals = {} + runner._session_db = None + runner._agent_cache_lock = None # disables _evict_cached_agent lock path + runner._is_user_authorized = lambda _source: True + runner._format_session_info = lambda: "" + + return runner + + +@pytest.mark.asyncio +async def test_new_command_clears_session_model_override(): + """/new must remove the session-scoped model override for that session.""" + runner = _make_runner() + session_key = build_session_key(_make_source()) + + # Simulate a prior /model switch stored as a session override + runner._session_model_overrides[session_key] = { + "model": "gpt-4o", + "provider": "openai", + "api_key": "sk-test", + "base_url": "", + "api_mode": "openai", + } + + await runner._handle_reset_command(_make_event("/new")) + + assert session_key not in runner._session_model_overrides + + +@pytest.mark.asyncio +async def test_new_command_no_override_is_noop(): + """/new with no prior model override must not raise.""" + runner = _make_runner() + session_key = build_session_key(_make_source()) + + assert session_key not in runner._session_model_overrides + + await runner._handle_reset_command(_make_event("/new")) + + assert session_key not in runner._session_model_overrides + + +@pytest.mark.asyncio +async def test_new_command_only_clears_own_session(): + """/new must only clear the override for the session that triggered it.""" + runner = _make_runner() + session_key = build_session_key(_make_source()) + other_key = "other_session_key" + + runner._session_model_overrides[session_key] = { + "model": "gpt-4o", + "provider": "openai", + "api_key": "sk-test", + "base_url": "", + "api_mode": "openai", + } + runner._session_model_overrides[other_key] = { + "model": "claude-sonnet-4-6", + "provider": "anthropic", + "api_key": "sk-ant-test", + "base_url": "", + "api_mode": "anthropic", + } + + await runner._handle_reset_command(_make_event("/new")) + + assert session_key not in runner._session_model_overrides + assert other_key in runner._session_model_overrides