diff --git a/run_agent.py b/run_agent.py index 685b8372f8..4ad047262a 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1271,7 +1271,10 @@ class AIAgent: _agent_cfg = _load_agent_config() except Exception: _agent_cfg = {} - self._config = _agent_cfg # stored for later use (e.g. compression feasibility check) + # Cache only the derived auxiliary compression context override that is + # needed later by the startup feasibility check. Avoid exposing a + # broad pseudo-public config object on the agent instance. + self._aux_compression_context_length_config = None # Persistent memory (MEMORY.md + USER.md) -- loaded from disk self._memory_store = None @@ -1402,6 +1405,24 @@ class AIAgent: compression_target_ratio = float(_compression_cfg.get("target_ratio", 0.20)) compression_protect_last = int(_compression_cfg.get("protect_last_n", 20)) + # Read optional explicit context_length override for the auxiliary + # compression model. Custom endpoints often cannot report this via + # /models, so the startup feasibility check needs the config hint. + try: + _aux_cfg = _agent_cfg.get("auxiliary", {}).get("compression", {}) + except Exception: + _aux_cfg = {} + if isinstance(_aux_cfg, dict): + _aux_context_config = _aux_cfg.get("context_length") + else: + _aux_context_config = None + if _aux_context_config is not None: + try: + _aux_context_config = int(_aux_context_config) + except (TypeError, ValueError): + _aux_context_config = None + self._aux_compression_context_length_config = _aux_context_config + # Read explicit context_length override from model config _model_cfg = _agent_cfg.get("model", {}) if isinstance(_model_cfg, dict): @@ -1999,24 +2020,11 @@ class AIAgent: aux_base_url = str(getattr(client, "base_url", "")) aux_api_key = str(getattr(client, "api_key", "")) - # Read user-configured context_length for the compression model. - # Custom endpoints often don't support /models API queries so - # get_model_context_length() falls through to the 128K default, - # ignoring the explicit config value. Pass it as the highest- - # priority hint so the configured value is always respected. - _aux_cfg = (self._config or {}).get("auxiliary", {}).get("compression", {}) - _aux_context_config = _aux_cfg.get("context_length") if isinstance(_aux_cfg, dict) else None - if _aux_context_config is not None: - try: - _aux_context_config = int(_aux_context_config) - except (TypeError, ValueError): - _aux_context_config = None - aux_context = get_model_context_length( aux_model, base_url=aux_base_url, api_key=aux_api_key, - config_context_length=_aux_context_config, + config_context_length=getattr(self, "_aux_compression_context_length_config", None), ) threshold = self.context_compressor.threshold_tokens diff --git a/tests/run_agent/test_compression_feasibility.py b/tests/run_agent/test_compression_feasibility.py index f0db50de4d..451eeb2f7e 100644 --- a/tests/run_agent/test_compression_feasibility.py +++ b/tests/run_agent/test_compression_feasibility.py @@ -38,7 +38,7 @@ def _make_agent( agent.status_callback = None agent.tool_progress_callback = None agent._compression_warning = None - agent._config = None + agent._aux_compression_context_length_config = None compressor = MagicMock(spec=ContextCompressor) compressor.context_length = main_context @@ -138,13 +138,7 @@ def test_feasibility_check_passes_config_context_length(mock_get_client, mock_ct get_model_context_length so custom endpoints that lack /models still report the correct context window (fixes #8499).""" agent = _make_agent(main_context=200_000, threshold_percent=0.85) - agent._config = { - "auxiliary": { - "compression": { - "context_length": 1_000_000, - }, - }, - } + agent._aux_compression_context_length_config = 1_000_000 mock_client = MagicMock() mock_client.base_url = "http://custom-endpoint:8080/v1" mock_client.api_key = "sk-custom" @@ -166,13 +160,7 @@ def test_feasibility_check_passes_config_context_length(mock_get_client, mock_ct def test_feasibility_check_ignores_invalid_context_length(mock_get_client, mock_ctx_len): """Non-integer context_length in config is silently ignored.""" agent = _make_agent(main_context=200_000, threshold_percent=0.50) - agent._config = { - "auxiliary": { - "compression": { - "context_length": "not-a-number", - }, - }, - } + agent._aux_compression_context_length_config = None mock_client = MagicMock() mock_client.base_url = "http://custom:8080/v1" mock_client.api_key = "sk-test" @@ -189,6 +177,58 @@ def test_feasibility_check_ignores_invalid_context_length(mock_get_client, mock_ ) +def test_init_feasibility_check_uses_aux_context_override_from_config(): + """Real AIAgent init should cache and forward auxiliary.compression.context_length.""" + + class _StubCompressor: + def __init__(self, *args, **kwargs): + self.context_length = 200_000 + self.threshold_tokens = 100_000 + self.threshold_percent = 0.50 + + def get_tool_schemas(self): + return [] + + def on_session_start(self, *args, **kwargs): + return None + + cfg = { + "auxiliary": { + "compression": { + "context_length": 1_000_000, + }, + }, + } + mock_client = MagicMock() + mock_client.base_url = "http://custom-endpoint:8080/v1" + mock_client.api_key = "sk-custom" + + with ( + patch("hermes_cli.config.load_config", return_value=cfg), + patch("run_agent.get_tool_definitions", return_value=[]), + patch("run_agent.check_toolset_requirements", return_value={}), + patch("run_agent.OpenAI"), + patch("run_agent.ContextCompressor", new=_StubCompressor), + patch("agent.auxiliary_client.get_text_auxiliary_client", return_value=(mock_client, "custom/big-model")), + patch("agent.model_metadata.get_model_context_length", return_value=1_000_000) as mock_ctx_len, + ): + agent = AIAgent( + api_key="test-key-1234567890", + base_url="https://openrouter.ai/api/v1", + quiet_mode=True, + skip_context_files=True, + skip_memory=True, + ) + + assert agent._aux_compression_context_length_config == 1_000_000 + mock_ctx_len.assert_called_once_with( + "custom/big-model", + base_url="http://custom-endpoint:8080/v1", + api_key="sk-custom", + config_context_length=1_000_000, + ) + + @patch("agent.auxiliary_client.get_text_auxiliary_client") def test_warns_when_no_auxiliary_provider(mock_get_client): """Warning emitted when no auxiliary provider is configured."""