mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-24 10:52:21 +00:00
Follow-up to the salvaged preflight-compression warning: - Replace silent `except Exception: pass` at all 5 guard call sites (cli.py x2, gateway/slash_commands.py x2, tui_gateway/server.py) with `logger.debug(...)` so signature drift in the guard helper isn't hidden. - tui_gateway/server.py: set the confirm dict's `warning` field to the merged message (was bare expensive-model text) so it matches `confirm_message` for any future consumer reading `warning`. - Add trailing newlines to the two new files.
169 lines
5.6 KiB
Python
169 lines
5.6 KiB
Python
"""Warn when an in-session model switch will trigger preflight compression on the next turn.
|
|
|
|
Addresses part of #23767 ("user-facing guardrail when switching from a
|
|
high-context provider to a substantially lower-context provider"). The other
|
|
proposed fixes from that issue (hard preflight token guard, metadata cache
|
|
invalidation on switch, compression safety invariant, oversized tool-output
|
|
handling) are tracked separately.
|
|
|
|
Mirrors the expensive-model guard pattern: merge into ``ModelSwitchResult.warning_message``
|
|
so Herm TUI, CLI, and gateway surfaces that already show switch warnings pick it up.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, List, Optional
|
|
|
|
from agent.model_metadata import MINIMUM_CONTEXT_LENGTH
|
|
from hermes_cli.model_switch import ModelSwitchResult, resolve_display_context_length
|
|
|
|
|
|
def _append_warning(result: ModelSwitchResult, text: str) -> None:
    """Attach ``text`` to ``result.warning_message`` in place.

    When a non-empty warning is already present it is kept and ``text`` is
    added after a `` | `` separator; otherwise ``text`` becomes the warning.
    """
    existing = result.warning_message
    result.warning_message = f"{existing} | {text}" if existing else text
|
|
|
|
|
|
def _threshold_tokens(context_length: int, threshold_percent: float) -> int:
    """Return the token threshold for a context window.

    The result is ``context_length * threshold_percent`` truncated to an
    int, but never lower than ``MINIMUM_CONTEXT_LENGTH``.
    """
    scaled = int(context_length * threshold_percent)
    if scaled < MINIMUM_CONTEXT_LENGTH:
        return MINIMUM_CONTEXT_LENGTH
    return scaled
|
|
|
|
|
|
def _estimate_tokens(agent: Any, messages: Optional[List[dict]]) -> Optional[int]:
    """Best-effort estimate of the prompt size of the agent's next request.

    Args:
        agent: Object read via ``getattr`` for ``context_compressor`` and,
            optionally, ``_cached_system_prompt``, ``tools`` and
            ``session_prompt_tokens``.
        messages: Conversation history, or ``None`` when it is unavailable.

    Returns:
        The estimated token count, or ``None`` when the agent has no
        compressor, when the history is no longer than the compressor's
        protected head + tail (+1), or when no positive count is available.
    """
    import logging

    cc = getattr(agent, "context_compressor", None)
    if cc is None:
        return None

    if messages is not None:
        # Histories that fit inside the protected head/tail window yield no
        # estimate (presumably nothing would be compressible — confirm
        # against the compressor implementation).
        protect = (
            int(getattr(cc, "protect_first_n", 3))
            + int(getattr(cc, "protect_last_n", 20))
            + 1
        )
        if len(messages) <= protect:
            return None
        try:
            from agent.model_metadata import estimate_request_tokens_rough

            system_prompt = getattr(agent, "_cached_system_prompt", None) or ""
            tools = getattr(agent, "tools", None)
            return int(
                estimate_request_tokens_rough(
                    messages,
                    system_prompt=system_prompt,
                    tools=tools or None,
                )
            )
        except Exception as exc:
            # Still best-effort, but no longer silent: a broken or drifted
            # estimator is reported before falling back to recorded counters.
            logging.getLogger(__name__).debug(
                "Rough request-token estimate failed; "
                "falling back to recorded prompt tokens: %s",
                exc,
            )

    # Fallbacks: the compressor's last recorded prompt size, then the
    # agent's session prompt counter.
    last = int(getattr(cc, "last_prompt_tokens", 0) or 0)
    if last > 0:
        return last
    session_prompt = int(getattr(agent, "session_prompt_tokens", 0) or 0)
    return session_prompt if session_prompt > 0 else None
|
|
|
|
|
|
def merge_preflight_compression_warning(
    result: ModelSwitchResult,
    *,
    agent: Any = None,
    messages: Optional[List[dict]] = None,
    custom_providers: list | None = None,
    config_context_length: int | None = None,
) -> None:
    """Append a preflight-compression warning to ``result`` when warranted.

    Nothing happens unless the switch succeeded, an agent with compression
    enabled and a ``context_compressor`` is supplied, the new model's context
    length resolves to a truthy value, and the estimated session size reaches
    the new model's compression threshold.

    Args:
        result: Outcome of the model switch; its ``warning_message`` is
            extended in place.
        agent: Live agent whose compressor state is inspected.
        messages: Conversation history used for the token estimate, if known.
        custom_providers: Forwarded to ``resolve_display_context_length``.
        config_context_length: Forwarded to ``resolve_display_context_length``.
    """
    if not (result.success and agent is not None):
        return
    if not getattr(agent, "compression_enabled", True):
        return

    compressor = getattr(agent, "context_compressor", None)
    if compressor is None:
        return

    previous_window = int(getattr(compressor, "context_length", 0) or 0)
    target_window = resolve_display_context_length(
        result.new_model,
        result.target_provider,
        base_url=result.base_url or getattr(agent, "base_url", "") or "",
        api_key=result.api_key or getattr(agent, "api_key", "") or "",
        model_info=result.model_info,
        custom_providers=custom_providers,
        config_context_length=config_context_length,
    )
    if not target_window:
        return

    session_tokens = _estimate_tokens(agent, messages)
    if session_tokens is None:
        return

    ratio = float(getattr(compressor, "threshold_percent", 0.5))
    trigger_at = _threshold_tokens(target_window, ratio)
    if session_tokens < trigger_at:
        return

    # NOTE(review): two or more ineffective compressions suppress the warning;
    # presumably the compressor will not run again in that state — confirm.
    ineffective_runs = int(
        getattr(compressor, "_ineffective_compression_count", 0) or 0
    )
    if ineffective_runs >= 2:
        return

    shrink_note = ""
    if previous_window and target_window < previous_window:
        shrink_note = (
            f"Context window shrinks ({previous_window:,} → {target_window:,}). "
        )
    message = shrink_note + (
        f"Session is ~{session_tokens:,} tokens; "
        f"{result.new_model} allows {target_window:,} "
        f"(auto-compress at ~{trigger_at:,}). "
        f"Your next message will run preflight compression before the model replies."
    )
    _append_warning(result, message)
|
|
|
|
|
|
def enrich_model_switch_warnings_for_gateway(
    result: ModelSwitchResult,
    runner: Any,
    *,
    session_key: str,
    source: Any,
    custom_providers: list | None = None,
    load_gateway_config: Callable[[], dict] | None = None,
) -> None:
    """Gateway helper: merge the preflight-compression warning into ``result``.

    Looks up the cached agent for ``session_key`` on ``runner``, optionally
    reads ``model.context_length`` from the gateway config, loads the session
    conversation from the runner's session DB, and delegates to
    ``merge_preflight_compression_warning``.

    Args:
        result: Outcome of the model switch; may gain a warning message.
        runner: Gateway runner; ``_agent_cache_lock``, ``_agent_cache``,
            ``_session_db`` and ``session_store`` are read via ``getattr``.
        session_key: Key of the agent cache entry to inspect.
        source: Passed to ``session_store.get_or_create_session``.
        custom_providers: Forwarded to the merge helper.
        load_gateway_config: Optional zero-argument loader for the config.

    Config and session-DB lookups are best-effort: failures are logged at
    debug level and the corresponding value is simply left as ``None``.
    """
    import logging

    log = logging.getLogger(__name__)

    lock = getattr(runner, "_agent_cache_lock", None)
    cache = getattr(runner, "_agent_cache", None)
    agent = None
    if lock is not None and cache is not None:
        with lock:
            cached = cache.get(session_key)
            # Only the first element of the cache entry is used here.
            if cached and cached[0] is not None:
                agent = cached[0]
    if agent is None:
        return

    cfg_ctx = None
    if load_gateway_config is not None:
        try:
            cfg = load_gateway_config()
            model_cfg = cfg.get("model", {}) if isinstance(cfg, dict) else {}
            if isinstance(model_cfg, dict) and model_cfg.get("context_length") is not None:
                cfg_ctx = int(model_cfg["context_length"])
        except Exception as exc:
            # Best-effort: keep going without a configured context length,
            # but leave a trace instead of swallowing the error silently.
            log.debug(
                "Could not read model.context_length from gateway config: %s",
                exc,
            )

    messages = None
    db = getattr(runner, "_session_db", None)
    store = getattr(runner, "session_store", None)
    if db is not None and store is not None:
        try:
            session = store.get_or_create_session(source)
            messages = db.get_messages_as_conversation(session.session_id)
        except Exception as exc:
            # Best-effort: the merge helper accepts messages=None.
            log.debug(
                "Could not load session messages for preflight warning: %s",
                exc,
            )

    merge_preflight_compression_warning(
        result,
        agent=agent,
        messages=messages,
        custom_providers=custom_providers,
        config_context_length=cfg_ctx,
    )
|