diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 850e166626..29d5e1e89b 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -106,9 +106,11 @@ _endpoint_model_metadata_cache_time: Dict[str, float] = {} _ENDPOINT_MODEL_CACHE_TTL = 300 # Descending tiers for context length probing when the model is unknown. -# We start at 128K (a safe default for most modern models) and step down -# on context-length errors until one works. +# We start at 256K (covers GPT-5.x, many current large-context models) and +# step down on context-length errors until one works. Tier[0] is also the +# default fallback when no detection method succeeds. CONTEXT_PROBE_TIERS = [ + 256_000, 128_000, 64_000, 32_000, @@ -1193,6 +1195,7 @@ def get_model_context_length( api_key: str = "", config_context_length: int | None = None, provider: str = "", + custom_providers: list | None = None, ) -> int: """Get the context length for a model. @@ -1213,6 +1216,23 @@ def get_model_context_length( if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0: return config_context_length + # 0b. custom_providers per-model override — check before any probe. + # This closes the gap where /model switch and display paths used to fall + # back to 128K despite the user having a per-model context_length set. + # See #15779. + if custom_providers and base_url and model: + try: + from hermes_cli.config import get_custom_provider_context_length + cp_ctx = get_custom_provider_context_length( + model=model, + base_url=base_url, + custom_providers=custom_providers, + ) + if cp_ctx: + return cp_ctx + except Exception: + pass # fall through to probing + # Normalise provider-prefixed model names (e.g. "local:model-name" → # "model-name") so cache lookups and server queries use the bare ID that # local servers actually know about. Ollama "model:tag" colons are preserved. @@ -1352,7 +1372,7 @@ def get_model_context_length( # 6. OpenRouter live API metadata (provider-unaware fallback) metadata = fetch_model_metadata() if model in metadata: - return metadata[model].get("context_length", 128000) + return metadata[model].get("context_length", DEFAULT_FALLBACK_CONTEXT) # 8. Hardcoded defaults (fuzzy match — longest key first for specificity) # Only check `default_model in model` (is the key a substring of the input). diff --git a/gateway/run.py b/gateway/run.py index 4c82a9274b..05578fa0d8 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -4891,6 +4891,7 @@ class GatewayRunner: provider = None base_url = None api_key = None + custom_provs = None try: cfg_path = _hermes_home / "config.yaml" @@ -4908,6 +4909,11 @@ class GatewayRunner: pass provider = model_cfg.get("provider") or None base_url = model_cfg.get("base_url") or None + try: + from hermes_cli.config import get_compatible_custom_providers + custom_provs = get_compatible_custom_providers(data) + except Exception: + custom_provs = data.get("custom_providers") except Exception: pass @@ -4926,6 +4932,7 @@ class GatewayRunner: api_key=api_key or "", config_context_length=config_context_length, provider=provider or "", + custom_providers=custom_provs, ) # Format context source hint @@ -5601,6 +5608,7 @@ class GatewayRunner: base_url=result.base_url or current_base_url or "", api_key=result.api_key or current_api_key or "", model_info=mi, + custom_providers=custom_provs, ) if ctx: lines.append(f"Context: {ctx:,} tokens") @@ -5748,6 +5756,7 @@ class GatewayRunner: base_url=result.base_url or current_base_url or "", api_key=result.api_key or current_api_key or "", model_info=mi, + custom_providers=custom_provs, ) if ctx: lines.append(f"Context: {ctx:,} tokens") diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 4b7ff9fba7..a32213c5f8 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -2206,6 +2206,71 @@ def get_compatible_custom_providers( return compatible +def get_custom_provider_context_length( + model: str, + base_url: str, + custom_providers: Optional[List[Dict[str, Any]]] = None, + config: Optional[Dict[str, Any]] = None, +) -> Optional[int]: + """Look up a per-model ``context_length`` override from ``custom_providers``. + + Matches any entry whose ``base_url`` equals ``base_url`` (trailing-slash + insensitive) and returns ``custom_providers[i].models..context_length`` + if present and valid. Returns ``None`` when no override applies. + + This is the single source of truth for custom-provider context overrides, + used by: + * ``AIAgent.__init__`` (startup resolution) + * ``AIAgent.switch_model`` (mid-session ``/model`` switch) + * ``hermes_cli.model_switch.resolve_display_context_length`` (``/model`` confirmation display) + * ``gateway.run._format_session_info`` (``/info`` display) + * ``agent.model_metadata.get_model_context_length`` (when custom_providers is threaded through) + + Before this helper existed, the lookup was duplicated in ``run_agent.py``'s + startup path only; every other path (notably ``/model`` switch) fell back + to the 128K default. See #15779. + """ + if not model or not base_url: + return None + if custom_providers is None: + try: + custom_providers = get_compatible_custom_providers(config) + except Exception: + if config is None: + return None + raw = config.get("custom_providers") + custom_providers = raw if isinstance(raw, list) else [] + if not isinstance(custom_providers, list): + return None + + target_url = (base_url or "").rstrip("/") + if not target_url: + return None + + for entry in custom_providers: + if not isinstance(entry, dict): + continue + entry_url = (entry.get("base_url") or "").rstrip("/") + if not entry_url or entry_url != target_url: + continue + models = entry.get("models") + if not isinstance(models, dict): + continue + model_cfg = models.get(model) + if not isinstance(model_cfg, dict): + continue + raw_ctx = model_cfg.get("context_length") + if raw_ctx is None: + continue + try: + ctx = int(raw_ctx) + except (TypeError, ValueError): + continue + if ctx > 0: + return ctx + return None + + def check_config_version() -> Tuple[int, int]: """ Check config version. diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index b91bed67a3..d9e1b04183 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -533,6 +533,7 @@ def resolve_display_context_length( base_url: str = "", api_key: str = "", model_info: Optional[ModelInfo] = None, + custom_providers: list | None = None, ) -> Optional[int]: """Resolve the context length to show in /model output. @@ -543,6 +544,11 @@ def resolve_display_context_length( about Codex OAuth, Copilot, Nous, and falls back to models.dev for the rest. + When ``custom_providers`` is provided, per-model ``context_length`` + overrides from ``custom_providers[].models..context_length`` are + honored — this closes #15779 where ``/model`` switch ignored user-set + overrides. + Prefer the provider-aware value; fall back to ``model_info.context_window`` only if the resolver returns nothing. """ @@ -553,6 +559,7 @@ def resolve_display_context_length( base_url=base_url or "", api_key=api_key or "", provider=provider or None, + custom_providers=custom_providers, ) if ctx: return int(ctx) diff --git a/run_agent.py b/run_agent.py index 370ad97822..70d9a84b06 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1765,43 +1765,64 @@ class AIAgent: # Store for reuse in switch_model (so config override persists across model switches) self._config_context_length = _config_context_length + # Resolve custom_providers list once for reuse below (startup + # context-length override and plugin context-engine init). + try: + from hermes_cli.config import get_compatible_custom_providers + _custom_providers = get_compatible_custom_providers(_agent_cfg) + except Exception: + _custom_providers = _agent_cfg.get("custom_providers") + if not isinstance(_custom_providers, list): + _custom_providers = [] + # Check custom_providers per-model context_length - if _config_context_length is None: + if _config_context_length is None and _custom_providers: try: - from hermes_cli.config import get_compatible_custom_providers - _custom_providers = get_compatible_custom_providers(_agent_cfg) + from hermes_cli.config import get_custom_provider_context_length + _cp_ctx_resolved = get_custom_provider_context_length( + model=self.model, + base_url=self.base_url, + custom_providers=_custom_providers, + ) + if _cp_ctx_resolved: + _config_context_length = int(_cp_ctx_resolved) except Exception: - _custom_providers = _agent_cfg.get("custom_providers") - if not isinstance(_custom_providers, list): - _custom_providers = [] - for _cp_entry in _custom_providers: - if not isinstance(_cp_entry, dict): - continue - _cp_url = (_cp_entry.get("base_url") or "").rstrip("/") - if _cp_url and _cp_url == self.base_url.rstrip("/"): - _cp_models = _cp_entry.get("models", {}) - if isinstance(_cp_models, dict): - _cp_model_cfg = _cp_models.get(self.model, {}) - if isinstance(_cp_model_cfg, dict): - _cp_ctx = _cp_model_cfg.get("context_length") - if _cp_ctx is not None: - try: - _config_context_length = int(_cp_ctx) - except (TypeError, ValueError): - logger.warning( - "Invalid context_length for model %r in " - "custom_providers: %r — must be a plain " - "integer (e.g. 256000, not '256K'). " - "Falling back to auto-detection.", - self.model, _cp_ctx, - ) - print( - f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n" - f" Must be a plain integer (e.g. 256000, not '256K').\n" - f" Falling back to auto-detected context window.\n", - file=sys.stderr, - ) - break + _cp_ctx_resolved = None + + # Surface a clear warning if the user set a context_length but it + # wasn't a valid positive int — the helper silently skips those. + if _config_context_length is None: + _target = self.base_url.rstrip("/") if self.base_url else "" + for _cp_entry in _custom_providers: + if not isinstance(_cp_entry, dict): + continue + _cp_url = (_cp_entry.get("base_url") or "").rstrip("/") + if _target and _cp_url == _target: + _cp_models = _cp_entry.get("models", {}) + if isinstance(_cp_models, dict): + _cp_model_cfg = _cp_models.get(self.model, {}) + if isinstance(_cp_model_cfg, dict): + _cp_ctx = _cp_model_cfg.get("context_length") + if _cp_ctx is not None: + try: + _parsed = int(_cp_ctx) + if _parsed <= 0: + raise ValueError + except (TypeError, ValueError): + logger.warning( + "Invalid context_length for model %r in " + "custom_providers: %r — must be a positive " + "integer (e.g. 256000, not '256K'). " + "Falling back to auto-detection.", + self.model, _cp_ctx, + ) + print( + f"\n⚠ Invalid context_length for model {self.model!r} in custom_providers: {_cp_ctx!r}\n" + f" Must be a positive integer (e.g. 256000, not '256K').\n" + f" Falling back to auto-detected context window.\n", + file=sys.stderr, + ) + break # Select context engine: config-driven (like memory providers). # 1. Check config.yaml context.engine setting @@ -1851,6 +1872,7 @@ class AIAgent: api_key=getattr(self, "api_key", ""), config_context_length=_config_context_length, provider=self.provider, + custom_providers=_custom_providers, ) self.context_compressor.update_model( model=self.model, @@ -2141,12 +2163,23 @@ class AIAgent: # ── Update context compressor ── if hasattr(self, "context_compressor") and self.context_compressor: from agent.model_metadata import get_model_context_length + # Re-read custom_providers from live config so per-model + # context_length overrides are honored when switching to a + # custom provider mid-session (closes #15779). + _sm_custom_providers = None + try: + from hermes_cli.config import load_config, get_compatible_custom_providers + _sm_cfg = load_config() + _sm_custom_providers = get_compatible_custom_providers(_sm_cfg) + except Exception: + _sm_custom_providers = None new_context_length = get_model_context_length( self.model, base_url=self.base_url, api_key=self.api_key, provider=self.provider, config_context_length=getattr(self, "_config_context_length", None), + custom_providers=_sm_custom_providers, ) self.context_compressor.update_model( model=self.model, diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index fc4ed0bf5b..42ec0a464f 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -459,9 +459,10 @@ class TestGetModelContextLength: @patch("agent.model_metadata.fetch_model_metadata") def test_api_missing_context_length_key(self, mock_fetch): - """Model in API but without context_length → defaults to 128000.""" + """Model in API but without context_length → defaults to the top + probe tier (currently 256K).""" mock_fetch.return_value = {"test/model": {"name": "Test"}} - assert get_model_context_length("test/model") == 128000 + assert get_model_context_length("test/model") == CONTEXT_PROBE_TIERS[0] @patch("agent.model_metadata.fetch_model_metadata") def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path): @@ -814,14 +815,17 @@ class TestContextProbeTiers: for i in range(len(CONTEXT_PROBE_TIERS) - 1): assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1] - def test_first_tier_is_128k(self): - assert CONTEXT_PROBE_TIERS[0] == 128_000 + def test_first_tier_is_256k(self): + assert CONTEXT_PROBE_TIERS[0] == 256_000 def test_last_tier_is_8k(self): assert CONTEXT_PROBE_TIERS[-1] == 8_000 class TestGetNextProbeTier: + def test_from_256k(self): + assert get_next_probe_tier(256_000) == 128_000 + def test_from_128k(self): assert get_next_probe_tier(128_000) == 64_000 @@ -841,8 +845,8 @@ class TestGetNextProbeTier: assert get_next_probe_tier(100_000) == 64_000 def test_above_max_tier(self): - """Value above 128K should return 128K.""" - assert get_next_probe_tier(500_000) == 128_000 + """Value above 256K should return 256K.""" + assert get_next_probe_tier(500_000) == 256_000 def test_zero_returns_none(self): assert get_next_probe_tier(0) is None diff --git a/tests/gateway/test_session_info.py b/tests/gateway/test_session_info.py index 5f04b1a482..d8c65305f7 100644 --- a/tests/gateway/test_session_info.py +++ b/tests/gateway/test_session_info.py @@ -58,7 +58,7 @@ class TestFormatSessionInfo: {"provider": "", "base_url": "", "api_key": ""}) with p1, p2, p3: info = runner._format_session_info() - assert "128K" in info + assert "256K" in info assert "model.context_length" in info def test_local_endpoint_shown(self, runner, tmp_path): diff --git a/tests/hermes_cli/test_custom_provider_context_length.py b/tests/hermes_cli/test_custom_provider_context_length.py new file mode 100644 index 0000000000..70e7760e7e --- /dev/null +++ b/tests/hermes_cli/test_custom_provider_context_length.py @@ -0,0 +1,240 @@ +"""Regression tests for custom_providers per-model context_length resolution. + +Covers the fix for #15779 — mid-session /model switch to a named custom +provider must honor ``custom_providers[].models..context_length`` the +same way startup already does. +""" +from __future__ import annotations + +from unittest.mock import patch + +from hermes_cli.config import get_custom_provider_context_length + + +class TestGetCustomProviderContextLength: + def test_returns_override_for_matching_entry(self): + custom = [ + { + "name": "my-endpoint", + "base_url": "https://example.invalid/v1", + "models": {"gpt-5.5": {"context_length": 1_050_000}}, + } + ] + assert ( + get_custom_provider_context_length( + "gpt-5.5", "https://example.invalid/v1", custom + ) + == 1_050_000 + ) + + def test_trailing_slash_insensitive(self): + custom = [ + { + "base_url": "https://example.invalid/v1/", + "models": {"m": {"context_length": 500_000}}, + } + ] + # config has trailing slash, runtime doesn't — must match + assert ( + get_custom_provider_context_length( + "m", "https://example.invalid/v1", custom + ) + == 500_000 + ) + # and the reverse + custom2 = [ + { + "base_url": "https://example.invalid/v1", + "models": {"m": {"context_length": 500_000}}, + } + ] + assert ( + get_custom_provider_context_length( + "m", "https://example.invalid/v1/", custom2 + ) + == 500_000 + ) + + def test_returns_none_when_url_does_not_match(self): + custom = [ + { + "base_url": "https://example.invalid/v1", + "models": {"m": {"context_length": 400_000}}, + } + ] + assert ( + get_custom_provider_context_length( + "m", "https://other.invalid/v1", custom + ) + is None + ) + + def test_returns_none_when_model_does_not_match(self): + custom = [ + { + "base_url": "https://example.invalid/v1", + "models": {"gpt-5.5": {"context_length": 400_000}}, + } + ] + assert ( + get_custom_provider_context_length( + "different-model", "https://example.invalid/v1", custom + ) + is None + ) + + def test_returns_none_for_string_value(self): + """'256K' string is not a valid int — skip silently. + + (The inline startup path still emits a user-visible warning; the + helper itself returns None so downstream fallbacks can run.) + """ + custom = [ + { + "base_url": "https://example.invalid/v1", + "models": {"m": {"context_length": "256K"}}, + } + ] + assert ( + get_custom_provider_context_length( + "m", "https://example.invalid/v1", custom + ) + is None + ) + + def test_returns_none_for_zero_or_negative(self): + for bad in (0, -1, -100): + custom = [ + { + "base_url": "https://example.invalid/v1", + "models": {"m": {"context_length": bad}}, + } + ] + assert ( + get_custom_provider_context_length( + "m", "https://example.invalid/v1", custom + ) + is None + ), f"value {bad!r} should be rejected" + + def test_empty_inputs_return_none(self): + assert get_custom_provider_context_length("", "http://x", [{"base_url": "http://x", "models": {"": {"context_length": 1}}}]) is None + assert get_custom_provider_context_length("m", "", [{"base_url": "", "models": {"m": {"context_length": 1}}}]) is None + assert get_custom_provider_context_length("m", "http://x", None) is None + assert get_custom_provider_context_length("m", "http://x", []) is None + + def test_ignores_non_dict_entries(self): + """Malformed entries must not crash the lookup.""" + custom = [ + "not a dict", + None, + {"base_url": "https://example.invalid/v1", "models": "not a dict"}, + {"base_url": "https://example.invalid/v1", "models": {"m": "not a dict"}}, + { + "base_url": "https://example.invalid/v1", + "models": {"m": {"context_length": 400_000}}, + }, + ] + assert ( + get_custom_provider_context_length( + "m", "https://example.invalid/v1", custom + ) + == 400_000 + ) + + +class TestGetModelContextLengthHonorsOverride: + """agent.model_metadata.get_model_context_length must honor the + custom_providers override at step 0b — before any probe, cache hit, + or models.dev lookup can override it. + """ + + def _mock_all_probes(self): + """Context manager that disables every downstream resolution step.""" + from agent import model_metadata as _mm + return [ + patch.object(_mm, "get_cached_context_length", return_value=None), + patch.object(_mm, "fetch_endpoint_model_metadata", return_value={}), + patch.object(_mm, "fetch_model_metadata", return_value={}), + patch.object(_mm, "is_local_endpoint", return_value=False), + patch.object(_mm, "_is_known_provider_base_url", return_value=False), + ] + + def test_custom_providers_override_wins_over_default_fallback(self): + from agent.model_metadata import get_model_context_length + custom = [ + { + "base_url": "https://example.invalid/v1", + "models": {"gpt-5.5": {"context_length": 1_050_000}}, + } + ] + patches = self._mock_all_probes() + for p in patches: + p.start() + try: + ctx = get_model_context_length( + "gpt-5.5", + base_url="https://example.invalid/v1", + provider="custom", + custom_providers=custom, + ) + finally: + for p in patches: + p.stop() + assert ctx == 1_050_000 + + def test_explicit_config_context_length_still_wins(self): + """Top-level model.context_length (step 0) outranks custom_providers (step 0b). + + Users who set both should see the top-level value — that's the + documented precedence and matches the long-standing step-0 behavior. + """ + from agent.model_metadata import get_model_context_length + custom = [ + { + "base_url": "https://example.invalid/v1", + "models": {"m": {"context_length": 1_050_000}}, + } + ] + ctx = get_model_context_length( + "m", + base_url="https://example.invalid/v1", + provider="custom", + config_context_length=500_000, # explicit top-level wins + custom_providers=custom, + ) + assert ctx == 500_000 + + def test_no_override_falls_through_to_default(self): + """With custom_providers=None and all probes disabled, resolver + returns DEFAULT_FALLBACK_CONTEXT (256K after the stepdown bump). + """ + from agent.model_metadata import get_model_context_length, DEFAULT_FALLBACK_CONTEXT + patches = self._mock_all_probes() + for p in patches: + p.start() + try: + ctx = get_model_context_length( + "unknown-model", + base_url="https://example.invalid/v1", + provider="custom", + custom_providers=None, + ) + finally: + for p in patches: + p.stop() + assert ctx == DEFAULT_FALLBACK_CONTEXT + + +class TestContextProbeTiers: + def test_256k_is_top_tier_and_default(self): + """The stepdown probe starts at 256K and 256K is the new default.""" + from agent.model_metadata import CONTEXT_PROBE_TIERS, DEFAULT_FALLBACK_CONTEXT + + assert CONTEXT_PROBE_TIERS[0] == 256_000 + assert DEFAULT_FALLBACK_CONTEXT == 256_000 + # Tiers still descend monotonically + for a, b in zip(CONTEXT_PROBE_TIERS, CONTEXT_PROBE_TIERS[1:]): + assert a > b, f"tiers must strictly descend, got {a} then {b}" + # 128K is still a tier (users relying on it probe-down get there) + assert 128_000 in CONTEXT_PROBE_TIERS diff --git a/tests/hermes_cli/test_model_switch_context_display.py b/tests/hermes_cli/test_model_switch_context_display.py index e30c5a3c6c..cb6275af09 100644 --- a/tests/hermes_cli/test_model_switch_context_display.py +++ b/tests/hermes_cli/test_model_switch_context_display.py @@ -88,3 +88,61 @@ class TestResolveDisplayContextLength: model_info=fake_mi, ) assert ctx == 128_000 + + def test_custom_providers_override_honored(self): + """Regression for #15779: /model switch onto a custom provider must + surface the configured per-model context_length, not the 128K/256K + fallback. + """ + custom_provs = [ + { + "name": "my-custom-endpoint", + "base_url": "https://example.invalid/v1", + "models": {"gpt-5.5": {"context_length": 1_050_000}}, + } + ] + # Real resolver call — no mock — so the override path is exercised + # through agent.model_metadata.get_model_context_length. + from unittest.mock import patch as _p + from agent import model_metadata as _mm + with _p.object(_mm, "get_cached_context_length", return_value=None), \ + _p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \ + _p.object(_mm, "fetch_model_metadata", return_value={}), \ + _p.object(_mm, "is_local_endpoint", return_value=False), \ + _p.object(_mm, "_is_known_provider_base_url", return_value=False): + ctx = resolve_display_context_length( + "gpt-5.5", + "custom", + base_url="https://example.invalid/v1", + api_key="k", + custom_providers=custom_provs, + ) + assert ctx == 1_050_000, ( + "custom_providers[].models.gpt-5.5.context_length=1.05M must win " + "over probe-down fallback" + ) + + def test_custom_providers_trailing_slash_insensitive(self): + """Base URL comparison must tolerate trailing-slash differences + between config.yaml and the runtime value. + """ + custom_provs = [ + { + "base_url": "https://example.invalid/v1/", + "models": {"m": {"context_length": 400_000}}, + } + ] + from unittest.mock import patch as _p + from agent import model_metadata as _mm + with _p.object(_mm, "get_cached_context_length", return_value=None), \ + _p.object(_mm, "fetch_endpoint_model_metadata", return_value={}), \ + _p.object(_mm, "fetch_model_metadata", return_value={}), \ + _p.object(_mm, "is_local_endpoint", return_value=False), \ + _p.object(_mm, "_is_known_provider_base_url", return_value=False): + ctx = resolve_display_context_length( + "m", + "custom", + base_url="https://example.invalid/v1", # no trailing slash + custom_providers=custom_provs, + ) + assert ctx == 400_000