diff --git a/run_agent.py b/run_agent.py index 4e9b95567..d22543f85 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1149,6 +1149,9 @@ class AIAgent: except (TypeError, ValueError): _config_context_length = None + # Store for reuse in switch_model (so config override persists across model switches) + self._config_context_length = _config_context_length + # Check custom_providers per-model context_length if _config_context_length is None: _custom_providers = _agent_cfg.get("custom_providers") @@ -1386,6 +1389,7 @@ class AIAgent: base_url=self.base_url, api_key=self.api_key, provider=self.provider, + config_context_length=getattr(self, "_config_context_length", None), ) self.context_compressor.model = self.model self.context_compressor.base_url = self.base_url diff --git a/tests/run_agent/test_switch_model_context.py b/tests/run_agent/test_switch_model_context.py new file mode 100644 index 000000000..8b04a7326 --- /dev/null +++ b/tests/run_agent/test_switch_model_context.py @@ -0,0 +1,74 @@ +"""Tests that switch_model preserves config_context_length.""" + +from unittest.mock import MagicMock, patch + +from run_agent import AIAgent +from agent.context_compressor import ContextCompressor + + +def _make_agent_with_compressor(config_context_length=None) -> AIAgent: + """Build a minimal AIAgent with a context_compressor, skipping __init__.""" + agent = AIAgent.__new__(AIAgent) + + # Primary model settings + agent.model = "primary-model" + agent.provider = "openrouter" + agent.base_url = "https://openrouter.ai/api/v1" + agent.api_key = "sk-primary" + agent.api_mode = "chat_completions" + agent.client = MagicMock() + agent.quiet_mode = True + + # Store config_context_length for later use in switch_model + agent._config_context_length = config_context_length + + # Context compressor with primary model values + compressor = ContextCompressor( + model="primary-model", + threshold_percent=0.50, + base_url="https://openrouter.ai/api/v1", + api_key="sk-primary", + provider="openrouter", + quiet_mode=True, + config_context_length=config_context_length, + ) + agent.context_compressor = compressor + + # For switch_model + agent._primary_runtime = {} + + return agent + + +@patch("agent.model_metadata.get_model_context_length", return_value=131_072) +def test_switch_model_preserves_config_context_length(mock_ctx_len): + """When switching models, config_context_length should be passed to get_model_context_length.""" + agent = _make_agent_with_compressor(config_context_length=32_768) + + assert agent.context_compressor.model == "primary-model" + assert agent.context_compressor.context_length == 32_768 # From config override + + # Switch model + agent.switch_model("new-model", "openrouter", api_key="sk-new", base_url="https://openrouter.ai/api/v1") + + # Verify get_model_context_length was called with config_context_length + mock_ctx_len.assert_called_once() + call_kwargs = mock_ctx_len.call_args.kwargs + assert call_kwargs.get("config_context_length") == 32_768 + + # Verify compressor was updated + assert agent.context_compressor.model == "new-model" + + +def test_switch_model_without_config_context_length(): + """When switching models without config override, config_context_length should be None.""" + agent = _make_agent_with_compressor(config_context_length=None) + + with patch("agent.model_metadata.get_model_context_length", return_value=128_000) as mock_ctx_len: + # Switch model + agent.switch_model("new-model", "openrouter", api_key="sk-new", base_url="https://openrouter.ai/api/v1") + + # Verify get_model_context_length was called with None + mock_ctx_len.assert_called_once() + call_kwargs = mock_ctx_len.call_args.kwargs + assert call_kwargs.get("config_context_length") is None