Merge pull request #50115 from NousResearch/salvage/model-switch-preflight-warning

fix(cli): warn when in-session model switch will preflight-compress
This commit is contained in:
kshitij 2026-06-21 16:41:44 +05:30 committed by GitHub
commit 44d552ea5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 362 additions and 2 deletions

26
cli.py
View file

@ -6936,6 +6936,19 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
_cprint(f"{result.error_message}")
return
if self.agent is not None:
try:
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
merge_preflight_compression_warning(
result,
agent=self.agent,
messages=list(self.conversation_history or []),
config_context_length=getattr(self.agent, "_config_context_length", None),
)
except Exception as exc:
logger.debug("preflight-compression switch warning failed: %s", exc)
old_model = self.model
self.model = result.new_model
self.provider = result.target_provider
@ -7202,6 +7215,19 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
_cprint(f"{result.error_message}")
return
if self.agent is not None:
try:
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
merge_preflight_compression_warning(
result,
agent=self.agent,
messages=list(self.conversation_history or []),
config_context_length=getattr(self.agent, "_config_context_length", None),
)
except Exception as exc:
logger.debug("preflight-compression switch warning failed: %s", exc)
if not self._confirm_expensive_model_switch(result):
_cprint(" Model switch cancelled.")
return

View file

@ -1160,6 +1160,22 @@ class GatewaySlashCommandsMixin:
if not result.success:
return t("gateway.model.error_prefix", error=result.error_message)
try:
from hermes_cli.context_switch_guard import (
enrich_model_switch_warnings_for_gateway,
)
enrich_model_switch_warnings_for_gateway(
result,
_self,
session_key=_session_key,
source=event.source,
custom_providers=custom_provs,
load_gateway_config=_load_gateway_config,
)
except Exception as exc:
logger.debug("preflight-compression switch warning failed: %s", exc)
# Update cached agent in-place
cached_entry = None
_cache_lock = getattr(_self, "_agent_cache_lock", None)
@ -1279,6 +1295,8 @@ class GatewaySlashCommandsMixin:
if mi.has_cost_data():
lines.append(t("gateway.model.cost_label", cost=mi.format_cost()))
lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities()))
if result.warning_message:
lines.append(t("gateway.model.warning_prefix", warning=result.warning_message))
if persist_global:
lines.append(t("gateway.model.saved_global"))
else:
@ -1345,6 +1363,22 @@ class GatewaySlashCommandsMixin:
if not result.success:
return t("gateway.model.error_prefix", error=result.error_message)
try:
from hermes_cli.context_switch_guard import (
enrich_model_switch_warnings_for_gateway,
)
enrich_model_switch_warnings_for_gateway(
result,
self,
session_key=session_key,
source=source,
custom_providers=custom_provs,
load_gateway_config=_load_gateway_config,
)
except Exception as exc:
logger.debug("preflight-compression switch warning failed: %s", exc)
async def _finish_switch() -> str:
"""Apply the resolved switch (agent, session, config) and build the reply."""
# If there's a cached agent, update it in-place

View file

@ -0,0 +1,169 @@
"""Warn when an in-session model switch will trigger preflight compression on the next turn.
Addresses part of #23767 ("user-facing guardrail when switching from a
high-context provider to a substantially lower-context provider"). The other
proposed fixes from that issue (hard preflight token guard, metadata cache
invalidation on switch, compression safety invariant, oversized tool-output
handling) are tracked separately.
Mirrors the expensive-model guard pattern: merge into ``ModelSwitchResult.warning_message``
so Herm TUI, CLI, and gateway surfaces that already show switch warnings pick it up.
"""
from __future__ import annotations
from typing import Any, Callable, List, Optional
from agent.model_metadata import MINIMUM_CONTEXT_LENGTH
from hermes_cli.model_switch import ModelSwitchResult, resolve_display_context_length
def _append_warning(result: ModelSwitchResult, text: str) -> None:
if result.warning_message:
result.warning_message = f"{result.warning_message} | {text}"
else:
result.warning_message = text
def _threshold_tokens(context_length: int, threshold_percent: float) -> int:
return max(int(context_length * threshold_percent), MINIMUM_CONTEXT_LENGTH)
def _estimate_tokens(agent: Any, messages: Optional[List[dict]]) -> Optional[int]:
cc = getattr(agent, "context_compressor", None)
if cc is None:
return None
if messages is not None:
protect = int(getattr(cc, "protect_first_n", 3)) + int(
getattr(cc, "protect_last_n", 20)
) + 1
if len(messages) <= protect:
return None
try:
from agent.model_metadata import estimate_request_tokens_rough
system_prompt = getattr(agent, "_cached_system_prompt", None) or ""
tools = getattr(agent, "tools", None)
return int(
estimate_request_tokens_rough(
messages,
system_prompt=system_prompt,
tools=tools or None,
)
)
except Exception:
pass
last = int(getattr(cc, "last_prompt_tokens", 0) or 0)
if last > 0:
return last
session_prompt = int(getattr(agent, "session_prompt_tokens", 0) or 0)
return session_prompt if session_prompt > 0 else None
def merge_preflight_compression_warning(
result: ModelSwitchResult,
*,
agent: Any = None,
messages: Optional[List[dict]] = None,
custom_providers: list | None = None,
config_context_length: int | None = None,
) -> None:
"""If the next user message will likely preflight-compress, append a warning."""
if not result.success or agent is None:
return
if not getattr(agent, "compression_enabled", True):
return
cc = getattr(agent, "context_compressor", None)
if cc is None:
return
old_ctx = int(getattr(cc, "context_length", 0) or 0)
new_ctx = resolve_display_context_length(
result.new_model,
result.target_provider,
base_url=result.base_url or getattr(agent, "base_url", "") or "",
api_key=result.api_key or getattr(agent, "api_key", "") or "",
model_info=result.model_info,
custom_providers=custom_providers,
config_context_length=config_context_length,
)
if not new_ctx:
return
estimate = _estimate_tokens(agent, messages)
if estimate is None:
return
pct = float(getattr(cc, "threshold_percent", 0.5))
new_threshold = _threshold_tokens(new_ctx, pct)
if estimate < new_threshold:
return
if int(getattr(cc, "_ineffective_compression_count", 0) or 0) >= 2:
return
parts: list[str] = []
if old_ctx and new_ctx < old_ctx:
parts.append(
f"Context window shrinks ({old_ctx:,}{new_ctx:,}). "
)
parts.append(
f"Session is ~{estimate:,} tokens; "
f"{result.new_model} allows {new_ctx:,} "
f"(auto-compress at ~{new_threshold:,}). "
f"Your next message will run preflight compression before the model replies."
)
_append_warning(result, "".join(parts))
def enrich_model_switch_warnings_for_gateway(
result: ModelSwitchResult,
runner: Any,
*,
session_key: str,
source: Any,
custom_providers: list | None = None,
load_gateway_config: Callable[[], dict] | None = None,
) -> None:
"""Gateway helper: cached agent + session DB messages."""
lock = getattr(runner, "_agent_cache_lock", None)
cache = getattr(runner, "_agent_cache", None)
agent = None
if lock is not None and cache is not None:
with lock:
entry = cache.get(session_key)
if entry and entry[0] is not None:
agent = entry[0]
if agent is None:
return
cfg_ctx = None
if load_gateway_config is not None:
try:
cfg = load_gateway_config()
model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
if isinstance(model_cfg, dict) and model_cfg.get("context_length") is not None:
cfg_ctx = int(model_cfg["context_length"])
except Exception:
pass
messages = None
db = getattr(runner, "_session_db", None)
store = getattr(runner, "session_store", None)
if db is not None and store is not None:
try:
entry = store.get_or_create_session(source)
messages = db.get_messages_as_conversation(entry.session_id)
except Exception:
pass
merge_preflight_compression_warning(
result,
agent=agent,
messages=messages,
custom_providers=custom_providers,
config_context_length=cfg_ctx,
)

View file

@ -0,0 +1,105 @@
"""Tests for hermes_cli.context_switch_guard."""
from __future__ import annotations
from types import SimpleNamespace
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
from hermes_cli.model_switch import ModelSwitchResult
def _result(*, model: str = "small-model") -> ModelSwitchResult:
return ModelSwitchResult(
success=True,
new_model=model,
target_provider="openrouter",
provider_changed=False,
api_key="k",
base_url="https://example.com/v1",
api_mode="chat_completions",
provider_label="openrouter",
model_info={"context_length": 32_000},
)
def _compressor(monkeypatch, *, context_length: int = 200_000):
from agent.context_compressor import ContextCompressor
monkeypatch.setattr(
"agent.context_compressor.get_model_context_length",
lambda *a, **k: context_length,
)
return ContextCompressor(
model="big-model",
threshold_percent=0.5,
protect_first_n=3,
protect_last_n=20,
quiet_mode=True,
config_context_length=context_length,
)
def test_no_warning_when_below_new_threshold(monkeypatch):
monkeypatch.setattr(
"hermes_cli.context_switch_guard.resolve_display_context_length",
lambda *a, **k: 32_000,
)
cc = _compressor(monkeypatch)
cc.last_prompt_tokens = 10_000
agent = SimpleNamespace(
context_compressor=cc,
compression_enabled=True,
conversation_history=[],
base_url="",
api_key="",
)
result = _result()
merge_preflight_compression_warning(result, agent=agent)
assert not result.warning_message
def test_warns_when_estimate_exceeds_new_threshold(monkeypatch):
monkeypatch.setattr(
"hermes_cli.context_switch_guard.resolve_display_context_length",
lambda *a, **k: 32_000,
)
monkeypatch.setattr(
"hermes_cli.context_switch_guard._estimate_tokens",
lambda *a, **k: 90_000,
)
cc = _compressor(monkeypatch)
agent = SimpleNamespace(
context_compressor=cc,
compression_enabled=True,
conversation_history=[],
base_url="",
api_key="",
)
result = _result()
merge_preflight_compression_warning(result, agent=agent)
assert result.warning_message
assert "preflight compression" in result.warning_message
assert "shrinks" in result.warning_message
def test_merge_appends_to_existing_warning(monkeypatch):
monkeypatch.setattr(
"hermes_cli.context_switch_guard._estimate_tokens",
lambda *a, **k: 90_000,
)
monkeypatch.setattr(
"hermes_cli.context_switch_guard.resolve_display_context_length",
lambda *a, **k: 32_000,
)
cc = _compressor(monkeypatch)
agent = SimpleNamespace(
context_compressor=cc,
compression_enabled=True,
base_url="",
api_key="",
)
result = _result()
result.warning_message = "expensive"
merge_preflight_compression_warning(result, agent=agent)
assert "expensive" in result.warning_message
assert "preflight compression" in result.warning_message

View file

@ -2248,6 +2248,25 @@ def _apply_model_switch(
if not result.success:
raise ValueError(result.error_message or "model switch failed")
if agent:
try:
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
_cfg_ctx = None
if isinstance(cfg, dict):
_mc = cfg.get("model", {})
if isinstance(_mc, dict) and _mc.get("context_length") is not None:
_cfg_ctx = int(_mc["context_length"])
merge_preflight_compression_warning(
result,
agent=agent,
messages=list(session.get("history", [])),
custom_providers=custom_provs,
config_context_length=_cfg_ctx,
)
except Exception as exc:
logger.debug("preflight-compression switch warning failed: %s", exc)
if not confirm_expensive_model:
try:
from hermes_cli.model_cost_guard import expensive_model_warning
@ -2262,11 +2281,14 @@ def _apply_model_switch(
except Exception:
warning = None
if warning is not None:
confirm_msg = warning.message
if result.warning_message:
confirm_msg = f"{confirm_msg}\n\n{result.warning_message}"
return {
"value": result.new_model,
"warning": warning.message,
"warning": confirm_msg,
"confirm_required": True,
"confirm_message": warning.message,
"confirm_message": confirm_msg,
}
if agent:

View file

@ -47,6 +47,10 @@ Type in the filter box to narrow by provider name, slug, or model ID.
Pick a model, hit **Switch**, and Hermes writes it to `~/.hermes/config.yaml` under the `model` section. **This applies to new sessions only** — any chat tab you already have open keeps running whatever model it started with. To hot-swap the current chat, use the `/model` slash command inside it.
### Mid-session switches and context warnings
When you switch models **inside an active session** (Herm TUI model picker, `hermes` CLI, or `/model` on Telegram/Discord), Hermes estimates whether your **next message** will run **preflight context compression** against the new model's window. If the session is already near or above that model's compression threshold (see [Context Compression](./configuration.md#context-compression)), the switch reply includes a warning — the same `warning_message` path used for expensive-model notices. The switch still applies immediately; compression runs on the **first user message after the switch**, before the model answers.
## Setting auxiliary models
Click **Show auxiliary** to reveal the 11 task slots: