diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 7aec7f99a0a..fa896db9d3a 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -3136,11 +3136,13 @@ class TelegramAdapter(BasePlatformAdapter): await query.answer(text="Picker expired.") return + switch_failed = False try: result_text = await callback(chat_id, model_id, provider_slug) except Exception as exc: logger.error("Model picker switch failed: %s", exc) result_text = f"Error switching model: {exc}" + switch_failed = True try: await query.edit_message_text( @@ -3157,7 +3159,9 @@ class TelegramAdapter(BasePlatformAdapter): ) except Exception: pass - await query.answer(text="Model switched!") + await query.answer( + text="Switch failed." if switch_failed else "Model switched!" + ) self._model_picker_state.pop(chat_id, None) elif data.startswith("mm:"): @@ -3184,7 +3188,10 @@ class TelegramAdapter(BasePlatformAdapter): try: from hermes_cli.model_cost_guard import expensive_model_warning - warning = expensive_model_warning( + # Pricing lookup can hit models.dev / a /models endpoint on a + # cache miss — keep it off the event loop. + warning = await asyncio.to_thread( + expensive_model_warning, model_id, provider=provider_slug, ) @@ -3208,11 +3215,13 @@ class TelegramAdapter(BasePlatformAdapter): await query.answer(text="Confirm expensive model") return + switch_failed = False try: result_text = await callback(chat_id, model_id, provider_slug) except Exception as exc: logger.error("Model picker switch failed: %s", exc) result_text = f"Error switching model: {exc}" + switch_failed = True # Edit message to show confirmation, remove buttons try: @@ -3231,7 +3240,9 @@ class TelegramAdapter(BasePlatformAdapter): ) except Exception: pass - await query.answer(text="Model switched!") + await query.answer( + text="Switch failed." if switch_failed else "Model switched!" + ) # Clean up state self._model_picker_state.pop(chat_id, None) diff --git a/gateway/slash_commands.py b/gateway/slash_commands.py index 3c7436498e1..ac210d367de 100644 --- a/gateway/slash_commands.py +++ b/gateway/slash_commands.py @@ -1146,149 +1146,197 @@ class GatewaySlashCommandsMixin: if not result.success: return t("gateway.model.error_prefix", error=result.error_message) - # If there's a cached agent, update it in-place - cached_entry = None - _cache_lock = getattr(self, "_agent_cache_lock", None) - _cache = getattr(self, "_agent_cache", None) - if _cache_lock and _cache is not None: - with _cache_lock: - cached_entry = _cache.get(session_key) + async def _finish_switch() -> str: + """Apply the resolved switch (agent, session, config) and build the reply.""" + # If there's a cached agent, update it in-place + cached_entry = None + _cache_lock = getattr(self, "_agent_cache_lock", None) + _cache = getattr(self, "_agent_cache", None) + if _cache_lock and _cache is not None: + with _cache_lock: + cached_entry = _cache.get(session_key) - if cached_entry and cached_entry[0] is not None: + if cached_entry and cached_entry[0] is not None: + try: + cached_entry[0].switch_model( + new_model=result.new_model, + new_provider=result.target_provider, + api_key=result.api_key, + base_url=result.base_url, + api_mode=result.api_mode, + ) + except Exception as exc: + logger.warning("In-place model switch failed for cached agent: %s", exc) + + # Persist the new model to the session DB so the dashboard + # shows the updated model (#34850). + _sess_db = getattr(self, "_session_db", None) + if _sess_db is not None: + try: + _sess_entry = self.session_store.get_or_create_session(source) + _sess_db.update_session_model( + _sess_entry.session_id, result.new_model + ) + except Exception as exc: + logger.debug( + "Failed to persist model switch to DB: %s", exc + ) + + # Store a note to prepend to the next user message so the model + # knows about the switch (avoids system messages mid-history). + if not hasattr(self, "_pending_model_notes"): + self._pending_model_notes = {} + self._pending_model_notes[session_key] = ( + f"[Note: model was just switched from {current_model} to {result.new_model} " + f"via {result.provider_label or result.target_provider}. " + f"Adjust your self-identification accordingly.]" + ) + + # Store session override so next agent creation uses the new model + self._session_model_overrides[session_key] = { + "model": result.new_model, + "provider": result.target_provider, + "api_key": result.api_key, + "base_url": result.base_url, + "api_mode": result.api_mode, + } + + # Evict cached agent so the next turn creates a fresh agent from the + # override rather than relying on cache signature mismatch detection. + self._evict_cached_agent(session_key) + + # Persist to config if --global + if persist_global: + try: + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + else: + cfg = {} + # Coerce scalar/None ``model:`` into a dict before mutation — + # otherwise ``cfg.setdefault("model", {})`` returns the existing + # scalar and the next assignment raises + # ``TypeError: 'str' object does not support item assignment``. + # Reproduces when ``config.yaml`` has ``model: `` (flat + # string) instead of the proper nested ``model: {default: ...}``. + raw_model = cfg.get("model") + if isinstance(raw_model, dict): + model_cfg = raw_model + elif isinstance(raw_model, str) and raw_model.strip(): + model_cfg = {"default": raw_model.strip()} + cfg["model"] = model_cfg + else: + model_cfg = {} + cfg["model"] = model_cfg + model_cfg["default"] = result.new_model + model_cfg["provider"] = result.target_provider + if result.base_url: + model_cfg["base_url"] = result.base_url + from hermes_cli.config import save_config + save_config(cfg) + except Exception as e: + logger.warning("Failed to persist model switch: %s", e) + + # Build confirmation message with full metadata + provider_label = result.provider_label or result.target_provider + lines = [t("gateway.model.switched", model=result.new_model)] + lines.append(t("gateway.model.provider_label", provider=provider_label)) + + # Context: always resolve via the provider-aware chain so Codex OAuth, + # Copilot, and Nous-enforced caps win over the raw models.dev entry. + mi = result.model_info + from hermes_cli.model_switch import resolve_display_context_length + _sw2_config_ctx = None try: - cached_entry[0].switch_model( - new_model=result.new_model, - new_provider=result.target_provider, - api_key=result.api_key, - base_url=result.base_url, - api_mode=result.api_mode, - ) - except Exception as exc: - logger.warning("In-place model switch failed for cached agent: %s", exc) + _sw2_cfg = _load_gateway_config() + _sw2_model_cfg = _sw2_cfg.get("model", {}) + if isinstance(_sw2_model_cfg, dict): + _sw2_raw = _sw2_model_cfg.get("context_length") + if _sw2_raw is not None: + _sw2_config_ctx = int(_sw2_raw) + except Exception: + pass + ctx = resolve_display_context_length( + result.new_model, + result.target_provider, + base_url=result.base_url or current_base_url or "", + api_key=result.api_key or current_api_key or "", + model_info=mi, + custom_providers=custom_provs, + config_context_length=_sw2_config_ctx, + ) + if ctx: + lines.append(t("gateway.model.context_label", tokens=f"{ctx:,}")) + if mi: + if mi.max_output: + lines.append(t("gateway.model.max_output_label", tokens=f"{mi.max_output:,}")) + if mi.has_cost_data(): + lines.append(t("gateway.model.cost_label", cost=mi.format_cost())) + lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities())) - # Persist the new model to the session DB so the dashboard - # shows the updated model (#34850). - _sess_db = getattr(self, "_session_db", None) - if _sess_db is not None: - try: - _sess_entry = self.session_store.get_or_create_session(source) - _sess_db.update_session_model( - _sess_entry.session_id, result.new_model - ) - except Exception as exc: - logger.debug( - "Failed to persist model switch to DB: %s", exc - ) + # Cache notice + cache_enabled = ( + (base_url_host_matches(result.base_url or "", "openrouter.ai") and "claude" in result.new_model.lower()) + or result.api_mode == "anthropic_messages" + ) + if cache_enabled: + lines.append(t("gateway.model.prompt_caching_enabled")) - # Store a note to prepend to the next user message so the model - # knows about the switch (avoids system messages mid-history). - if not hasattr(self, "_pending_model_notes"): - self._pending_model_notes = {} - self._pending_model_notes[session_key] = ( - f"[Note: model was just switched from {current_model} to {result.new_model} " - f"via {result.provider_label or result.target_provider}. " - f"Adjust your self-identification accordingly.]" - ) + if result.warning_message: + lines.append(t("gateway.model.warning_prefix", warning=result.warning_message)) - # Store session override so next agent creation uses the new model - self._session_model_overrides[session_key] = { - "model": result.new_model, - "provider": result.target_provider, - "api_key": result.api_key, - "base_url": result.base_url, - "api_mode": result.api_mode, - } + if persist_global: + lines.append(t("gateway.model.saved_global")) + else: + lines.append(t("gateway.model.session_only_hint")) - # Evict cached agent so the next turn creates a fresh agent from the - # override rather than relying on cache signature mismatch detection. - self._evict_cached_agent(session_key) + return "\n".join(lines) - # Persist to config if --global - if persist_global: - try: - if config_path.exists(): - with open(config_path, encoding="utf-8") as f: - cfg = yaml.safe_load(f) or {} - else: - cfg = {} - # Coerce scalar/None ``model:`` into a dict before mutation — - # otherwise ``cfg.setdefault("model", {})`` returns the existing - # scalar and the next assignment raises - # ``TypeError: 'str' object does not support item assignment``. - # Reproduces when ``config.yaml`` has ``model: `` (flat - # string) instead of the proper nested ``model: {default: ...}``. - raw_model = cfg.get("model") - if isinstance(raw_model, dict): - model_cfg = raw_model - elif isinstance(raw_model, str) and raw_model.strip(): - model_cfg = {"default": raw_model.strip()} - cfg["model"] = model_cfg - else: - model_cfg = {} - cfg["model"] = model_cfg - model_cfg["default"] = result.new_model - model_cfg["provider"] = result.target_provider - if result.base_url: - model_cfg["base_url"] = result.base_url - from hermes_cli.config import save_config - save_config(cfg) - except Exception as e: - logger.warning("Failed to persist model switch: %s", e) - - # Build confirmation message with full metadata - provider_label = result.provider_label or result.target_provider - lines = [t("gateway.model.switched", model=result.new_model)] - lines.append(t("gateway.model.provider_label", provider=provider_label)) - - # Context: always resolve via the provider-aware chain so Codex OAuth, - # Copilot, and Nous-enforced caps win over the raw models.dev entry. - mi = result.model_info - from hermes_cli.model_switch import resolve_display_context_length - _sw2_config_ctx = None + # Expensive-model confirmation gate (typed /model path). + # The pickers (Telegram/Discord inline keyboards, TUI, dashboard) + # already confirm via their own UI affordances; this covers the + # direct text command, which previously bypassed the guard. + # expensive_model_warning() may hit models.dev or a /models endpoint + # on a cache miss, so run it off the event loop. + _cost_warning = None try: - _sw2_cfg = _load_gateway_config() - _sw2_model_cfg = _sw2_cfg.get("model", {}) - if isinstance(_sw2_model_cfg, dict): - _sw2_raw = _sw2_model_cfg.get("context_length") - if _sw2_raw is not None: - _sw2_config_ctx = int(_sw2_raw) + from hermes_cli.model_cost_guard import expensive_model_warning + + _cost_warning = await asyncio.to_thread( + expensive_model_warning, + result.new_model, + provider=result.target_provider, + base_url=result.base_url or current_base_url or "", + api_key=result.api_key or current_api_key or "", + model_info=result.model_info, + ) except Exception: - pass - ctx = resolve_display_context_length( - result.new_model, - result.target_provider, - base_url=result.base_url or current_base_url or "", - api_key=result.api_key or current_api_key or "", - model_info=mi, - custom_providers=custom_provs, - config_context_length=_sw2_config_ctx, - ) - if ctx: - lines.append(t("gateway.model.context_label", tokens=f"{ctx:,}")) - if mi: - if mi.max_output: - lines.append(t("gateway.model.max_output_label", tokens=f"{mi.max_output:,}")) - if mi.has_cost_data(): - lines.append(t("gateway.model.cost_label", cost=mi.format_cost())) - lines.append(t("gateway.model.capabilities_label", capabilities=mi.format_capabilities())) + _cost_warning = None + if _cost_warning is not None: + async def _on_cost_confirm(choice: str) -> str: + if choice == "cancel": + return ( + f"🟡 Model switch cancelled. Current model unchanged " + f"({current_model or 'unknown'})." + ) + # "once" and "always" both proceed — there is no persistent + # opt-out for the cost guard (each expensive switch should be + # an explicit decision). + return await _finish_switch() - # Cache notice - cache_enabled = ( - (base_url_host_matches(result.base_url or "", "openrouter.ai") and "claude" in result.new_model.lower()) - or result.api_mode == "anthropic_messages" - ) - if cache_enabled: - lines.append(t("gateway.model.prompt_caching_enabled")) + return await self._request_slash_confirm( + event=event, + command="model", + title="Expensive Model Warning", + message=( + f"⚠️ **Expensive Model Warning**\n\n{_cost_warning.message}\n\n" + "_Text fallback: reply `/approve` to switch or `/cancel` to keep " + "the current model._" + ), + handler=_on_cost_confirm, + ) - if result.warning_message: - lines.append(t("gateway.model.warning_prefix", warning=result.warning_message)) - - if persist_global: - lines.append(t("gateway.model.saved_global")) - else: - lines.append(t("gateway.model.session_only_hint")) - - return "\n".join(lines) + return await _finish_switch() async def _handle_codex_runtime_command(self, event: MessageEvent) -> str: """Handle /codex-runtime command in the gateway. diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 267035fe0d2..f848a4ad5c0 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -2460,7 +2460,10 @@ async def set_model_assignment(body: ModelAssignment): try: from hermes_cli.model_cost_guard import expensive_model_warning - warning = expensive_model_warning( + # Pricing lookup can hit models.dev / a /models endpoint on a + # cache miss — keep it off the event loop. + warning = await asyncio.to_thread( + expensive_model_warning, model, provider=provider, base_url=base_url, diff --git a/plugins/platforms/discord/adapter.py b/plugins/platforms/discord/adapter.py index 06357c2b547..46544cd1f44 100644 --- a/plugins/platforms/discord/adapter.py +++ b/plugins/platforms/discord/adapter.py @@ -5742,11 +5742,14 @@ def _define_discord_view_classes() -> None: cancel_btn.callback = self._on_cancel self.add_item(cancel_btn) - def _expensive_warning_for(self, model_id: str): + async def _expensive_warning_for(self, model_id: str): try: from hermes_cli.model_cost_guard import expensive_model_warning - return expensive_model_warning( + # Pricing lookup can hit models.dev / a /models endpoint on a + # cache miss — keep it off the event loop. + return await asyncio.to_thread( + expensive_model_warning, model_id, provider=self._selected_provider, ) @@ -5840,7 +5843,7 @@ def _define_discord_view_classes() -> None: return model_id = interaction.data["values"][0] - warning = self._expensive_warning_for(model_id) + warning = await self._expensive_warning_for(model_id) if warning is not None: self._build_expensive_confirm(model_id) await interaction.response.edit_message( diff --git a/tests/gateway/test_model_command_expensive_confirm.py b/tests/gateway/test_model_command_expensive_confirm.py new file mode 100644 index 00000000000..c78ae3818af --- /dev/null +++ b/tests/gateway/test_model_command_expensive_confirm.py @@ -0,0 +1,186 @@ +"""Gateway typed ``/model `` must route through the expensive-model +confirmation gate. + +The pickers (Telegram/Discord inline keyboards, TUI, dashboard) confirm +expensive models via their own UI affordances; the typed text command +previously bypassed the guard entirely — a user typing +``/model openai/gpt-5.5-pro`` switched silently while the picker warned. +These tests pin the typed path: + +- warning fires → handler returns the slash-confirm prompt, switch NOT applied +- confirm ("once") → switch applies (session override set) +- cancel → switch not applied, current model unchanged +- no warning (cheap model) → switch applies immediately, no prompt +""" + +from types import SimpleNamespace + +import pytest +import yaml + +from gateway.config import Platform +from gateway.platforms.base import MessageEvent, MessageType +from gateway.run import GatewayRunner +from gateway.session import SessionSource + + +def _make_runner(): + runner = object.__new__(GatewayRunner) + runner.adapters = {} + runner._voice_mode = {} + runner._session_model_overrides = {} + runner._running_agents = {} + return runner + + +def _make_event(text): + return MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=SessionSource(platform=Platform.TELEGRAM, chat_id="12345", chat_type="dm"), + ) + + +def _fake_switch_result(): + from hermes_cli.model_switch import ModelSwitchResult + + return ModelSwitchResult( + success=True, + new_model="openai/gpt-5.5-pro", + target_provider="openrouter", + provider_changed=False, + api_key="sk-test", + base_url="https://openrouter.ai/api/v1", + api_mode="chat_completions", + provider_label="OpenRouter", + ) + + +def _fake_warning(): + return SimpleNamespace( + message=( + "!!! EXPENSIVE MODEL WARNING !!!\n" + "openai/gpt-5.5-pro has known pricing above Hermes' safety threshold.\n" + "did you mean to select openai/gpt-5.5?" + ), + ) + + +def _setup_isolated_home(tmp_path, monkeypatch, *, warn): + import gateway.run as gateway_run + + hermes_home = tmp_path / ".hermes" + hermes_home.mkdir() + cfg_path = hermes_home / "config.yaml" + cfg_path.write_text( + yaml.safe_dump({"model": {"default": "old-model", "provider": "openrouter"}, "providers": {}}), + encoding="utf-8", + ) + + monkeypatch.setattr(gateway_run, "_hermes_home", hermes_home) + monkeypatch.setattr("agent.models_dev.fetch_models_dev", lambda: {}) + monkeypatch.setattr( + "hermes_cli.model_switch.switch_model", + lambda **kw: _fake_switch_result(), + ) + monkeypatch.setattr("hermes_constants.get_hermes_home", lambda: hermes_home) + monkeypatch.setattr("hermes_cli.config.get_hermes_home", lambda: hermes_home) + monkeypatch.setattr( + "hermes_cli.model_cost_guard.expensive_model_warning", + (lambda *a, **kw: _fake_warning()) if warn else (lambda *a, **kw: None), + ) + return cfg_path + + +@pytest.mark.asyncio +async def test_typed_model_expensive_prompts_instead_of_switching(tmp_path, monkeypatch): + """Expensive model typed directly → confirm prompt, no switch applied.""" + _setup_isolated_home(tmp_path, monkeypatch, warn=True) + runner = _make_runner() + + captured = {} + + async def _fake_request_slash_confirm(**kwargs): + captured.update(kwargs) + return kwargs["message"] + + runner._request_slash_confirm = _fake_request_slash_confirm + + result = await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro")) + + assert result is not None + assert "EXPENSIVE MODEL WARNING" in result + # The switch must NOT have been applied yet. + assert runner._session_model_overrides == {} + assert captured["command"] == "model" + + +@pytest.mark.asyncio +async def test_typed_model_expensive_confirm_once_applies_switch(tmp_path, monkeypatch): + """Resolving the confirm with "once" applies the switch.""" + _setup_isolated_home(tmp_path, monkeypatch, warn=True) + runner = _make_runner() + runner._evict_cached_agent = lambda session_key: None + + captured = {} + + async def _fake_request_slash_confirm(**kwargs): + captured.update(kwargs) + return None # buttons rendered + + runner._request_slash_confirm = _fake_request_slash_confirm + + await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro")) + assert runner._session_model_overrides == {} + + reply = await captured["handler"]("once") + + assert "gpt-5.5-pro" in reply + overrides = list(runner._session_model_overrides.values()) + assert len(overrides) == 1 + assert overrides[0]["model"] == "openai/gpt-5.5-pro" + + +@pytest.mark.asyncio +async def test_typed_model_expensive_cancel_keeps_current_model(tmp_path, monkeypatch): + """Resolving the confirm with "cancel" leaves everything unchanged.""" + cfg_path = _setup_isolated_home(tmp_path, monkeypatch, warn=True) + runner = _make_runner() + + captured = {} + + async def _fake_request_slash_confirm(**kwargs): + captured.update(kwargs) + return None + + runner._request_slash_confirm = _fake_request_slash_confirm + + await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro --global")) + + reply = await captured["handler"]("cancel") + + assert "cancelled" in reply.lower() + assert runner._session_model_overrides == {} + # --global must not have persisted the cancelled switch. + written = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) + assert written["model"]["default"] == "old-model" + + +@pytest.mark.asyncio +async def test_typed_model_cheap_switches_without_prompt(tmp_path, monkeypatch): + """No warning → switch applies immediately; confirm primitive never invoked.""" + _setup_isolated_home(tmp_path, monkeypatch, warn=False) + runner = _make_runner() + runner._evict_cached_agent = lambda session_key: None + + async def _fail_request_slash_confirm(**kwargs): # pragma: no cover + raise AssertionError("confirm should not be requested for cheap models") + + runner._request_slash_confirm = _fail_request_slash_confirm + + result = await runner._handle_model_command(_make_event("/model openai/gpt-5.5-pro")) + + assert result is not None + assert "gpt-5.5-pro" in result + overrides = list(runner._session_model_overrides.values()) + assert len(overrides) == 1