mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
fix(cli): warn when in-session model switch will preflight-compress
Adds hermes_cli/context_switch_guard.py mirroring the model_cost_guard pattern. When a user switches models mid-session (Herm TUI picker, CLI, or /model on Telegram/Discord), the warning surfaces on the existing ModelSwitchResult.warning_message path used by the expensive-model guard if the new model's compression threshold is below the current session size. Partial fix for #23767 — addresses only the 'user-facing guardrail when switching from a high-context provider to a substantially lower-context provider' slice. The other proposed fixes from that issue (hard preflight token guard, metadata cache invalidation on switch, compression safety invariant, oversized tool-output handling) are out of scope for this PR.
This commit is contained in:
parent
7b9a0b315b
commit
04730f32e7
6 changed files with 361 additions and 1 deletions
26
cli.py
26
cli.py
|
|
@ -6936,6 +6936,19 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
|
|||
_cprint(f" ✗ {result.error_message}")
|
||||
return
|
||||
|
||||
if self.agent is not None:
|
||||
try:
|
||||
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
|
||||
|
||||
merge_preflight_compression_warning(
|
||||
result,
|
||||
agent=self.agent,
|
||||
messages=list(self.conversation_history or []),
|
||||
config_context_length=getattr(self.agent, "_config_context_length", None),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
old_model = self.model
|
||||
self.model = result.new_model
|
||||
self.provider = result.target_provider
|
||||
|
|
@ -7202,6 +7215,19 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
|
|||
_cprint(f" ✗ {result.error_message}")
|
||||
return
|
||||
|
||||
if self.agent is not None:
|
||||
try:
|
||||
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
|
||||
|
||||
merge_preflight_compression_warning(
|
||||
result,
|
||||
agent=self.agent,
|
||||
messages=list(self.conversation_history or []),
|
||||
config_context_length=getattr(self.agent, "_config_context_length", None),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not self._confirm_expensive_model_switch(result):
|
||||
_cprint(" Model switch cancelled.")
|
||||
return
|
||||
|
|
|
|||
|
|
@ -1160,6 +1160,22 @@ class GatewaySlashCommandsMixin:
|
|||
if not result.success:
|
||||
return t("gateway.model.error_prefix", error=result.error_message)
|
||||
|
||||
try:
|
||||
from hermes_cli.context_switch_guard import (
|
||||
enrich_model_switch_warnings_for_gateway,
|
||||
)
|
||||
|
||||
enrich_model_switch_warnings_for_gateway(
|
||||
result,
|
||||
_self,
|
||||
session_key=_session_key,
|
||||
source=event.source,
|
||||
custom_providers=custom_provs,
|
||||
load_gateway_config=_load_gateway_config,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Update cached agent in-place
|
||||
cached_entry = None
|
||||
_cache_lock = getattr(_self, "_agent_cache_lock", None)
|
||||
|
|
@ -1279,6 +1295,8 @@ class GatewaySlashCommandsMixin:
|
|||
if mi.has_cost_data():
|
||||
lines.append(t("gateway.model.cost_label", cost=mi.format_cost()))
|
||||
lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities()))
|
||||
if result.warning_message:
|
||||
lines.append(t("gateway.model.warning_prefix", warning=result.warning_message))
|
||||
if persist_global:
|
||||
lines.append(t("gateway.model.saved_global"))
|
||||
else:
|
||||
|
|
@ -1345,6 +1363,22 @@ class GatewaySlashCommandsMixin:
|
|||
if not result.success:
|
||||
return t("gateway.model.error_prefix", error=result.error_message)
|
||||
|
||||
try:
|
||||
from hermes_cli.context_switch_guard import (
|
||||
enrich_model_switch_warnings_for_gateway,
|
||||
)
|
||||
|
||||
enrich_model_switch_warnings_for_gateway(
|
||||
result,
|
||||
self,
|
||||
session_key=session_key,
|
||||
source=source,
|
||||
custom_providers=custom_provs,
|
||||
load_gateway_config=_load_gateway_config,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _finish_switch() -> str:
|
||||
"""Apply the resolved switch (agent, session, config) and build the reply."""
|
||||
# If there's a cached agent, update it in-place
|
||||
|
|
|
|||
169
hermes_cli/context_switch_guard.py
Normal file
169
hermes_cli/context_switch_guard.py
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
"""Warn when an in-session model switch will trigger preflight compression on the next turn.
|
||||
|
||||
Addresses part of #23767 ("user-facing guardrail when switching from a
|
||||
high-context provider to a substantially lower-context provider"). The other
|
||||
proposed fixes from that issue (hard preflight token guard, metadata cache
|
||||
invalidation on switch, compression safety invariant, oversized tool-output
|
||||
handling) are tracked separately.
|
||||
|
||||
Mirrors the expensive-model guard pattern: merge into ``ModelSwitchResult.warning_message``
|
||||
so Herm TUI, CLI, and gateway surfaces that already show switch warnings pick it up.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable, List, Optional
|
||||
|
||||
from agent.model_metadata import MINIMUM_CONTEXT_LENGTH
|
||||
from hermes_cli.model_switch import ModelSwitchResult, resolve_display_context_length
|
||||
|
||||
|
||||
def _append_warning(result: ModelSwitchResult, text: str) -> None:
    """Attach *text* to ``result.warning_message``.

    An existing non-empty warning is kept and the new text is joined to it
    with ``" | "``; otherwise *text* becomes the warning outright.
    """
    existing = result.warning_message
    result.warning_message = f"{existing} | {text}" if existing else text
|
||||
|
||||
|
||||
def _threshold_tokens(context_length: int, threshold_percent: float) -> int:
    """Return the token count at which compression is expected to trigger.

    The value is ``context_length * threshold_percent`` truncated to an int,
    but never lower than ``MINIMUM_CONTEXT_LENGTH``.
    """
    scaled = int(context_length * threshold_percent)
    if scaled < MINIMUM_CONTEXT_LENGTH:
        return MINIMUM_CONTEXT_LENGTH
    return scaled
|
||||
|
||||
|
||||
def _estimate_tokens(agent: Any, messages: Optional[List[dict]]) -> Optional[int]:
    """Return a best-effort token count for the current session, or ``None``.

    Args:
        agent: Object expected to expose ``context_compressor`` and, optionally,
            ``_cached_system_prompt``, ``tools`` and ``session_prompt_tokens``.
        messages: Conversation history to size, or ``None`` to rely only on
            the token counters already recorded on the agent/compressor.

    Returns:
        * ``None`` when the agent has no ``context_compressor``.
        * ``None`` when *messages* is given but holds no more than
          ``protect_first_n + protect_last_n + 1`` entries (presumably because
          the compressor would have nothing to compress — confirm against
          ``ContextCompressor``).
        * The rough request estimate for *messages* when it can be computed.
        * Otherwise the compressor's ``last_prompt_tokens`` if positive, else
          the agent's ``session_prompt_tokens`` if positive, else ``None``.
    """
    cc = getattr(agent, "context_compressor", None)
    if cc is None:
        return None

    if messages is not None:
        protect = (
            int(getattr(cc, "protect_first_n", 3))
            + int(getattr(cc, "protect_last_n", 20))
            + 1
        )
        if len(messages) <= protect:
            return None
        try:
            from agent.model_metadata import estimate_request_tokens_rough

            system_prompt = getattr(agent, "_cached_system_prompt", None) or ""
            tools = getattr(agent, "tools", None)
            return int(
                estimate_request_tokens_rough(
                    messages,
                    system_prompt=system_prompt,
                    tools=tools or None,
                )
            )
        except Exception:
            # Still best-effort: fall through to the recorded counters, but
            # leave a trace so a broken estimator is not silently invisible.
            import logging

            logging.getLogger(__name__).debug(
                "Rough token estimate failed; falling back to recorded prompt tokens",
                exc_info=True,
            )

    last = int(getattr(cc, "last_prompt_tokens", 0) or 0)
    if last > 0:
        return last
    session_prompt = int(getattr(agent, "session_prompt_tokens", 0) or 0)
    return session_prompt if session_prompt > 0 else None
|
||||
|
||||
|
||||
def merge_preflight_compression_warning(
    result: ModelSwitchResult,
    *,
    agent: Any = None,
    messages: Optional[List[dict]] = None,
    custom_providers: list | None = None,
    config_context_length: int | None = None,
) -> None:
    """Append a preflight-compression warning to *result* when one applies.

    Nothing is appended when the switch failed, no agent is supplied, the
    agent reports ``compression_enabled`` as false, the agent has no
    ``context_compressor``, the new model's context length cannot be
    resolved, no token estimate is available, the estimate is below the new
    compression threshold, or the compressor's
    ``_ineffective_compression_count`` is already 2 or more.

    Args:
        result: Switch outcome; its ``warning_message`` is updated in place.
        agent: Active agent whose compressor and counters are inspected.
        messages: Conversation history used for the token estimate.
        custom_providers: Forwarded to ``resolve_display_context_length``.
        config_context_length: Forwarded to ``resolve_display_context_length``.
    """
    if (
        not result.success
        or agent is None
        or not getattr(agent, "compression_enabled", True)
    ):
        return

    compressor = getattr(agent, "context_compressor", None)
    if compressor is None:
        return

    previous_window = int(getattr(compressor, "context_length", 0) or 0)
    target_window = resolve_display_context_length(
        result.new_model,
        result.target_provider,
        # Fall back to the live agent's endpoint/credentials when the switch
        # result does not carry its own.
        base_url=result.base_url or getattr(agent, "base_url", "") or "",
        api_key=result.api_key or getattr(agent, "api_key", "") or "",
        model_info=result.model_info,
        custom_providers=custom_providers,
        config_context_length=config_context_length,
    )
    if not target_window:
        return

    session_tokens = _estimate_tokens(agent, messages)
    if session_tokens is None:
        return

    ratio = float(getattr(compressor, "threshold_percent", 0.5))
    trigger_at = _threshold_tokens(target_window, ratio)
    if session_tokens < trigger_at:
        return

    stalled_runs = int(getattr(compressor, "_ineffective_compression_count", 0) or 0)
    if stalled_runs >= 2:
        return

    shrink_note = ""
    if previous_window and target_window < previous_window:
        shrink_note = (
            f"Context window shrinks ({previous_window:,} → {target_window:,}). "
        )
    session_note = (
        f"Session is ~{session_tokens:,} tokens; "
        f"{result.new_model} allows {target_window:,} "
        f"(auto-compress at ~{trigger_at:,}). "
        f"Your next message will run preflight compression before the model replies."
    )
    _append_warning(result, shrink_note + session_note)
|
||||
|
||||
|
||||
def enrich_model_switch_warnings_for_gateway(
    result: ModelSwitchResult,
    runner: Any,
    *,
    session_key: str,
    source: Any,
    custom_providers: list | None = None,
    load_gateway_config: Callable[[], dict] | None = None,
) -> None:
    """Gateway entry point for the preflight-compression warning.

    Looks up the cached agent for *session_key* (under the runner's cache
    lock), reads ``model.context_length`` from the gateway config when a
    loader is supplied, fetches the session's messages through the runner's
    session store and DB, and hands everything to
    ``merge_preflight_compression_warning``.

    Returns without touching *result* when the runner has no cache lock, no
    agent cache, or no cached agent for the session. Failures while loading
    the config or the messages are ignored; the corresponding value is then
    passed on as ``None``.
    """
    cache_lock = getattr(runner, "_agent_cache_lock", None)
    agent_cache = getattr(runner, "_agent_cache", None)
    if cache_lock is None or agent_cache is None:
        return

    with cache_lock:
        cached = agent_cache.get(session_key)
        # The first slot of a cache entry is read as the agent.
        agent = cached[0] if cached else None
    if agent is None:
        return

    configured_ctx = None
    if load_gateway_config is not None:
        try:
            gateway_cfg = load_gateway_config()
            model_section = (
                gateway_cfg.get("model", {}) if isinstance(gateway_cfg, dict) else {}
            )
            if isinstance(model_section, dict):
                raw_ctx = model_section.get("context_length")
                if raw_ctx is not None:
                    configured_ctx = int(raw_ctx)
        except Exception:
            pass

    history = None
    session_db = getattr(runner, "_session_db", None)
    session_store = getattr(runner, "session_store", None)
    if not (session_db is None or session_store is None):
        try:
            session_entry = session_store.get_or_create_session(source)
            history = session_db.get_messages_as_conversation(session_entry.session_id)
        except Exception:
            pass

    merge_preflight_compression_warning(
        result,
        agent=agent,
        messages=history,
        custom_providers=custom_providers,
        config_context_length=configured_ctx,
    )
|
||||
105
tests/hermes_cli/test_context_switch_guard.py
Normal file
105
tests/hermes_cli/test_context_switch_guard.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
"""Tests for hermes_cli.context_switch_guard."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
|
||||
from hermes_cli.model_switch import ModelSwitchResult
|
||||
|
||||
|
||||
def _result(*, model: str = "small-model") -> ModelSwitchResult:
    """Build a successful ``ModelSwitchResult`` targeting *model* on openrouter."""
    fields = dict(
        success=True,
        new_model=model,
        target_provider="openrouter",
        provider_changed=False,
        api_key="k",
        base_url="https://example.com/v1",
        api_mode="chat_completions",
        provider_label="openrouter",
        model_info={"context_length": 32_000},
    )
    return ModelSwitchResult(**fields)
|
||||
|
||||
|
||||
def _compressor(monkeypatch, *, context_length: int = 200_000):
    """Return a quiet ``ContextCompressor`` pinned to *context_length*.

    ``agent.context_compressor.get_model_context_length`` is patched to
    return *context_length* for any arguments before the compressor is built.
    """
    from agent.context_compressor import ContextCompressor

    def _fixed_context_length(*_args, **_kwargs):
        return context_length

    monkeypatch.setattr(
        "agent.context_compressor.get_model_context_length",
        _fixed_context_length,
    )
    settings = dict(
        model="big-model",
        threshold_percent=0.5,
        protect_first_n=3,
        protect_last_n=20,
        quiet_mode=True,
        config_context_length=context_length,
    )
    return ContextCompressor(**settings)
|
||||
|
||||
|
||||
def test_no_warning_when_below_new_threshold(monkeypatch):
    """A session recorded at 10,000 prompt tokens gets no warning for a 32,000 window."""

    def _new_window(*_args, **_kwargs):
        return 32_000

    monkeypatch.setattr(
        "hermes_cli.context_switch_guard.resolve_display_context_length",
        _new_window,
    )
    compressor = _compressor(monkeypatch)
    # No messages are passed below, so the guard reads this recorded counter.
    compressor.last_prompt_tokens = 10_000
    stub_agent = SimpleNamespace(
        context_compressor=compressor,
        compression_enabled=True,
        conversation_history=[],
        base_url="",
        api_key="",
    )
    switch = _result()

    merge_preflight_compression_warning(switch, agent=stub_agent)

    assert not switch.warning_message
|
||||
|
||||
|
||||
def test_warns_when_estimate_exceeds_new_threshold(monkeypatch):
    """A 90,000-token estimate against a 32,000 window produces the shrink warning."""

    def _new_window(*_args, **_kwargs):
        return 32_000

    def _session_size(*_args, **_kwargs):
        return 90_000

    monkeypatch.setattr(
        "hermes_cli.context_switch_guard.resolve_display_context_length",
        _new_window,
    )
    monkeypatch.setattr(
        "hermes_cli.context_switch_guard._estimate_tokens",
        _session_size,
    )
    compressor = _compressor(monkeypatch)
    stub_agent = SimpleNamespace(
        context_compressor=compressor,
        compression_enabled=True,
        conversation_history=[],
        base_url="",
        api_key="",
    )
    switch = _result()

    merge_preflight_compression_warning(switch, agent=stub_agent)

    message = switch.warning_message
    assert message
    assert "preflight compression" in message
    assert "shrinks" in message
|
||||
|
||||
|
||||
def test_merge_appends_to_existing_warning(monkeypatch):
    """An existing warning is preserved and the compression notice is added to it."""

    def _session_size(*_args, **_kwargs):
        return 90_000

    def _new_window(*_args, **_kwargs):
        return 32_000

    monkeypatch.setattr(
        "hermes_cli.context_switch_guard._estimate_tokens",
        _session_size,
    )
    monkeypatch.setattr(
        "hermes_cli.context_switch_guard.resolve_display_context_length",
        _new_window,
    )
    compressor = _compressor(monkeypatch)
    stub_agent = SimpleNamespace(
        context_compressor=compressor,
        compression_enabled=True,
        base_url="",
        api_key="",
    )
    switch = _result()
    switch.warning_message = "expensive"

    merge_preflight_compression_warning(switch, agent=stub_agent)

    combined = switch.warning_message
    assert "expensive" in combined
    assert "preflight compression" in combined
|
||||
|
|
@ -2248,6 +2248,25 @@ def _apply_model_switch(
|
|||
if not result.success:
|
||||
raise ValueError(result.error_message or "model switch failed")
|
||||
|
||||
if agent:
|
||||
try:
|
||||
from hermes_cli.context_switch_guard import merge_preflight_compression_warning
|
||||
|
||||
_cfg_ctx = None
|
||||
if isinstance(cfg, dict):
|
||||
_mc = cfg.get("model", {})
|
||||
if isinstance(_mc, dict) and _mc.get("context_length") is not None:
|
||||
_cfg_ctx = int(_mc["context_length"])
|
||||
merge_preflight_compression_warning(
|
||||
result,
|
||||
agent=agent,
|
||||
messages=list(session.get("history", [])),
|
||||
custom_providers=custom_provs,
|
||||
config_context_length=_cfg_ctx,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not confirm_expensive_model:
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
|
@ -2262,11 +2281,14 @@ def _apply_model_switch(
|
|||
except Exception:
|
||||
warning = None
|
||||
if warning is not None:
|
||||
confirm_msg = warning.message
|
||||
if result.warning_message:
|
||||
confirm_msg = f"{confirm_msg}\n\n{result.warning_message}"
|
||||
return {
|
||||
"value": result.new_model,
|
||||
"warning": warning.message,
|
||||
"confirm_required": True,
|
||||
"confirm_message": warning.message,
|
||||
"confirm_message": confirm_msg,
|
||||
}
|
||||
|
||||
if agent:
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ Type in the filter box to narrow by provider name, slug, or model ID.
|
|||
|
||||
Pick a model, hit **Switch**, and Hermes writes it to `~/.hermes/config.yaml` under the `model` section. **This applies to new sessions only** — any chat tab you already have open keeps running whatever model it started with. To hot-swap the current chat, use the `/model` slash command inside it.
|
||||
|
||||
### Mid-session switches and context warnings
|
||||
|
||||
When you switch models **inside an active session** (Herm TUI model picker, `hermes` CLI, or `/model` on Telegram/Discord), Hermes estimates whether your **next message** will run **preflight context compression** against the new model's window. If the session is already near or above that model's compression threshold (see [Context Compression](./configuration.md#context-compression)), the switch reply includes a warning — the same `warning_message` path used for expensive-model notices. The switch still applies immediately; compression runs on the **first user message after the switch**, before the model answers.
|
||||
|
||||
## Setting auxiliary models
|
||||
|
||||
Click **Show auxiliary** to reveal the 11 task slots:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue