mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-18 09:51:59 +00:00
fix(model): cover typed gateway /model path + async-safe pricing lookups
Follow-ups on top of #26016's expensive-model guard:
- gateway/slash_commands.py: typed '/model <name>' now routes through the expensive-model confirmation gate (slash-confirm buttons / text fallback) instead of bypassing the guard that the pickers enforce. Cancelling leaves the session override and the --global config untouched.
- telegram/discord/web_server: run expensive_model_warning() via asyncio.to_thread, because it can hit models.dev or a /models endpoint on a cache miss, which would otherwise block the event loop.
- telegram: the picker callback no longer toasts 'Model switched!' when the switch callback raised (both the mm: and mc: paths).
- tests: new tests/gateway/test_model_command_expensive_confirm.py pins the typed-path gate (prompt, confirm-once, cancel, cheap-model no-op).
This commit is contained in:
parent
af978ecb17
commit
243cada157
5 changed files with 390 additions and 139 deletions
|
|
@ -3136,11 +3136,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await query.answer(text="Picker expired.")
|
||||
return
|
||||
|
||||
switch_failed = False
|
||||
try:
|
||||
result_text = await callback(chat_id, model_id, provider_slug)
|
||||
except Exception as exc:
|
||||
logger.error("Model picker switch failed: %s", exc)
|
||||
result_text = f"Error switching model: {exc}"
|
||||
switch_failed = True
|
||||
|
||||
try:
|
||||
await query.edit_message_text(
|
||||
|
|
@ -3157,7 +3159,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
await query.answer(text="Model switched!")
|
||||
await query.answer(
|
||||
text="Switch failed." if switch_failed else "Model switched!"
|
||||
)
|
||||
self._model_picker_state.pop(chat_id, None)
|
||||
|
||||
elif data.startswith("mm:"):
|
||||
|
|
@ -3184,7 +3188,10 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
# Pricing lookup can hit models.dev / a /models endpoint on a
|
||||
# cache miss — keep it off the event loop.
|
||||
warning = await asyncio.to_thread(
|
||||
expensive_model_warning,
|
||||
model_id,
|
||||
provider=provider_slug,
|
||||
)
|
||||
|
|
@ -3208,11 +3215,13 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await query.answer(text="Confirm expensive model")
|
||||
return
|
||||
|
||||
switch_failed = False
|
||||
try:
|
||||
result_text = await callback(chat_id, model_id, provider_slug)
|
||||
except Exception as exc:
|
||||
logger.error("Model picker switch failed: %s", exc)
|
||||
result_text = f"Error switching model: {exc}"
|
||||
switch_failed = True
|
||||
|
||||
# Edit message to show confirmation, remove buttons
|
||||
try:
|
||||
|
|
@ -3231,7 +3240,9 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
await query.answer(text="Model switched!")
|
||||
await query.answer(
|
||||
text="Switch failed." if switch_failed else "Model switched!"
|
||||
)
|
||||
|
||||
# Clean up state
|
||||
self._model_picker_state.pop(chat_id, None)
|
||||
|
|
|
|||
|
|
@ -1146,149 +1146,197 @@ class GatewaySlashCommandsMixin:
|
|||
if not result.success:
|
||||
return t("gateway.model.error_prefix", error=result.error_message)
|
||||
|
||||
# If there's a cached agent, update it in-place
|
||||
cached_entry = None
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
cached_entry = _cache.get(session_key)
|
||||
async def _finish_switch() -> str:
|
||||
"""Apply the resolved switch (agent, session, config) and build the reply."""
|
||||
# If there's a cached agent, update it in-place
|
||||
cached_entry = None
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache_lock and _cache is not None:
|
||||
with _cache_lock:
|
||||
cached_entry = _cache.get(session_key)
|
||||
|
||||
if cached_entry and cached_entry[0] is not None:
|
||||
if cached_entry and cached_entry[0] is not None:
|
||||
try:
|
||||
cached_entry[0].switch_model(
|
||||
new_model=result.new_model,
|
||||
new_provider=result.target_provider,
|
||||
api_key=result.api_key,
|
||||
base_url=result.base_url,
|
||||
api_mode=result.api_mode,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("In-place model switch failed for cached agent: %s", exc)
|
||||
|
||||
# Persist the new model to the session DB so the dashboard
|
||||
# shows the updated model (#34850).
|
||||
_sess_db = getattr(self, "_session_db", None)
|
||||
if _sess_db is not None:
|
||||
try:
|
||||
_sess_entry = self.session_store.get_or_create_session(source)
|
||||
_sess_db.update_session_model(
|
||||
_sess_entry.session_id, result.new_model
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Failed to persist model switch to DB: %s", exc
|
||||
)
|
||||
|
||||
# Store a note to prepend to the next user message so the model
|
||||
# knows about the switch (avoids system messages mid-history).
|
||||
if not hasattr(self, "_pending_model_notes"):
|
||||
self._pending_model_notes = {}
|
||||
self._pending_model_notes[session_key] = (
|
||||
f"[Note: model was just switched from {current_model} to {result.new_model} "
|
||||
f"via {result.provider_label or result.target_provider}. "
|
||||
f"Adjust your self-identification accordingly.]"
|
||||
)
|
||||
|
||||
# Store session override so next agent creation uses the new model
|
||||
self._session_model_overrides[session_key] = {
|
||||
"model": result.new_model,
|
||||
"provider": result.target_provider,
|
||||
"api_key": result.api_key,
|
||||
"base_url": result.base_url,
|
||||
"api_mode": result.api_mode,
|
||||
}
|
||||
|
||||
# Evict cached agent so the next turn creates a fresh agent from the
|
||||
# override rather than relying on cache signature mismatch detection.
|
||||
self._evict_cached_agent(session_key)
|
||||
|
||||
# Persist to config if --global
|
||||
if persist_global:
|
||||
try:
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
else:
|
||||
cfg = {}
|
||||
# Coerce scalar/None ``model:`` into a dict before mutation —
|
||||
# otherwise ``cfg.setdefault("model", {})`` returns the existing
|
||||
# scalar and the next assignment raises
|
||||
# ``TypeError: 'str' object does not support item assignment``.
|
||||
# Reproduces when ``config.yaml`` has ``model: <name>`` (flat
|
||||
# string) instead of the proper nested ``model: {default: ...}``.
|
||||
raw_model = cfg.get("model")
|
||||
if isinstance(raw_model, dict):
|
||||
model_cfg = raw_model
|
||||
elif isinstance(raw_model, str) and raw_model.strip():
|
||||
model_cfg = {"default": raw_model.strip()}
|
||||
cfg["model"] = model_cfg
|
||||
else:
|
||||
model_cfg = {}
|
||||
cfg["model"] = model_cfg
|
||||
model_cfg["default"] = result.new_model
|
||||
model_cfg["provider"] = result.target_provider
|
||||
if result.base_url:
|
||||
model_cfg["base_url"] = result.base_url
|
||||
from hermes_cli.config import save_config
|
||||
save_config(cfg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist model switch: %s", e)
|
||||
|
||||
# Build confirmation message with full metadata
|
||||
provider_label = result.provider_label or result.target_provider
|
||||
lines = [t("gateway.model.switched", model=result.new_model)]
|
||||
lines.append(t("gateway.model.provider_label", provider=provider_label))
|
||||
|
||||
# Context: always resolve via the provider-aware chain so Codex OAuth,
|
||||
# Copilot, and Nous-enforced caps win over the raw models.dev entry.
|
||||
mi = result.model_info
|
||||
from hermes_cli.model_switch import resolve_display_context_length
|
||||
_sw2_config_ctx = None
|
||||
try:
|
||||
cached_entry[0].switch_model(
|
||||
new_model=result.new_model,
|
||||
new_provider=result.target_provider,
|
||||
api_key=result.api_key,
|
||||
base_url=result.base_url,
|
||||
api_mode=result.api_mode,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("In-place model switch failed for cached agent: %s", exc)
|
||||
_sw2_cfg = _load_gateway_config()
|
||||
_sw2_model_cfg = _sw2_cfg.get("model", {})
|
||||
if isinstance(_sw2_model_cfg, dict):
|
||||
_sw2_raw = _sw2_model_cfg.get("context_length")
|
||||
if _sw2_raw is not None:
|
||||
_sw2_config_ctx = int(_sw2_raw)
|
||||
except Exception:
|
||||
pass
|
||||
ctx = resolve_display_context_length(
|
||||
result.new_model,
|
||||
result.target_provider,
|
||||
base_url=result.base_url or current_base_url or "",
|
||||
api_key=result.api_key or current_api_key or "",
|
||||
model_info=mi,
|
||||
custom_providers=custom_provs,
|
||||
config_context_length=_sw2_config_ctx,
|
||||
)
|
||||
if ctx:
|
||||
lines.append(t("gateway.model.context_label", tokens=f"{ctx:,}"))
|
||||
if mi:
|
||||
if mi.max_output:
|
||||
lines.append(t("gateway.model.max_output_label", tokens=f"{mi.max_output:,}"))
|
||||
if mi.has_cost_data():
|
||||
lines.append(t("gateway.model.cost_label", cost=mi.format_cost()))
|
||||
lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities()))
|
||||
|
||||
# Persist the new model to the session DB so the dashboard
|
||||
# shows the updated model (#34850).
|
||||
_sess_db = getattr(self, "_session_db", None)
|
||||
if _sess_db is not None:
|
||||
try:
|
||||
_sess_entry = self.session_store.get_or_create_session(source)
|
||||
_sess_db.update_session_model(
|
||||
_sess_entry.session_id, result.new_model
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(
|
||||
"Failed to persist model switch to DB: %s", exc
|
||||
)
|
||||
# Cache notice
|
||||
cache_enabled = (
|
||||
(base_url_host_matches(result.base_url or "", "openrouter.ai") and "claude" in result.new_model.lower())
|
||||
or result.api_mode == "anthropic_messages"
|
||||
)
|
||||
if cache_enabled:
|
||||
lines.append(t("gateway.model.prompt_caching_enabled"))
|
||||
|
||||
# Store a note to prepend to the next user message so the model
|
||||
# knows about the switch (avoids system messages mid-history).
|
||||
if not hasattr(self, "_pending_model_notes"):
|
||||
self._pending_model_notes = {}
|
||||
self._pending_model_notes[session_key] = (
|
||||
f"[Note: model was just switched from {current_model} to {result.new_model} "
|
||||
f"via {result.provider_label or result.target_provider}. "
|
||||
f"Adjust your self-identification accordingly.]"
|
||||
)
|
||||
if result.warning_message:
|
||||
lines.append(t("gateway.model.warning_prefix", warning=result.warning_message))
|
||||
|
||||
# Store session override so next agent creation uses the new model
|
||||
self._session_model_overrides[session_key] = {
|
||||
"model": result.new_model,
|
||||
"provider": result.target_provider,
|
||||
"api_key": result.api_key,
|
||||
"base_url": result.base_url,
|
||||
"api_mode": result.api_mode,
|
||||
}
|
||||
if persist_global:
|
||||
lines.append(t("gateway.model.saved_global"))
|
||||
else:
|
||||
lines.append(t("gateway.model.session_only_hint"))
|
||||
|
||||
# Evict cached agent so the next turn creates a fresh agent from the
|
||||
# override rather than relying on cache signature mismatch detection.
|
||||
self._evict_cached_agent(session_key)
|
||||
return "\n".join(lines)
|
||||
|
||||
# Persist to config if --global
|
||||
if persist_global:
|
||||
try:
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
else:
|
||||
cfg = {}
|
||||
# Coerce scalar/None ``model:`` into a dict before mutation —
|
||||
# otherwise ``cfg.setdefault("model", {})`` returns the existing
|
||||
# scalar and the next assignment raises
|
||||
# ``TypeError: 'str' object does not support item assignment``.
|
||||
# Reproduces when ``config.yaml`` has ``model: <name>`` (flat
|
||||
# string) instead of the proper nested ``model: {default: ...}``.
|
||||
raw_model = cfg.get("model")
|
||||
if isinstance(raw_model, dict):
|
||||
model_cfg = raw_model
|
||||
elif isinstance(raw_model, str) and raw_model.strip():
|
||||
model_cfg = {"default": raw_model.strip()}
|
||||
cfg["model"] = model_cfg
|
||||
else:
|
||||
model_cfg = {}
|
||||
cfg["model"] = model_cfg
|
||||
model_cfg["default"] = result.new_model
|
||||
model_cfg["provider"] = result.target_provider
|
||||
if result.base_url:
|
||||
model_cfg["base_url"] = result.base_url
|
||||
from hermes_cli.config import save_config
|
||||
save_config(cfg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to persist model switch: %s", e)
|
||||
|
||||
# Build confirmation message with full metadata
|
||||
provider_label = result.provider_label or result.target_provider
|
||||
lines = [t("gateway.model.switched", model=result.new_model)]
|
||||
lines.append(t("gateway.model.provider_label", provider=provider_label))
|
||||
|
||||
# Context: always resolve via the provider-aware chain so Codex OAuth,
|
||||
# Copilot, and Nous-enforced caps win over the raw models.dev entry.
|
||||
mi = result.model_info
|
||||
from hermes_cli.model_switch import resolve_display_context_length
|
||||
_sw2_config_ctx = None
|
||||
# Expensive-model confirmation gate (typed /model <name> path).
|
||||
# The pickers (Telegram/Discord inline keyboards, TUI, dashboard)
|
||||
# already confirm via their own UI affordances; this covers the
|
||||
# direct text command, which previously bypassed the guard.
|
||||
# expensive_model_warning() may hit models.dev or a /models endpoint
|
||||
# on a cache miss, so run it off the event loop.
|
||||
_cost_warning = None
|
||||
try:
|
||||
_sw2_cfg = _load_gateway_config()
|
||||
_sw2_model_cfg = _sw2_cfg.get("model", {})
|
||||
if isinstance(_sw2_model_cfg, dict):
|
||||
_sw2_raw = _sw2_model_cfg.get("context_length")
|
||||
if _sw2_raw is not None:
|
||||
_sw2_config_ctx = int(_sw2_raw)
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
_cost_warning = await asyncio.to_thread(
|
||||
expensive_model_warning,
|
||||
result.new_model,
|
||||
provider=result.target_provider,
|
||||
base_url=result.base_url or current_base_url or "",
|
||||
api_key=result.api_key or current_api_key or "",
|
||||
model_info=result.model_info,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
ctx = resolve_display_context_length(
|
||||
result.new_model,
|
||||
result.target_provider,
|
||||
base_url=result.base_url or current_base_url or "",
|
||||
api_key=result.api_key or current_api_key or "",
|
||||
model_info=mi,
|
||||
custom_providers=custom_provs,
|
||||
config_context_length=_sw2_config_ctx,
|
||||
)
|
||||
if ctx:
|
||||
lines.append(t("gateway.model.context_label", tokens=f"{ctx:,}"))
|
||||
if mi:
|
||||
if mi.max_output:
|
||||
lines.append(t("gateway.model.max_output_label", tokens=f"{mi.max_output:,}"))
|
||||
if mi.has_cost_data():
|
||||
lines.append(t("gateway.model.cost_label", cost=mi.format_cost()))
|
||||
lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities()))
|
||||
_cost_warning = None
|
||||
if _cost_warning is not None:
|
||||
async def _on_cost_confirm(choice: str) -> str:
|
||||
if choice == "cancel":
|
||||
return (
|
||||
f"🟡 Model switch cancelled. Current model unchanged "
|
||||
f"({current_model or 'unknown'})."
|
||||
)
|
||||
# "once" and "always" both proceed — there is no persistent
|
||||
# opt-out for the cost guard (each expensive switch should be
|
||||
# an explicit decision).
|
||||
return await _finish_switch()
|
||||
|
||||
# Cache notice
|
||||
cache_enabled = (
|
||||
(base_url_host_matches(result.base_url or "", "openrouter.ai") and "claude" in result.new_model.lower())
|
||||
or result.api_mode == "anthropic_messages"
|
||||
)
|
||||
if cache_enabled:
|
||||
lines.append(t("gateway.model.prompt_caching_enabled"))
|
||||
return await self._request_slash_confirm(
|
||||
event=event,
|
||||
command="model",
|
||||
title="Expensive Model Warning",
|
||||
message=(
|
||||
f"⚠️ **Expensive Model Warning**\n\n{_cost_warning.message}\n\n"
|
||||
"_Text fallback: reply `/approve` to switch or `/cancel` to keep "
|
||||
"the current model._"
|
||||
),
|
||||
handler=_on_cost_confirm,
|
||||
)
|
||||
|
||||
if result.warning_message:
|
||||
lines.append(t("gateway.model.warning_prefix", warning=result.warning_message))
|
||||
|
||||
if persist_global:
|
||||
lines.append(t("gateway.model.saved_global"))
|
||||
else:
|
||||
lines.append(t("gateway.model.session_only_hint"))
|
||||
|
||||
return "\n".join(lines)
|
||||
return await _finish_switch()
|
||||
|
||||
async def _handle_codex_runtime_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /codex-runtime command in the gateway.
|
||||
|
|
|
|||
|
|
@ -2460,7 +2460,10 @@ async def set_model_assignment(body: ModelAssignment):
|
|||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
# Pricing lookup can hit models.dev / a /models endpoint on a
|
||||
# cache miss — keep it off the event loop.
|
||||
warning = await asyncio.to_thread(
|
||||
expensive_model_warning,
|
||||
model,
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
|
|
|
|||
|
|
@ -5742,11 +5742,14 @@ def _define_discord_view_classes() -> None:
|
|||
cancel_btn.callback = self._on_cancel
|
||||
self.add_item(cancel_btn)
|
||||
|
||||
def _expensive_warning_for(self, model_id: str):
|
||||
async def _expensive_warning_for(self, model_id: str):
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
return expensive_model_warning(
|
||||
# Pricing lookup can hit models.dev / a /models endpoint on a
|
||||
# cache miss — keep it off the event loop.
|
||||
return await asyncio.to_thread(
|
||||
expensive_model_warning,
|
||||
model_id,
|
||||
provider=self._selected_provider,
|
||||
)
|
||||
|
|
@ -5840,7 +5843,7 @@ def _define_discord_view_classes() -> None:
|
|||
return
|
||||
|
||||
model_id = interaction.data["values"][0]
|
||||
warning = self._expensive_warning_for(model_id)
|
||||
warning = await self._expensive_warning_for(model_id)
|
||||
if warning is not None:
|
||||
self._build_expensive_confirm(model_id)
|
||||
await interaction.response.edit_message(
|
||||
|
|
|
|||
186
tests/gateway/test_model_command_expensive_confirm.py
Normal file
186
tests/gateway/test_model_command_expensive_confirm.py
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
"""Gateway typed ``/model <name>`` must route through the expensive-model
|
||||
confirmation gate.
|
||||
|
||||
The pickers (Telegram/Discord inline keyboards, TUI, dashboard) confirm
|
||||
expensive models via their own UI affordances; the typed text command
|
||||
previously bypassed the guard entirely — a user typing
|
||||
``/model openai/gpt-5.5-pro`` switched silently while the picker warned.
|
||||
These tests pin the typed path:
|
||||
|
||||
- warning fires → handler returns the slash-confirm prompt, switch NOT applied
|
||||
- confirm ("once") → switch applies (session override set)
|
||||
- cancel → switch not applied, current model unchanged
|
||||
- no warning (cheap model) → switch applies immediately, no prompt
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.platforms.base import MessageEvent, MessageType
|
||||
from gateway.run import GatewayRunner
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner._voice_mode = {}
|
||||
runner._session_model_overrides = {}
|
||||
runner._running_agents = {}
|
||||
return runner
|
||||
|
||||
|
||||
def _make_event(text):
|
||||
return MessageEvent(
|
||||
text=text,
|
||||
message_type=MessageType.TEXT,
|
||||
source=SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"),
|
||||
)
|
||||
|
||||
|
||||
def _fake_switch_result():
|
||||
from hermes_cli.model_switch import ModelSwitchResult
|
||||
|
||||
return ModelSwitchResult(
|
||||
success=True,
|
||||
new_model="openai/gpt-5.5-pro",
|
||||
target_provider="openrouter",
|
||||
provider_changed=False,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
api_mode="chat_completions",
|
||||
provider_label="OpenRouter",
|
||||
)
|
||||
|
||||
|
||||
def _fake_warning():
|
||||
return SimpleNamespace(
|
||||
message=(
|
||||
"!!! EXPENSIVE MODEL WARNING !!!\n"
|
||||
"openai/gpt-5.5-pro has known pricing above Hermes' safety threshold.\n"
|
||||
"did you mean to select openai/gpt-5.5?"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _setup_isolated_home(tmp_path, monkeypatch, *, warn):
|
||||
import gateway.run as gateway_run
|
||||
|
||||
hermes_home = tmp_path / ".hermes"
|
||||
hermes_home.mkdir()
|
||||
cfg_path = hermes_home / "config.yaml"
|
||||
cfg_path.write_text(
|
||||
yaml.safe_dump({"model": {"default": "old-model", "provider": "openrouter"}, "providers": {}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
|
||||
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.model_switch.switch_model",
|
||||
lambda **kw: _fake_switch_result(),
|
||||
)
|
||||
monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setattr("hermes_cli.config.get_hermes_home", lambda: hermes_home)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.model_cost_guard.expensive_model_warning",
|
||||
(lambda *a, **kw: _fake_warning()) if warn else (lambda *a, **kw: None),
|
||||
)
|
||||
return cfg_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typed_model_expensive_prompts_instead_of_switching(tmp_path, monkeypatch):
|
||||
"""Expensive model typed directly → confirm prompt, no switch applied."""
|
||||
_setup_isolated_home(tmp_path, monkeypatch, warn=True)
|
||||
runner = _make_runner()
|
||||
|
||||
captured = {}
|
||||
|
||||
async def _fake_request_slash_confirm(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return kwargs["message"]
|
||||
|
||||
runner._request_slash_confirm = _fake_request_slash_confirm
|
||||
|
||||
result = await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro"))
|
||||
|
||||
assert result is not None
|
||||
assert "EXPENSIVE MODEL WARNING" in result
|
||||
# The switch must NOT have been applied yet.
|
||||
assert runner._session_model_overrides == {}
|
||||
assert captured["command"] == "model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typed_model_expensive_confirm_once_applies_switch(tmp_path, monkeypatch):
|
||||
"""Resolving the confirm with "once" applies the switch."""
|
||||
_setup_isolated_home(tmp_path, monkeypatch, warn=True)
|
||||
runner = _make_runner()
|
||||
runner._evict_cached_agent = lambda session_key: None
|
||||
|
||||
captured = {}
|
||||
|
||||
async def _fake_request_slash_confirm(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return None # buttons rendered
|
||||
|
||||
runner._request_slash_confirm = _fake_request_slash_confirm
|
||||
|
||||
await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro"))
|
||||
assert runner._session_model_overrides == {}
|
||||
|
||||
reply = await captured["handler"]("once")
|
||||
|
||||
assert "gpt-5.5-pro" in reply
|
||||
overrides = list(runner._session_model_overrides.values())
|
||||
assert len(overrides) == 1
|
||||
assert overrides[0]["model"] == "openai/gpt-5.5-pro"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typed_model_expensive_cancel_keeps_current_model(tmp_path, monkeypatch):
|
||||
"""Resolving the confirm with "cancel" leaves everything unchanged."""
|
||||
cfg_path = _setup_isolated_home(tmp_path, monkeypatch, warn=True)
|
||||
runner = _make_runner()
|
||||
|
||||
captured = {}
|
||||
|
||||
async def _fake_request_slash_confirm(**kwargs):
|
||||
captured.update(kwargs)
|
||||
return None
|
||||
|
||||
runner._request_slash_confirm = _fake_request_slash_confirm
|
||||
|
||||
await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro --global"))
|
||||
|
||||
reply = await captured["handler"]("cancel")
|
||||
|
||||
assert "cancelled" in reply.lower()
|
||||
assert runner._session_model_overrides == {}
|
||||
# --global must not have persisted the cancelled switch.
|
||||
written = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
|
||||
assert written["model"]["default"] == "old-model"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_typed_model_cheap_switches_without_prompt(tmp_path, monkeypatch):
|
||||
"""No warning → switch applies immediately; confirm primitive never invoked."""
|
||||
_setup_isolated_home(tmp_path, monkeypatch, warn=False)
|
||||
runner = _make_runner()
|
||||
runner._evict_cached_agent = lambda session_key: None
|
||||
|
||||
async def _fail_request_slash_confirm(**kwargs): # pragma: no cover
|
||||
raise AssertionError("confirm should not be requested for cheap models")
|
||||
|
||||
runner._request_slash_confirm = _fail_request_slash_confirm
|
||||
|
||||
result = await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro"))
|
||||
|
||||
assert result is not None
|
||||
assert "gpt-5.5-pro" in result
|
||||
overrides = list(runner._session_model_overrides.values())
|
||||
assert len(overrides) == 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue