From 04730f32e7e836fb3b227caed3fcbea7e2985083 Mon Sep 17 00:00:00 2001 From: Tuna Dev Date: Sat, 20 Jun 2026 15:32:43 +0800 Subject: [PATCH] fix(cli): warn when in-session model switch will preflight-compress MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds hermes_cli/context_switch_guard.py mirroring the model_cost_guard pattern. When a user switches models mid-session (Herm TUI picker, CLI, or /model on Telegram/Discord), a warning is surfaced on the existing ModelSwitchResult.warning_message path (the one used by the expensive-model guard) whenever the estimated session size is at or above the new model's compression threshold. Partial fix for #23767 — addresses only the 'user-facing guardrail when switching from a high-context provider to a substantially lower-context provider' slice. The other proposed fixes from that issue (hard preflight token guard, metadata cache invalidation on switch, compression safety invariant, oversized tool-output handling) are out of scope for this PR. 
--- cli.py | 26 +++ gateway/slash_commands.py | 34 ++++ hermes_cli/context_switch_guard.py | 169 ++++++++++++++++++ tests/hermes_cli/test_context_switch_guard.py | 105 +++++++++++ tui_gateway/server.py | 24 ++- website/docs/user-guide/configuring-models.md | 4 + 6 files changed, 361 insertions(+), 1 deletion(-) create mode 100644 hermes_cli/context_switch_guard.py create mode 100644 tests/hermes_cli/test_context_switch_guard.py diff --git a/cli.py b/cli.py index 794bf65763f..159f3486052 100644 --- a/cli.py +++ b/cli.py @@ -6936,6 +6936,19 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin): _cprint(f" ✗ {result.error_message}") return + if self.agent is not None: + try: + from hermes_cli.context_switch_guard import merge_preflight_compression_warning + + merge_preflight_compression_warning( + result, + agent=self.agent, + messages=list(self.conversation_history or []), + config_context_length=getattr(self.agent, "_config_context_length", None), + ) + except Exception: + pass + old_model = self.model self.model = result.new_model self.provider = result.target_provider @@ -7202,6 +7215,19 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin): _cprint(f" ✗ {result.error_message}") return + if self.agent is not None: + try: + from hermes_cli.context_switch_guard import merge_preflight_compression_warning + + merge_preflight_compression_warning( + result, + agent=self.agent, + messages=list(self.conversation_history or []), + config_context_length=getattr(self.agent, "_config_context_length", None), + ) + except Exception: + pass + if not self._confirm_expensive_model_switch(result): _cprint(" Model switch cancelled.") return diff --git a/gateway/slash_commands.py b/gateway/slash_commands.py index dbfd778daf9..b222b62ff1e 100644 --- a/gateway/slash_commands.py +++ b/gateway/slash_commands.py @@ -1160,6 +1160,22 @@ class GatewaySlashCommandsMixin: if not result.success: return t("gateway.model.error_prefix", error=result.error_message) + try: + from 
hermes_cli.context_switch_guard import ( + enrich_model_switch_warnings_for_gateway, + ) + + enrich_model_switch_warnings_for_gateway( + result, + _self, + session_key=_session_key, + source=event.source, + custom_providers=custom_provs, + load_gateway_config=_load_gateway_config, + ) + except Exception: + pass + # Update cached agent in-place cached_entry = None _cache_lock = getattr(_self, "_agent_cache_lock", None) @@ -1279,6 +1295,8 @@ class GatewaySlashCommandsMixin: if mi.has_cost_data(): lines.append(t("gateway.model.cost_label", cost=mi.format_cost())) lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities())) + if result.warning_message: + lines.append(t("gateway.model.warning_prefix", warning=result.warning_message)) if persist_global: lines.append(t("gateway.model.saved_global")) else: @@ -1345,6 +1363,22 @@ class GatewaySlashCommandsMixin: if not result.success: return t("gateway.model.error_prefix", error=result.error_message) + try: + from hermes_cli.context_switch_guard import ( + enrich_model_switch_warnings_for_gateway, + ) + + enrich_model_switch_warnings_for_gateway( + result, + self, + session_key=session_key, + source=source, + custom_providers=custom_provs, + load_gateway_config=_load_gateway_config, + ) + except Exception: + pass + async def _finish_switch() -> str: """Apply the resolved switch (agent, session, config) and build the reply.""" # If there's a cached agent, update it in-place diff --git a/hermes_cli/context_switch_guard.py b/hermes_cli/context_switch_guard.py new file mode 100644 index 00000000000..f0cb55bc73d --- /dev/null +++ b/hermes_cli/context_switch_guard.py @@ -0,0 +1,169 @@ +"""Warn when an in-session model switch will trigger preflight compression on the next turn. + +Addresses part of #23767 ("user-facing guardrail when switching from a +high-context provider to a substantially lower-context provider"). 
The other +proposed fixes from that issue (hard preflight token guard, metadata cache +invalidation on switch, compression safety invariant, oversized tool-output +handling) are tracked separately. + +Mirrors the expensive-model guard pattern: merge into ``ModelSwitchResult.warning_message`` +so Herm TUI, CLI, and gateway surfaces that already show switch warnings pick it up. +""" + +from __future__ import annotations + +from typing import Any, Callable, List, Optional + +from agent.model_metadata import MINIMUM_CONTEXT_LENGTH +from hermes_cli.model_switch import ModelSwitchResult, resolve_display_context_length + + +def _append_warning(result: ModelSwitchResult, text: str) -> None: + if result.warning_message: + result.warning_message = f"{result.warning_message} | {text}" + else: + result.warning_message = text + + +def _threshold_tokens(context_length: int, threshold_percent: float) -> int: + return max(int(context_length * threshold_percent), MINIMUM_CONTEXT_LENGTH) + + +def _estimate_tokens(agent: Any, messages: Optional[List[dict]]) -> Optional[int]: + cc = getattr(agent, "context_compressor", None) + if cc is None: + return None + + if messages is not None: + protect = int(getattr(cc, "protect_first_n", 3)) + int( + getattr(cc, "protect_last_n", 20) + ) + 1 + if len(messages) <= protect: + return None + try: + from agent.model_metadata import estimate_request_tokens_rough + + system_prompt = getattr(agent, "_cached_system_prompt", None) or "" + tools = getattr(agent, "tools", None) + return int( + estimate_request_tokens_rough( + messages, + system_prompt=system_prompt, + tools=tools or None, + ) + ) + except Exception: + pass + + last = int(getattr(cc, "last_prompt_tokens", 0) or 0) + if last > 0: + return last + session_prompt = int(getattr(agent, "session_prompt_tokens", 0) or 0) + return session_prompt if session_prompt > 0 else None + + +def merge_preflight_compression_warning( + result: ModelSwitchResult, + *, + agent: Any = None, + messages: 
Optional[List[dict]] = None, + custom_providers: list | None = None, + config_context_length: int | None = None, +) -> None: + """If the next user message will likely preflight-compress, append a warning.""" + if not result.success or agent is None: + return + if not getattr(agent, "compression_enabled", True): + return + + cc = getattr(agent, "context_compressor", None) + if cc is None: + return + + old_ctx = int(getattr(cc, "context_length", 0) or 0) + new_ctx = resolve_display_context_length( + result.new_model, + result.target_provider, + base_url=result.base_url or getattr(agent, "base_url", "") or "", + api_key=result.api_key or getattr(agent, "api_key", "") or "", + model_info=result.model_info, + custom_providers=custom_providers, + config_context_length=config_context_length, + ) + if not new_ctx: + return + + estimate = _estimate_tokens(agent, messages) + if estimate is None: + return + + pct = float(getattr(cc, "threshold_percent", 0.5)) + new_threshold = _threshold_tokens(new_ctx, pct) + if estimate < new_threshold: + return + + if int(getattr(cc, "_ineffective_compression_count", 0) or 0) >= 2: + return + + parts: list[str] = [] + if old_ctx and new_ctx < old_ctx: + parts.append( + f"Context window shrinks ({old_ctx:,} → {new_ctx:,}). " + ) + parts.append( + f"Session is ~{estimate:,} tokens; " + f"{result.new_model} allows {new_ctx:,} " + f"(auto-compress at ~{new_threshold:,}). " + f"Your next message will run preflight compression before the model replies." 
+ ) + _append_warning(result, "".join(parts)) + + +def enrich_model_switch_warnings_for_gateway( + result: ModelSwitchResult, + runner: Any, + *, + session_key: str, + source: Any, + custom_providers: list | None = None, + load_gateway_config: Callable[[], dict] | None = None, +) -> None: + """Gateway helper: cached agent + session DB messages.""" + lock = getattr(runner, "_agent_cache_lock", None) + cache = getattr(runner, "_agent_cache", None) + agent = None + if lock is not None and cache is not None: + with lock: + entry = cache.get(session_key) + if entry and entry[0] is not None: + agent = entry[0] + if agent is None: + return + + cfg_ctx = None + if load_gateway_config is not None: + try: + cfg = load_gateway_config() + model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {} + if isinstance(model_cfg, dict) and model_cfg.get("context_length") is not None: + cfg_ctx = int(model_cfg["context_length"]) + except Exception: + pass + + messages = None + db = getattr(runner, "_session_db", None) + store = getattr(runner, "session_store", None) + if db is not None and store is not None: + try: + entry = store.get_or_create_session(source) + messages = db.get_messages_as_conversation(entry.session_id) + except Exception: + pass + + merge_preflight_compression_warning( + result, + agent=agent, + messages=messages, + custom_providers=custom_providers, + config_context_length=cfg_ctx, + ) \ No newline at end of file diff --git a/tests/hermes_cli/test_context_switch_guard.py b/tests/hermes_cli/test_context_switch_guard.py new file mode 100644 index 00000000000..ec61074444a --- /dev/null +++ b/tests/hermes_cli/test_context_switch_guard.py @@ -0,0 +1,105 @@ +"""Tests for hermes_cli.context_switch_guard.""" + +from __future__ import annotations + +from types import SimpleNamespace + +from hermes_cli.context_switch_guard import merge_preflight_compression_warning +from hermes_cli.model_switch import ModelSwitchResult + + +def _result(*, model: str = "small-model") 
-> ModelSwitchResult: + return ModelSwitchResult( + success=True, + new_model=model, + target_provider="openrouter", + provider_changed=False, + api_key="k", + base_url="https://example.com/v1", + api_mode="chat_completions", + provider_label="openrouter", + model_info={"context_length": 32_000}, + ) + + +def _compressor(monkeypatch, *, context_length: int = 200_000): + from agent.context_compressor import ContextCompressor + + monkeypatch.setattr( + "agent.context_compressor.get_model_context_length", + lambda *a, **k: context_length, + ) + return ContextCompressor( + model="big-model", + threshold_percent=0.5, + protect_first_n=3, + protect_last_n=20, + quiet_mode=True, + config_context_length=context_length, + ) + + +def test_no_warning_when_below_new_threshold(monkeypatch): + monkeypatch.setattr( + "hermes_cli.context_switch_guard.resolve_display_context_length", + lambda *a, **k: 32_000, + ) + cc = _compressor(monkeypatch) + cc.last_prompt_tokens = 10_000 + agent = SimpleNamespace( + context_compressor=cc, + compression_enabled=True, + conversation_history=[], + base_url="", + api_key="", + ) + result = _result() + merge_preflight_compression_warning(result, agent=agent) + assert not result.warning_message + + +def test_warns_when_estimate_exceeds_new_threshold(monkeypatch): + monkeypatch.setattr( + "hermes_cli.context_switch_guard.resolve_display_context_length", + lambda *a, **k: 32_000, + ) + monkeypatch.setattr( + "hermes_cli.context_switch_guard._estimate_tokens", + lambda *a, **k: 90_000, + ) + cc = _compressor(monkeypatch) + agent = SimpleNamespace( + context_compressor=cc, + compression_enabled=True, + conversation_history=[], + base_url="", + api_key="", + ) + result = _result() + merge_preflight_compression_warning(result, agent=agent) + assert result.warning_message + assert "preflight compression" in result.warning_message + assert "shrinks" in result.warning_message + + +def test_merge_appends_to_existing_warning(monkeypatch): + 
monkeypatch.setattr( + "hermes_cli.context_switch_guard._estimate_tokens", + lambda *a, **k: 90_000, + ) + monkeypatch.setattr( + "hermes_cli.context_switch_guard.resolve_display_context_length", + lambda *a, **k: 32_000, + ) + cc = _compressor(monkeypatch) + agent = SimpleNamespace( + context_compressor=cc, + compression_enabled=True, + base_url="", + api_key="", + ) + result = _result() + result.warning_message = "expensive" + merge_preflight_compression_warning(result, agent=agent) + assert "expensive" in result.warning_message + assert "preflight compression" in result.warning_message \ No newline at end of file diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 76a10c61206..81df58ca66b 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -2248,6 +2248,25 @@ def _apply_model_switch( if not result.success: raise ValueError(result.error_message or "model switch failed") + if agent: + try: + from hermes_cli.context_switch_guard import merge_preflight_compression_warning + + _cfg_ctx = None + if isinstance(cfg, dict): + _mc = cfg.get("model", {}) + if isinstance(_mc, dict) and _mc.get("context_length") is not None: + _cfg_ctx = int(_mc["context_length"]) + merge_preflight_compression_warning( + result, + agent=agent, + messages=list(session.get("history", [])), + custom_providers=custom_provs, + config_context_length=_cfg_ctx, + ) + except Exception: + pass + if not confirm_expensive_model: try: from hermes_cli.model_cost_guard import expensive_model_warning @@ -2262,11 +2281,14 @@ def _apply_model_switch( except Exception: warning = None if warning is not None: + confirm_msg = warning.message + if result.warning_message: + confirm_msg = f"{confirm_msg}\n\n{result.warning_message}" return { "value": result.new_model, "warning": warning.message, "confirm_required": True, - "confirm_message": warning.message, + "confirm_message": confirm_msg, } if agent: diff --git a/website/docs/user-guide/configuring-models.md 
b/website/docs/user-guide/configuring-models.md index 8d749e15143..f73d2b28769 100644 --- a/website/docs/user-guide/configuring-models.md +++ b/website/docs/user-guide/configuring-models.md @@ -47,6 +47,10 @@ Type in the filter box to narrow by provider name, slug, or model ID. Pick a model, hit **Switch**, and Hermes writes it to `~/.hermes/config.yaml` under the `model` section. **This applies to new sessions only** — any chat tab you already have open keeps running whatever model it started with. To hot-swap the current chat, use the `/model` slash command inside it. +### Mid-session switches and context warnings + +When you switch models **inside an active session** (Herm TUI model picker, `hermes` CLI, or `/model` on Telegram/Discord), Hermes estimates whether your **next message** will run **preflight context compression** against the new model's window. If the estimated session size is already at or above that model's compression threshold (see [Context Compression](./configuration.md#context-compression)), the switch reply includes a warning — the same `warning_message` path used for expensive-model notices. The switch still applies immediately; compression runs on the **first user message after the switch**, before the model answers. + ## Setting auxiliary models Click **Show auxiliary** to reveal the 11 task slots: