fix(model): cover typed gateway /model path + async-safe pricing lookups

Follow-ups on top of #26016's expensive-model guard:

- gateway/slash_commands.py: typed '/model <name>' now routes through the
  expensive-model confirmation gate (slash-confirm buttons / text fallback)
  instead of bypassing the guard the pickers enforce. Cancel leaves the
  session override and --global config untouched.
- telegram/discord/web_server: run expensive_model_warning() via
  asyncio.to_thread — it can hit models.dev or a /models endpoint on a
  cache miss, which would otherwise block the event loop.
- telegram: picker callback no longer toasts 'Model switched!' when the
  switch callback raised (both mm: and mc: paths).
- tests: new tests/gateway/test_model_command_expensive_confirm.py pins
  the typed-path gate (prompt, confirm-once, cancel, cheap-model no-op).
This commit is contained in:
Teknium 2026-06-10 00:08:53 -07:00
parent af978ecb17
commit 243cada157
5 changed files with 390 additions and 139 deletions

View file

@ -3136,11 +3136,13 @@ class TelegramAdapter(BasePlatformAdapter):
await query.answer(text="Picker expired.")
return
switch_failed = False
try:
result_text = await callback(chat_id, model_id, provider_slug)
except Exception as exc:
logger.error("Model picker switch failed: %s", exc)
result_text = f"Error switching model: {exc}"
switch_failed = True
try:
await query.edit_message_text(
@ -3157,7 +3159,9 @@ class TelegramAdapter(BasePlatformAdapter):
)
except Exception:
pass
await query.answer(text="Model switched!")
await query.answer(
text="Switch failed." if switch_failed else "Model switched!"
)
self._model_picker_state.pop(chat_id, None)
elif data.startswith("mm:"):
@ -3184,7 +3188,10 @@ class TelegramAdapter(BasePlatformAdapter):
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
# Pricing lookup can hit models.dev / a /models endpoint on a
# cache miss — keep it off the event loop.
warning = await asyncio.to_thread(
expensive_model_warning,
model_id,
provider=provider_slug,
)
@ -3208,11 +3215,13 @@ class TelegramAdapter(BasePlatformAdapter):
await query.answer(text="Confirm expensive model")
return
switch_failed = False
try:
result_text = await callback(chat_id, model_id, provider_slug)
except Exception as exc:
logger.error("Model picker switch failed: %s", exc)
result_text = f"Error switching model: {exc}"
switch_failed = True
# Edit message to show confirmation, remove buttons
try:
@ -3231,7 +3240,9 @@ class TelegramAdapter(BasePlatformAdapter):
)
except Exception:
pass
await query.answer(text="Model switched!")
await query.answer(
text="Switch failed." if switch_failed else "Model switched!"
)
# Clean up state
self._model_picker_state.pop(chat_id, None)

View file

@ -1146,149 +1146,197 @@ class GatewaySlashCommandsMixin:
if not result.success:
return t("gateway.model.error_prefix", error=result.error_message)
# If there's a cached agent, update it in-place
cached_entry = None
_cache_lock = getattr(self, "_agent_cache_lock", None)
_cache = getattr(self, "_agent_cache", None)
if _cache_lock and _cache is not None:
with _cache_lock:
cached_entry = _cache.get(session_key)
async def _finish_switch() -> str:
"""Apply the resolved switch (agent, session, config) and build the reply."""
# If there's a cached agent, update it in-place
cached_entry = None
_cache_lock = getattr(self, "_agent_cache_lock", None)
_cache = getattr(self, "_agent_cache", None)
if _cache_lock and _cache is not None:
with _cache_lock:
cached_entry = _cache.get(session_key)
if cached_entry and cached_entry[0] is not None:
if cached_entry and cached_entry[0] is not None:
try:
cached_entry[0].switch_model(
new_model=result.new_model,
new_provider=result.target_provider,
api_key=result.api_key,
base_url=result.base_url,
api_mode=result.api_mode,
)
except Exception as exc:
logger.warning("In-place model switch failed for cached agent: %s", exc)
# Persist the new model to the session DB so the dashboard
# shows the updated model (#34850).
_sess_db = getattr(self, "_session_db", None)
if _sess_db is not None:
try:
_sess_entry = self.session_store.get_or_create_session(source)
_sess_db.update_session_model(
_sess_entry.session_id, result.new_model
)
except Exception as exc:
logger.debug(
"Failed to persist model switch to DB: %s", exc
)
# Store a note to prepend to the next user message so the model
# knows about the switch (avoids system messages mid-history).
if not hasattr(self, "_pending_model_notes"):
self._pending_model_notes = {}
self._pending_model_notes[session_key] = (
f"[Note: model was just switched from {current_model} to {result.new_model} "
f"via {result.provider_label or result.target_provider}. "
f"Adjust your self-identification accordingly.]"
)
# Store session override so next agent creation uses the new model
self._session_model_overrides[session_key] = {
"model": result.new_model,
"provider": result.target_provider,
"api_key": result.api_key,
"base_url": result.base_url,
"api_mode": result.api_mode,
}
# Evict cached agent so the next turn creates a fresh agent from the
# override rather than relying on cache signature mismatch detection.
self._evict_cached_agent(session_key)
# Persist to config if --global
if persist_global:
try:
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
else:
cfg = {}
# Coerce scalar/None ``model:`` into a dict before mutation —
# otherwise ``cfg.setdefault("model", {})`` returns the existing
# scalar and the next assignment raises
# ``TypeError: 'str' object does not support item assignment``.
# Reproduces when ``config.yaml`` has ``model: <name>`` (flat
# string) instead of the proper nested ``model: {default: ...}``.
raw_model = cfg.get("model")
if isinstance(raw_model, dict):
model_cfg = raw_model
elif isinstance(raw_model, str) and raw_model.strip():
model_cfg = {"default": raw_model.strip()}
cfg["model"] = model_cfg
else:
model_cfg = {}
cfg["model"] = model_cfg
model_cfg["default"] = result.new_model
model_cfg["provider"] = result.target_provider
if result.base_url:
model_cfg["base_url"] = result.base_url
from hermes_cli.config import save_config
save_config(cfg)
except Exception as e:
logger.warning("Failed to persist model switch: %s", e)
# Build confirmation message with full metadata
provider_label = result.provider_label or result.target_provider
lines = [t("gateway.model.switched", model=result.new_model)]
lines.append(t("gateway.model.provider_label", provider=provider_label))
# Context: always resolve via the provider-aware chain so Codex OAuth,
# Copilot, and Nous-enforced caps win over the raw models.dev entry.
mi = result.model_info
from hermes_cli.model_switch import resolve_display_context_length
_sw2_config_ctx = None
try:
cached_entry[0].switch_model(
new_model=result.new_model,
new_provider=result.target_provider,
api_key=result.api_key,
base_url=result.base_url,
api_mode=result.api_mode,
)
except Exception as exc:
logger.warning("In-place model switch failed for cached agent: %s", exc)
_sw2_cfg = _load_gateway_config()
_sw2_model_cfg = _sw2_cfg.get("model", {})
if isinstance(_sw2_model_cfg, dict):
_sw2_raw = _sw2_model_cfg.get("context_length")
if _sw2_raw is not None:
_sw2_config_ctx = int(_sw2_raw)
except Exception:
pass
ctx = resolve_display_context_length(
result.new_model,
result.target_provider,
base_url=result.base_url or current_base_url or "",
api_key=result.api_key or current_api_key or "",
model_info=mi,
custom_providers=custom_provs,
config_context_length=_sw2_config_ctx,
)
if ctx:
lines.append(t("gateway.model.context_label", tokens=f"{ctx:,}"))
if mi:
if mi.max_output:
lines.append(t("gateway.model.max_output_label", tokens=f"{mi.max_output:,}"))
if mi.has_cost_data():
lines.append(t("gateway.model.cost_label", cost=mi.format_cost()))
lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities()))
# Persist the new model to the session DB so the dashboard
# shows the updated model (#34850).
_sess_db = getattr(self, "_session_db", None)
if _sess_db is not None:
try:
_sess_entry = self.session_store.get_or_create_session(source)
_sess_db.update_session_model(
_sess_entry.session_id, result.new_model
)
except Exception as exc:
logger.debug(
"Failed to persist model switch to DB: %s", exc
)
# Cache notice
cache_enabled = (
(base_url_host_matches(result.base_url or "", "openrouter.ai") and "claude" in result.new_model.lower())
or result.api_mode == "anthropic_messages"
)
if cache_enabled:
lines.append(t("gateway.model.prompt_caching_enabled"))
# Store a note to prepend to the next user message so the model
# knows about the switch (avoids system messages mid-history).
if not hasattr(self, "_pending_model_notes"):
self._pending_model_notes = {}
self._pending_model_notes[session_key] = (
f"[Note: model was just switched from {current_model} to {result.new_model} "
f"via {result.provider_label or result.target_provider}. "
f"Adjust your self-identification accordingly.]"
)
if result.warning_message:
lines.append(t("gateway.model.warning_prefix", warning=result.warning_message))
# Store session override so next agent creation uses the new model
self._session_model_overrides[session_key] = {
"model": result.new_model,
"provider": result.target_provider,
"api_key": result.api_key,
"base_url": result.base_url,
"api_mode": result.api_mode,
}
if persist_global:
lines.append(t("gateway.model.saved_global"))
else:
lines.append(t("gateway.model.session_only_hint"))
# Evict cached agent so the next turn creates a fresh agent from the
# override rather than relying on cache signature mismatch detection.
self._evict_cached_agent(session_key)
return "\n".join(lines)
# Persist to config if --global
if persist_global:
try:
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
cfg = yaml.safe_load(f) or {}
else:
cfg = {}
# Coerce scalar/None ``model:`` into a dict before mutation —
# otherwise ``cfg.setdefault("model", {})`` returns the existing
# scalar and the next assignment raises
# ``TypeError: 'str' object does not support item assignment``.
# Reproduces when ``config.yaml`` has ``model: <name>`` (flat
# string) instead of the proper nested ``model: {default: ...}``.
raw_model = cfg.get("model")
if isinstance(raw_model, dict):
model_cfg = raw_model
elif isinstance(raw_model, str) and raw_model.strip():
model_cfg = {"default": raw_model.strip()}
cfg["model"] = model_cfg
else:
model_cfg = {}
cfg["model"] = model_cfg
model_cfg["default"] = result.new_model
model_cfg["provider"] = result.target_provider
if result.base_url:
model_cfg["base_url"] = result.base_url
from hermes_cli.config import save_config
save_config(cfg)
except Exception as e:
logger.warning("Failed to persist model switch: %s", e)
# Build confirmation message with full metadata
provider_label = result.provider_label or result.target_provider
lines = [t("gateway.model.switched", model=result.new_model)]
lines.append(t("gateway.model.provider_label", provider=provider_label))
# Context: always resolve via the provider-aware chain so Codex OAuth,
# Copilot, and Nous-enforced caps win over the raw models.dev entry.
mi = result.model_info
from hermes_cli.model_switch import resolve_display_context_length
_sw2_config_ctx = None
# Expensive-model confirmation gate (typed /model <name> path).
# The pickers (Telegram/Discord inline keyboards, TUI, dashboard)
# already confirm via their own UI affordances; this covers the
# direct text command, which previously bypassed the guard.
# expensive_model_warning() may hit models.dev or a /models endpoint
# on a cache miss, so run it off the event loop.
_cost_warning = None
try:
_sw2_cfg = _load_gateway_config()
_sw2_model_cfg = _sw2_cfg.get("model", {})
if isinstance(_sw2_model_cfg, dict):
_sw2_raw = _sw2_model_cfg.get("context_length")
if _sw2_raw is not None:
_sw2_config_ctx = int(_sw2_raw)
from hermes_cli.model_cost_guard import expensive_model_warning
_cost_warning = await asyncio.to_thread(
expensive_model_warning,
result.new_model,
provider=result.target_provider,
base_url=result.base_url or current_base_url or "",
api_key=result.api_key or current_api_key or "",
model_info=result.model_info,
)
except Exception:
pass
ctx = resolve_display_context_length(
result.new_model,
result.target_provider,
base_url=result.base_url or current_base_url or "",
api_key=result.api_key or current_api_key or "",
model_info=mi,
custom_providers=custom_provs,
config_context_length=_sw2_config_ctx,
)
if ctx:
lines.append(t("gateway.model.context_label", tokens=f"{ctx:,}"))
if mi:
if mi.max_output:
lines.append(t("gateway.model.max_output_label", tokens=f"{mi.max_output:,}"))
if mi.has_cost_data():
lines.append(t("gateway.model.cost_label", cost=mi.format_cost()))
lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities()))
_cost_warning = None
if _cost_warning is not None:
async def _on_cost_confirm(choice: str) -> str:
if choice == "cancel":
return (
f"🟡 Model switch cancelled. Current model unchanged "
f"({current_model or 'unknown'})."
)
# "once" and "always" both proceed — there is no persistent
# opt-out for the cost guard (each expensive switch should be
# an explicit decision).
return await _finish_switch()
# Cache notice
cache_enabled = (
(base_url_host_matches(result.base_url or "", "openrouter.ai") and "claude" in result.new_model.lower())
or result.api_mode == "anthropic_messages"
)
if cache_enabled:
lines.append(t("gateway.model.prompt_caching_enabled"))
return await self._request_slash_confirm(
event=event,
command="model",
title="Expensive Model Warning",
message=(
f"⚠️ **Expensive Model Warning**\n\n{_cost_warning.message}\n\n"
"_Text fallback: reply `/approve` to switch or `/cancel` to keep "
"the current model._"
),
handler=_on_cost_confirm,
)
if result.warning_message:
lines.append(t("gateway.model.warning_prefix", warning=result.warning_message))
if persist_global:
lines.append(t("gateway.model.saved_global"))
else:
lines.append(t("gateway.model.session_only_hint"))
return "\n".join(lines)
return await _finish_switch()
async def _handle_codex_runtime_command(self, event: MessageEvent) -> str:
"""Handle /codex-runtime command in the gateway.

View file

@ -2460,7 +2460,10 @@ async def set_model_assignment(body: ModelAssignment):
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
# Pricing lookup can hit models.dev / a /models endpoint on a
# cache miss — keep it off the event loop.
warning = await asyncio.to_thread(
expensive_model_warning,
model,
provider=provider,
base_url=base_url,

View file

@ -5742,11 +5742,14 @@ def _define_discord_view_classes() -> None:
cancel_btn.callback = self._on_cancel
self.add_item(cancel_btn)
def _expensive_warning_for(self, model_id: str):
async def _expensive_warning_for(self, model_id: str):
try:
from hermes_cli.model_cost_guard import expensive_model_warning
return expensive_model_warning(
# Pricing lookup can hit models.dev / a /models endpoint on a
# cache miss — keep it off the event loop.
return await asyncio.to_thread(
expensive_model_warning,
model_id,
provider=self._selected_provider,
)
@ -5840,7 +5843,7 @@ def _define_discord_view_classes() -> None:
return
model_id = interaction.data["values"][0]
warning = self._expensive_warning_for(model_id)
warning = await self._expensive_warning_for(model_id)
if warning is not None:
self._build_expensive_confirm(model_id)
await interaction.response.edit_message(

View file

@ -0,0 +1,186 @@
"""Gateway typed ``/model <name>`` must route through the expensive-model
confirmation gate.
The pickers (Telegram/Discord inline keyboards, TUI, dashboard) confirm
expensive models via their own UI affordances; the typed text command
previously bypassed the guard entirely a user typing
``/model openai/gpt-5.5-pro`` switched silently while the picker warned.
These tests pin the typed path:
- warning fires handler returns the slash-confirm prompt, switch NOT applied
- confirm ("once") switch applies (session override set)
- cancel switch not applied, current model unchanged
- no warning (cheap model) switch applies immediately, no prompt
"""
from types import SimpleNamespace
import pytest
import yaml
from gateway.config import Platform
from gateway.platforms.base import MessageEvent, MessageType
from gateway.run import GatewayRunner
from gateway.session import SessionSource
def _make_runner():
runner = object.__new__(GatewayRunner)
runner.adapters = {}
runner._voice_mode = {}
runner._session_model_overrides = {}
runner._running_agents = {}
return runner
def _make_event(text):
return MessageEvent(
text=text,
message_type=MessageType.TEXT,
source=SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"),
)
def _fake_switch_result():
from hermes_cli.model_switch import ModelSwitchResult
return ModelSwitchResult(
success=True,
new_model="openai/gpt-5.5-pro",
target_provider="openrouter",
provider_changed=False,
api_key="sk-test",
base_url="https://openrouter.ai/api/v1",
api_mode="chat_completions",
provider_label="OpenRouter",
)
def _fake_warning():
return SimpleNamespace(
message=(
"!!! EXPENSIVE MODEL WARNING !!!\n"
"openai/gpt-5.5-pro has known pricing above Hermes' safety threshold.\n"
"did you mean to select openai/gpt-5.5?"
),
)
def _setup_isolated_home(tmp_path, monkeypatch, *, warn):
import gateway.run as gateway_run
hermes_home = tmp_path / ".hermes"
hermes_home.mkdir()
cfg_path = hermes_home / "config.yaml"
cfg_path.write_text(
yaml.safe_dump({"model": {"default": "old-model", "provider": "openrouter"}, "providers": {}}),
encoding="utf-8",
)
monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home)
monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {})
monkeypatch.setattr(
"hermes_cli.model_switch.switch_model",
lambda **kw: _fake_switch_result(),
)
monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: hermes_home)
monkeypatch.setattr("hermes_cli.config.get_hermes_home", lambda: hermes_home)
monkeypatch.setattr(
"hermes_cli.model_cost_guard.expensive_model_warning",
(lambda *a, **kw: _fake_warning()) if warn else (lambda *a, **kw: None),
)
return cfg_path
@pytest.mark.asyncio
async def test_typed_model_expensive_prompts_instead_of_switching(tmp_path, monkeypatch):
"""Expensive model typed directly → confirm prompt, no switch applied."""
_setup_isolated_home(tmp_path, monkeypatch, warn=True)
runner = _make_runner()
captured = {}
async def _fake_request_slash_confirm(**kwargs):
captured.update(kwargs)
return kwargs["message"]
runner._request_slash_confirm = _fake_request_slash_confirm
result = await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro"))
assert result is not None
assert "EXPENSIVE MODEL WARNING" in result
# The switch must NOT have been applied yet.
assert runner._session_model_overrides == {}
assert captured["command"] == "model"
@pytest.mark.asyncio
async def test_typed_model_expensive_confirm_once_applies_switch(tmp_path, monkeypatch):
"""Resolving the confirm with "once" applies the switch."""
_setup_isolated_home(tmp_path, monkeypatch, warn=True)
runner = _make_runner()
runner._evict_cached_agent = lambda session_key: None
captured = {}
async def _fake_request_slash_confirm(**kwargs):
captured.update(kwargs)
return None # buttons rendered
runner._request_slash_confirm = _fake_request_slash_confirm
await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro"))
assert runner._session_model_overrides == {}
reply = await captured["handler"]("once")
assert "gpt-5.5-pro" in reply
overrides = list(runner._session_model_overrides.values())
assert len(overrides) == 1
assert overrides[0]["model"] == "openai/gpt-5.5-pro"
@pytest.mark.asyncio
async def test_typed_model_expensive_cancel_keeps_current_model(tmp_path, monkeypatch):
"""Resolving the confirm with "cancel" leaves everything unchanged."""
cfg_path = _setup_isolated_home(tmp_path, monkeypatch, warn=True)
runner = _make_runner()
captured = {}
async def _fake_request_slash_confirm(**kwargs):
captured.update(kwargs)
return None
runner._request_slash_confirm = _fake_request_slash_confirm
await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro --global"))
reply = await captured["handler"]("cancel")
assert "cancelled" in reply.lower()
assert runner._session_model_overrides == {}
# --global must not have persisted the cancelled switch.
written = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
assert written["model"]["default"] == "old-model"
@pytest.mark.asyncio
async def test_typed_model_cheap_switches_without_prompt(tmp_path, monkeypatch):
"""No warning → switch applies immediately; confirm primitive never invoked."""
_setup_isolated_home(tmp_path, monkeypatch, warn=False)
runner = _make_runner()
runner._evict_cached_agent = lambda session_key: None
async def _fail_request_slash_confirm(**kwargs): # pragma: no cover
raise AssertionError("confirm should not be requested for cheap models")
runner._request_slash_confirm = _fail_request_slash_confirm
result = await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro"))
assert result is not None
assert "gpt-5.5-pro" in result
overrides = list(runner._session_model_overrides.values())
assert len(overrides) == 1