From a78e622dfe5504dd7d08c5243f60ed00f6a1f08f Mon Sep 17 00:00:00 2001
From: LeonSGP43
Date: Mon, 4 May 2026 09:36:43 +0800
Subject: [PATCH] fix(agent): honor configured model max tokens
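
`model.max_tokens` set in config.yaml was previously ignored: the cap was
only honored when callers passed max_tokens to the AIAgent constructor
directly. Read the config value when no explicit override is given,
validate it as a positive integer (rejecting bools, zero, and negatives,
with a warning and fallback to the provider default), and record the
effective value in the session-init model config. For example, a config
such as:

    model:
      max_tokens: 4096

now caps completion tokens on chat requests. An explicit constructor
argument still takes precedence over the config value, and
model.max_tokens joins the gateway's cache-busting keys so that editing
it invalidates any cached agent.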

---
 gateway/run.py                    |  1 +
 run_agent.py                      | 29 +++++++++++++++++-
 tests/gateway/test_agent_cache.py | 24 ++++++++++++++-
 tests/run_agent/test_run_agent.py | 50 +++++++++++++++++++++++++++++++
 4 files changed, 102 insertions(+), 2 deletions(-)

diff --git a/gateway/run.py b/gateway/run.py
index de80262710..f96d77b3c0 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -12251,6 +12251,7 @@ class GatewayRunner:
     # Add more here as new baked-at-construction config settings are added.
     _CACHE_BUSTING_CONFIG_KEYS: tuple = (
         ("model", "context_length"),
+        ("model", "max_tokens"),
         ("compression", "enabled"),
         ("compression", "threshold"),
         ("compression", "target_ratio"),
diff --git a/run_agent.py b/run_agent.py
index 3e1f2772a9..185431671b 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -1901,8 +1901,35 @@ class AIAgent:
             _aux_context_config = None
         self._aux_compression_context_length_config = _aux_context_config
 
-        # Read explicit context_length override from model config
+        # Read explicit model output-token override from config when the
+        # caller did not pass one directly.
         _model_cfg = _agent_cfg.get("model", {})
+        if self.max_tokens is None and isinstance(_model_cfg, dict):
+            _config_max_tokens = _model_cfg.get("max_tokens")
+            if _config_max_tokens is not None:
+                try:
+                    if isinstance(_config_max_tokens, bool):
+                        raise ValueError
+                    _parsed_max_tokens = int(_config_max_tokens)
+                    if _parsed_max_tokens <= 0:
+                        raise ValueError
+                    self.max_tokens = _parsed_max_tokens
+                except (TypeError, ValueError):
+                    logger.warning(
+                        "Invalid model.max_tokens in config.yaml: %r — "
+                        "must be a positive integer (e.g. 4096). "
+                        "Falling back to provider default.",
+                        _config_max_tokens,
+                    )
+                    print(
+                        f"\n⚠ Invalid model.max_tokens in config.yaml: {_config_max_tokens!r}\n"
+                        f"  Must be a positive integer (e.g. 4096).\n"
+                        f"  Falling back to provider default.\n",
+                        file=sys.stderr,
+                    )
+        self._session_init_model_config["max_tokens"] = self.max_tokens
+
+        # Read explicit context_length override from model config
         if isinstance(_model_cfg, dict):
             _config_context_length = _model_cfg.get("context_length")
         else:
diff --git a/tests/gateway/test_agent_cache.py b/tests/gateway/test_agent_cache.py
index abf0ce3481..fad7e6c1cf 100644
--- a/tests/gateway/test_agent_cache.py
+++ b/tests/gateway/test_agent_cache.py
@@ -127,6 +127,21 @@ class TestAgentConfigSignature:
         )
         assert sig1 != sig2
 
+    def test_max_tokens_change_busts_cache(self):
+        """Editing model.max_tokens in config must produce a new signature."""
+        from gateway.run import GatewayRunner
+
+        runtime = {"api_key": "k", "base_url": "u", "provider": "p"}
+        sig1 = GatewayRunner._agent_config_signature(
+            "m", runtime, [], "",
+            cache_keys={"model.max_tokens": 4096},
+        )
+        sig2 = GatewayRunner._agent_config_signature(
+            "m", runtime, [], "",
+            cache_keys={"model.max_tokens": 8192},
+        )
+        assert sig1 != sig2
+
     def test_compression_threshold_change_busts_cache(self):
         from gateway.run import GatewayRunner
 
@@ -195,9 +210,16 @@ class TestExtractCacheBustingConfig:
         from gateway.run import GatewayRunner
 
         out = GatewayRunner._extract_cache_busting_config(
-            {"model": {"context_length": 272_000, "provider": "openrouter"}}
+            {
+                "model": {
+                    "context_length": 272_000,
+                    "max_tokens": 4096,
+                    "provider": "openrouter",
+                }
+            }
         )
         assert out["model.context_length"] == 272_000
+        assert out["model.max_tokens"] == 4096
 
     def test_reads_compression_subkeys(self):
         from gateway.run import GatewayRunner
diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py
index cbce772d3a..7c5973617b 100644
--- a/tests/run_agent/test_run_agent.py
+++ b/tests/run_agent/test_run_agent.py
@@ -724,6 +724,56 @@ class TestInit:
         )
         assert a._cache_ttl == "1h"
 
+    def test_model_max_tokens_from_config(self):
+        """model.max_tokens config populates the chat-completions request cap."""
+        with (
+            patch("run_agent.get_tool_definitions", return_value=_make_tool_defs("terminal")),
+            patch("run_agent.check_toolset_requirements", return_value={}),
+            patch("run_agent.OpenAI"),
+            patch(
+                "hermes_cli.config.load_config",
+                return_value={"model": {"max_tokens": 4096}},
+            ),
+        ):
+            a = AIAgent(
+                api_key="test-k...7890",
+                provider="custom",
+                model="claude-opus-4-6-thinking",
+                base_url="http://proxy.example/v1",
+                quiet_mode=True,
+                skip_context_files=True,
+                skip_memory=True,
+            )
+
+        kwargs = a._build_api_kwargs([{"role": "user", "content": "Hi"}])
+
+        assert a.max_tokens == 4096
+        assert kwargs["max_tokens"] == 4096
+
+    def test_constructor_max_tokens_wins_over_config(self):
+        """Explicit constructor max_tokens keeps programmatic callers stable."""
+        with (
+            patch("run_agent.get_tool_definitions", return_value=[]),
+            patch("run_agent.check_toolset_requirements", return_value={}),
+            patch("run_agent.OpenAI"),
+            patch(
+                "hermes_cli.config.load_config",
+                return_value={"model": {"max_tokens": 4096}},
+            ),
+        ):
+            a = AIAgent(
+                api_key="test-k...7890",
+                provider="custom",
+                model="claude-opus-4-6-thinking",
+                base_url="http://proxy.example/v1",
+                max_tokens=8192,
+                quiet_mode=True,
+                skip_context_files=True,
+                skip_memory=True,
+            )
+
+        assert a.max_tokens == 8192
+
     def test_prompt_caching_cache_ttl_invalid_falls_back(self):
         """Non-Anthropic TTL values keep default 5m without raising."""
         with (