diff --git a/agent/usage_pricing.py b/agent/usage_pricing.py index 8d6b85cd0b8..95bb11df521 100644 --- a/agent/usage_pricing.py +++ b/agent/usage_pricing.py @@ -13,6 +13,7 @@ DEFAULT_PRICING = {"input": 0.0, "output": 0.0} _ZERO = Decimal("0") _ONE_MILLION = Decimal("1000000") +_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1" CostStatus = Literal["actual", "estimated", "included", "unknown"] CostSource = Literal[ @@ -570,6 +571,8 @@ def resolve_billing_route( return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included") if provider_name == "openrouter" or base_url_host_matches(base_url or "", "openrouter.ai"): return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api") + if provider_name == "nous" or base_url_host_matches(base_url or "", "inference-api.nousresearch.com"): + return BillingRoute(provider="nous", model=model, base_url=base_url or _NOUS_DEFAULT_BASE_URL, billing_mode="official_models_api") if provider_name == "anthropic": return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot") if provider_name == "openai": diff --git a/cli.py b/cli.py index 70c30f00730..641c200ad3d 100644 --- a/cli.py +++ b/cli.py @@ -6516,6 +6516,47 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin): } self._invalidate(min_interval=0.0) + def _confirm_expensive_model_switch(self, result) -> bool: + """Ask for explicit confirmation before applying costly model switches.""" + if not getattr(result, "success", False): + return True + try: + from hermes_cli.model_cost_guard import expensive_model_warning + + warning = expensive_model_warning( + result.new_model, + provider=result.target_provider, + base_url=result.base_url or self.base_url or "", + api_key=result.api_key or self.api_key or "", + model_info=result.model_info, + ) + except Exception: + warning = None + if warning is None: + return True + + choices = [ + ("once", "Switch anyway", "Use this model for the current Hermes session."), + ("cancel", "Cancel", "Keep the current model."), + ] + raw = self._prompt_text_input_modal( + title="!!! Expensive Model Warning !!!", + detail=warning.message, + choices=choices, + timeout=120, + ) + choice = self._normalize_slash_confirm_choice(raw, choices) + return choice == "once" + + def _confirm_and_apply_model_switch_result(self, result, persist_global: bool) -> None: + try: + if result.success and not self._confirm_expensive_model_switch(result): + _cprint(" Model switch cancelled.") + return + self._apply_model_switch_result(result, persist_global) + except Exception as exc: + _cprint(f" ✗ Model selection failed: {exc}") + def _close_model_picker(self) -> None: self._model_picker_state = None self._restore_modal_input_snapshot() @@ -6692,7 +6733,14 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin): custom_providers=state.get("custom_provs"), ) self._close_model_picker() - self._apply_model_switch_result(result, persist_global) + if getattr(self, "_app", None): + threading.Thread( + target=self._confirm_and_apply_model_switch_result, + args=(result, persist_global), + daemon=True, + ).start() + else: + self._confirm_and_apply_model_switch_result(result, persist_global) return self._close_model_picker() @@ -6793,6 +6841,10 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin): _cprint(f" ✗ {result.error_message}") return + if not self._confirm_expensive_model_switch(result): + _cprint(" Model switch cancelled.") + return + # Apply to CLI state. # Update requested_provider so _ensure_runtime_credentials() doesn't # overwrite the switch on the next turn (it re-resolves from this). diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index fca6fa66255..7aec7f99a0a 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -3030,7 +3030,7 @@ class TelegramAdapter(BasePlatformAdapter): async def _handle_model_picker_callback( self, query, data: str, chat_id: str ) -> None: - """Handle model picker inline keyboard callbacks (mp:/mm:/mb:/mx:/mg:).""" + """Handle model picker inline keyboard callbacks (mp:/mm:/mc:/mb:/mx:/mg:).""" state = self._model_picker_state.get(chat_id) if not state: await query.answer(text="Picker expired — use /model again.") @@ -3115,6 +3115,51 @@ class TelegramAdapter(BasePlatformAdapter): ) await query.answer() + elif data.startswith("mc:"): + # --- Expensive model confirmed: perform the switch --- + try: + idx = int(data[3:]) + except ValueError: + await query.answer(text="Invalid selection.") + return + + model_list = state.get("model_list", []) + if idx < 0 or idx >= len(model_list): + await query.answer(text="Invalid model index.") + return + + model_id = model_list[idx] + provider_slug = state.get("selected_provider", "") + callback = state.get("on_model_selected") + + if not callback: + await query.answer(text="Picker expired.") + return + + try: + result_text = await callback(chat_id, model_id, provider_slug) + except Exception as exc: + logger.error("Model picker switch failed: %s", exc) + result_text = f"Error switching model: {exc}" + + try: + await query.edit_message_text( + text=self.format_message(result_text), + parse_mode=ParseMode.MARKDOWN_V2, + reply_markup=None, + ) + except Exception: + try: + await query.edit_message_text( + text=result_text, + parse_mode=None, + reply_markup=None, + ) + except Exception: + pass + await query.answer(text="Model switched!") + self._model_picker_state.pop(chat_id, None) + elif data.startswith("mm:"): # --- Model selected: perform the switch --- try: @@ -3136,6 +3181,33 @@ class TelegramAdapter(BasePlatformAdapter): await query.answer(text="Picker expired.") return + try: + from hermes_cli.model_cost_guard import expensive_model_warning + + warning = expensive_model_warning( + model_id, + provider=provider_slug, + ) + except Exception: + warning = None + if warning is not None: + keyboard = InlineKeyboardMarkup([ + [InlineKeyboardButton("Switch anyway", callback_data=f"mc:{idx}")], + [ + InlineKeyboardButton("◀ Back", callback_data="mb"), + InlineKeyboardButton("✗ Cancel", callback_data="mx"), + ], + ]) + await query.edit_message_text( + text=self.format_message( + f"⚠ *Expensive Model Warning*\n\n{warning.message}" + ), + parse_mode=ParseMode.MARKDOWN_V2, + reply_markup=keyboard, + ) + await query.answer(text="Confirm expensive model") + return + try: result_text = await callback(chat_id, model_id, provider_slug) except Exception as exc: @@ -3260,7 +3332,7 @@ class TelegramAdapter(BasePlatformAdapter): query_user_name = getattr(query.from_user, "first_name", None) # --- Model picker callbacks --- - if data.startswith(("mp:", "mpg:", "mm:", "mb", "mx", "mg:")): + if data.startswith(("mp:", "mpg:", "mm:", "mc:", "mb", "mx", "mg:")): chat_id = str(query.message.chat_id) if query.message else None if chat_id: await self._handle_model_picker_callback(query, data, chat_id) diff --git a/hermes_cli/auth.py b/hermes_cli/auth.py index c3c48a1fd4b..a65e9ea78b8 100644 --- a/hermes_cli/auth.py +++ b/hermes_cli/auth.py @@ -6175,6 +6175,40 @@ def _reset_config_provider() -> Path: return config_path +def _confirm_expensive_model_selection( + model_id: str, + *, + provider: str = "", + base_url: str = "", + api_key: str = "", +) -> bool: + """Prompt before saving a model whose known pricing exceeds guardrails.""" + try: + from hermes_cli.model_cost_guard import expensive_model_warning + + warning = expensive_model_warning( + model_id, + provider=provider, + base_url=base_url, + api_key=api_key, + ) + except Exception: + warning = None + if warning is None: + return True + + print() + print("=" * 72) + print(warning.message) + print("=" * 72) + try: + response = input("Switch anyway? [y/N]: ").strip().lower() + except (KeyboardInterrupt, EOFError): + print() + return False + return response in {"y", "yes"} + + def _prompt_model_selection( model_ids: List[str], current_model: str = "", @@ -6182,6 +6216,9 @@ def _prompt_model_selection( unavailable_models: Optional[List[str]] = None, portal_url: str = "", unavailable_message: str = "", + confirm_provider: str = "", + confirm_base_url: str = "", + confirm_api_key: str = "", ) -> Optional[str]: """Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None. @@ -6195,6 +6232,18 @@ def _prompt_model_selection( _unavailable = unavailable_models or [] + def _confirmed_selection(mid: str) -> Optional[str]: + if not mid: + return None + if confirm_provider and not _confirm_expensive_model_selection( + mid, + provider=confirm_provider, + base_url=confirm_base_url, + api_key=confirm_api_key, + ): + return None + return mid + # Reorder: current model first, then the rest (deduplicated) ordered = [] if current_model and current_model in model_ids: @@ -6310,13 +6359,13 @@ def _prompt_model_selection( return None print() if idx < len(ordered): - return ordered[idx] + return _confirmed_selection(ordered[idx]) elif idx == len(ordered): try: custom = input("Enter model name: ").strip() except (EOFError, KeyboardInterrupt): return None - return custom if custom else None + return _confirmed_selection(custom) if custom else None return None except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError): pass @@ -6348,10 +6397,10 @@ def _prompt_model_selection( return None idx = int(choice) if 1 <= idx <= n: - return ordered[idx - 1] + return _confirmed_selection(ordered[idx - 1]) elif idx == n + 1: custom = input("Enter model name: ").strip() - return custom if custom else None + return _confirmed_selection(custom) if custom else None elif idx == n + 2: return None print(f"Please enter 1-{n + 2}") @@ -7730,6 +7779,9 @@ def _login_nous(args, pconfig: ProviderConfig) -> None: unavailable_models=unavailable_models, portal_url=_portal, unavailable_message=unavailable_message, + confirm_provider="nous", + confirm_base_url=inference_base_url, + confirm_api_key=runtime_key, ) elif unavailable_models: _url = (_portal or DEFAULT_NOUS_PORTAL_URL).rstrip("/") diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 16ca582ef06..ad0861609ca 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -3269,6 +3269,7 @@ def _aux_flow_provider_model( model_list, current_model=current_model, pricing=pricing, + confirm_provider=provider_slug, ) if selected is None: print("No change.") diff --git a/hermes_cli/model_cost_guard.py b/hermes_cli/model_cost_guard.py new file mode 100644 index 00000000000..fd7e65b8551 --- /dev/null +++ b/hermes_cli/model_cost_guard.py @@ -0,0 +1,134 @@ +"""Expensive-model confirmation helpers for model selection surfaces.""" + +from __future__ import annotations + +from dataclasses import dataclass +from decimal import Decimal, InvalidOperation +from typing import Optional + +from agent.models_dev import ModelInfo + + +INPUT_COST_WARNING_THRESHOLD = Decimal("20") +OUTPUT_COST_WARNING_THRESHOLD = Decimal("100") +GPT55_PRO_OPENROUTER_ID = "openai/gpt-5.5-pro" +GPT55_SUGGESTION = "did you mean to select openai/gpt-5.5?" + + +@dataclass(frozen=True) +class ExpensiveModelWarning: + """Confirmation payload for models above Hermes' cost guardrail.""" + + model: str + provider: str + input_cost_per_million: Optional[Decimal] + output_cost_per_million: Optional[Decimal] + source: str + message: str + + +def _to_decimal(value: object) -> Optional[Decimal]: + if value is None: + return None + try: + return Decimal(str(value)) + except (InvalidOperation, ValueError): + return None + + +def _format_money(value: Optional[Decimal]) -> str: + if value is None: + return "unknown" + return f"${value:.2f}/M" + + +def _pricing_from_model_info( + model_info: Optional[ModelInfo], +) -> tuple[Optional[Decimal], Optional[Decimal], str]: + if model_info is None or not model_info.has_cost_data(): + return None, None, "" + return ( + _to_decimal(model_info.cost_input), + _to_decimal(model_info.cost_output), + "models.dev", + ) + + +def expensive_model_warning( + model_name: str, + *, + provider: Optional[str] = None, + base_url: Optional[str] = None, + api_key: Optional[str] = None, + model_info: Optional[ModelInfo] = None, +) -> Optional[ExpensiveModelWarning]: + """Return a warning payload when known pricing exceeds safety thresholds. + + The guard only triggers when pricing is known. Callers should use this after + model resolution so aliases and provider-specific model IDs have settled. + """ + model = (model_name or "").strip() + if not model: + return None + + input_cost, output_cost, source = _pricing_from_model_info(model_info) + if input_cost is None and output_cost is None and provider: + try: + from agent.models_dev import get_model_info + + input_cost, output_cost, source = _pricing_from_model_info( + get_model_info(provider, model) + ) + except Exception: + pass + if input_cost is None and output_cost is None: + try: + from agent.usage_pricing import get_pricing_entry + + entry = get_pricing_entry( + model, + provider=provider, + base_url=base_url, + api_key=api_key, + ) + except Exception: + entry = None + if entry is not None: + input_cost = entry.input_cost_per_million + output_cost = entry.output_cost_per_million + source = entry.source + + over_input = ( + input_cost is not None and input_cost > INPUT_COST_WARNING_THRESHOLD + ) + over_output = ( + output_cost is not None and output_cost > OUTPUT_COST_WARNING_THRESHOLD + ) + if not over_input and not over_output: + return None + + lines = [ + "!!! EXPENSIVE MODEL WARNING !!!", + "", + f"{model} has known pricing above Hermes' safety threshold.", + f"Input tokens: {_format_money(input_cost)}", + f"Output tokens: {_format_money(output_cost)}", + ( + "Threshold: more than $20/M input tokens or more than " + "$100/M output tokens." + ), + ] + if source: + lines.append(f"Pricing source: {source}.") + if model.lower() == GPT55_PRO_OPENROUTER_ID: + lines.append(GPT55_SUGGESTION) + lines.append("Confirm only if you intend to use this model.") + + return ExpensiveModelWarning( + model=model, + provider=(provider or "").strip(), + input_cost_per_million=input_cost, + output_cost_per_million=output_cost, + source=source or "unknown", + message="\n".join(lines), + ) diff --git a/hermes_cli/model_setup_flows.py b/hermes_cli/model_setup_flows.py index 2885b241dca..83e60fc20a2 100644 --- a/hermes_cli/model_setup_flows.py +++ b/hermes_cli/model_setup_flows.py @@ -102,7 +102,12 @@ def _model_flow_openrouter(config, current_model=""): pricing = get_pricing_for_provider("openrouter", force_refresh=True) selected = _prompt_model_selection( - openrouter_models, current_model=current_model, pricing=pricing + openrouter_models, + current_model=current_model, + pricing=pricing, + confirm_provider="openrouter", + confirm_base_url=OPENROUTER_BASE_URL, + confirm_api_key=_resolved or existing_key, ) if selected: _save_model_choice(selected) @@ -311,6 +316,9 @@ def _model_flow_nous(config, current_model="", args=None): unavailable_models=unavailable_models, portal_url=_nous_portal_url, unavailable_message=unavailable_message, + confirm_provider="nous", + confirm_base_url=creds.get("base_url", ""), + confirm_api_key=creds.get("api_key", ""), ) if selected: _save_model_choice(selected) @@ -416,7 +424,13 @@ def _model_flow_openai_codex(config, current_model=""): codex_models = get_codex_model_ids(access_token=_codex_token) - selected = _prompt_model_selection(codex_models, current_model=current_model) + selected = _prompt_model_selection( + codex_models, + current_model=current_model, + confirm_provider="openai-codex", + confirm_base_url=DEFAULT_CODEX_BASE_URL, + confirm_api_key=_codex_token or "", + ) if selected: _save_model_choice(selected) _update_config_for_provider("openai-codex", DEFAULT_CODEX_BASE_URL) @@ -546,7 +560,12 @@ def _model_flow_qwen_oauth(_config, current_model=""): models = list(_DEFAULT_QWEN_PORTAL_MODELS) default = current_model or (models[0] if models else "qwen3-coder-plus") - selected = _prompt_model_selection(models, current_model=default) + selected = _prompt_model_selection( + models, + current_model=default, + confirm_provider="qwen-oauth", + confirm_base_url=DEFAULT_QWEN_BASE_URL, + ) if selected: _save_model_choice(selected) _update_config_for_provider("qwen-oauth", DEFAULT_QWEN_BASE_URL) @@ -595,7 +614,12 @@ def _model_flow_minimax_oauth(config, current_model="", args=None): from hermes_cli.models import _PROVIDER_MODELS model_ids = _PROVIDER_MODELS.get("minimax-oauth", []) - selected = _prompt_model_selection(model_ids, current_model) + selected = _prompt_model_selection( + model_ids, + current_model, + confirm_provider="minimax-oauth", + confirm_base_url=creds["base_url"], + ) if not selected: return _save_model_choice(selected) @@ -664,7 +688,12 @@ def _model_flow_google_gemini_cli(_config, current_model=""): models = list(_PROVIDER_MODELS.get("google-gemini-cli") or []) default = current_model or (models[0] if models else "gemini-3-flash-preview") - selected = _prompt_model_selection(models, current_model=default) + selected = _prompt_model_selection( + models, + current_model=default, + confirm_provider="google-gemini-cli", + confirm_base_url=DEFAULT_GEMINI_CLOUDCODE_BASE_URL, + ) if selected: _save_model_choice(selected) _update_config_for_provider( @@ -1589,7 +1618,11 @@ def _model_flow_copilot(config, current_model=""): if model_list: selected = _prompt_model_selection( - model_list, current_model=normalized_current_model + model_list, + current_model=normalized_current_model, + confirm_provider=provider_id, + confirm_base_url=effective_base, + confirm_api_key=api_key, ) else: try: @@ -1727,6 +1760,9 @@ def _model_flow_copilot_acp(config, current_model=""): selected = _prompt_model_selection( model_list, current_model=normalized_current_model, + confirm_provider=provider_id, + confirm_base_url=effective_base, + confirm_api_key=catalog_api_key, ) else: try: @@ -1831,7 +1867,13 @@ def _model_flow_kimi(config, current_model=""): model_list = _PROVIDER_MODELS.get("moonshot", []) if model_list: - selected = _prompt_model_selection(model_list, current_model=current_model) + selected = _prompt_model_selection( + model_list, + current_model=current_model, + confirm_provider=provider_id, + confirm_base_url=effective_base, + confirm_api_key=existing_key, + ) else: try: selected = input("Enter model name: ").strip() @@ -1939,7 +1981,13 @@ def _model_flow_stepfun(config, current_model=""): ) if model_list: - selected = _prompt_model_selection(model_list, current_model=current_model) + selected = _prompt_model_selection( + model_list, + current_model=current_model, + confirm_provider=provider_id, + confirm_base_url=effective_base, + confirm_api_key=existing_key, + ) else: try: selected = input("Model name: ").strip() @@ -2015,7 +2063,13 @@ def _model_flow_bedrock_api_key(config, region, current_model=""): print(f" Showing {len(model_list)} curated models") if model_list: - selected = _prompt_model_selection(model_list, current_model=current_model) + selected = _prompt_model_selection( + model_list, + current_model=current_model, + confirm_provider="custom", + confirm_base_url=mantle_base_url, + confirm_api_key=existing_key, + ) else: try: selected = input(" Model ID: ").strip() @@ -2204,7 +2258,12 @@ def _model_flow_bedrock(config, current_model=""): # 4. Model selection if model_list: - selected = _prompt_model_selection(model_list, current_model=current_model) + selected = _prompt_model_selection( + model_list, + current_model=current_model, + confirm_provider="bedrock", + confirm_base_url=f"https://bedrock-runtime.{region}.amazonaws.com", + ) else: try: selected = input(" Model ID: ").strip() @@ -2488,7 +2547,13 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""): model_list = list(dict.fromkeys(mid for mid in model_list if mid)) if model_list: - selected = _prompt_model_selection(model_list, current_model=current_model) + selected = _prompt_model_selection( + model_list, + current_model=current_model, + confirm_provider=provider_id, + confirm_base_url=effective_base, + confirm_api_key=existing_key, + ) else: try: selected = input("Model name: ").strip() @@ -2638,7 +2703,11 @@ def _model_flow_anthropic(config, current_model=""): # Model selection model_list = _PROVIDER_MODELS.get("anthropic", []) if model_list: - selected = _prompt_model_selection(model_list, current_model=current_model) + selected = _prompt_model_selection( + model_list, + current_model=current_model, + confirm_provider="anthropic", + ) else: try: selected = input("Model name (e.g., claude-sonnet-4-20250514): ").strip() diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 061a2e46816..267035fe0d2 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -657,21 +657,6 @@ class AudioTranscriptionRequest(BaseModel): mime_type: Optional[str] = None -class ModelAssignment(BaseModel): - """Payload for POST /api/model/set — assign a provider/model to a slot. - - scope="main" → writes model.provider + model.default - scope="auxiliary" → writes auxiliary..provider + auxiliary..model - scope="auxiliary" with task="" → applied to every auxiliary.* slot - scope="auxiliary" with task="__reset__" → resets every slot to provider="auto" - """ - - scope: str - provider: str - model: str - task: str = "" - - _AUDIO_MIME_EXTENSIONS: Dict[str, str] = { "audio/aac": ".aac", "audio/flac": ".flac", @@ -713,6 +698,7 @@ class ModelAssignment(BaseModel): # reads model.base_url from config (it ignores OPENAI_BASE_URL), so this is # the path that actually wires a local endpoint into resolution. base_url: str = "" + confirm_expensive_model: bool = False def _apply_main_model_assignment( @@ -2470,6 +2456,27 @@ async def set_model_assignment(body: ModelAssignment): try: cfg = load_config() + if model and not body.confirm_expensive_model: + try: + from hermes_cli.model_cost_guard import expensive_model_warning + + warning = expensive_model_warning( + model, + provider=provider, + base_url=base_url, + ) + except Exception: + warning = None + if warning is not None: + return { + "ok": False, + "scope": scope, + "provider": provider, + "model": model, + "confirm_required": True, + "confirm_message": warning.message, + } + if scope == "main": if not provider or not model: raise HTTPException(status_code=400, detail="provider and model required for main") diff --git a/plugins/platforms/discord/adapter.py b/plugins/platforms/discord/adapter.py index d9db208fc4f..06357c2b547 100644 --- a/plugins/platforms/discord/adapter.py +++ b/plugins/platforms/discord/adapter.py @@ -5638,6 +5638,7 @@ def _define_discord_view_classes() -> None: self.allowed_role_ids = allowed_role_ids or set() self.resolved = False self._selected_provider: str = "" + self._pending_expensive_model: str = "" self._build_provider_select() @@ -5720,6 +5721,38 @@ def _define_discord_view_classes() -> None: cancel_btn.callback = self._on_cancel self.add_item(cancel_btn) + def _build_expensive_confirm(self, model_id: str): + """Build confirmation buttons for unusually expensive models.""" + self.clear_items() + self._pending_expensive_model = model_id + + confirm_btn = discord.ui.Button( + label="Switch anyway", + style=discord.ButtonStyle.red, + custom_id="model_expensive_confirm", + ) + confirm_btn.callback = self._on_expensive_confirm + self.add_item(confirm_btn) + + cancel_btn = discord.ui.Button( + label="Cancel", + style=discord.ButtonStyle.grey, + custom_id="model_expensive_cancel", + ) + cancel_btn.callback = self._on_cancel + self.add_item(cancel_btn) + + def _expensive_warning_for(self, model_id: str): + try: + from hermes_cli.model_cost_guard import expensive_model_warning + + return expensive_model_warning( + model_id, + provider=self._selected_provider, + ) + except Exception: + return None + async def _on_provider_selected(self, interaction: discord.Interaction): if not self._check_auth(interaction): await interaction.response.send_message( @@ -5749,7 +5782,11 @@ def _define_discord_view_classes() -> None: view=self, ) - async def _on_model_selected(self, interaction: discord.Interaction): + async def _switch_selected_model( + self, + interaction: discord.Interaction, + model_id: str, + ): if self.resolved: await interaction.response.send_message( "Already resolved~", ephemeral=True @@ -5762,7 +5799,6 @@ def _define_discord_view_classes() -> None: return self.resolved = True - model_id = interaction.data["values"][0] self.clear_items() await interaction.response.edit_message( embed=discord.Embed( @@ -5791,6 +5827,50 @@ def _define_discord_view_classes() -> None: view=None, ) + async def _on_model_selected(self, interaction: discord.Interaction): + if self.resolved: + await interaction.response.send_message( + "Already resolved~", ephemeral=True + ) + return + if not self._check_auth(interaction): + await interaction.response.send_message( + "You're not authorized~", ephemeral=True + ) + return + + model_id = interaction.data["values"][0] + warning = self._expensive_warning_for(model_id) + if warning is not None: + self._build_expensive_confirm(model_id) + await interaction.response.edit_message( + embed=discord.Embed( + title="⚠ Expensive Model Warning", + description=warning.message, + color=discord.Color.red(), + ), + view=self, + ) + return + + await self._switch_selected_model(interaction, model_id) + + async def _on_expensive_confirm(self, interaction: discord.Interaction): + if not self._check_auth(interaction): + await interaction.response.send_message( + "You're not authorized~", ephemeral=True + ) + return + if not self._pending_expensive_model: + await interaction.response.send_message( + "Model selection expired.", ephemeral=True + ) + return + await self._switch_selected_model( + interaction, + self._pending_expensive_model, + ) + async def _on_back(self, interaction: discord.Interaction): if not self._check_auth(interaction): await interaction.response.send_message( diff --git a/tests/agent/test_usage_pricing.py b/tests/agent/test_usage_pricing.py index 3a745a60441..319a8028b3e 100644 --- a/tests/agent/test_usage_pricing.py +++ b/tests/agent/test_usage_pricing.py @@ -192,6 +192,32 @@ def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch): assert float(entry.output_cost_per_million) == 2.0 +def test_nous_portal_pricing_preserves_vendor_prefixed_model_ids(monkeypatch): + seen = {} + + def _fake_fetch_endpoint_model_metadata(base_url, api_key=None): + seen["base_url"] = base_url + return { + "openai/gpt-5.5-pro": { + "pricing": { + "prompt": "0.000025", + "completion": "0.000125", + } + } + } + + monkeypatch.setattr( + "agent.usage_pricing.fetch_endpoint_model_metadata", + _fake_fetch_endpoint_model_metadata, + ) + + entry = get_pricing_entry("openai/gpt-5.5-pro", provider="nous") + + assert seen["base_url"] == "https://inference-api.nousresearch.com/v1" + assert float(entry.input_cost_per_million) == 25.0 + assert float(entry.output_cost_per_million) == 125.0 + + def test_deepseek_v4_pro_pricing_entry_exists(): """Regression test: deepseek-v4-pro must have a pricing entry. diff --git a/tests/gateway/test_discord_model_picker.py b/tests/gateway/test_discord_model_picker.py index a07abfb21c3..86025f4e429 100644 --- a/tests/gateway/test_discord_model_picker.py +++ b/tests/gateway/test_discord_model_picker.py @@ -80,3 +80,91 @@ async def test_model_picker_clears_controls_before_running_switch_callback(): interaction.response.edit_message.assert_awaited_once() interaction.response.defer.assert_not_called() interaction.edit_original_response.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_expensive_model_requires_confirmation(monkeypatch): + events: list[object] = [] + + async def on_model_selected(chat_id: str, model_id: str, provider_slug: str) -> str: + events.append(("switch", chat_id, model_id, provider_slug)) + return "Model switched" + + async def edit_message(**kwargs): + events.append( + ( + "edit", + kwargs["embed"].title, + kwargs["embed"].description, + kwargs["view"], + ) + ) + + async def edit_original_response(**kwargs): + events.append(( + "final-edit", + kwargs["embed"].title, + kwargs["embed"].description, + kwargs["view"], + )) + + monkeypatch.setattr( + "hermes_cli.model_cost_guard.expensive_model_warning", + lambda *_args, **_kwargs: SimpleNamespace( + message="!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?" + ), + ) + + view = ModelPickerView( + providers=[ + { + "slug": "openrouter", + "name": "OpenRouter", + "models": ["openai/gpt-5.5-pro"], + "total_models": 1, + "is_current": True, + } + ], + current_model="openai/gpt-5.5", + current_provider="openrouter", + session_key="session-1", + on_model_selected=on_model_selected, + allowed_user_ids={"123"}, # matches the interaction user; empty = fail-closed + ) + view._selected_provider = "openrouter" + + interaction = SimpleNamespace( + user=SimpleNamespace(id=123), + channel_id=456, + data={"values": ["openai/gpt-5.5-pro"]}, + response=SimpleNamespace( + send_message=AsyncMock(), + edit_message=AsyncMock(side_effect=edit_message), + ), + edit_original_response=AsyncMock(side_effect=edit_original_response), + ) + + await view._on_model_selected(interaction) + + assert events == [ + ( + "edit", + "⚠ Expensive Model Warning", + "!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?", + view, + ), + ] + assert view.resolved is False + + await view._on_expensive_confirm(interaction) + + assert events[1:] == [ + ( + "edit", + "⚙ Switching Model", + "Switching to `openai/gpt-5.5-pro`...", + None, + ), + ("switch", "456", "openai/gpt-5.5-pro", "openrouter"), + ("final-edit", "⚙ Model Switched", "Model switched", None), + ] diff --git a/tests/gateway/test_telegram_model_picker.py b/tests/gateway/test_telegram_model_picker.py index f6c887ef3f4..7b91b92647a 100644 --- a/tests/gateway/test_telegram_model_picker.py +++ b/tests/gateway/test_telegram_model_picker.py @@ -91,10 +91,6 @@ class TestTelegramModelPicker: query.answer = AsyncMock() query.edit_message_text = AsyncMock() - update = MagicMock() - update.callback_query = query - context = MagicMock() - await adapter._handle_model_picker_callback(query, "mb", "12345") edit_kwargs = query.edit_message_text.call_args[1] @@ -133,17 +129,11 @@ class TestTelegramModelPicker: await adapter._handle_model_picker_callback(query, "mm:0", "12345") - # The callback was invoked with the selected model callback.assert_awaited_once() - # edit_message_text MUST be called on the success path (this is the - # regression we're guarding). query.edit_message_text.assert_awaited() edit_kwargs = query.edit_message_text.call_args[1] assert "MARKDOWN_V2" in repr(edit_kwargs["parse_mode"]) - # The dynamic result text was routed through format_message - # (backtick code blocks survive escaping). assert "`gpt-5`" in edit_kwargs["text"] - # State is cleaned up after a successful switch. assert "12345" not in adapter._model_picker_state @pytest.mark.asyncio @@ -184,7 +174,7 @@ class TestTelegramModelPicker: providers = [ {"slug": "minimax", "name": "MiniMax", "total_models": 2}, {"slug": "minimax-cn", "name": "MiniMax (China)", "total_models": 3}, - {"slug": "xai", "name": "xAI", "total_models": 1}, # lone group member + {"slug": "xai", "name": "xAI", "total_models": 1}, ] await adapter.send_model_picker( @@ -197,14 +187,11 @@ class TestTelegramModelPicker: metadata=None, ) - # Top-level keyboard: MiniMax family folded into one group button; - # xai (lone member) degraded to a direct provider button. assert "mpg:minimax" in built assert "mp:xai" in built assert "mp:minimax" not in built assert "mp:minimax-cn" not in built - # Drill into the MiniMax group → members appear as mp: buttons + back. built.clear() query = AsyncMock() query.message = MagicMock() @@ -216,7 +203,49 @@ class TestTelegramModelPicker: assert "mp:minimax" in built assert "mp:minimax-cn" in built - assert "mb" in built # back-to-providers button present + assert "mb" in built + + @pytest.mark.asyncio + async def test_expensive_model_requires_confirmation(self, monkeypatch): + adapter = _make_adapter() + callback = AsyncMock(return_value="Switched to `openai/gpt-5.5-pro`") + adapter._model_picker_state["12345"] = { + "providers": [ + {"slug": "openrouter", "name": "OpenRouter", "total_models": 1, "is_current": True} + ], + "current_model": "model_1", + "current_provider": "openrouter", + "session_key": "s", + "on_model_selected": callback, + "selected_provider": "openrouter", + "model_list": ["openai/gpt-5.5-pro"], + "msg_id": 42, + } + monkeypatch.setattr( + "hermes_cli.model_cost_guard.expensive_model_warning", + lambda *_args, **_kwargs: SimpleNamespace( + message="!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?" + ), + ) + + query = AsyncMock() + query.message = MagicMock() + query.message.chat_id = 12345 + query.answer = AsyncMock() + query.edit_message_text = AsyncMock() + + await adapter._handle_model_picker_callback(query, "mm:0", "12345") + + callback.assert_not_awaited() + assert "12345" in adapter._model_picker_state + first_edit = query.edit_message_text.call_args[1] + assert "EXPENSIVE MODEL WARNING" in first_edit["text"] + assert first_edit["reply_markup"] is not None + + await adapter._handle_model_picker_callback(query, "mc:0", "12345") + + callback.assert_awaited_once_with("12345", "openai/gpt-5.5-pro", "openrouter") + assert "12345" not in adapter._model_picker_state @pytest.mark.asyncio async def test_retries_without_thread_when_thread_not_found(self): diff --git a/tests/hermes_cli/test_codex_models.py b/tests/hermes_cli/test_codex_models.py index 7d8fa81dc91..f755fe7a320 100644 --- a/tests/hermes_cli/test_codex_models.py +++ b/tests/hermes_cli/test_codex_models.py @@ -133,7 +133,7 @@ def test_model_command_uses_runtime_access_token_for_codex_list(monkeypatch): captured["access_token"] = access_token return ["gpt-5.2-codex", "gpt-5.2"] - def _fake_prompt_model_selection(model_ids, current_model=""): + def _fake_prompt_model_selection(model_ids, current_model="", **_kwargs): captured["model_ids"] = list(model_ids) captured["current_model"] = current_model return None @@ -181,7 +181,7 @@ def test_model_command_prompts_to_reuse_or_reauthenticate_codex_session(monkeypa ) monkeypatch.setattr( "hermes_cli.auth._prompt_model_selection", - lambda model_ids, current_model="": None, + lambda model_ids, current_model="", **_kwargs: None, ) _model_flow_openai_codex({}, current_model="gpt-5.4") @@ -219,7 +219,7 @@ def test_model_command_uses_existing_codex_session_without_relogin(monkeypatch): ) monkeypatch.setattr( "hermes_cli.auth._prompt_model_selection", - lambda model_ids, current_model="": None, + lambda model_ids, current_model="", **_kwargs: None, ) monkeypatch.setattr( "hermes_cli.auth._login_openai_codex", diff --git a/tests/hermes_cli/test_model_cost_guard.py b/tests/hermes_cli/test_model_cost_guard.py new file mode 100644 index 00000000000..5e6e146e3eb --- /dev/null +++ b/tests/hermes_cli/test_model_cost_guard.py @@ -0,0 +1,97 @@ +from decimal import Decimal + +from agent.models_dev import ModelInfo +from agent.usage_pricing import PricingEntry +from hermes_cli.model_cost_guard import expensive_model_warning + + +def test_no_warning_when_known_prices_are_at_threshold(): + info = ModelInfo( + id="edge/model", + name="edge/model", + family="", + provider_id="test", + cost_input=20.0, + cost_output=100.0, + ) + + assert expensive_model_warning("edge/model", provider="test", model_info=info) is None + + +def test_warns_when_models_dev_input_price_exceeds_threshold(): + info = ModelInfo( + id="expensive/input", + name="expensive/input", + family="", + provider_id="test", + cost_input=20.01, + cost_output=1.0, + ) + + warning = expensive_model_warning( + "expensive/input", + provider="test", + model_info=info, + ) + + assert warning is not None + assert warning.input_cost_per_million == Decimal("20.01") + assert "EXPENSIVE MODEL WARNING" in warning.message + assert "$20/M input" in warning.message + + +def test_warns_when_pricing_entry_output_price_exceeds_threshold(monkeypatch): + monkeypatch.setattr("agent.models_dev.get_model_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + "agent.usage_pricing.get_pricing_entry", + lambda *_args, **_kwargs: PricingEntry( + input_cost_per_million=Decimal("1.00"), + output_cost_per_million=Decimal("100.01"), + source="provider_models_api", + ), + ) + + warning = expensive_model_warning("provider/expensive-output", provider="openrouter") + + assert warning is not None + assert warning.output_cost_per_million == Decimal("100.01") + assert "$100.01/M" in warning.message + + +def test_openai_gpt55_pro_adds_suggestion(monkeypatch): + monkeypatch.setattr("agent.models_dev.get_model_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + "agent.usage_pricing.get_pricing_entry", + lambda *_args, **_kwargs: PricingEntry( + input_cost_per_million=Decimal("25"), + output_cost_per_million=Decimal("125"), + source="provider_models_api", + ), + ) + + warning = expensive_model_warning("openai/gpt-5.5-pro", provider="openrouter") + + assert warning is not None + assert "did you mean to select openai/gpt-5.5?" in warning.message + + +def test_openai_gpt55_pro_warns_for_nous_portal_pricing(monkeypatch): + monkeypatch.setattr("agent.models_dev.get_model_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr( + "agent.usage_pricing.fetch_endpoint_model_metadata", + lambda base_url, api_key="": { + "openai/gpt-5.5-pro": { + "pricing": { + "prompt": "0.000025", + "completion": "0.000125", + } + } + }, + ) + + warning = expensive_model_warning("openai/gpt-5.5-pro", provider="nous") + + assert warning is not None + assert warning.input_cost_per_million == Decimal("25.000000") + assert warning.output_cost_per_million == Decimal("125.000000") + assert "did you mean to select openai/gpt-5.5?" in warning.message diff --git a/tests/hermes_cli/test_model_picker_expensive_confirm.py b/tests/hermes_cli/test_model_picker_expensive_confirm.py new file mode 100644 index 00000000000..b827be3c9e8 --- /dev/null +++ b/tests/hermes_cli/test_model_picker_expensive_confirm.py @@ -0,0 +1,64 @@ +from types import SimpleNamespace + +from hermes_cli.model_switch import ModelSwitchResult + + +def _bound(fn, instance): + return fn.__get__(instance, type(instance)) + + +def test_prompt_toolkit_model_picker_defers_confirmation_off_key_handler(monkeypatch): + import cli as cli_mod + + result = ModelSwitchResult( + success=True, + new_model="openai/gpt-5.5-pro", + target_provider="nous", + ) + monkeypatch.setattr( + "hermes_cli.model_switch.switch_model", + lambda **_kwargs: result, + ) + + captured = {} + + class _Thread: + def __init__(self, *, target, args, daemon): + captured["target"] = target + captured["args"] = args + captured["daemon"] = daemon + + def start(self): + captured["started"] = True + + monkeypatch.setattr(cli_mod.threading, "Thread", _Thread) + + self_ = SimpleNamespace( + _app=object(), + _model_picker_state={ + "stage": "model", + "provider_data": {"slug": "nous"}, + "model_list": ["openai/gpt-5.5-pro"], + "selected": 0, + "user_provs": None, + "custom_provs": None, + }, + provider="nous", + model="openai/gpt-5.5", + base_url="", + api_key="", + _restore_modal_input_snapshot=lambda: None, + _invalidate=lambda **_kwargs: None, + ) + self_._close_model_picker = _bound(cli_mod.HermesCLI._close_model_picker, self_) + self_._confirm_and_apply_model_switch_result = ( + lambda *_args: captured.setdefault("ran_inline", True) + ) + + _bound(cli_mod.HermesCLI._handle_model_picker_selection, self_)() + + assert self_._model_picker_state is None + assert captured["started"] is True + assert captured["daemon"] is True + assert captured["args"] == (result, False) + assert "ran_inline" not in captured diff --git a/tests/hermes_cli/test_terminal_menu_fallbacks.py b/tests/hermes_cli/test_terminal_menu_fallbacks.py index 626858af4ce..642f7a8c0be 100644 --- a/tests/hermes_cli/test_terminal_menu_fallbacks.py +++ b/tests/hermes_cli/test_terminal_menu_fallbacks.py @@ -2,6 +2,7 @@ cannot initialize (e.g. non-TTY, curses unavailable, terminal error).""" import subprocess +from types import SimpleNamespace from hermes_cli.config import load_config, save_config @@ -24,6 +25,46 @@ def test_prompt_model_selection_falls_back_on_menu_runtime_error(monkeypatch): assert selected == "model-b" +def test_prompt_model_selection_requires_expensive_confirmation(monkeypatch, capsys): + from hermes_cli.auth import _prompt_model_selection + + monkeypatch.setattr("hermes_cli.curses_ui.curses_radiolist", _raise_menu) + monkeypatch.setattr( + "hermes_cli.model_cost_guard.expensive_model_warning", + lambda *_args, **_kwargs: SimpleNamespace(message="EXPENSIVE MODEL WARNING"), + ) + responses = iter(["1", "n"]) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(responses)) + + selected = _prompt_model_selection( + ["openai/gpt-5.5-pro"], + confirm_provider="nous", + ) + + out = capsys.readouterr().out + assert selected is None + assert "EXPENSIVE MODEL WARNING" in out + + +def test_prompt_model_selection_allows_confirmed_expensive_model(monkeypatch): + from hermes_cli.auth import _prompt_model_selection + + monkeypatch.setattr("hermes_cli.curses_ui.curses_radiolist", _raise_menu) + monkeypatch.setattr( + "hermes_cli.model_cost_guard.expensive_model_warning", + lambda *_args, **_kwargs: SimpleNamespace(message="EXPENSIVE MODEL WARNING"), + ) + responses = iter(["1", "y"]) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(responses)) + + selected = _prompt_model_selection( + ["openai/gpt-5.5-pro"], + confirm_provider="nous", + ) + + assert selected == "openai/gpt-5.5-pro" + + def test_prompt_reasoning_effort_falls_back_on_menu_runtime_error(monkeypatch): from hermes_cli.main import _prompt_reasoning_effort_selection diff --git a/tests/hermes_cli/test_web_server.py b/tests/hermes_cli/test_web_server.py index 50314debfc1..1ccd4704ca1 100644 --- a/tests/hermes_cli/test_web_server.py +++ b/tests/hermes_cli/test_web_server.py @@ -4,6 +4,7 @@ import os import json import shutil from pathlib import Path +from types import SimpleNamespace from unittest.mock import patch, MagicMock import pytest @@ -1069,6 +1070,41 @@ class TestWebServerEndpoints: assert "GATEWAY_PROXY_URL" not in managed assert "GATEWAY_PROXY_URL" in _MESSAGING_KEYS_PAGE_KEYS + def test_model_set_requires_confirmation_for_expensive_model(self, monkeypatch): + monkeypatch.setattr( + "hermes_cli.model_cost_guard.expensive_model_warning", + lambda *_args, **_kwargs: SimpleNamespace(message="EXPENSIVE MODEL WARNING"), + ) + + resp = self.client.post( + "/api/model/set", + json={ + "scope": "main", + "provider": "nous", + "model": "openai/gpt-5.5-pro", + }, + ) + + assert resp.status_code == 200 + data = resp.json() + assert data["ok"] is False + assert data["confirm_required"] is True + assert data["confirm_message"] == "EXPENSIVE MODEL WARNING" + + confirmed = self.client.post( + "/api/model/set", + json={ + "scope": "main", + "provider": "nous", + "model": "openai/gpt-5.5-pro", + "confirm_expensive_model": True, + }, + ) + + assert confirmed.status_code == 200 + assert confirmed.json()["ok"] is True + + def test_reveal_env_var(self, tmp_path): """POST /api/env/reveal should return the real unredacted value.""" from hermes_cli.config import save_env_value diff --git a/tests/test_tui_gateway_server.py b/tests/test_tui_gateway_server.py index 7998af7292c..72dc43564c0 100644 --- a/tests/test_tui_gateway_server.py +++ b/tests/test_tui_gateway_server.py @@ -2397,7 +2397,7 @@ def test_config_set_model_waits_for_lazy_agent_before_switch(monkeypatch): target["agent"] = agent agent_ready.set() - def fake_apply(sid, target, raw): + def fake_apply(sid, target, raw, **kwargs): calls.append(("apply", sid, target.get("agent"), raw)) if target.get("agent") is not agent: raise AssertionError("model switch ran before lazy agent was ready") @@ -2424,7 +2424,7 @@ def test_config_set_model_uses_live_switch_path(monkeypatch): server._sessions["sid"] = _session() seen = {} - def _fake_apply(sid, session, raw): + def _fake_apply(sid, session, raw, **_kwargs): seen["args"] = (sid, session["session_key"], raw) return {"value": "new/model", "warning": "catalog unreachable"} @@ -2442,6 +2442,74 @@ def test_config_set_model_uses_live_switch_path(monkeypatch): assert seen["args"] == ("sid", "session-key", "new/model") +def test_config_set_model_requires_confirmation_for_expensive_model(monkeypatch): + class _Agent: + provider = "openrouter" + model = "old/model" + base_url = "" + api_key = "sk-or" + switched = False + + def switch_model(self, **_kwargs): + self.switched = True + + result = types.SimpleNamespace( + success=True, + new_model="openai/gpt-5.5-pro", + target_provider="openrouter", + api_key="sk-or", + base_url="https://openrouter.ai/api/v1", + api_mode="chat_completions", + warning_message="", + model_info=types.SimpleNamespace( + has_cost_data=lambda: True, + cost_input=25.0, + cost_output=125.0, + ), + ) + + agent = _Agent() + server._sessions["sid"] = _session(agent=agent) + monkeypatch.setattr( + "hermes_cli.model_switch.switch_model", lambda **_kwargs: result + ) + monkeypatch.setattr(server, "_restart_slash_worker", lambda sid, session: None) + monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None) + + resp = server.handle_request( + { + "id": "1", + "method": "config.set", + "params": { + "session_id": "sid", + "key": "model", + "value": "openai/gpt-5.5-pro --provider openrouter", + }, + } + ) + + assert resp["result"]["confirm_required"] is True + assert "did you mean to select openai/gpt-5.5?" in resp["result"]["confirm_message"] + assert agent.switched is False + + confirmed = server.handle_request( + { + "id": "2", + "method": "config.set", + "params": { + "session_id": "sid", + "key": "model", + "value": "openai/gpt-5.5-pro --provider openrouter", + "confirm_expensive_model": True, + }, + } + ) + + assert confirmed["result"]["confirm_required"] is False + assert confirmed["result"]["value"] == "openai/gpt-5.5-pro" + assert agent.switched is True + + def test_config_set_model_global_persists(monkeypatch): class _Agent: provider = "openrouter" @@ -3944,7 +4012,7 @@ def test_config_set_model_rejects_while_running(monkeypatch): """/model via config.set must reject during an in-flight turn.""" seen = {"called": False} - def _fake_apply(sid, session, raw): + def _fake_apply(sid, session, raw, **_kwargs): seen["called"] = True return {"value": raw, "warning": ""} @@ -3978,7 +4046,7 @@ def test_config_set_model_allowed_when_idle(monkeypatch): """Regression guard: idle sessions can still switch models.""" seen = {"called": False} - def _fake_apply(sid, session, raw): + def _fake_apply(sid, session, raw, **_kwargs): seen["called"] = True return {"value": "newmodel", "warning": ""} diff --git a/tui_gateway/server.py b/tui_gateway/server.py index 12bfd502fdb..7aedc0e7813 100644 --- a/tui_gateway/server.py +++ b/tui_gateway/server.py @@ -1696,7 +1696,13 @@ def _persist_model_switch(result) -> None: save_config(cfg) -def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict: +def _apply_model_switch( + sid: str, + session: dict, + raw_input: str, + *, + confirm_expensive_model: bool = False, +) -> dict: from hermes_cli.model_switch import parse_model_flags, switch_model from hermes_cli.runtime_provider import resolve_runtime_provider @@ -1753,6 +1759,27 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict: if not result.success: raise ValueError(result.error_message or "model switch failed") + if not confirm_expensive_model: + try: + from hermes_cli.model_cost_guard import expensive_model_warning + + warning = expensive_model_warning( + result.new_model, + provider=result.target_provider, + base_url=result.base_url or current_base_url, + api_key=result.api_key or current_api_key, + model_info=result.model_info, + ) + except Exception: + warning = None + if warning is not None: + return { + "value": result.new_model, + "warning": warning.message, + "confirm_required": True, + "confirm_message": warning.message, + } + if agent: agent.switch_model( new_model=result.new_model, @@ -1787,7 +1814,11 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict: } if persist_global: _persist_model_switch(result) - return {"value": result.new_model, "warning": result.warning_message or ""} + return { + "value": result.new_model, + "warning": result.warning_message or "", + "confirm_required": False, + } def _compress_session_history( @@ -6196,13 +6227,31 @@ def _(rid, params: dict) -> dict: if session.get("agent") is None: return _err(rid, 5032, "agent initialization failed") result = _apply_model_switch( - params.get("session_id", ""), session, value + params.get("session_id", ""), + session, + value, + confirm_expensive_model=bool( + params.get("confirm_expensive_model", False) + ), ) else: - result = _apply_model_switch("", {"agent": None}, value) + result = _apply_model_switch( + "", + {"agent": None}, + value, + confirm_expensive_model=bool( + params.get("confirm_expensive_model", False) + ), + ) return _ok( rid, - {"key": key, "value": result["value"], "warning": result["warning"]}, + { + "key": key, + "value": result["value"], + "warning": result["warning"], + "confirm_required": result.get("confirm_required", False), + "confirm_message": result.get("confirm_message", ""), + }, ) except Exception as e: return _err(rid, 5001, str(e)) diff --git a/ui-tui/src/__tests__/createSlashHandler.test.ts b/ui-tui/src/__tests__/createSlashHandler.test.ts index 9a68c5b2fbd..a671063e5e9 100644 --- a/ui-tui/src/__tests__/createSlashHandler.test.ts +++ b/ui-tui/src/__tests__/createSlashHandler.test.ts @@ -108,6 +108,7 @@ describe('createSlashHandler', () => { expect(createSlashHandler(ctx)('/model x-model')).toBe(true) expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', { + confirm_expensive_model: false, key: 'model', session_id: 'sid-abc', value: 'x-model' @@ -128,6 +129,7 @@ describe('createSlashHandler', () => { createSlashHandler(ctx)(`/model anthropic/claude-sonnet-4.6 --provider openrouter ${TUI_SESSION_MODEL_FLAG}`) ).toBe(true) expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', { + confirm_expensive_model: false, key: 'model', session_id: 'sid-abc', value: 'anthropic/claude-sonnet-4.6 --provider openrouter' @@ -140,6 +142,7 @@ describe('createSlashHandler', () => { createSlashHandler(ctx)('/model x-model --global') expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', { + confirm_expensive_model: false, key: 'model', session_id: 'sid-abc', value: 'x-model --global' diff --git a/ui-tui/src/app/slash/commands/session.ts b/ui-tui/src/app/slash/commands/session.ts index 2456e718f68..b716504b353 100644 --- a/ui-tui/src/app/slash/commands/session.ts +++ b/ui-tui/src/app/slash/commands/session.ts @@ -72,10 +72,25 @@ export const sessionCommands: SlashCommand[] = [ return patchOverlayState({ modelPicker: true }) } - ctx.gateway - .rpc('config.set', { key: 'model', session_id: ctx.sid, value: modelValueForConfigSet(arg) }) + const switchModel = (confirmExpensiveModel = false) => ctx.gateway + .rpc('config.set', { confirm_expensive_model: confirmExpensiveModel, key: 'model', session_id: ctx.sid, value: modelValueForConfigSet(arg) }) .then( ctx.guarded(r => { + if (r.confirm_required) { + patchOverlayState({ + confirm: { + cancelLabel: 'Cancel', + confirmLabel: 'Switch anyway', + danger: true, + detail: r.confirm_message || r.warning || 'This model has unusually high known pricing.', + onConfirm: () => switchModel(true), + title: 'Expensive model selection' + } + }) + + return + } + if (!r.value) { return ctx.transcript.sys('error: invalid response: model switch') } @@ -89,6 +104,8 @@ export const sessionCommands: SlashCommand[] = [ })) }) ) + + switchModel() } }, diff --git a/ui-tui/src/gatewayTypes.ts b/ui-tui/src/gatewayTypes.ts index 0f1e898415c..531f08feec8 100644 --- a/ui-tui/src/gatewayTypes.ts +++ b/ui-tui/src/gatewayTypes.ts @@ -103,6 +103,8 @@ export interface ConfigGetValueResponse { } export interface ConfigSetResponse { + confirm_message?: string + confirm_required?: boolean credential_warning?: string history_reset?: boolean info?: SessionInfo diff --git a/web/src/components/ChatSidebar.tsx b/web/src/components/ChatSidebar.tsx index e24ddfa5b10..66b15b95f92 100644 --- a/web/src/components/ChatSidebar.tsx +++ b/web/src/components/ChatSidebar.tsx @@ -292,24 +292,6 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) { setVersion((v) => v + 1); }, []); - // Picker hands us a fully-formed slash command (e.g. "/model anthropic/..."). - // Fire-and-forget through `slash.exec`; the TUI pane will render the result - // via PTY, so the sidebar doesn't need to surface output of its own. - const onModelSubmit = useCallback( - (slashCommand: string) => { - if (!sessionId) { - return; - } - - void gw.request("slash.exec", { - session_id: sessionId, - command: slashCommand, - }); - setModelOpen(false); - }, - [gw, sessionId], - ); - const canPickModel = state === "open" && !!sessionId; const modelLabel = (info.model ?? "—").split("/").slice(-1)[0] ?? "—"; const banner = error ?? info.credential_warning ?? null; @@ -390,7 +372,6 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) { gw={gw} sessionId={sessionId} onClose={() => setModelOpen(false)} - onSubmit={onModelSubmit} /> )} diff --git a/web/src/components/ConfirmDialog.tsx b/web/src/components/ConfirmDialog.tsx new file mode 100644 index 00000000000..9c257729b45 --- /dev/null +++ b/web/src/components/ConfirmDialog.tsx @@ -0,0 +1,122 @@ +import { Button } from "@nous-research/ui/ui/components/button"; +import { AlertTriangle } from "lucide-react"; +import { useEffect, useRef } from "react"; +import { createPortal } from "react-dom"; +import { cn, themedBody } from "@/lib/utils"; + +interface ConfirmDialogProps { + cancelLabel?: string; + confirmLabel?: string; + description?: string; + destructive?: boolean; + loading?: boolean; + onCancel: () => void; + onConfirm: () => void; + open: boolean; + title: string; +} + +export function ConfirmDialog({ + cancelLabel = "Cancel", + confirmLabel = "Confirm", + description, + destructive = false, + loading = false, + onCancel, + onConfirm, + open, + title, +}: ConfirmDialogProps) { + const dialogRef = useRef(null); + + useEffect(() => { + if (!open) return; + + const prevActive = document.activeElement as HTMLElement | null; + dialogRef.current + ?.querySelector("[data-confirm]") + ?.focus(); + + const onKey = (e: KeyboardEvent) => { + if (e.key === "Escape") { + e.preventDefault(); + onCancel(); + } + }; + + document.addEventListener("keydown", onKey); + const prevOverflow = document.body.style.overflow; + document.body.style.overflow = "hidden"; + + return () => { + document.removeEventListener("keydown", onKey); + document.body.style.overflow = prevOverflow; + prevActive?.focus?.(); + }; + }, [open, onCancel]); + + if (!open) return null; + + return createPortal( +
{ + if (e.target === e.currentTarget) onCancel(); + }} + className="fixed inset-0 z-[200] flex items-center justify-center bg-background/85 backdrop-blur-sm p-4" + > +
+
+ {destructive && ( +
+ +
+ )} + +
+

+ {title} +

+ + {description && ( +

+ {description} +

+ )} +
+
+ +
+ + +
+
+
, + document.body, + ); +} diff --git a/web/src/components/ModelPickerDialog.tsx b/web/src/components/ModelPickerDialog.tsx index 54489dd1f05..96b40ae68b0 100644 --- a/web/src/components/ModelPickerDialog.tsx +++ b/web/src/components/ModelPickerDialog.tsx @@ -4,6 +4,7 @@ import { ListItem } from "@nous-research/ui/ui/components/list-item"; import { Spinner } from "@nous-research/ui/ui/components/spinner"; import { Input } from "@nous-research/ui/ui/components/input"; import { Label } from "@nous-research/ui/ui/components/label"; +import { ConfirmDialog } from "@/components/ConfirmDialog"; import type { GatewayClient } from "@/lib/gatewayClient"; import { Check, Search, X } from "lucide-react"; import { useEffect, useMemo, useRef, useState } from "react"; @@ -21,9 +22,8 @@ import { fuzzyRank } from "@/lib/fuzzy"; * Two invocation modes: * * 1. Chat-session mode (ChatSidebar) — pass `gw` + `sessionId`. The picker - * loads options via `model.options` JSON-RPC and emits the result as a - * slash command string (`/model --provider [--global]`) - * through `onSubmit`, which the ChatPage pipes to `slashExec`. + * loads options via `model.options` JSON-RPC and applies the choice via + * `config.set`, so expensive-model confirmation can happen before switch. * * 2. Standalone mode (ModelsPage, Config settings) — pass a `loader` and * `onApply`. The picker fetches options via the REST endpoint and calls @@ -47,6 +47,23 @@ interface ModelOptionsResponse { providers?: ModelOptionProvider[]; } +interface ExpensiveModelConfirmResponse { + confirm_message?: string; + confirm_required?: boolean; + warning?: string; +} + +interface ConfigSetResponse extends ExpensiveModelConfirmResponse { + value?: string; +} + +interface PendingExpensiveConfirm { + message: string; + model: string; + persistGlobal: boolean; + provider: string; +} + interface Props { /** Chat-mode: when present, picker emits a slash command via onSubmit. */ gw?: GatewayClient; @@ -56,10 +73,14 @@ interface Props { /** Standalone-mode: when present (and onSubmit absent), picker calls onApply. */ loader?(): Promise; onApply?(args: { + confirmExpensiveModel?: boolean; provider: string; model: string; persistGlobal: boolean; - }): Promise | void; + }): + | Promise + | ExpensiveModelConfirmResponse + | void; onClose(): void; title?: string; @@ -90,6 +111,8 @@ export function ModelPickerDialog(props: Props) { const [query, setQuery] = useState(""); const [persistGlobal, setPersistGlobal] = useState(alwaysGlobal); const [applying, setApplying] = useState(false); + const [pendingConfirm, setPendingConfirm] = + useState(null); const closedRef = useRef(false); // Load providers + models on open. @@ -179,16 +202,65 @@ export function ModelPickerDialog(props: Props) { const canConfirm = !!selectedProvider && !!selectedModel && !applying; - const confirm = async () => { - if (!canConfirm || !selectedProvider) return; + const applySelection = async ( + confirmExpensiveModel = false, + forced?: PendingExpensiveConfirm, + ) => { + const providerSlug = forced?.provider ?? selectedProvider?.slug ?? ""; + const model = forced?.model ?? selectedModel; + const shouldPersistGlobal = forced?.persistGlobal ?? persistGlobal; + + if (!providerSlug || !model || applying) return; + if (standalone && onApply) { setApplying(true); try { - await onApply({ - provider: selectedProvider.slug, - model: selectedModel, - persistGlobal, + const result = await onApply({ + confirmExpensiveModel, + provider: providerSlug, + model, + persistGlobal: shouldPersistGlobal, }); + if (result?.confirm_required) { + setPendingConfirm({ + provider: providerSlug, + model, + persistGlobal: shouldPersistGlobal, + message: + result.confirm_message || + result.warning || + "This model has unusually high known pricing.", + }); + return; + } + onClose(); + } catch (e) { + setError(e instanceof Error ? e.message : String(e)); + } finally { + setApplying(false); + } + } else if (gw && sessionId) { + setApplying(true); + try { + const global = shouldPersistGlobal ? " --global" : ""; + const result = await gw.request("config.set", { + confirm_expensive_model: confirmExpensiveModel, + key: "model", + session_id: sessionId, + value: `${model} --provider ${providerSlug}${global}`, + }); + if (result?.confirm_required) { + setPendingConfirm({ + provider: providerSlug, + model, + persistGlobal: shouldPersistGlobal, + message: + result.confirm_message || + result.warning || + "This model has unusually high known pricing.", + }); + return; + } onClose(); } catch (e) { setError(e instanceof Error ? e.message : String(e)); @@ -196,14 +268,17 @@ export function ModelPickerDialog(props: Props) { setApplying(false); } } else if (onSubmit) { - const global = persistGlobal ? " --global" : ""; - onSubmit( - `/model ${selectedModel} --provider ${selectedProvider.slug}${global}`, - ); + const global = shouldPersistGlobal ? " --global" : ""; + onSubmit(`/model ${model} --provider ${providerSlug}${global}`); onClose(); } }; + const confirm = () => { + if (!canConfirm) return; + void applySelection(); + }; + // Portal to document.body: the main dashboard column in App.tsx is // `relative z-2`, which creates a stacking context that traps fixed // descendants below the app sidebar (z-50). Without the portal this @@ -280,8 +355,12 @@ export function ModelPickerDialog(props: Props) { onSelect={setSelectedModel} onConfirm={(m) => { setSelectedModel(m); - // Confirm on next tick so state settles. - window.setTimeout(confirm, 0); + void applySelection(false, { + provider: selectedProvider?.slug ?? "", + model: m, + persistGlobal, + message: "", + }); }} /> @@ -320,6 +399,22 @@ export function ModelPickerDialog(props: Props) { + setPendingConfirm(null)} + onConfirm={() => { + const pending = pendingConfirm; + if (!pending) return; + setPendingConfirm(null); + void applySelection(true, pending); + }} + /> , document.body, ); diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 980faf3d11f..c38a72bc40f 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -1763,9 +1763,12 @@ export interface AuxiliaryModelsResponse { } export interface ModelAssignmentRequest { + confirm_expensive_model?: boolean; scope: "main" | "auxiliary"; provider: string; model: string; + /** Optional OpenAI-compatible endpoint URL for custom/local main providers. */ + base_url?: string; /** For auxiliary: task slot name, "" for all, "__reset__" to reset all. */ task?: string; } @@ -1779,6 +1782,8 @@ export interface StaleAuxAssignment { } export interface ModelAssignmentResponse { + confirm_message?: string; + confirm_required?: boolean; ok: boolean; scope?: string; provider?: string; diff --git a/web/src/pages/ModelsPage.tsx b/web/src/pages/ModelsPage.tsx index 50cd695158f..80eec8bfb3a 100644 --- a/web/src/pages/ModelsPage.tsx +++ b/web/src/pages/ModelsPage.tsx @@ -26,7 +26,7 @@ import { Spinner } from "@nous-research/ui/ui/components/spinner"; import { Stats } from "@nous-research/ui/ui/components/stats"; import { Card, CardContent, CardHeader, CardTitle } from "@nous-research/ui/ui/components/card"; import { Badge } from "@nous-research/ui/ui/components/badge"; -import { ConfirmDialog } from "@nous-research/ui/ui/components/confirm-dialog"; +import { ConfirmDialog } from "@/components/ConfirmDialog"; import { useModalBehavior } from "@/hooks/useModalBehavior"; import { usePageHeader } from "@/contexts/usePageHeader"; import { useI18n } from "@/i18n"; @@ -209,10 +209,16 @@ function UseAsMenu({ const [open, setOpen] = useState(false); const [busy, setBusy] = useState(false); const [error, setError] = useState(null); + const [pendingConfirm, setPendingConfirm] = useState<{ + message: string; + scope: "main" | "auxiliary"; + task: string; + } | null>(null); const assign = async ( scope: "main" | "auxiliary", task: string, + confirmExpensiveModel = false, ) => { if (!provider || !model) { setError("Missing provider/model"); @@ -221,7 +227,23 @@ function UseAsMenu({ setBusy(true); setError(null); try { - await api.setModelAssignment({ scope, provider, model, task }); + const result = await api.setModelAssignment({ + confirm_expensive_model: confirmExpensiveModel, + scope, + provider, + model, + task, + }); + if (result.confirm_required) { + setPendingConfirm({ + scope, + task, + message: + result.confirm_message || + "This model has unusually high known pricing.", + }); + return; + } onAssigned(); setOpen(false); } catch (e) { @@ -310,6 +332,22 @@ function UseAsMenu({ )} )} + setPendingConfirm(null)} + onConfirm={() => { + const pending = pendingConfirm; + if (!pending) return; + setPendingConfirm(null); + void assign(pending.scope, pending.task, true); + }} + /> ); } @@ -619,14 +657,16 @@ function AuxiliaryTasksModal({ AUX_TASKS.find((t) => t.key === picker.task)?.label ?? picker.task }`} - onApply={async ({ provider, model }) => { - await api.setModelAssignment({ + onApply={async ({ provider, model, confirmExpensiveModel }) => { + const result = await api.setModelAssignment({ + confirm_expensive_model: confirmExpensiveModel, scope: "auxiliary", task: picker.task, provider, model, }); - onSaved(); + if (!result.confirm_required) onSaved(); + return result; }} onClose={() => setPicker(null)} /> @@ -666,14 +706,23 @@ function ModelSettingsPanel({ task, provider, model, + confirmExpensiveModel, }: { + confirmExpensiveModel?: boolean; scope: "main" | "auxiliary"; task: string; provider: string; model: string; }) => { - await api.setModelAssignment({ scope, task, provider, model }); - onSaved(); + const result = await api.setModelAssignment({ + confirm_expensive_model: confirmExpensiveModel, + scope, + task, + provider, + model, + }); + if (!result.confirm_required) onSaved(); + return result; }; // Count how many aux tasks have overrides @@ -749,14 +798,15 @@ function ModelSettingsPanel({ loader={api.getModelOptions} alwaysGlobal title="Set Main Model" - onApply={async ({ provider, model }) => { - await applyAssignment({ + onApply={({ provider, model, confirmExpensiveModel }) => + applyAssignment({ + confirmExpensiveModel, scope: "main", task: "", provider, model, - }); - }} + }) + } onClose={() => setPicker(null)} /> )}