fix(model): require confirmation for expensive model selections

Rebased onto current main and re-ported across the restructured
surfaces: model flows now thread confirm_provider/base_url/api_key
through hermes_cli/model_setup_flows.py, the Discord picker lives in
plugins/platforms/discord/adapter.py, and the web dashboard picker
applies chat-mode switches via config.set so the expensive-model
confirmation can ride the response.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
Robin Fernandes 2026-05-15 10:07:45 +10:00 committed by Teknium
parent 4eadef18a9
commit af978ecb17
27 changed files with 1354 additions and 111 deletions

View file

@ -13,6 +13,7 @@ DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
_ZERO = Decimal("0")
_ONE_MILLION = Decimal("1000000")
_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1"
CostStatus = Literal["actual", "estimated", "included", "unknown"]
CostSource = Literal[
@ -570,6 +571,8 @@ def resolve_billing_route(
return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included")
if provider_name == "openrouter" or base_url_host_matches(base_url or "", "openrouter.ai"):
return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api")
if provider_name == "nous" or base_url_host_matches(base_url or "", "inference-api.nousresearch.com"):
return BillingRoute(provider="nous", model=model, base_url=base_url or _NOUS_DEFAULT_BASE_URL, billing_mode="official_models_api")
if provider_name == "anthropic":
return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
if provider_name == "openai":

54
cli.py
View file

@ -6516,6 +6516,47 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
}
self._invalidate(min_interval=0.0)
def _confirm_expensive_model_switch(self, result) -> bool:
"""Ask for explicit confirmation before applying costly model switches."""
if not getattr(result, "success", False):
return True
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
result.new_model,
provider=result.target_provider,
base_url=result.base_url or self.base_url or "",
api_key=result.api_key or self.api_key or "",
model_info=result.model_info,
)
except Exception:
warning = None
if warning is None:
return True
choices = [
("once", "Switch anyway", "Use this model for the current Hermes session."),
("cancel", "Cancel", "Keep the current model."),
]
raw = self._prompt_text_input_modal(
title="!!! Expensive Model Warning !!!",
detail=warning.message,
choices=choices,
timeout=120,
)
choice = self._normalize_slash_confirm_choice(raw, choices)
return choice == "once"
def _confirm_and_apply_model_switch_result(self, result, persist_global: bool) -> None:
try:
if result.success and not self._confirm_expensive_model_switch(result):
_cprint(" Model switch cancelled.")
return
self._apply_model_switch_result(result, persist_global)
except Exception as exc:
_cprint(f" ✗ Model selection failed: {exc}")
def _close_model_picker(self) -> None:
self._model_picker_state = None
self._restore_modal_input_snapshot()
@ -6692,7 +6733,14 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
custom_providers=state.get("custom_provs"),
)
self._close_model_picker()
self._apply_model_switch_result(result, persist_global)
if getattr(self, "_app", None):
threading.Thread(
target=self._confirm_and_apply_model_switch_result,
args=(result, persist_global),
daemon=True,
).start()
else:
self._confirm_and_apply_model_switch_result(result, persist_global)
return
self._close_model_picker()
@ -6793,6 +6841,10 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
_cprint(f"{result.error_message}")
return
if not self._confirm_expensive_model_switch(result):
_cprint(" Model switch cancelled.")
return
# Apply to CLI state.
# Update requested_provider so _ensure_runtime_credentials() doesn't
# overwrite the switch on the next turn (it re-resolves from this).

View file

@ -3030,7 +3030,7 @@ class TelegramAdapter(BasePlatformAdapter):
async def _handle_model_picker_callback(
self, query, data: str, chat_id: str
) -> None:
"""Handle model picker inline keyboard callbacks (mp:/mm:/mb:/mx:/mg:)."""
"""Handle model picker inline keyboard callbacks (mp:/mm:/mc:/mb:/mx:/mg:)."""
state = self._model_picker_state.get(chat_id)
if not state:
await query.answer(text="Picker expired — use /model again.")
@ -3115,6 +3115,51 @@ class TelegramAdapter(BasePlatformAdapter):
)
await query.answer()
elif data.startswith("mc:"):
# --- Expensive model confirmed: perform the switch ---
try:
idx = int(data[3:])
except ValueError:
await query.answer(text="Invalid selection.")
return
model_list = state.get("model_list", [])
if idx < 0 or idx >= len(model_list):
await query.answer(text="Invalid model index.")
return
model_id = model_list[idx]
provider_slug = state.get("selected_provider", "")
callback = state.get("on_model_selected")
if not callback:
await query.answer(text="Picker expired.")
return
try:
result_text = await callback(chat_id, model_id, provider_slug)
except Exception as exc:
logger.error("Model picker switch failed: %s", exc)
result_text = f"Error switching model: {exc}"
try:
await query.edit_message_text(
text=self.format_message(result_text),
parse_mode=ParseMode.MARKDOWN_V2,
reply_markup=None,
)
except Exception:
try:
await query.edit_message_text(
text=result_text,
parse_mode=None,
reply_markup=None,
)
except Exception:
pass
await query.answer(text="Model switched!")
self._model_picker_state.pop(chat_id, None)
elif data.startswith("mm:"):
# --- Model selected: perform the switch ---
try:
@ -3136,6 +3181,33 @@ class TelegramAdapter(BasePlatformAdapter):
await query.answer(text="Picker expired.")
return
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
model_id,
provider=provider_slug,
)
except Exception:
warning = None
if warning is not None:
keyboard = InlineKeyboardMarkup([
[InlineKeyboardButton("Switch anyway", callback_data=f"mc:{idx}")],
[
InlineKeyboardButton("◀ Back", callback_data="mb"),
InlineKeyboardButton("✗ Cancel", callback_data="mx"),
],
])
await query.edit_message_text(
text=self.format_message(
f"⚠ *Expensive Model Warning*\n\n{warning.message}"
),
parse_mode=ParseMode.MARKDOWN_V2,
reply_markup=keyboard,
)
await query.answer(text="Confirm expensive model")
return
try:
result_text = await callback(chat_id, model_id, provider_slug)
except Exception as exc:
@ -3260,7 +3332,7 @@ class TelegramAdapter(BasePlatformAdapter):
query_user_name = getattr(query.from_user, "first_name", None)
# --- Model picker callbacks ---
if data.startswith(("mp:", "mpg:", "mm:", "mb", "mx", "mg:")):
if data.startswith(("mp:", "mpg:", "mm:", "mc:", "mb", "mx", "mg:")):
chat_id = str(query.message.chat_id) if query.message else None
if chat_id:
await self._handle_model_picker_callback(query, data, chat_id)

View file

@ -6175,6 +6175,40 @@ def _reset_config_provider() -> Path:
return config_path
def _confirm_expensive_model_selection(
model_id: str,
*,
provider: str = "",
base_url: str = "",
api_key: str = "",
) -> bool:
"""Prompt before saving a model whose known pricing exceeds guardrails."""
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
model_id,
provider=provider,
base_url=base_url,
api_key=api_key,
)
except Exception:
warning = None
if warning is None:
return True
print()
print("=" * 72)
print(warning.message)
print("=" * 72)
try:
response = input("Switch anyway? [y/N]: ").strip().lower()
except (KeyboardInterrupt, EOFError):
print()
return False
return response in {"y", "yes"}
def _prompt_model_selection(
model_ids: List[str],
current_model: str = "",
@ -6182,6 +6216,9 @@ def _prompt_model_selection(
unavailable_models: Optional[List[str]] = None,
portal_url: str = "",
unavailable_message: str = "",
confirm_provider: str = "",
confirm_base_url: str = "",
confirm_api_key: str = "",
) -> Optional[str]:
"""Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None.
@ -6195,6 +6232,18 @@ def _prompt_model_selection(
_unavailable = unavailable_models or []
def _confirmed_selection(mid: str) -> Optional[str]:
if not mid:
return None
if confirm_provider and not _confirm_expensive_model_selection(
mid,
provider=confirm_provider,
base_url=confirm_base_url,
api_key=confirm_api_key,
):
return None
return mid
# Reorder: current model first, then the rest (deduplicated)
ordered = []
if current_model and current_model in model_ids:
@ -6310,13 +6359,13 @@ def _prompt_model_selection(
return None
print()
if idx < len(ordered):
return ordered[idx]
return _confirmed_selection(ordered[idx])
elif idx == len(ordered):
try:
custom = input("Enter model name: ").strip()
except (EOFError, KeyboardInterrupt):
return None
return custom if custom else None
return _confirmed_selection(custom) if custom else None
return None
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
pass
@ -6348,10 +6397,10 @@ def _prompt_model_selection(
return None
idx = int(choice)
if 1 <= idx <= n:
return ordered[idx - 1]
return _confirmed_selection(ordered[idx - 1])
elif idx == n + 1:
custom = input("Enter model name: ").strip()
return custom if custom else None
return _confirmed_selection(custom) if custom else None
elif idx == n + 2:
return None
print(f"Please enter 1-{n + 2}")
@ -7730,6 +7779,9 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
unavailable_models=unavailable_models,
portal_url=_portal,
unavailable_message=unavailable_message,
confirm_provider="nous",
confirm_base_url=inference_base_url,
confirm_api_key=runtime_key,
)
elif unavailable_models:
_url = (_portal or DEFAULT_NOUS_PORTAL_URL).rstrip("/")

View file

@ -3269,6 +3269,7 @@ def _aux_flow_provider_model(
model_list,
current_model=current_model,
pricing=pricing,
confirm_provider=provider_slug,
)
if selected is None:
print("No change.")

View file

@ -0,0 +1,134 @@
"""Expensive-model confirmation helpers for model selection surfaces."""
from __future__ import annotations
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
from typing import Optional
from agent.models_dev import ModelInfo
INPUT_COST_WARNING_THRESHOLD = Decimal("20")
OUTPUT_COST_WARNING_THRESHOLD = Decimal("100")
GPT55_PRO_OPENROUTER_ID = "openai/gpt-5.5-pro"
GPT55_SUGGESTION = "did you mean to select openai/gpt-5.5?"
@dataclass(frozen=True)
class ExpensiveModelWarning:
    """Immutable description of a model that tripped the cost guardrail.

    Instances are built by ``expensive_model_warning``; ``message`` is the
    text each UI surface shows when asking the user to confirm.
    """

    # Model identifier the warning refers to.
    model: str
    # Provider slug the pricing lookup was made for ("" when not given).
    provider: str
    # Prices per million tokens; ``None`` when that side is unknown.
    input_cost_per_million: Optional[Decimal]
    output_cost_per_million: Optional[Decimal]
    # Label of the pricing data origin.
    source: str
    # Multi-line confirmation text.
    message: str
def _to_decimal(value: object) -> Optional[Decimal]:
if value is None:
return None
try:
return Decimal(str(value))
except (InvalidOperation, ValueError):
return None
def _format_money(value: Optional[Decimal]) -> str:
if value is None:
return "unknown"
return f"${value:.2f}/M"
def _pricing_from_model_info(
model_info: Optional[ModelInfo],
) -> tuple[Optional[Decimal], Optional[Decimal], str]:
if model_info is None or not model_info.has_cost_data():
return None, None, ""
return (
_to_decimal(model_info.cost_input),
_to_decimal(model_info.cost_output),
"models.dev",
)
def expensive_model_warning(
    model_name: str,
    *,
    provider: Optional[str] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> Optional[ExpensiveModelWarning]:
    """Return a warning payload when known pricing exceeds safety thresholds.

    The guard only triggers when pricing is known. Callers should use this after
    model resolution so aliases and provider-specific model IDs have settled.

    Pricing is looked up in order, stopping at the first source that yields
    an input or output price:

    1. the ``model_info`` passed by the caller,
    2. ``agent.models_dev.get_model_info(provider, model)`` (needs ``provider``),
    3. ``agent.usage_pricing.get_pricing_entry(...)``.

    Lookups 2 and 3 are best-effort: any exception is treated as "no
    pricing from this source".

    Args:
        model_name: Model identifier; blank or ``None`` yields ``None``.
        provider: Provider slug used for the pricing lookups.
        base_url: Endpoint forwarded to ``get_pricing_entry``.
        api_key: Credential forwarded to ``get_pricing_entry``.
        model_info: Already-resolved model metadata, if the caller has it.

    Returns:
        An ``ExpensiveModelWarning`` when the input price is strictly above
        ``INPUT_COST_WARNING_THRESHOLD`` or the output price is strictly
        above ``OUTPUT_COST_WARNING_THRESHOLD``; otherwise ``None``.
    """
    model = (model_name or "").strip()
    if not model:
        return None

    input_cost, output_cost, source = _pricing_from_model_info(model_info)

    if input_cost is None and output_cost is None and provider:
        try:
            from agent.models_dev import get_model_info

            input_cost, output_cost, source = _pricing_from_model_info(
                get_model_info(provider, model)
            )
        except Exception:
            # Best-effort: fall through to the pricing-entry lookup below.
            pass

    if input_cost is None and output_cost is None:
        try:
            from agent.usage_pricing import get_pricing_entry

            entry = get_pricing_entry(
                model,
                provider=provider,
                base_url=base_url,
                api_key=api_key,
            )
        except Exception:
            entry = None
        if entry is not None:
            # Normalize like the models.dev path so both sources are
            # compared against the thresholds as Decimal.
            input_cost = _to_decimal(entry.input_cost_per_million)
            output_cost = _to_decimal(entry.output_cost_per_million)
            source = entry.source

    # Strictly greater than: a price exactly at the threshold does not warn.
    over_input = (
        input_cost is not None and input_cost > INPUT_COST_WARNING_THRESHOLD
    )
    over_output = (
        output_cost is not None and output_cost > OUTPUT_COST_WARNING_THRESHOLD
    )
    if not over_input and not over_output:
        return None

    lines = [
        "!!! EXPENSIVE MODEL WARNING !!!",
        "",
        f"{model} has known pricing above Hermes' safety threshold.",
        f"Input tokens: {_format_money(input_cost)}",
        f"Output tokens: {_format_money(output_cost)}",
        # Built from the constants so the text cannot drift from the
        # values actually enforced above.
        (
            f"Threshold: more than ${INPUT_COST_WARNING_THRESHOLD}/M input tokens "
            f"or more than ${OUTPUT_COST_WARNING_THRESHOLD}/M output tokens."
        ),
    ]
    if source:
        lines.append(f"Pricing source: {source}.")
    if model.lower() == GPT55_PRO_OPENROUTER_ID:
        lines.append(GPT55_SUGGESTION)
    lines.append("Confirm only if you intend to use this model.")

    return ExpensiveModelWarning(
        model=model,
        provider=(provider or "").strip(),
        input_cost_per_million=input_cost,
        output_cost_per_million=output_cost,
        source=source or "unknown",
        message="\n".join(lines),
    )

View file

@ -102,7 +102,12 @@ def _model_flow_openrouter(config, current_model=""):
pricing = get_pricing_for_provider("openrouter", force_refresh=True)
selected = _prompt_model_selection(
openrouter_models, current_model=current_model, pricing=pricing
openrouter_models,
current_model=current_model,
pricing=pricing,
confirm_provider="openrouter",
confirm_base_url=OPENROUTER_BASE_URL,
confirm_api_key=_resolved or existing_key,
)
if selected:
_save_model_choice(selected)
@ -311,6 +316,9 @@ def _model_flow_nous(config, current_model="", args=None):
unavailable_models=unavailable_models,
portal_url=_nous_portal_url,
unavailable_message=unavailable_message,
confirm_provider="nous",
confirm_base_url=creds.get("base_url", ""),
confirm_api_key=creds.get("api_key", ""),
)
if selected:
_save_model_choice(selected)
@ -416,7 +424,13 @@ def _model_flow_openai_codex(config, current_model=""):
codex_models = get_codex_model_ids(access_token=_codex_token)
selected = _prompt_model_selection(codex_models, current_model=current_model)
selected = _prompt_model_selection(
codex_models,
current_model=current_model,
confirm_provider="openai-codex",
confirm_base_url=DEFAULT_CODEX_BASE_URL,
confirm_api_key=_codex_token or "",
)
if selected:
_save_model_choice(selected)
_update_config_for_provider("openai-codex", DEFAULT_CODEX_BASE_URL)
@ -546,7 +560,12 @@ def _model_flow_qwen_oauth(_config, current_model=""):
models = list(_DEFAULT_QWEN_PORTAL_MODELS)
default = current_model or (models[0] if models else "qwen3-coder-plus")
selected = _prompt_model_selection(models, current_model=default)
selected = _prompt_model_selection(
models,
current_model=default,
confirm_provider="qwen-oauth",
confirm_base_url=DEFAULT_QWEN_BASE_URL,
)
if selected:
_save_model_choice(selected)
_update_config_for_provider("qwen-oauth", DEFAULT_QWEN_BASE_URL)
@ -595,7 +614,12 @@ def _model_flow_minimax_oauth(config, current_model="", args=None):
from hermes_cli.models import _PROVIDER_MODELS
model_ids = _PROVIDER_MODELS.get("minimax-oauth", [])
selected = _prompt_model_selection(model_ids, current_model)
selected = _prompt_model_selection(
model_ids,
current_model,
confirm_provider="minimax-oauth",
confirm_base_url=creds["base_url"],
)
if not selected:
return
_save_model_choice(selected)
@ -664,7 +688,12 @@ def _model_flow_google_gemini_cli(_config, current_model=""):
models = list(_PROVIDER_MODELS.get("google-gemini-cli") or [])
default = current_model or (models[0] if models else "gemini-3-flash-preview")
selected = _prompt_model_selection(models, current_model=default)
selected = _prompt_model_selection(
models,
current_model=default,
confirm_provider="google-gemini-cli",
confirm_base_url=DEFAULT_GEMINI_CLOUDCODE_BASE_URL,
)
if selected:
_save_model_choice(selected)
_update_config_for_provider(
@ -1589,7 +1618,11 @@ def _model_flow_copilot(config, current_model=""):
if model_list:
selected = _prompt_model_selection(
model_list, current_model=normalized_current_model
model_list,
current_model=normalized_current_model,
confirm_provider=provider_id,
confirm_base_url=effective_base,
confirm_api_key=api_key,
)
else:
try:
@ -1727,6 +1760,9 @@ def _model_flow_copilot_acp(config, current_model=""):
selected = _prompt_model_selection(
model_list,
current_model=normalized_current_model,
confirm_provider=provider_id,
confirm_base_url=effective_base,
confirm_api_key=catalog_api_key,
)
else:
try:
@ -1831,7 +1867,13 @@ def _model_flow_kimi(config, current_model=""):
model_list = _PROVIDER_MODELS.get("moonshot", [])
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
selected = _prompt_model_selection(
model_list,
current_model=current_model,
confirm_provider=provider_id,
confirm_base_url=effective_base,
confirm_api_key=existing_key,
)
else:
try:
selected = input("Enter model name: ").strip()
@ -1939,7 +1981,13 @@ def _model_flow_stepfun(config, current_model=""):
)
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
selected = _prompt_model_selection(
model_list,
current_model=current_model,
confirm_provider=provider_id,
confirm_base_url=effective_base,
confirm_api_key=existing_key,
)
else:
try:
selected = input("Model name: ").strip()
@ -2015,7 +2063,13 @@ def _model_flow_bedrock_api_key(config, region, current_model=""):
print(f" Showing {len(model_list)} curated models")
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
selected = _prompt_model_selection(
model_list,
current_model=current_model,
confirm_provider="custom",
confirm_base_url=mantle_base_url,
confirm_api_key=existing_key,
)
else:
try:
selected = input(" Model ID: ").strip()
@ -2204,7 +2258,12 @@ def _model_flow_bedrock(config, current_model=""):
# 4. Model selection
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
selected = _prompt_model_selection(
model_list,
current_model=current_model,
confirm_provider="bedrock",
confirm_base_url=f"https://bedrock-runtime.{region}.amazonaws.com",
)
else:
try:
selected = input(" Model ID: ").strip()
@ -2488,7 +2547,13 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
model_list = list(dict.fromkeys(mid for mid in model_list if mid))
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
selected = _prompt_model_selection(
model_list,
current_model=current_model,
confirm_provider=provider_id,
confirm_base_url=effective_base,
confirm_api_key=existing_key,
)
else:
try:
selected = input("Model name: ").strip()
@ -2638,7 +2703,11 @@ def _model_flow_anthropic(config, current_model=""):
# Model selection
model_list = _PROVIDER_MODELS.get("anthropic", [])
if model_list:
selected = _prompt_model_selection(model_list, current_model=current_model)
selected = _prompt_model_selection(
model_list,
current_model=current_model,
confirm_provider="anthropic",
)
else:
try:
selected = input("Model name (e.g., claude-sonnet-4-20250514): ").strip()

View file

@ -657,21 +657,6 @@ class AudioTranscriptionRequest(BaseModel):
mime_type: Optional[str] = None
class ModelAssignment(BaseModel):
"""Payload for POST /api/model/set — assign a provider/model to a slot.
scope="main" writes model.provider + model.default
scope="auxiliary" writes auxiliary.<task>.provider + auxiliary.<task>.model
scope="auxiliary" with task="" applied to every auxiliary.* slot
scope="auxiliary" with task="__reset__" resets every slot to provider="auto"
"""
scope: str
provider: str
model: str
task: str = ""
_AUDIO_MIME_EXTENSIONS: Dict[str, str] = {
"audio/aac": ".aac",
"audio/flac": ".flac",
@ -713,6 +698,7 @@ class ModelAssignment(BaseModel):
# reads model.base_url from config (it ignores OPENAI_BASE_URL), so this is
# the path that actually wires a local endpoint into resolution.
base_url: str = ""
confirm_expensive_model: bool = False
def _apply_main_model_assignment(
@ -2470,6 +2456,27 @@ async def set_model_assignment(body: ModelAssignment):
try:
cfg = load_config()
if model and not body.confirm_expensive_model:
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
model,
provider=provider,
base_url=base_url,
)
except Exception:
warning = None
if warning is not None:
return {
"ok": False,
"scope": scope,
"provider": provider,
"model": model,
"confirm_required": True,
"confirm_message": warning.message,
}
if scope == "main":
if not provider or not model:
raise HTTPException(status_code=400, detail="provider and model required for main")

View file

@ -5638,6 +5638,7 @@ def _define_discord_view_classes() -> None:
self.allowed_role_ids = allowed_role_ids or set()
self.resolved = False
self._selected_provider: str = ""
self._pending_expensive_model: str = ""
self._build_provider_select()
@ -5720,6 +5721,38 @@ def _define_discord_view_classes() -> None:
cancel_btn.callback = self._on_cancel
self.add_item(cancel_btn)
def _build_expensive_confirm(self, model_id: str):
    """Swap the view's controls for a "Switch anyway" / "Cancel" button pair.

    Clears the current items and stores ``model_id`` in
    ``_pending_expensive_model`` for ``_on_expensive_confirm`` to read.
    """
    self.clear_items()
    self._pending_expensive_model = model_id

    # (label, style, custom_id, click handler), in display order.
    button_specs = (
        (
            "Switch anyway",
            discord.ButtonStyle.red,
            "model_expensive_confirm",
            self._on_expensive_confirm,
        ),
        (
            "Cancel",
            discord.ButtonStyle.grey,
            "model_expensive_cancel",
            self._on_cancel,
        ),
    )
    for label, style, custom_id, handler in button_specs:
        button = discord.ui.Button(label=label, style=style, custom_id=custom_id)
        button.callback = handler
        self.add_item(button)
def _expensive_warning_for(self, model_id: str):
try:
from hermes_cli.model_cost_guard import expensive_model_warning
return expensive_model_warning(
model_id,
provider=self._selected_provider,
)
except Exception:
return None
async def _on_provider_selected(self, interaction: discord.Interaction):
if not self._check_auth(interaction):
await interaction.response.send_message(
@ -5749,7 +5782,11 @@ def _define_discord_view_classes() -> None:
view=self,
)
async def _on_model_selected(self, interaction: discord.Interaction):
async def _switch_selected_model(
self,
interaction: discord.Interaction,
model_id: str,
):
if self.resolved:
await interaction.response.send_message(
"Already resolved~", ephemeral=True
@ -5762,7 +5799,6 @@ def _define_discord_view_classes() -> None:
return
self.resolved = True
model_id = interaction.data["values"][0]
self.clear_items()
await interaction.response.edit_message(
embed=discord.Embed(
@ -5791,6 +5827,50 @@ def _define_discord_view_classes() -> None:
view=None,
)
async def _on_model_selected(self, interaction: discord.Interaction):
if self.resolved:
await interaction.response.send_message(
"Already resolved~", ephemeral=True
)
return
if not self._check_auth(interaction):
await interaction.response.send_message(
"You're not authorized~", ephemeral=True
)
return
model_id = interaction.data["values"][0]
warning = self._expensive_warning_for(model_id)
if warning is not None:
self._build_expensive_confirm(model_id)
await interaction.response.edit_message(
embed=discord.Embed(
title="⚠ Expensive Model Warning",
description=warning.message,
color=discord.Color.red(),
),
view=self,
)
return
await self._switch_selected_model(interaction, model_id)
async def _on_expensive_confirm(self, interaction: discord.Interaction):
if not self._check_auth(interaction):
await interaction.response.send_message(
"You're not authorized~", ephemeral=True
)
return
if not self._pending_expensive_model:
await interaction.response.send_message(
"Model selection expired.", ephemeral=True
)
return
await self._switch_selected_model(
interaction,
self._pending_expensive_model,
)
async def _on_back(self, interaction: discord.Interaction):
if not self._check_auth(interaction):
await interaction.response.send_message(

View file

@ -192,6 +192,32 @@ def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch):
assert float(entry.output_cost_per_million) == 2.0
def test_nous_portal_pricing_preserves_vendor_prefixed_model_ids(monkeypatch):
    """Nous pricing hits the default Nous endpoint and keeps the vendor prefix.

    The stubbed catalogue is keyed by the full ``openai/gpt-5.5-pro`` id,
    so the price assertions only pass if the lookup uses that id as-is.
    """
    requested_base_urls = []

    def _fake_fetch_endpoint_model_metadata(base_url, api_key=None):
        requested_base_urls.append(base_url)
        # Per-token prices as strings; asserted below as per-million values.
        per_token_pricing = {"prompt": "0.000025", "completion": "0.000125"}
        return {"openai/gpt-5.5-pro": {"pricing": per_token_pricing}}

    monkeypatch.setattr(
        "agent.usage_pricing.fetch_endpoint_model_metadata",
        _fake_fetch_endpoint_model_metadata,
    )

    entry = get_pricing_entry("openai/gpt-5.5-pro", provider="nous")

    assert requested_base_urls[-1] == "https://inference-api.nousresearch.com/v1"
    assert float(entry.input_cost_per_million) == 25.0
    assert float(entry.output_cost_per_million) == 125.0
def test_deepseek_v4_pro_pricing_entry_exists():
"""Regression test: deepseek-v4-pro must have a pricing entry.

View file

@ -80,3 +80,91 @@ async def test_model_picker_clears_controls_before_running_switch_callback():
interaction.response.edit_message.assert_awaited_once()
interaction.response.defer.assert_not_called()
interaction.edit_original_response.assert_awaited_once()
@pytest.mark.asyncio
async def test_expensive_model_requires_confirmation(monkeypatch):
    """Picking an expensive model shows a warning first; confirming switches.

    Every message edit and the switch callback are recorded in ``events``
    so the exact order of UI updates can be asserted.
    """
    warning_text = (
        "!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?"
    )
    events: list[object] = []

    def _record_embed(kind: str, kwargs: dict) -> None:
        embed = kwargs["embed"]
        events.append((kind, embed.title, embed.description, kwargs["view"]))

    async def on_model_selected(chat_id: str, model_id: str, provider_slug: str) -> str:
        events.append(("switch", chat_id, model_id, provider_slug))
        return "Model switched"

    async def edit_message(**kwargs):
        _record_embed("edit", kwargs)

    async def edit_original_response(**kwargs):
        _record_embed("final-edit", kwargs)

    # Force the cost guard to flag whatever model is selected.
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(message=warning_text),
    )

    openrouter = {
        "slug": "openrouter",
        "name": "OpenRouter",
        "models": ["openai/gpt-5.5-pro"],
        "total_models": 1,
        "is_current": True,
    }
    view = ModelPickerView(
        providers=[openrouter],
        current_model="openai/gpt-5.5",
        current_provider="openrouter",
        session_key="session-1",
        on_model_selected=on_model_selected,
        allowed_user_ids={"123"},  # matches the interaction user; empty = fail-closed
    )
    view._selected_provider = "openrouter"

    response = SimpleNamespace(
        send_message=AsyncMock(),
        edit_message=AsyncMock(side_effect=edit_message),
    )
    interaction = SimpleNamespace(
        user=SimpleNamespace(id=123),
        channel_id=456,
        data={"values": ["openai/gpt-5.5-pro"]},
        response=response,
        edit_original_response=AsyncMock(side_effect=edit_original_response),
    )

    # Step 1: the selection only produces the warning; nothing is switched.
    await view._on_model_selected(interaction)

    expected_warning_event = ("edit", "⚠ Expensive Model Warning", warning_text, view)
    assert events == [expected_warning_event]
    assert view.resolved is False

    # Step 2: confirming runs the normal switch sequence.
    await view._on_expensive_confirm(interaction)

    assert events[1:] == [
        ("edit", "⚙ Switching Model", "Switching to `openai/gpt-5.5-pro`...", None),
        ("switch", "456", "openai/gpt-5.5-pro", "openrouter"),
        ("final-edit", "⚙ Model Switched", "Model switched", None),
    ]

View file

@ -91,10 +91,6 @@ class TestTelegramModelPicker:
query.answer = AsyncMock()
query.edit_message_text = AsyncMock()
update = MagicMock()
update.callback_query = query
context = MagicMock()
await adapter._handle_model_picker_callback(query, "mb", "12345")
edit_kwargs = query.edit_message_text.call_args[1]
@ -133,17 +129,11 @@ class TestTelegramModelPicker:
await adapter._handle_model_picker_callback(query, "mm:0", "12345")
# The callback was invoked with the selected model
callback.assert_awaited_once()
# edit_message_text MUST be called on the success path (this is the
# regression we're guarding).
query.edit_message_text.assert_awaited()
edit_kwargs = query.edit_message_text.call_args[1]
assert "MARKDOWN_V2" in repr(edit_kwargs["parse_mode"])
# The dynamic result text was routed through format_message
# (backtick code blocks survive escaping).
assert "`gpt-5`" in edit_kwargs["text"]
# State is cleaned up after a successful switch.
assert "12345" not in adapter._model_picker_state
@pytest.mark.asyncio
@ -184,7 +174,7 @@ class TestTelegramModelPicker:
providers = [
{"slug": "minimax", "name": "MiniMax", "total_models": 2},
{"slug": "minimax-cn", "name": "MiniMax (China)", "total_models": 3},
{"slug": "xai", "name": "xAI", "total_models": 1}, # lone group member
{"slug": "xai", "name": "xAI", "total_models": 1},
]
await adapter.send_model_picker(
@ -197,14 +187,11 @@ class TestTelegramModelPicker:
metadata=None,
)
# Top-level keyboard: MiniMax family folded into one group button;
# xai (lone member) degraded to a direct provider button.
assert "mpg:minimax" in built
assert "mp:xai" in built
assert "mp:minimax" not in built
assert "mp:minimax-cn" not in built
# Drill into the MiniMax group → members appear as mp: buttons + back.
built.clear()
query = AsyncMock()
query.message = MagicMock()
@ -216,7 +203,49 @@ class TestTelegramModelPicker:
assert "mp:minimax" in built
assert "mp:minimax-cn" in built
assert "mb" in built # back-to-providers button present
assert "mb" in built
@pytest.mark.asyncio
async def test_expensive_model_requires_confirmation(self, monkeypatch):
    """``mm:`` on an expensive model only warns; ``mc:`` performs the switch."""
    chat_id = "12345"
    expensive_model = "openai/gpt-5.5-pro"

    # Force the cost guard to flag whatever model is selected.
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(
            message="!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?"
        ),
    )

    adapter = _make_adapter()
    callback = AsyncMock(return_value="Switched to `openai/gpt-5.5-pro`")
    picker_state = {
        "providers": [
            {"slug": "openrouter", "name": "OpenRouter", "total_models": 1, "is_current": True}
        ],
        "current_model": "model_1",
        "current_provider": "openrouter",
        "session_key": "s",
        "on_model_selected": callback,
        "selected_provider": "openrouter",
        "model_list": [expensive_model],
        "msg_id": 42,
    }
    adapter._model_picker_state[chat_id] = picker_state

    query = AsyncMock()
    query.message = MagicMock()
    query.message.chat_id = 12345
    query.answer = AsyncMock()
    query.edit_message_text = AsyncMock()

    # Selecting the model must not switch yet, and the picker state must
    # survive so the confirmation button can still be handled.
    await adapter._handle_model_picker_callback(query, "mm:0", chat_id)

    callback.assert_not_awaited()
    assert chat_id in adapter._model_picker_state
    warning_edit = query.edit_message_text.call_args[1]
    assert "EXPENSIVE MODEL WARNING" in warning_edit["text"]
    assert warning_edit["reply_markup"] is not None

    # Confirming runs the switch callback once and clears the picker state.
    await adapter._handle_model_picker_callback(query, "mc:0", chat_id)

    callback.assert_awaited_once_with(chat_id, expensive_model, "openrouter")
    assert chat_id not in adapter._model_picker_state
@pytest.mark.asyncio
async def test_retries_without_thread_when_thread_not_found(self):

View file

@ -133,7 +133,7 @@ def test_model_command_uses_runtime_access_token_for_codex_list(monkeypatch):
captured["access_token"] = access_token
return ["gpt-5.2-codex", "gpt-5.2"]
def _fake_prompt_model_selection(model_ids, current_model=""):
def _fake_prompt_model_selection(model_ids, current_model="", **_kwargs):
captured["model_ids"] = list(model_ids)
captured["current_model"] = current_model
return None
@ -181,7 +181,7 @@ def test_model_command_prompts_to_reuse_or_reauthenticate_codex_session(monkeypa
)
monkeypatch.setattr(
"hermes_cli.auth._prompt_model_selection",
lambda model_ids, current_model="": None,
lambda model_ids, current_model="", **_kwargs: None,
)
_model_flow_openai_codex({}, current_model="gpt-5.4")
@ -219,7 +219,7 @@ def test_model_command_uses_existing_codex_session_without_relogin(monkeypatch):
)
monkeypatch.setattr(
"hermes_cli.auth._prompt_model_selection",
lambda model_ids, current_model="": None,
lambda model_ids, current_model="", **_kwargs: None,
)
monkeypatch.setattr(
"hermes_cli.auth._login_openai_codex",

View file

@ -0,0 +1,97 @@
from decimal import Decimal
from agent.models_dev import ModelInfo
from agent.usage_pricing import PricingEntry
from hermes_cli.model_cost_guard import expensive_model_warning
def test_no_warning_when_known_prices_are_at_threshold():
    """Expect no warning for a model priced at 20.0 input / 100.0 output.

    These are the values the test name refers to as "at threshold"; the
    pricing is supplied directly through ``model_info``.
    """
    model_id = "edge/model"
    at_threshold = ModelInfo(
        id=model_id,
        name=model_id,
        family="",
        provider_id="test",
        cost_input=20.0,
        cost_output=100.0,
    )

    outcome = expensive_model_warning(
        model_id,
        provider="test",
        model_info=at_threshold,
    )

    assert outcome is None
def test_warns_when_models_dev_input_price_exceeds_threshold():
    """Expect a warning when ``model_info`` reports an input price of 20.01.

    The warning must expose the input price as a ``Decimal`` and its message
    must contain both the warning banner and the ``$20/M input`` text.
    """
    model_id = "expensive/input"
    pricey_input = ModelInfo(
        id=model_id,
        name=model_id,
        family="",
        provider_id="test",
        cost_input=20.01,
        cost_output=1.0,
    )

    warning = expensive_model_warning(
        model_id,
        provider="test",
        model_info=pricey_input,
    )

    assert warning is not None
    assert warning.input_cost_per_million == Decimal("20.01")
    for expected_fragment in ("EXPENSIVE MODEL WARNING", "$20/M input"):
        assert expected_fragment in warning.message
def test_warns_when_pricing_entry_output_price_exceeds_threshold(monkeypatch):
    """Expect a warning when the stubbed pricing entry has a 100.01 output price.

    ``agent.models_dev.get_model_info`` is stubbed to return ``None`` and
    ``agent.usage_pricing.get_pricing_entry`` to return a fixed entry; the
    warning must carry the output price and quote it in its message.
    """

    def _no_model_info(*_args, **_kwargs):
        return None

    def _costly_output_entry(*_args, **_kwargs):
        return PricingEntry(
            input_cost_per_million=Decimal("1.00"),
            output_cost_per_million=Decimal("100.01"),
            source="provider_models_api",
        )

    monkeypatch.setattr("agent.models_dev.get_model_info", _no_model_info)
    monkeypatch.setattr("agent.usage_pricing.get_pricing_entry", _costly_output_entry)

    warning = expensive_model_warning(
        "provider/expensive-output",
        provider="openrouter",
    )

    assert warning is not None
    assert warning.output_cost_per_million == Decimal("100.01")
    assert "$100.01/M" in warning.message
def test_openai_gpt55_pro_adds_suggestion(monkeypatch):
    """Expect the gpt-5.5 suggestion in the warning for ``openai/gpt-5.5-pro``.

    Model info is stubbed to ``None`` and the pricing entry to 25 / 125 per
    million; the warning message must include the "did you mean" hint.
    """

    def _no_model_info(*_args, **_kwargs):
        return None

    def _pro_pricing_entry(*_args, **_kwargs):
        return PricingEntry(
            input_cost_per_million=Decimal("25"),
            output_cost_per_million=Decimal("125"),
            source="provider_models_api",
        )

    monkeypatch.setattr("agent.models_dev.get_model_info", _no_model_info)
    monkeypatch.setattr("agent.usage_pricing.get_pricing_entry", _pro_pricing_entry)

    warning = expensive_model_warning("openai/gpt-5.5-pro", provider="openrouter")

    assert warning is not None
    assert "did you mean to select openai/gpt-5.5?" in warning.message
def test_openai_gpt55_pro_warns_for_nous_portal_pricing(monkeypatch):
    """Expect a warning for provider ``nous`` using stubbed endpoint metadata.

    ``fetch_endpoint_model_metadata`` is replaced with a stub that reports
    string prices of 0.000025 (prompt) and 0.000125 (completion); the warning
    must expose them as 25.000000 / 125.000000 per million and include the
    gpt-5.5 suggestion.
    """

    def _no_model_info(*_args, **_kwargs):
        return None

    def _fake_endpoint_metadata(base_url, api_key=""):
        # Build a fresh mapping per call, as the original inline stub did.
        return {
            "openai/gpt-5.5-pro": {
                "pricing": {
                    "prompt": "0.000025",
                    "completion": "0.000125",
                }
            }
        }

    monkeypatch.setattr("agent.models_dev.get_model_info", _no_model_info)
    monkeypatch.setattr(
        "agent.usage_pricing.fetch_endpoint_model_metadata",
        _fake_endpoint_metadata,
    )

    warning = expensive_model_warning("openai/gpt-5.5-pro", provider="nous")

    assert warning is not None
    assert warning.input_cost_per_million == Decimal("25.000000")
    assert warning.output_cost_per_million == Decimal("125.000000")
    assert "did you mean to select openai/gpt-5.5?" in warning.message

View file

@ -0,0 +1,64 @@
from types import SimpleNamespace
from hermes_cli.model_switch import ModelSwitchResult
def _bound(fn, instance):
    """Return *fn* bound to *instance* through the descriptor protocol.

    Lets a test call an unbound method against a stand-in object without
    constructing the real class.
    """
    owner = type(instance)
    return fn.__get__(instance, owner)
def test_prompt_toolkit_model_picker_defers_confirmation_off_key_handler(monkeypatch):
    """Selecting a model in the picker hands confirmation to a daemon thread.

    ``threading.Thread`` in the ``cli`` module is swapped for a recorder, and
    the confirm-and-apply method for a marker. After the selection handler
    runs, the picker state must be cleared, the thread must have been started
    as a daemon with ``(result, False)`` as its args, and the marker must not
    have been invoked directly.
    """
    import cli as cli_mod

    switch_result = ModelSwitchResult(
        success=True,
        new_model="openai/gpt-5.5-pro",
        target_provider="nous",
    )

    def _fake_switch_model(**_kwargs):
        return switch_result

    monkeypatch.setattr("hermes_cli.model_switch.switch_model", _fake_switch_model)

    recorded = {}

    class _RecordingThread:
        # Keyword-only signature mirrors how the code under test builds the thread.
        def __init__(self, *, target, args, daemon):
            recorded.update(target=target, args=args, daemon=daemon)

        def start(self):
            recorded["started"] = True

    monkeypatch.setattr(cli_mod.threading, "Thread", _RecordingThread)

    picker_state = {
        "stage": "model",
        "provider_data": {"slug": "nous"},
        "model_list": ["openai/gpt-5.5-pro"],
        "selected": 0,
        "user_provs": None,
        "custom_provs": None,
    }
    fake_cli = SimpleNamespace(
        _app=object(),
        _model_picker_state=picker_state,
        provider="nous",
        model="openai/gpt-5.5",
        base_url="",
        api_key="",
        _restore_modal_input_snapshot=lambda: None,
        _invalidate=lambda **_kwargs: None,
    )
    fake_cli._close_model_picker = _bound(
        cli_mod.HermesCLI._close_model_picker,
        fake_cli,
    )

    def _inline_confirm(*_args):
        # Only reached if confirmation runs synchronously in the key handler.
        return recorded.setdefault("ran_inline", True)

    fake_cli._confirm_and_apply_model_switch_result = _inline_confirm

    handle_selection = _bound(cli_mod.HermesCLI._handle_model_picker_selection, fake_cli)
    handle_selection()

    assert fake_cli._model_picker_state is None
    assert recorded["started"] is True
    assert recorded["daemon"] is True
    assert recorded["args"] == (switch_result, False)
    assert "ran_inline" not in recorded

View file

@ -2,6 +2,7 @@
cannot initialize (e.g. non-TTY, curses unavailable, terminal error)."""
import subprocess
from types import SimpleNamespace
from hermes_cli.config import load_config, save_config
@ -24,6 +25,46 @@ def test_prompt_model_selection_falls_back_on_menu_runtime_error(monkeypatch):
assert selected == "model-b"
def test_prompt_model_selection_requires_expensive_confirmation(monkeypatch, capsys):
    """Answering "n" to the expensive-model prompt yields no selection.

    The curses menu is replaced with ``_raise_menu``, the cost guard is stubbed
    to always warn, and ``input`` is scripted with "1" then "n". The warning
    text must appear on stdout and the result must be ``None``.
    """
    from hermes_cli.auth import _prompt_model_selection

    def _always_warn(*_args, **_kwargs):
        return SimpleNamespace(message="EXPENSIVE MODEL WARNING")

    scripted_answers = iter(["1", "n"])

    def _scripted_input(_prompt=""):
        return next(scripted_answers)

    monkeypatch.setattr("hermes_cli.curses_ui.curses_radiolist", _raise_menu)
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        _always_warn,
    )
    monkeypatch.setattr("builtins.input", _scripted_input)

    selected = _prompt_model_selection(
        ["openai/gpt-5.5-pro"],
        confirm_provider="nous",
    )
    printed = capsys.readouterr().out

    assert selected is None
    assert "EXPENSIVE MODEL WARNING" in printed
def test_prompt_model_selection_allows_confirmed_expensive_model(monkeypatch):
    """Answering "y" to the expensive-model prompt returns the chosen model.

    Same setup as the declining case, but ``input`` is scripted with "1" then
    "y", so the selected model id must be returned.
    """
    from hermes_cli.auth import _prompt_model_selection

    def _always_warn(*_args, **_kwargs):
        return SimpleNamespace(message="EXPENSIVE MODEL WARNING")

    scripted_answers = iter(["1", "y"])

    def _scripted_input(_prompt=""):
        return next(scripted_answers)

    monkeypatch.setattr("hermes_cli.curses_ui.curses_radiolist", _raise_menu)
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        _always_warn,
    )
    monkeypatch.setattr("builtins.input", _scripted_input)

    selected = _prompt_model_selection(
        ["openai/gpt-5.5-pro"],
        confirm_provider="nous",
    )

    assert selected == "openai/gpt-5.5-pro"
def test_prompt_reasoning_effort_falls_back_on_menu_runtime_error(monkeypatch):
from hermes_cli.main import _prompt_reasoning_effort_selection

View file

@ -4,6 +4,7 @@ import os
import json
import shutil
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch, MagicMock
import pytest
@ -1069,6 +1070,41 @@ class TestWebServerEndpoints:
assert "GATEWAY_PROXY_URL" not in managed
assert "GATEWAY_PROXY_URL" in _MESSAGING_KEYS_PAGE_KEYS
def test_model_set_requires_confirmation_for_expensive_model(self, monkeypatch):
    """POST /api/model/set asks for confirmation, then accepts the confirmed retry.

    With the cost guard stubbed to always warn, the first request must return
    200 with ``ok`` false plus the confirmation fields; repeating it with
    ``confirm_expensive_model`` set must return 200 with ``ok`` true.
    """

    def _always_warn(*_args, **_kwargs):
        return SimpleNamespace(message="EXPENSIVE MODEL WARNING")

    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        _always_warn,
    )

    base_payload = {
        "scope": "main",
        "provider": "nous",
        "model": "openai/gpt-5.5-pro",
    }

    resp = self.client.post("/api/model/set", json=dict(base_payload))
    assert resp.status_code == 200
    data = resp.json()
    assert data["ok"] is False
    assert data["confirm_required"] is True
    assert data["confirm_message"] == "EXPENSIVE MODEL WARNING"

    confirmed = self.client.post(
        "/api/model/set",
        json={**base_payload, "confirm_expensive_model": True},
    )
    assert confirmed.status_code == 200
    assert confirmed.json()["ok"] is True
def test_reveal_env_var(self, tmp_path):
"""POST /api/env/reveal should return the real unredacted value."""
from hermes_cli.config import save_env_value

View file

@ -2397,7 +2397,7 @@ def test_config_set_model_waits_for_lazy_agent_before_switch(monkeypatch):
target["agent"] = agent
agent_ready.set()
def fake_apply(sid, target, raw):
def fake_apply(sid, target, raw, **kwargs):
calls.append(("apply", sid, target.get("agent"), raw))
if target.get("agent") is not agent:
raise AssertionError("model switch ran before lazy agent was ready")
@ -2424,7 +2424,7 @@ def test_config_set_model_uses_live_switch_path(monkeypatch):
server._sessions["sid"] = _session()
seen = {}
def _fake_apply(sid, session, raw):
def _fake_apply(sid, session, raw, **_kwargs):
seen["args"] = (sid, session["session_key"], raw)
return {"value": "new/model", "warning": "catalog unreachable"}
@ -2442,6 +2442,74 @@ def test_config_set_model_uses_live_switch_path(monkeypatch):
assert seen["args"] == ("sid", "session-key", "new/model")
def test_config_set_model_requires_confirmation_for_expensive_model(monkeypatch):
    """``config.set`` for an expensive model needs an explicit confirmed retry.

    ``switch_model`` is stubbed to return a result whose ``model_info`` reports
    25.0 / 125.0 costs. The first request must come back with
    ``confirm_required`` true and leave the agent unswitched; the second, sent
    with ``confirm_expensive_model``, must switch the agent.
    """
    raw_value = "openai/gpt-5.5-pro --provider openrouter"

    class _Agent:
        provider = "openrouter"
        model = "old/model"
        base_url = ""
        api_key = "sk-or"
        switched = False

        def switch_model(self, **_kwargs):
            self.switched = True

    pricing_info = types.SimpleNamespace(
        has_cost_data=lambda: True,
        cost_input=25.0,
        cost_output=125.0,
    )
    switch_result = types.SimpleNamespace(
        success=True,
        new_model="openai/gpt-5.5-pro",
        target_provider="openrouter",
        api_key="sk-or",
        base_url="https://openrouter.ai/api/v1",
        api_mode="chat_completions",
        warning_message="",
        model_info=pricing_info,
    )

    agent = _Agent()
    server._sessions["sid"] = _session(agent=agent)

    def _fake_switch_model(**_kwargs):
        return switch_result

    monkeypatch.setattr("hermes_cli.model_switch.switch_model", _fake_switch_model)
    monkeypatch.setattr(server, "_restart_slash_worker", lambda sid, session: None)
    monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)

    def _config_set(request_id, **extra_params):
        # Both requests share everything except the id and the confirm flag.
        return server.handle_request(
            {
                "id": request_id,
                "method": "config.set",
                "params": {
                    "session_id": "sid",
                    "key": "model",
                    "value": raw_value,
                    **extra_params,
                },
            }
        )

    resp = _config_set("1")
    assert resp["result"]["confirm_required"] is True
    assert "did you mean to select openai/gpt-5.5?" in resp["result"]["confirm_message"]
    assert agent.switched is False

    confirmed = _config_set("2", confirm_expensive_model=True)
    assert confirmed["result"]["confirm_required"] is False
    assert confirmed["result"]["value"] == "openai/gpt-5.5-pro"
    assert agent.switched is True
def test_config_set_model_global_persists(monkeypatch):
class _Agent:
provider = "openrouter"
@ -3944,7 +4012,7 @@ def test_config_set_model_rejects_while_running(monkeypatch):
"""/model via config.set must reject during an in-flight turn."""
seen = {"called": False}
def _fake_apply(sid, session, raw):
def _fake_apply(sid, session, raw, **_kwargs):
seen["called"] = True
return {"value": raw, "warning": ""}
@ -3978,7 +4046,7 @@ def test_config_set_model_allowed_when_idle(monkeypatch):
"""Regression guard: idle sessions can still switch models."""
seen = {"called": False}
def _fake_apply(sid, session, raw):
def _fake_apply(sid, session, raw, **_kwargs):
seen["called"] = True
return {"value": "newmodel", "warning": ""}

View file

@ -1696,7 +1696,13 @@ def _persist_model_switch(result) -> None:
save_config(cfg)
def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
def _apply_model_switch(
sid: str,
session: dict,
raw_input: str,
*,
confirm_expensive_model: bool = False,
) -> dict:
from hermes_cli.model_switch import parse_model_flags, switch_model
from hermes_cli.runtime_provider import resolve_runtime_provider
@ -1753,6 +1759,27 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
if not result.success:
raise ValueError(result.error_message or "model switch failed")
if not confirm_expensive_model:
try:
from hermes_cli.model_cost_guard import expensive_model_warning
warning = expensive_model_warning(
result.new_model,
provider=result.target_provider,
base_url=result.base_url or current_base_url,
api_key=result.api_key or current_api_key,
model_info=result.model_info,
)
except Exception:
warning = None
if warning is not None:
return {
"value": result.new_model,
"warning": warning.message,
"confirm_required": True,
"confirm_message": warning.message,
}
if agent:
agent.switch_model(
new_model=result.new_model,
@ -1787,7 +1814,11 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
}
if persist_global:
_persist_model_switch(result)
return {"value": result.new_model, "warning": result.warning_message or ""}
return {
"value": result.new_model,
"warning": result.warning_message or "",
"confirm_required": False,
}
def _compress_session_history(
@ -6196,13 +6227,31 @@ def _(rid, params: dict) -> dict:
if session.get("agent") is None:
return _err(rid, 5032, "agent initialization failed")
result = _apply_model_switch(
params.get("session_id", ""), session, value
params.get("session_id", ""),
session,
value,
confirm_expensive_model=bool(
params.get("confirm_expensive_model", False)
),
)
else:
result = _apply_model_switch("", {"agent": None}, value)
result = _apply_model_switch(
"",
{"agent": None},
value,
confirm_expensive_model=bool(
params.get("confirm_expensive_model", False)
),
)
return _ok(
rid,
{"key": key, "value": result["value"], "warning": result["warning"]},
{
"key": key,
"value": result["value"],
"warning": result["warning"],
"confirm_required": result.get("confirm_required", False),
"confirm_message": result.get("confirm_message", ""),
},
)
except Exception as e:
return _err(rid, 5001, str(e))

View file

@ -108,6 +108,7 @@ describe('createSlashHandler', () => {
expect(createSlashHandler(ctx)('/model x-model')).toBe(true)
expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', {
confirm_expensive_model: false,
key: 'model',
session_id: 'sid-abc',
value: 'x-model'
@ -128,6 +129,7 @@ describe('createSlashHandler', () => {
createSlashHandler(ctx)(`/model anthropic/claude-sonnet-4.6 --provider openrouter ${TUI_SESSION_MODEL_FLAG}`)
).toBe(true)
expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', {
confirm_expensive_model: false,
key: 'model',
session_id: 'sid-abc',
value: 'anthropic/claude-sonnet-4.6 --provider openrouter'
@ -140,6 +142,7 @@ describe('createSlashHandler', () => {
createSlashHandler(ctx)('/model x-model --global')
expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', {
confirm_expensive_model: false,
key: 'model',
session_id: 'sid-abc',
value: 'x-model --global'

View file

@ -72,10 +72,25 @@ export const sessionCommands: SlashCommand[] = [
return patchOverlayState({ modelPicker: true })
}
ctx.gateway
.rpc<ConfigSetResponse>('config.set', { key: 'model', session_id: ctx.sid, value: modelValueForConfigSet(arg) })
const switchModel = (confirmExpensiveModel = false) => ctx.gateway
.rpc<ConfigSetResponse>('config.set', { confirm_expensive_model: confirmExpensiveModel, key: 'model', session_id: ctx.sid, value: modelValueForConfigSet(arg) })
.then(
ctx.guarded<ConfigSetResponse>(r => {
if (r.confirm_required) {
patchOverlayState({
confirm: {
cancelLabel: 'Cancel',
confirmLabel: 'Switch anyway',
danger: true,
detail: r.confirm_message || r.warning || 'This model has unusually high known pricing.',
onConfirm: () => switchModel(true),
title: 'Expensive model selection'
}
})
return
}
if (!r.value) {
return ctx.transcript.sys('error: invalid response: model switch')
}
@ -89,6 +104,8 @@ export const sessionCommands: SlashCommand[] = [
}))
})
)
switchModel()
}
},

View file

@ -103,6 +103,8 @@ export interface ConfigGetValueResponse {
}
export interface ConfigSetResponse {
confirm_message?: string
confirm_required?: boolean
credential_warning?: string
history_reset?: boolean
info?: SessionInfo

View file

@ -292,24 +292,6 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) {
setVersion((v) => v + 1);
}, []);
// Picker hands us a fully-formed slash command (e.g. "/model anthropic/...").
// Fire-and-forget through `slash.exec`; the TUI pane will render the result
// via PTY, so the sidebar doesn't need to surface output of its own.
const onModelSubmit = useCallback(
(slashCommand: string) => {
if (!sessionId) {
return;
}
void gw.request("slash.exec", {
session_id: sessionId,
command: slashCommand,
});
setModelOpen(false);
},
[gw, sessionId],
);
const canPickModel = state === "open" && !!sessionId;
const modelLabel = (info.model ?? "—").split("/").slice(-1)[0] ?? "—";
const banner = error ?? info.credential_warning ?? null;
@ -390,7 +372,6 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) {
gw={gw}
sessionId={sessionId}
onClose={() => setModelOpen(false)}
onSubmit={onModelSubmit}
/>
)}
</aside>

View file

@ -0,0 +1,122 @@
import { Button } from "@nous-research/ui/ui/components/button";
import { AlertTriangle } from "lucide-react";
import { useEffect, useRef } from "react";
import { createPortal } from "react-dom";
import { cn, themedBody } from "@/lib/utils";
/** Props accepted by `ConfirmDialog`. */
interface ConfirmDialogProps {
/** Text of the cancel button; the component defaults it to "Cancel". */
cancelLabel?: string;
/** Text of the confirm button; the component defaults it to "Confirm". */
confirmLabel?: string;
/** Optional body text shown under the title; rendered with `whitespace-pre-line`. */
description?: string;
/** When true, shows the warning icon and passes `destructive` to the confirm button. */
destructive?: boolean;
/** When true, disables both buttons and replaces the confirm label with "…". */
loading?: boolean;
/** Called from the cancel button, the Escape key, and a click on the backdrop. */
onCancel: () => void;
/** Called when the confirm button is clicked. */
onConfirm: () => void;
/** When false the component renders nothing. */
open: boolean;
/** Heading text of the dialog. */
title: string;
}
/**
 * Modal confirmation dialog, portalled onto `document.body`.
 *
 * While `open` is true it focuses the confirm button, cancels on Escape or a
 * backdrop click, and locks body scrolling. On close it restores the previous
 * body overflow and returns focus to the element that was active before the
 * dialog opened. Renders nothing when `open` is false.
 */
export function ConfirmDialog({
  cancelLabel = "Cancel",
  confirmLabel = "Confirm",
  description,
  destructive = false,
  loading = false,
  onCancel,
  onConfirm,
  open,
  title,
}: ConfirmDialogProps) {
  const dialogRef = useRef<HTMLDivElement>(null);

  // Callers commonly pass an inline arrow for `onCancel`, which changes
  // identity on every render. Reading it through a ref keeps the open/close
  // effect below from re-running (and re-stealing focus onto the confirm
  // button) each time the parent re-renders while the dialog is open.
  const onCancelRef = useRef(onCancel);
  useEffect(() => {
    onCancelRef.current = onCancel;
  }, [onCancel]);

  useEffect(() => {
    if (!open) return;

    // Remember what had focus so it can be restored when the dialog closes.
    const prevActive = document.activeElement as HTMLElement | null;
    dialogRef.current
      ?.querySelector<HTMLButtonElement>("[data-confirm]")
      ?.focus();

    const onKey = (e: KeyboardEvent) => {
      if (e.key === "Escape") {
        e.preventDefault();
        onCancelRef.current();
      }
    };
    document.addEventListener("keydown", onKey);

    // Lock background scrolling for as long as the dialog is open.
    const prevOverflow = document.body.style.overflow;
    document.body.style.overflow = "hidden";

    return () => {
      document.removeEventListener("keydown", onKey);
      document.body.style.overflow = prevOverflow;
      prevActive?.focus?.();
    };
  }, [open]);

  if (!open) return null;

  return createPortal(
    <div
      role="dialog"
      aria-modal="true"
      aria-labelledby="confirm-dialog-title"
      aria-describedby={description ? "confirm-dialog-desc" : undefined}
      onClick={(e) => {
        // Only a click on the backdrop itself cancels, not clicks inside the card.
        if (e.target === e.currentTarget) onCancel();
      }}
      className="fixed inset-0 z-[200] flex items-center justify-center bg-background/85 backdrop-blur-sm p-4"
    >
      <div
        ref={dialogRef}
        className={cn(
          themedBody,
          "relative w-full max-w-md border border-border bg-card shadow-2xl",
        )}
      >
        <div className="flex items-start gap-3 p-4 border-b border-border">
          {destructive && (
            <div aria-hidden className="mt-0.5 shrink-0 text-destructive">
              <AlertTriangle className="h-4 w-4" />
            </div>
          )}
          <div className="flex-1 min-w-0 flex flex-col gap-1">
            <h2
              id="confirm-dialog-title"
              className="font-mondwest text-display text-base tracking-wider"
            >
              {title}
            </h2>
            {description && (
              <p
                id="confirm-dialog-desc"
                className="text-xs text-muted-foreground leading-relaxed whitespace-pre-line"
              >
                {description}
              </p>
            )}
          </div>
        </div>
        <div className="flex items-center justify-end gap-2 p-3">
          <Button type="button" outlined onClick={onCancel} disabled={loading}>
            {cancelLabel}
          </Button>
          <Button
            data-confirm
            type="button"
            destructive={destructive}
            onClick={onConfirm}
            disabled={loading}
          >
            {loading ? "…" : confirmLabel}
          </Button>
        </div>
      </div>
    </div>,
    document.body,
  );
}

View file

@ -4,6 +4,7 @@ import { ListItem } from "@nous-research/ui/ui/components/list-item";
import { Spinner } from "@nous-research/ui/ui/components/spinner";
import { Input } from "@nous-research/ui/ui/components/input";
import { Label } from "@nous-research/ui/ui/components/label";
import { ConfirmDialog } from "@/components/ConfirmDialog";
import type { GatewayClient } from "@/lib/gatewayClient";
import { Check, Search, X } from "lucide-react";
import { useEffect, useMemo, useRef, useState } from "react";
@ -21,9 +22,8 @@ import { fuzzyRank } from "@/lib/fuzzy";
* Two invocation modes:
*
* 1. Chat-session mode (ChatSidebar) pass `gw` + `sessionId`. The picker
* loads options via `model.options` JSON-RPC and emits the result as a
* slash command string (`/model <model> --provider <slug> [--global]`)
* through `onSubmit`, which the ChatPage pipes to `slashExec`.
* loads options via `model.options` JSON-RPC and applies the choice via
* `config.set`, so expensive-model confirmation can happen before switch.
*
* 2. Standalone mode (ModelsPage, Config settings) pass a `loader` and
* `onApply`. The picker fetches options via the REST endpoint and calls
@ -47,6 +47,23 @@ interface ModelOptionsResponse {
providers?: ModelOptionProvider[];
}
interface ExpensiveModelConfirmResponse {
confirm_message?: string;
confirm_required?: boolean;
warning?: string;
}
interface ConfigSetResponse extends ExpensiveModelConfirmResponse {
value?: string;
}
interface PendingExpensiveConfirm {
message: string;
model: string;
persistGlobal: boolean;
provider: string;
}
interface Props {
/** Chat-mode: when present (together with `sessionId`), picker applies the choice via the `config.set` RPC. */
gw?: GatewayClient;
@ -56,10 +73,14 @@ interface Props {
/** Standalone-mode: when present (and onSubmit absent), picker calls onApply. */
loader?(): Promise<ModelOptionsResponse>;
onApply?(args: {
confirmExpensiveModel?: boolean;
provider: string;
model: string;
persistGlobal: boolean;
}): Promise<void> | void;
}):
| Promise<ExpensiveModelConfirmResponse | void>
| ExpensiveModelConfirmResponse
| void;
onClose(): void;
title?: string;
@ -90,6 +111,8 @@ export function ModelPickerDialog(props: Props) {
const [query, setQuery] = useState("");
const [persistGlobal, setPersistGlobal] = useState(alwaysGlobal);
const [applying, setApplying] = useState(false);
const [pendingConfirm, setPendingConfirm] =
useState<PendingExpensiveConfirm | null>(null);
const closedRef = useRef(false);
// Load providers + models on open.
@ -179,16 +202,65 @@ export function ModelPickerDialog(props: Props) {
const canConfirm = !!selectedProvider && !!selectedModel && !applying;
const confirm = async () => {
if (!canConfirm || !selectedProvider) return;
const applySelection = async (
confirmExpensiveModel = false,
forced?: PendingExpensiveConfirm,
) => {
const providerSlug = forced?.provider ?? selectedProvider?.slug ?? "";
const model = forced?.model ?? selectedModel;
const shouldPersistGlobal = forced?.persistGlobal ?? persistGlobal;
if (!providerSlug || !model || applying) return;
if (standalone && onApply) {
setApplying(true);
try {
await onApply({
provider: selectedProvider.slug,
model: selectedModel,
persistGlobal,
const result = await onApply({
confirmExpensiveModel,
provider: providerSlug,
model,
persistGlobal: shouldPersistGlobal,
});
if (result?.confirm_required) {
setPendingConfirm({
provider: providerSlug,
model,
persistGlobal: shouldPersistGlobal,
message:
result.confirm_message ||
result.warning ||
"This model has unusually high known pricing.",
});
return;
}
onClose();
} catch (e) {
setError(e instanceof Error ? e.message : String(e));
} finally {
setApplying(false);
}
} else if (gw && sessionId) {
setApplying(true);
try {
const global = shouldPersistGlobal ? " --global" : "";
const result = await gw.request<ConfigSetResponse>("config.set", {
confirm_expensive_model: confirmExpensiveModel,
key: "model",
session_id: sessionId,
value: `${model} --provider ${providerSlug}${global}`,
});
if (result?.confirm_required) {
setPendingConfirm({
provider: providerSlug,
model,
persistGlobal: shouldPersistGlobal,
message:
result.confirm_message ||
result.warning ||
"This model has unusually high known pricing.",
});
return;
}
onClose();
} catch (e) {
setError(e instanceof Error ? e.message : String(e));
@ -196,14 +268,17 @@ export function ModelPickerDialog(props: Props) {
setApplying(false);
}
} else if (onSubmit) {
const global = persistGlobal ? " --global" : "";
onSubmit(
`/model ${selectedModel} --provider ${selectedProvider.slug}${global}`,
);
const global = shouldPersistGlobal ? " --global" : "";
onSubmit(`/model ${model} --provider ${providerSlug}${global}`);
onClose();
}
};
const confirm = () => {
if (!canConfirm) return;
void applySelection();
};
// Portal to document.body: the main dashboard column in App.tsx is
// `relative z-2`, which creates a stacking context that traps fixed
// descendants below the app sidebar (z-50). Without the portal this
@ -280,8 +355,12 @@ export function ModelPickerDialog(props: Props) {
onSelect={setSelectedModel}
onConfirm={(m) => {
setSelectedModel(m);
// Confirm on next tick so state settles.
window.setTimeout(confirm, 0);
void applySelection(false, {
provider: selectedProvider?.slug ?? "",
model: m,
persistGlobal,
message: "",
});
}}
/>
</div>
@ -320,6 +399,22 @@ export function ModelPickerDialog(props: Props) {
</div>
</footer>
</div>
<ConfirmDialog
open={!!pendingConfirm}
title="Expensive Model Warning"
description={pendingConfirm?.message}
destructive
confirmLabel="Switch anyway"
cancelLabel="Cancel"
loading={applying}
onCancel={() => setPendingConfirm(null)}
onConfirm={() => {
const pending = pendingConfirm;
if (!pending) return;
setPendingConfirm(null);
void applySelection(true, pending);
}}
/>
</div>,
document.body,
);

View file

@ -1763,9 +1763,12 @@ export interface AuxiliaryModelsResponse {
}
export interface ModelAssignmentRequest {
confirm_expensive_model?: boolean;
scope: "main" | "auxiliary";
provider: string;
model: string;
/** Optional OpenAI-compatible endpoint URL for custom/local main providers. */
base_url?: string;
/** For auxiliary: task slot name, "" for all, "__reset__" to reset all. */
task?: string;
}
@ -1779,6 +1782,8 @@ export interface StaleAuxAssignment {
}
export interface ModelAssignmentResponse {
confirm_message?: string;
confirm_required?: boolean;
ok: boolean;
scope?: string;
provider?: string;

View file

@ -26,7 +26,7 @@ import { Spinner } from "@nous-research/ui/ui/components/spinner";
import { Stats } from "@nous-research/ui/ui/components/stats";
import { Card, CardContent, CardHeader, CardTitle } from "@nous-research/ui/ui/components/card";
import { Badge } from "@nous-research/ui/ui/components/badge";
import { ConfirmDialog } from "@nous-research/ui/ui/components/confirm-dialog";
import { ConfirmDialog } from "@/components/ConfirmDialog";
import { useModalBehavior } from "@/hooks/useModalBehavior";
import { usePageHeader } from "@/contexts/usePageHeader";
import { useI18n } from "@/i18n";
@ -209,10 +209,16 @@ function UseAsMenu({
const [open, setOpen] = useState(false);
const [busy, setBusy] = useState(false);
const [error, setError] = useState<string | null>(null);
const [pendingConfirm, setPendingConfirm] = useState<{
message: string;
scope: "main" | "auxiliary";
task: string;
} | null>(null);
const assign = async (
scope: "main" | "auxiliary",
task: string,
confirmExpensiveModel = false,
) => {
if (!provider || !model) {
setError("Missing provider/model");
@ -221,7 +227,23 @@ function UseAsMenu({
setBusy(true);
setError(null);
try {
await api.setModelAssignment({ scope, provider, model, task });
const result = await api.setModelAssignment({
confirm_expensive_model: confirmExpensiveModel,
scope,
provider,
model,
task,
});
if (result.confirm_required) {
setPendingConfirm({
scope,
task,
message:
result.confirm_message ||
"This model has unusually high known pricing.",
});
return;
}
onAssigned();
setOpen(false);
} catch (e) {
@ -310,6 +332,22 @@ function UseAsMenu({
)}
</div>
)}
<ConfirmDialog
open={!!pendingConfirm}
title="Expensive Model Warning"
description={pendingConfirm?.message}
destructive
confirmLabel="Switch anyway"
cancelLabel="Cancel"
loading={busy}
onCancel={() => setPendingConfirm(null)}
onConfirm={() => {
const pending = pendingConfirm;
if (!pending) return;
setPendingConfirm(null);
void assign(pending.scope, pending.task, true);
}}
/>
</div>
);
}
@ -619,14 +657,16 @@ function AuxiliaryTasksModal({
AUX_TASKS.find((t) => t.key === picker.task)?.label ??
picker.task
}`}
onApply={async ({ provider, model }) => {
await api.setModelAssignment({
onApply={async ({ provider, model, confirmExpensiveModel }) => {
const result = await api.setModelAssignment({
confirm_expensive_model: confirmExpensiveModel,
scope: "auxiliary",
task: picker.task,
provider,
model,
});
onSaved();
if (!result.confirm_required) onSaved();
return result;
}}
onClose={() => setPicker(null)}
/>
@ -666,14 +706,23 @@ function ModelSettingsPanel({
task,
provider,
model,
confirmExpensiveModel,
}: {
confirmExpensiveModel?: boolean;
scope: "main" | "auxiliary";
task: string;
provider: string;
model: string;
}) => {
await api.setModelAssignment({ scope, task, provider, model });
onSaved();
const result = await api.setModelAssignment({
confirm_expensive_model: confirmExpensiveModel,
scope,
task,
provider,
model,
});
if (!result.confirm_required) onSaved();
return result;
};
// Count how many aux tasks have overrides
@ -749,14 +798,15 @@ function ModelSettingsPanel({
loader={api.getModelOptions}
alwaysGlobal
title="Set Main Model"
onApply={async ({ provider, model }) => {
await applyAssignment({
onApply={({ provider, model, confirmExpensiveModel }) =>
applyAssignment({
confirmExpensiveModel,
scope: "main",
task: "",
provider,
model,
});
}}
})
}
onClose={() => setPicker(null)}
/>
)}