mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-10 08:32:09 +00:00
fix(model): require confirmation for expensive model selections
Rebased onto current main and re-ported across the restructured surfaces: model flows now thread confirm_provider/base_url/api_key through hermes_cli/model_setup_flows.py, the Discord picker lives in plugins/platforms/discord/adapter.py, and the web dashboard picker applies chat-mode switches via config.set so the expensive-model confirmation can ride the response. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
parent
4eadef18a9
commit
af978ecb17
27 changed files with 1354 additions and 111 deletions
|
|
@ -13,6 +13,7 @@ DEFAULT_PRICING = {"input": 0.0, "output": 0.0}
|
|||
|
||||
_ZERO = Decimal("0")
|
||||
_ONE_MILLION = Decimal("1000000")
|
||||
_NOUS_DEFAULT_BASE_URL = "https://inference-api.nousresearch.com/v1"
|
||||
|
||||
CostStatus = Literal["actual", "estimated", "included", "unknown"]
|
||||
CostSource = Literal[
|
||||
|
|
@ -570,6 +571,8 @@ def resolve_billing_route(
|
|||
return BillingRoute(provider="openai-codex", model=model, base_url=base_url or "", billing_mode="subscription_included")
|
||||
if provider_name == "openrouter" or base_url_host_matches(base_url or "", "openrouter.ai"):
|
||||
return BillingRoute(provider="openrouter", model=model, base_url=base_url or "", billing_mode="official_models_api")
|
||||
if provider_name == "nous" or base_url_host_matches(base_url or "", "inference-api.nousresearch.com"):
|
||||
return BillingRoute(provider="nous", model=model, base_url=base_url or _NOUS_DEFAULT_BASE_URL, billing_mode="official_models_api")
|
||||
if provider_name == "anthropic":
|
||||
return BillingRoute(provider="anthropic", model=model.split("/")[-1], base_url=base_url or "", billing_mode="official_docs_snapshot")
|
||||
if provider_name == "openai":
|
||||
|
|
|
|||
54
cli.py
54
cli.py
|
|
@ -6516,6 +6516,47 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
|
|||
}
|
||||
self._invalidate(min_interval=0.0)
|
||||
|
||||
def _confirm_expensive_model_switch(self, result) -> bool:
|
||||
"""Ask for explicit confirmation before applying costly model switches."""
|
||||
if not getattr(result, "success", False):
|
||||
return True
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
result.new_model,
|
||||
provider=result.target_provider,
|
||||
base_url=result.base_url or self.base_url or "",
|
||||
api_key=result.api_key or self.api_key or "",
|
||||
model_info=result.model_info,
|
||||
)
|
||||
except Exception:
|
||||
warning = None
|
||||
if warning is None:
|
||||
return True
|
||||
|
||||
choices = [
|
||||
("once", "Switch anyway", "Use this model for the current Hermes session."),
|
||||
("cancel", "Cancel", "Keep the current model."),
|
||||
]
|
||||
raw = self._prompt_text_input_modal(
|
||||
title="!!! Expensive Model Warning !!!",
|
||||
detail=warning.message,
|
||||
choices=choices,
|
||||
timeout=120,
|
||||
)
|
||||
choice = self._normalize_slash_confirm_choice(raw, choices)
|
||||
return choice == "once"
|
||||
|
||||
def _confirm_and_apply_model_switch_result(self, result, persist_global: bool) -> None:
|
||||
try:
|
||||
if result.success and not self._confirm_expensive_model_switch(result):
|
||||
_cprint(" Model switch cancelled.")
|
||||
return
|
||||
self._apply_model_switch_result(result, persist_global)
|
||||
except Exception as exc:
|
||||
_cprint(f" ✗ Model selection failed: {exc}")
|
||||
|
||||
def _close_model_picker(self) -> None:
|
||||
self._model_picker_state = None
|
||||
self._restore_modal_input_snapshot()
|
||||
|
|
@ -6692,7 +6733,14 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
|
|||
custom_providers=state.get("custom_provs"),
|
||||
)
|
||||
self._close_model_picker()
|
||||
self._apply_model_switch_result(result, persist_global)
|
||||
if getattr(self, "_app", None):
|
||||
threading.Thread(
|
||||
target=self._confirm_and_apply_model_switch_result,
|
||||
args=(result, persist_global),
|
||||
daemon=True,
|
||||
).start()
|
||||
else:
|
||||
self._confirm_and_apply_model_switch_result(result, persist_global)
|
||||
return
|
||||
self._close_model_picker()
|
||||
|
||||
|
|
@ -6793,6 +6841,10 @@ class HermesCLI(CLIAgentSetupMixin, CLICommandsMixin):
|
|||
_cprint(f" ✗ {result.error_message}")
|
||||
return
|
||||
|
||||
if not self._confirm_expensive_model_switch(result):
|
||||
_cprint(" Model switch cancelled.")
|
||||
return
|
||||
|
||||
# Apply to CLI state.
|
||||
# Update requested_provider so _ensure_runtime_credentials() doesn't
|
||||
# overwrite the switch on the next turn (it re-resolves from this).
|
||||
|
|
|
|||
|
|
@ -3030,7 +3030,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
async def _handle_model_picker_callback(
|
||||
self, query, data: str, chat_id: str
|
||||
) -> None:
|
||||
"""Handle model picker inline keyboard callbacks (mp:/mm:/mb:/mx:/mg:)."""
|
||||
"""Handle model picker inline keyboard callbacks (mp:/mm:/mc:/mb:/mx:/mg:)."""
|
||||
state = self._model_picker_state.get(chat_id)
|
||||
if not state:
|
||||
await query.answer(text="Picker expired — use /model again.")
|
||||
|
|
@ -3115,6 +3115,51 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
)
|
||||
await query.answer()
|
||||
|
||||
elif data.startswith("mc:"):
|
||||
# --- Expensive model confirmed: perform the switch ---
|
||||
try:
|
||||
idx = int(data[3:])
|
||||
except ValueError:
|
||||
await query.answer(text="Invalid selection.")
|
||||
return
|
||||
|
||||
model_list = state.get("model_list", [])
|
||||
if idx < 0 or idx >= len(model_list):
|
||||
await query.answer(text="Invalid model index.")
|
||||
return
|
||||
|
||||
model_id = model_list[idx]
|
||||
provider_slug = state.get("selected_provider", "")
|
||||
callback = state.get("on_model_selected")
|
||||
|
||||
if not callback:
|
||||
await query.answer(text="Picker expired.")
|
||||
return
|
||||
|
||||
try:
|
||||
result_text = await callback(chat_id, model_id, provider_slug)
|
||||
except Exception as exc:
|
||||
logger.error("Model picker switch failed: %s", exc)
|
||||
result_text = f"Error switching model: {exc}"
|
||||
|
||||
try:
|
||||
await query.edit_message_text(
|
||||
text=self.format_message(result_text),
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_markup=None,
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
await query.edit_message_text(
|
||||
text=result_text,
|
||||
parse_mode=None,
|
||||
reply_markup=None,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
await query.answer(text="Model switched!")
|
||||
self._model_picker_state.pop(chat_id, None)
|
||||
|
||||
elif data.startswith("mm:"):
|
||||
# --- Model selected: perform the switch ---
|
||||
try:
|
||||
|
|
@ -3136,6 +3181,33 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
await query.answer(text="Picker expired.")
|
||||
return
|
||||
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
model_id,
|
||||
provider=provider_slug,
|
||||
)
|
||||
except Exception:
|
||||
warning = None
|
||||
if warning is not None:
|
||||
keyboard = InlineKeyboardMarkup([
|
||||
[InlineKeyboardButton("Switch anyway", callback_data=f"mc:{idx}")],
|
||||
[
|
||||
InlineKeyboardButton("◀ Back", callback_data="mb"),
|
||||
InlineKeyboardButton("✗ Cancel", callback_data="mx"),
|
||||
],
|
||||
])
|
||||
await query.edit_message_text(
|
||||
text=self.format_message(
|
||||
f"⚠ *Expensive Model Warning*\n\n{warning.message}"
|
||||
),
|
||||
parse_mode=ParseMode.MARKDOWN_V2,
|
||||
reply_markup=keyboard,
|
||||
)
|
||||
await query.answer(text="Confirm expensive model")
|
||||
return
|
||||
|
||||
try:
|
||||
result_text = await callback(chat_id, model_id, provider_slug)
|
||||
except Exception as exc:
|
||||
|
|
@ -3260,7 +3332,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
query_user_name = getattr(query.from_user, "first_name", None)
|
||||
|
||||
# --- Model picker callbacks ---
|
||||
if data.startswith(("mp:", "mpg:", "mm:", "mb", "mx", "mg:")):
|
||||
if data.startswith(("mp:", "mpg:", "mm:", "mc:", "mb", "mx", "mg:")):
|
||||
chat_id = str(query.message.chat_id) if query.message else None
|
||||
if chat_id:
|
||||
await self._handle_model_picker_callback(query, data, chat_id)
|
||||
|
|
|
|||
|
|
@ -6175,6 +6175,40 @@ def _reset_config_provider() -> Path:
|
|||
return config_path
|
||||
|
||||
|
||||
def _confirm_expensive_model_selection(
|
||||
model_id: str,
|
||||
*,
|
||||
provider: str = "",
|
||||
base_url: str = "",
|
||||
api_key: str = "",
|
||||
) -> bool:
|
||||
"""Prompt before saving a model whose known pricing exceeds guardrails."""
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
model_id,
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
)
|
||||
except Exception:
|
||||
warning = None
|
||||
if warning is None:
|
||||
return True
|
||||
|
||||
print()
|
||||
print("=" * 72)
|
||||
print(warning.message)
|
||||
print("=" * 72)
|
||||
try:
|
||||
response = input("Switch anyway? [y/N]: ").strip().lower()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return False
|
||||
return response in {"y", "yes"}
|
||||
|
||||
|
||||
def _prompt_model_selection(
|
||||
model_ids: List[str],
|
||||
current_model: str = "",
|
||||
|
|
@ -6182,6 +6216,9 @@ def _prompt_model_selection(
|
|||
unavailable_models: Optional[List[str]] = None,
|
||||
portal_url: str = "",
|
||||
unavailable_message: str = "",
|
||||
confirm_provider: str = "",
|
||||
confirm_base_url: str = "",
|
||||
confirm_api_key: str = "",
|
||||
) -> Optional[str]:
|
||||
"""Interactive model selection. Puts current_model first with a marker. Returns chosen model ID or None.
|
||||
|
||||
|
|
@ -6195,6 +6232,18 @@ def _prompt_model_selection(
|
|||
|
||||
_unavailable = unavailable_models or []
|
||||
|
||||
def _confirmed_selection(mid: str) -> Optional[str]:
|
||||
if not mid:
|
||||
return None
|
||||
if confirm_provider and not _confirm_expensive_model_selection(
|
||||
mid,
|
||||
provider=confirm_provider,
|
||||
base_url=confirm_base_url,
|
||||
api_key=confirm_api_key,
|
||||
):
|
||||
return None
|
||||
return mid
|
||||
|
||||
# Reorder: current model first, then the rest (deduplicated)
|
||||
ordered = []
|
||||
if current_model and current_model in model_ids:
|
||||
|
|
@ -6310,13 +6359,13 @@ def _prompt_model_selection(
|
|||
return None
|
||||
print()
|
||||
if idx < len(ordered):
|
||||
return ordered[idx]
|
||||
return _confirmed_selection(ordered[idx])
|
||||
elif idx == len(ordered):
|
||||
try:
|
||||
custom = input("Enter model name: ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
return None
|
||||
return custom if custom else None
|
||||
return _confirmed_selection(custom) if custom else None
|
||||
return None
|
||||
except (ImportError, NotImplementedError, OSError, subprocess.SubprocessError):
|
||||
pass
|
||||
|
|
@ -6348,10 +6397,10 @@ def _prompt_model_selection(
|
|||
return None
|
||||
idx = int(choice)
|
||||
if 1 <= idx <= n:
|
||||
return ordered[idx - 1]
|
||||
return _confirmed_selection(ordered[idx - 1])
|
||||
elif idx == n + 1:
|
||||
custom = input("Enter model name: ").strip()
|
||||
return custom if custom else None
|
||||
return _confirmed_selection(custom) if custom else None
|
||||
elif idx == n + 2:
|
||||
return None
|
||||
print(f"Please enter 1-{n + 2}")
|
||||
|
|
@ -7730,6 +7779,9 @@ def _login_nous(args, pconfig: ProviderConfig) -> None:
|
|||
unavailable_models=unavailable_models,
|
||||
portal_url=_portal,
|
||||
unavailable_message=unavailable_message,
|
||||
confirm_provider="nous",
|
||||
confirm_base_url=inference_base_url,
|
||||
confirm_api_key=runtime_key,
|
||||
)
|
||||
elif unavailable_models:
|
||||
_url = (_portal or DEFAULT_NOUS_PORTAL_URL).rstrip("/")
|
||||
|
|
|
|||
|
|
@ -3269,6 +3269,7 @@ def _aux_flow_provider_model(
|
|||
model_list,
|
||||
current_model=current_model,
|
||||
pricing=pricing,
|
||||
confirm_provider=provider_slug,
|
||||
)
|
||||
if selected is None:
|
||||
print("No change.")
|
||||
|
|
|
|||
134
hermes_cli/model_cost_guard.py
Normal file
134
hermes_cli/model_cost_guard.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Expensive-model confirmation helpers for model selection surfaces."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal, InvalidOperation
|
||||
from typing import Optional
|
||||
|
||||
from agent.models_dev import ModelInfo
|
||||
|
||||
|
||||
INPUT_COST_WARNING_THRESHOLD = Decimal("20")
|
||||
OUTPUT_COST_WARNING_THRESHOLD = Decimal("100")
|
||||
GPT55_PRO_OPENROUTER_ID = "openai/gpt-5.5-pro"
|
||||
GPT55_SUGGESTION = "did you mean to select openai/gpt-5.5?"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ExpensiveModelWarning:
    """Immutable payload describing a model that tripped the cost guardrail."""

    # Model identifier the warning applies to.
    model: str
    # Provider the pricing was evaluated for.
    provider: str
    # Price per million input tokens, or None when not known.
    input_cost_per_million: Optional[Decimal]
    # Price per million output tokens, or None when not known.
    output_cost_per_million: Optional[Decimal]
    # Label of the pricing source.
    source: str
    # Multi-line, user-facing warning text.
    message: str
|
||||
|
||||
|
||||
def _to_decimal(value: object) -> Optional[Decimal]:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return Decimal(str(value))
|
||||
except (InvalidOperation, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _format_money(value: Optional[Decimal]) -> str:
|
||||
if value is None:
|
||||
return "unknown"
|
||||
return f"${value:.2f}/M"
|
||||
|
||||
|
||||
def _pricing_from_model_info(
|
||||
model_info: Optional[ModelInfo],
|
||||
) -> tuple[Optional[Decimal], Optional[Decimal], str]:
|
||||
if model_info is None or not model_info.has_cost_data():
|
||||
return None, None, ""
|
||||
return (
|
||||
_to_decimal(model_info.cost_input),
|
||||
_to_decimal(model_info.cost_output),
|
||||
"models.dev",
|
||||
)
|
||||
|
||||
|
||||
def expensive_model_warning(
    model_name: str,
    *,
    provider: Optional[str] = None,
    base_url: Optional[str] = None,
    api_key: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> Optional[ExpensiveModelWarning]:
    """Return a warning payload when known pricing exceeds safety thresholds.

    The guard only triggers when pricing is known. Callers should use this after
    model resolution so aliases and provider-specific model IDs have settled.

    Pricing sources are tried in order, each one only while both costs are
    still unknown: the supplied ``model_info``, then
    ``agent.models_dev.get_model_info`` (only when ``provider`` is given),
    then ``agent.usage_pricing.get_pricing_entry``.  Errors raised by the
    two lookups are swallowed so the guard stays best-effort.

    Returns None when ``model_name`` is blank or when neither cost exceeds
    its threshold.
    """
    model = (model_name or "").strip()
    if not model:
        return None

    input_cost, output_cost, source = _pricing_from_model_info(model_info)
    if input_cost is None and output_cost is None and provider:
        try:
            from agent.models_dev import get_model_info

            input_cost, output_cost, source = _pricing_from_model_info(
                get_model_info(provider, model)
            )
        except Exception:
            # Best-effort: fall through to the next pricing source.
            pass
    if input_cost is None and output_cost is None:
        try:
            from agent.usage_pricing import get_pricing_entry

            entry = get_pricing_entry(
                model,
                provider=provider,
                base_url=base_url,
                api_key=api_key,
            )
        except Exception:
            entry = None
        if entry is not None:
            input_cost = entry.input_cost_per_million
            output_cost = entry.output_cost_per_million
            source = entry.source

    over_input = (
        input_cost is not None and input_cost > INPUT_COST_WARNING_THRESHOLD
    )
    over_output = (
        output_cost is not None and output_cost > OUTPUT_COST_WARNING_THRESHOLD
    )
    if not over_input and not over_output:
        return None

    lines = [
        "!!! EXPENSIVE MODEL WARNING !!!",
        "",
        f"{model} has known pricing above Hermes' safety threshold.",
        f"Input tokens: {_format_money(input_cost)}",
        f"Output tokens: {_format_money(output_cost)}",
        # Derived from the threshold constants so the text cannot drift from
        # the values actually enforced above.
        (
            f"Threshold: more than ${INPUT_COST_WARNING_THRESHOLD}/M input "
            f"tokens or more than ${OUTPUT_COST_WARNING_THRESHOLD}/M output "
            "tokens."
        ),
    ]
    if source:
        lines.append(f"Pricing source: {source}.")
    if model.lower() == GPT55_PRO_OPENROUTER_ID:
        lines.append(GPT55_SUGGESTION)
    lines.append("Confirm only if you intend to use this model.")

    return ExpensiveModelWarning(
        model=model,
        provider=(provider or "").strip(),
        input_cost_per_million=input_cost,
        output_cost_per_million=output_cost,
        source=source or "unknown",
        message="\n".join(lines),
    )
|
||||
|
|
@ -102,7 +102,12 @@ def _model_flow_openrouter(config, current_model=""):
|
|||
pricing = get_pricing_for_provider("openrouter", force_refresh=True)
|
||||
|
||||
selected = _prompt_model_selection(
|
||||
openrouter_models, current_model=current_model, pricing=pricing
|
||||
openrouter_models,
|
||||
current_model=current_model,
|
||||
pricing=pricing,
|
||||
confirm_provider="openrouter",
|
||||
confirm_base_url=OPENROUTER_BASE_URL,
|
||||
confirm_api_key=_resolved or existing_key,
|
||||
)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
|
|
@ -311,6 +316,9 @@ def _model_flow_nous(config, current_model="", args=None):
|
|||
unavailable_models=unavailable_models,
|
||||
portal_url=_nous_portal_url,
|
||||
unavailable_message=unavailable_message,
|
||||
confirm_provider="nous",
|
||||
confirm_base_url=creds.get("base_url", ""),
|
||||
confirm_api_key=creds.get("api_key", ""),
|
||||
)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
|
|
@ -416,7 +424,13 @@ def _model_flow_openai_codex(config, current_model=""):
|
|||
|
||||
codex_models = get_codex_model_ids(access_token=_codex_token)
|
||||
|
||||
selected = _prompt_model_selection(codex_models, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
codex_models,
|
||||
current_model=current_model,
|
||||
confirm_provider="openai-codex",
|
||||
confirm_base_url=DEFAULT_CODEX_BASE_URL,
|
||||
confirm_api_key=_codex_token or "",
|
||||
)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
_update_config_for_provider("openai-codex", DEFAULT_CODEX_BASE_URL)
|
||||
|
|
@ -546,7 +560,12 @@ def _model_flow_qwen_oauth(_config, current_model=""):
|
|||
models = list(_DEFAULT_QWEN_PORTAL_MODELS)
|
||||
|
||||
default = current_model or (models[0] if models else "qwen3-coder-plus")
|
||||
selected = _prompt_model_selection(models, current_model=default)
|
||||
selected = _prompt_model_selection(
|
||||
models,
|
||||
current_model=default,
|
||||
confirm_provider="qwen-oauth",
|
||||
confirm_base_url=DEFAULT_QWEN_BASE_URL,
|
||||
)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
_update_config_for_provider("qwen-oauth", DEFAULT_QWEN_BASE_URL)
|
||||
|
|
@ -595,7 +614,12 @@ def _model_flow_minimax_oauth(config, current_model="", args=None):
|
|||
from hermes_cli.models import _PROVIDER_MODELS
|
||||
|
||||
model_ids = _PROVIDER_MODELS.get("minimax-oauth", [])
|
||||
selected = _prompt_model_selection(model_ids, current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_ids,
|
||||
current_model,
|
||||
confirm_provider="minimax-oauth",
|
||||
confirm_base_url=creds["base_url"],
|
||||
)
|
||||
if not selected:
|
||||
return
|
||||
_save_model_choice(selected)
|
||||
|
|
@ -664,7 +688,12 @@ def _model_flow_google_gemini_cli(_config, current_model=""):
|
|||
|
||||
models = list(_PROVIDER_MODELS.get("google-gemini-cli") or [])
|
||||
default = current_model or (models[0] if models else "gemini-3-flash-preview")
|
||||
selected = _prompt_model_selection(models, current_model=default)
|
||||
selected = _prompt_model_selection(
|
||||
models,
|
||||
current_model=default,
|
||||
confirm_provider="google-gemini-cli",
|
||||
confirm_base_url=DEFAULT_GEMINI_CLOUDCODE_BASE_URL,
|
||||
)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
_update_config_for_provider(
|
||||
|
|
@ -1589,7 +1618,11 @@ def _model_flow_copilot(config, current_model=""):
|
|||
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(
|
||||
model_list, current_model=normalized_current_model
|
||||
model_list,
|
||||
current_model=normalized_current_model,
|
||||
confirm_provider=provider_id,
|
||||
confirm_base_url=effective_base,
|
||||
confirm_api_key=api_key,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
|
@ -1727,6 +1760,9 @@ def _model_flow_copilot_acp(config, current_model=""):
|
|||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=normalized_current_model,
|
||||
confirm_provider=provider_id,
|
||||
confirm_base_url=effective_base,
|
||||
confirm_api_key=catalog_api_key,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
|
@ -1831,7 +1867,13 @@ def _model_flow_kimi(config, current_model=""):
|
|||
model_list = _PROVIDER_MODELS.get("moonshot", [])
|
||||
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=current_model,
|
||||
confirm_provider=provider_id,
|
||||
confirm_base_url=effective_base,
|
||||
confirm_api_key=existing_key,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
selected = input("Enter model name: ").strip()
|
||||
|
|
@ -1939,7 +1981,13 @@ def _model_flow_stepfun(config, current_model=""):
|
|||
)
|
||||
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=current_model,
|
||||
confirm_provider=provider_id,
|
||||
confirm_base_url=effective_base,
|
||||
confirm_api_key=existing_key,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
selected = input("Model name: ").strip()
|
||||
|
|
@ -2015,7 +2063,13 @@ def _model_flow_bedrock_api_key(config, region, current_model=""):
|
|||
print(f" Showing {len(model_list)} curated models")
|
||||
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=current_model,
|
||||
confirm_provider="custom",
|
||||
confirm_base_url=mantle_base_url,
|
||||
confirm_api_key=existing_key,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
selected = input(" Model ID: ").strip()
|
||||
|
|
@ -2204,7 +2258,12 @@ def _model_flow_bedrock(config, current_model=""):
|
|||
|
||||
# 4. Model selection
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=current_model,
|
||||
confirm_provider="bedrock",
|
||||
confirm_base_url=f"https://bedrock-runtime.{region}.amazonaws.com",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
selected = input(" Model ID: ").strip()
|
||||
|
|
@ -2488,7 +2547,13 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""):
|
|||
model_list = list(dict.fromkeys(mid for mid in model_list if mid))
|
||||
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=current_model,
|
||||
confirm_provider=provider_id,
|
||||
confirm_base_url=effective_base,
|
||||
confirm_api_key=existing_key,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
selected = input("Model name: ").strip()
|
||||
|
|
@ -2638,7 +2703,11 @@ def _model_flow_anthropic(config, current_model=""):
|
|||
# Model selection
|
||||
model_list = _PROVIDER_MODELS.get("anthropic", [])
|
||||
if model_list:
|
||||
selected = _prompt_model_selection(model_list, current_model=current_model)
|
||||
selected = _prompt_model_selection(
|
||||
model_list,
|
||||
current_model=current_model,
|
||||
confirm_provider="anthropic",
|
||||
)
|
||||
else:
|
||||
try:
|
||||
selected = input("Model name (e.g., claude-sonnet-4-20250514): ").strip()
|
||||
|
|
|
|||
|
|
@ -657,21 +657,6 @@ class AudioTranscriptionRequest(BaseModel):
|
|||
mime_type: Optional[str] = None
|
||||
|
||||
|
||||
class ModelAssignment(BaseModel):
|
||||
"""Payload for POST /api/model/set — assign a provider/model to a slot.
|
||||
|
||||
scope="main" → writes model.provider + model.default
|
||||
scope="auxiliary" → writes auxiliary.<task>.provider + auxiliary.<task>.model
|
||||
scope="auxiliary" with task="" → applied to every auxiliary.* slot
|
||||
scope="auxiliary" with task="__reset__" → resets every slot to provider="auto"
|
||||
"""
|
||||
|
||||
scope: str
|
||||
provider: str
|
||||
model: str
|
||||
task: str = ""
|
||||
|
||||
|
||||
_AUDIO_MIME_EXTENSIONS: Dict[str, str] = {
|
||||
"audio/aac": ".aac",
|
||||
"audio/flac": ".flac",
|
||||
|
|
@ -713,6 +698,7 @@ class ModelAssignment(BaseModel):
|
|||
# reads model.base_url from config (it ignores OPENAI_BASE_URL), so this is
|
||||
# the path that actually wires a local endpoint into resolution.
|
||||
base_url: str = ""
|
||||
confirm_expensive_model: bool = False
|
||||
|
||||
|
||||
def _apply_main_model_assignment(
|
||||
|
|
@ -2470,6 +2456,27 @@ async def set_model_assignment(body: ModelAssignment):
|
|||
try:
|
||||
cfg = load_config()
|
||||
|
||||
if model and not body.confirm_expensive_model:
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
model,
|
||||
provider=provider,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
warning = None
|
||||
if warning is not None:
|
||||
return {
|
||||
"ok": False,
|
||||
"scope": scope,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"confirm_required": True,
|
||||
"confirm_message": warning.message,
|
||||
}
|
||||
|
||||
if scope == "main":
|
||||
if not provider or not model:
|
||||
raise HTTPException(status_code=400, detail="provider and model required for main")
|
||||
|
|
|
|||
|
|
@ -5638,6 +5638,7 @@ def _define_discord_view_classes() -> None:
|
|||
self.allowed_role_ids = allowed_role_ids or set()
|
||||
self.resolved = False
|
||||
self._selected_provider: str = ""
|
||||
self._pending_expensive_model: str = ""
|
||||
|
||||
self._build_provider_select()
|
||||
|
||||
|
|
@ -5720,6 +5721,38 @@ def _define_discord_view_classes() -> None:
|
|||
cancel_btn.callback = self._on_cancel
|
||||
self.add_item(cancel_btn)
|
||||
|
||||
def _build_expensive_confirm(self, model_id: str):
    """Replace the view's items with confirm/cancel buttons for *model_id*.

    The model is remembered in ``_pending_expensive_model`` so the confirm
    handler knows what to switch to.
    """
    self.clear_items()
    self._pending_expensive_model = model_id

    button_specs = (
        (
            "Switch anyway",
            discord.ButtonStyle.red,
            "model_expensive_confirm",
            self._on_expensive_confirm,
        ),
        (
            "Cancel",
            discord.ButtonStyle.grey,
            "model_expensive_cancel",
            self._on_cancel,
        ),
    )
    for label, style, custom_id, handler in button_specs:
        button = discord.ui.Button(label=label, style=style, custom_id=custom_id)
        button.callback = handler
        self.add_item(button)
|
||||
|
||||
def _expensive_warning_for(self, model_id: str):
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
return expensive_model_warning(
|
||||
model_id,
|
||||
provider=self._selected_provider,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _on_provider_selected(self, interaction: discord.Interaction):
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message(
|
||||
|
|
@ -5749,7 +5782,11 @@ def _define_discord_view_classes() -> None:
|
|||
view=self,
|
||||
)
|
||||
|
||||
async def _on_model_selected(self, interaction: discord.Interaction):
|
||||
async def _switch_selected_model(
|
||||
self,
|
||||
interaction: discord.Interaction,
|
||||
model_id: str,
|
||||
):
|
||||
if self.resolved:
|
||||
await interaction.response.send_message(
|
||||
"Already resolved~", ephemeral=True
|
||||
|
|
@ -5762,7 +5799,6 @@ def _define_discord_view_classes() -> None:
|
|||
return
|
||||
|
||||
self.resolved = True
|
||||
model_id = interaction.data["values"][0]
|
||||
self.clear_items()
|
||||
await interaction.response.edit_message(
|
||||
embed=discord.Embed(
|
||||
|
|
@ -5791,6 +5827,50 @@ def _define_discord_view_classes() -> None:
|
|||
view=None,
|
||||
)
|
||||
|
||||
async def _on_model_selected(self, interaction: discord.Interaction):
|
||||
if self.resolved:
|
||||
await interaction.response.send_message(
|
||||
"Already resolved~", ephemeral=True
|
||||
)
|
||||
return
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message(
|
||||
"You're not authorized~", ephemeral=True
|
||||
)
|
||||
return
|
||||
|
||||
model_id = interaction.data["values"][0]
|
||||
warning = self._expensive_warning_for(model_id)
|
||||
if warning is not None:
|
||||
self._build_expensive_confirm(model_id)
|
||||
await interaction.response.edit_message(
|
||||
embed=discord.Embed(
|
||||
title="⚠ Expensive Model Warning",
|
||||
description=warning.message,
|
||||
color=discord.Color.red(),
|
||||
),
|
||||
view=self,
|
||||
)
|
||||
return
|
||||
|
||||
await self._switch_selected_model(interaction, model_id)
|
||||
|
||||
async def _on_expensive_confirm(self, interaction: discord.Interaction):
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message(
|
||||
"You're not authorized~", ephemeral=True
|
||||
)
|
||||
return
|
||||
if not self._pending_expensive_model:
|
||||
await interaction.response.send_message(
|
||||
"Model selection expired.", ephemeral=True
|
||||
)
|
||||
return
|
||||
await self._switch_selected_model(
|
||||
interaction,
|
||||
self._pending_expensive_model,
|
||||
)
|
||||
|
||||
async def _on_back(self, interaction: discord.Interaction):
|
||||
if not self._check_auth(interaction):
|
||||
await interaction.response.send_message(
|
||||
|
|
|
|||
|
|
@ -192,6 +192,32 @@ def test_custom_endpoint_models_api_pricing_is_supported(monkeypatch):
|
|||
assert float(entry.output_cost_per_million) == 2.0
|
||||
|
||||
|
||||
def test_nous_portal_pricing_preserves_vendor_prefixed_model_ids(monkeypatch):
    """Nous pricing keeps the vendor-prefixed ID and queries the Nous inference URL."""
    captured = {}

    def _fake_fetch_endpoint_model_metadata(base_url, api_key=None):
        # Record the endpoint that was queried and return per-token prices.
        captured["base_url"] = base_url
        pricing = {"prompt": "0.000025", "completion": "0.000125"}
        return {"openai/gpt-5.5-pro": {"pricing": pricing}}

    monkeypatch.setattr(
        "agent.usage_pricing.fetch_endpoint_model_metadata",
        _fake_fetch_endpoint_model_metadata,
    )

    entry = get_pricing_entry("openai/gpt-5.5-pro", provider="nous")

    assert captured["base_url"] == "https://inference-api.nousresearch.com/v1"
    assert float(entry.input_cost_per_million) == 25.0
    assert float(entry.output_cost_per_million) == 125.0
|
||||
|
||||
|
||||
def test_deepseek_v4_pro_pricing_entry_exists():
|
||||
"""Regression test: deepseek-v4-pro must have a pricing entry.
|
||||
|
||||
|
|
|
|||
|
|
@ -80,3 +80,91 @@ async def test_model_picker_clears_controls_before_running_switch_callback():
|
|||
interaction.response.edit_message.assert_awaited_once()
|
||||
interaction.response.defer.assert_not_called()
|
||||
interaction.edit_original_response.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_expensive_model_requires_confirmation(monkeypatch):
    """Selecting a flagged model shows a warning first; the switch runs on confirm.

    Every message edit and the switch callback are appended to one ordered
    timeline so the test can assert both *what* happened and *in which order*.
    """
    warning_text = "!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?"
    timeline: list[object] = []

    def _embed_recorder(tag: str):
        # Build an async side effect that logs the embed/view it was given.
        async def _record(**kwargs):
            timeline.append(
                (
                    tag,
                    kwargs["embed"].title,
                    kwargs["embed"].description,
                    kwargs["view"],
                )
            )

        return _record

    async def on_model_selected(chat_id: str, model_id: str, provider_slug: str) -> str:
        timeline.append(("switch", chat_id, model_id, provider_slug))
        return "Model switched"

    # Force the cost guard to flag whatever model is selected.
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(message=warning_text),
    )

    openrouter = {
        "slug": "openrouter",
        "name": "OpenRouter",
        "models": ["openai/gpt-5.5-pro"],
        "total_models": 1,
        "is_current": True,
    }
    view = ModelPickerView(
        providers=[openrouter],
        current_model="openai/gpt-5.5",
        current_provider="openrouter",
        session_key="session-1",
        on_model_selected=on_model_selected,
        allowed_user_ids={"123"},  # matches the interaction user; empty = fail-closed
    )
    view._selected_provider = "openrouter"

    response = SimpleNamespace(
        send_message=AsyncMock(),
        edit_message=AsyncMock(side_effect=_embed_recorder("edit")),
    )
    interaction = SimpleNamespace(
        user=SimpleNamespace(id=123),
        channel_id=456,
        data={"values": ["openai/gpt-5.5-pro"]},
        response=response,
        edit_original_response=AsyncMock(side_effect=_embed_recorder("final-edit")),
    )

    # Step 1: picking the model only surfaces the warning, keeping the view alive.
    await view._on_model_selected(interaction)

    expected_warning_edit = (
        "edit",
        "⚠ Expensive Model Warning",
        "!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?",
        view,
    )
    assert timeline == [expected_warning_edit]
    assert view.resolved is False

    # Step 2: confirming clears the controls, runs the switch, then reports it.
    await view._on_expensive_confirm(interaction)

    assert timeline[1:] == [
        ("edit", "⚙ Switching Model", "Switching to `openai/gpt-5.5-pro`...", None),
        ("switch", "456", "openai/gpt-5.5-pro", "openrouter"),
        ("final-edit", "⚙ Model Switched", "Model switched", None),
    ]
|
||||
|
|
|
|||
|
|
@ -91,10 +91,6 @@ class TestTelegramModelPicker:
|
|||
query.answer = AsyncMock()
|
||||
query.edit_message_text = AsyncMock()
|
||||
|
||||
update = MagicMock()
|
||||
update.callback_query = query
|
||||
context = MagicMock()
|
||||
|
||||
await adapter._handle_model_picker_callback(query, "mb", "12345")
|
||||
|
||||
edit_kwargs = query.edit_message_text.call_args[1]
|
||||
|
|
@ -133,17 +129,11 @@ class TestTelegramModelPicker:
|
|||
|
||||
await adapter._handle_model_picker_callback(query, "mm:0", "12345")
|
||||
|
||||
# The callback was invoked with the selected model
|
||||
callback.assert_awaited_once()
|
||||
# edit_message_text MUST be called on the success path (this is the
|
||||
# regression we're guarding).
|
||||
query.edit_message_text.assert_awaited()
|
||||
edit_kwargs = query.edit_message_text.call_args[1]
|
||||
assert "MARKDOWN_V2" in repr(edit_kwargs["parse_mode"])
|
||||
# The dynamic result text was routed through format_message
|
||||
# (backtick code blocks survive escaping).
|
||||
assert "`gpt-5`" in edit_kwargs["text"]
|
||||
# State is cleaned up after a successful switch.
|
||||
assert "12345" not in adapter._model_picker_state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -184,7 +174,7 @@ class TestTelegramModelPicker:
|
|||
providers = [
|
||||
{"slug": "minimax", "name": "MiniMax", "total_models": 2},
|
||||
{"slug": "minimax-cn", "name": "MiniMax (China)", "total_models": 3},
|
||||
{"slug": "xai", "name": "xAI", "total_models": 1}, # lone group member
|
||||
{"slug": "xai", "name": "xAI", "total_models": 1},
|
||||
]
|
||||
|
||||
await adapter.send_model_picker(
|
||||
|
|
@ -197,14 +187,11 @@ class TestTelegramModelPicker:
|
|||
metadata=None,
|
||||
)
|
||||
|
||||
# Top-level keyboard: MiniMax family folded into one group button;
|
||||
# xai (lone member) degraded to a direct provider button.
|
||||
assert "mpg:minimax" in built
|
||||
assert "mp:xai" in built
|
||||
assert "mp:minimax" not in built
|
||||
assert "mp:minimax-cn" not in built
|
||||
|
||||
# Drill into the MiniMax group → members appear as mp: buttons + back.
|
||||
built.clear()
|
||||
query = AsyncMock()
|
||||
query.message = MagicMock()
|
||||
|
|
@ -216,7 +203,49 @@ class TestTelegramModelPicker:
|
|||
|
||||
assert "mp:minimax" in built
|
||||
assert "mp:minimax-cn" in built
|
||||
assert "mb" in built # back-to-providers button present
|
||||
assert "mb" in built
|
||||
|
||||
@pytest.mark.asyncio
async def test_expensive_model_requires_confirmation(self, monkeypatch):
    """A flagged model needs an ``mc:`` confirm callback before switching.

    The first ``mm:0`` callback must only edit the message into a warning
    (with a reply keyboard) and keep the picker state; the follow-up
    ``mc:0`` callback performs the switch and clears the state.
    """
    chat_key = "12345"
    warning_text = "!!! EXPENSIVE MODEL WARNING !!!\ndid you mean to select openai/gpt-5.5?"

    adapter = _make_adapter()
    switch_cb = AsyncMock(return_value="Switched to `openai/gpt-5.5-pro`")
    openrouter = {
        "slug": "openrouter",
        "name": "OpenRouter",
        "total_models": 1,
        "is_current": True,
    }
    adapter._model_picker_state[chat_key] = {
        "providers": [openrouter],
        "current_model": "model_1",
        "current_provider": "openrouter",
        "session_key": "s",
        "on_model_selected": switch_cb,
        "selected_provider": "openrouter",
        "model_list": ["openai/gpt-5.5-pro"],
        "msg_id": 42,
    }

    # Make the cost guard flag every model.
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(message=warning_text),
    )

    query = AsyncMock()
    query.message = MagicMock()
    query.message.chat_id = 12345
    query.answer = AsyncMock()
    query.edit_message_text = AsyncMock()

    # Selection: warning shown, nothing switched yet.
    await adapter._handle_model_picker_callback(query, "mm:0", chat_key)

    switch_cb.assert_not_awaited()
    assert chat_key in adapter._model_picker_state
    warning_edit = query.edit_message_text.call_args[1]
    assert "EXPENSIVE MODEL WARNING" in warning_edit["text"]
    assert warning_edit["reply_markup"] is not None

    # Confirmation: the switch callback fires and picker state is dropped.
    await adapter._handle_model_picker_callback(query, "mc:0", chat_key)

    switch_cb.assert_awaited_once_with("12345", "openai/gpt-5.5-pro", "openrouter")
    assert chat_key not in adapter._model_picker_state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retries_without_thread_when_thread_not_found(self):
|
||||
|
|
|
|||
|
|
@ -133,7 +133,7 @@ def test_model_command_uses_runtime_access_token_for_codex_list(monkeypatch):
|
|||
captured["access_token"] = access_token
|
||||
return ["gpt-5.2-codex", "gpt-5.2"]
|
||||
|
||||
def _fake_prompt_model_selection(model_ids, current_model=""):
|
||||
def _fake_prompt_model_selection(model_ids, current_model="", **_kwargs):
|
||||
captured["model_ids"] = list(model_ids)
|
||||
captured["current_model"] = current_model
|
||||
return None
|
||||
|
|
@ -181,7 +181,7 @@ def test_model_command_prompts_to_reuse_or_reauthenticate_codex_session(monkeypa
|
|||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._prompt_model_selection",
|
||||
lambda model_ids, current_model="": None,
|
||||
lambda model_ids, current_model="", **_kwargs: None,
|
||||
)
|
||||
|
||||
_model_flow_openai_codex({}, current_model="gpt-5.4")
|
||||
|
|
@ -219,7 +219,7 @@ def test_model_command_uses_existing_codex_session_without_relogin(monkeypatch):
|
|||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._prompt_model_selection",
|
||||
lambda model_ids, current_model="": None,
|
||||
lambda model_ids, current_model="", **_kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.auth._login_openai_codex",
|
||||
|
|
|
|||
97
tests/hermes_cli/test_model_cost_guard.py
Normal file
97
tests/hermes_cli/test_model_cost_guard.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from agent.models_dev import ModelInfo
|
||||
from agent.usage_pricing import PricingEntry
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
|
||||
def test_no_warning_when_known_prices_are_at_threshold():
    """Input 20.0 / output 100.0 per million must not produce a warning."""
    at_threshold = ModelInfo(
        id="edge/model",
        name="edge/model",
        family="",
        provider_id="test",
        cost_input=20.0,
        cost_output=100.0,
    )

    outcome = expensive_model_warning(
        "edge/model",
        provider="test",
        model_info=at_threshold,
    )

    assert outcome is None
|
||||
|
||||
|
||||
def test_warns_when_models_dev_input_price_exceeds_threshold():
    """An input price of 20.01 per million (cheap output) triggers a warning."""
    model_id = "expensive/input"
    pricey_input = ModelInfo(
        id=model_id,
        name=model_id,
        family="",
        provider_id="test",
        cost_input=20.01,
        cost_output=1.0,
    )

    warning = expensive_model_warning(model_id, provider="test", model_info=pricey_input)

    assert warning is not None
    assert warning.input_cost_per_million == Decimal("20.01")
    # The message carries both the banner and the "$20/M input" text.
    for fragment in ("EXPENSIVE MODEL WARNING", "$20/M input"):
        assert fragment in warning.message
|
||||
|
||||
|
||||
def test_warns_when_pricing_entry_output_price_exceeds_threshold(monkeypatch):
    """With no models.dev info, a pricing entry with 100.01 output cost warns."""
    stub_entry_kwargs = {
        "input_cost_per_million": Decimal("1.00"),
        "output_cost_per_million": Decimal("100.01"),
        "source": "provider_models_api",
    }

    # No models.dev record is available, so the guard must use the pricing entry.
    monkeypatch.setattr("agent.models_dev.get_model_info", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        "agent.usage_pricing.get_pricing_entry",
        lambda *_args, **_kwargs: PricingEntry(**stub_entry_kwargs),
    )

    warning = expensive_model_warning("provider/expensive-output", provider="openrouter")

    assert warning is not None
    assert warning.output_cost_per_million == Decimal("100.01")
    assert "$100.01/M" in warning.message
|
||||
|
||||
|
||||
def test_openai_gpt55_pro_adds_suggestion(monkeypatch):
    """The gpt-5.5-pro warning suggests ``openai/gpt-5.5`` as the likely intent."""

    def _no_model_info(*_args, **_kwargs):
        return None

    def _expensive_entry(*_args, **_kwargs):
        return PricingEntry(
            input_cost_per_million=Decimal("25"),
            output_cost_per_million=Decimal("125"),
            source="provider_models_api",
        )

    monkeypatch.setattr("agent.models_dev.get_model_info", _no_model_info)
    monkeypatch.setattr("agent.usage_pricing.get_pricing_entry", _expensive_entry)

    warning = expensive_model_warning("openai/gpt-5.5-pro", provider="openrouter")

    assert warning is not None
    assert "did you mean to select openai/gpt-5.5?" in warning.message
|
||||
|
||||
|
||||
def test_openai_gpt55_pro_warns_for_nous_portal_pricing(monkeypatch):
    """Per-token prices from the stubbed Nous endpoint metadata trigger the warning.

    Only the metadata fetcher is stubbed here (not ``get_pricing_entry``), so
    the per-token strings must be converted to per-million costs on the way.
    """

    def _portal_metadata(base_url, api_key=""):
        pricing = {"prompt": "0.000025", "completion": "0.000125"}
        return {"openai/gpt-5.5-pro": {"pricing": pricing}}

    monkeypatch.setattr("agent.models_dev.get_model_info", lambda *_args, **_kwargs: None)
    monkeypatch.setattr(
        "agent.usage_pricing.fetch_endpoint_model_metadata",
        _portal_metadata,
    )

    warning = expensive_model_warning("openai/gpt-5.5-pro", provider="nous")

    assert warning is not None
    assert warning.input_cost_per_million == Decimal("25.000000")
    assert warning.output_cost_per_million == Decimal("125.000000")
    assert "did you mean to select openai/gpt-5.5?" in warning.message
|
||||
64
tests/hermes_cli/test_model_picker_expensive_confirm.py
Normal file
64
tests/hermes_cli/test_model_picker_expensive_confirm.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
from hermes_cli.model_switch import ModelSwitchResult
|
||||
|
||||
|
||||
def _bound(fn, instance):
|
||||
return fn.__get__(instance, type(instance))
|
||||
|
||||
|
||||
def test_prompt_toolkit_model_picker_defers_confirmation_off_key_handler(monkeypatch):
    """Picker selection hands confirmation to a daemon thread, not the key handler.

    ``threading.Thread`` is replaced with a recorder, and the inline
    confirmation method is replaced with a tripwire; the test then checks
    that the thread was started with ``(result, False)`` and that the
    tripwire never fired.
    """
    import cli as cli_mod

    switch_result = ModelSwitchResult(
        success=True,
        new_model="openai/gpt-5.5-pro",
        target_provider="nous",
    )
    monkeypatch.setattr(
        "hermes_cli.model_switch.switch_model",
        lambda **_kwargs: switch_result,
    )

    observed = {}

    class _RecordingThread:
        """Stand-in for ``threading.Thread`` that never runs its target."""

        def __init__(self, *, target, args, daemon):
            observed["target"] = target
            observed["args"] = args
            observed["daemon"] = daemon

        def start(self):
            observed["started"] = True

    monkeypatch.setattr(cli_mod.threading, "Thread", _RecordingThread)

    picker_state = {
        "stage": "model",
        "provider_data": {"slug": "nous"},
        "model_list": ["openai/gpt-5.5-pro"],
        "selected": 0,
        "user_provs": None,
        "custom_provs": None,
    }
    fake_cli = SimpleNamespace(
        _app=object(),
        _model_picker_state=picker_state,
        provider="nous",
        model="openai/gpt-5.5",
        base_url="",
        api_key="",
        _restore_modal_input_snapshot=lambda: None,
        _invalidate=lambda **_kwargs: None,
    )
    fake_cli._close_model_picker = _bound(cli_mod.HermesCLI._close_model_picker, fake_cli)
    # Tripwire: records a flag if confirmation were ever run inline.
    fake_cli._confirm_and_apply_model_switch_result = (
        lambda *_args: observed.setdefault("ran_inline", True)
    )

    select = _bound(cli_mod.HermesCLI._handle_model_picker_selection, fake_cli)
    select()

    assert fake_cli._model_picker_state is None
    assert observed["started"] is True
    assert observed["daemon"] is True
    assert observed["args"] == (switch_result, False)
    assert "ran_inline" not in observed
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
cannot initialize (e.g. non-TTY, curses unavailable, terminal error)."""
|
||||
|
||||
import subprocess
|
||||
from types import SimpleNamespace
|
||||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
|
|
@ -24,6 +25,46 @@ def test_prompt_model_selection_falls_back_on_menu_runtime_error(monkeypatch):
|
|||
assert selected == "model-b"
|
||||
|
||||
|
||||
def test_prompt_model_selection_requires_expensive_confirmation(monkeypatch, capsys):
    """Answering "n" to the expensive-model prompt yields no selection.

    The curses menu is forced to fail so the prompt reads from ``input``:
    first the menu choice ("1"), then the confirmation answer ("n").
    """
    from hermes_cli.auth import _prompt_model_selection

    scripted_answers = iter(["1", "n"])

    monkeypatch.setattr("hermes_cli.curses_ui.curses_radiolist", _raise_menu)
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(message="EXPENSIVE MODEL WARNING"),
    )
    monkeypatch.setattr("builtins.input", lambda _prompt="": next(scripted_answers))

    chosen = _prompt_model_selection(["openai/gpt-5.5-pro"], confirm_provider="nous")

    printed = capsys.readouterr().out
    assert chosen is None
    assert "EXPENSIVE MODEL WARNING" in printed
|
||||
|
||||
|
||||
def test_prompt_model_selection_allows_confirmed_expensive_model(monkeypatch):
    """Answering "y" to the expensive-model prompt returns the chosen model.

    The curses menu is forced to fail so the prompt reads from ``input``:
    first the menu choice ("1"), then the confirmation answer ("y").
    """
    from hermes_cli.auth import _prompt_model_selection

    scripted_answers = iter(["1", "y"])

    monkeypatch.setattr("hermes_cli.curses_ui.curses_radiolist", _raise_menu)
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(message="EXPENSIVE MODEL WARNING"),
    )
    monkeypatch.setattr("builtins.input", lambda _prompt="": next(scripted_answers))

    chosen = _prompt_model_selection(["openai/gpt-5.5-pro"], confirm_provider="nous")

    assert chosen == "openai/gpt-5.5-pro"
|
||||
|
||||
|
||||
def test_prompt_reasoning_effort_falls_back_on_menu_runtime_error(monkeypatch):
|
||||
from hermes_cli.main import _prompt_reasoning_effort_selection
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import os
|
|||
import json
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -1069,6 +1070,41 @@ class TestWebServerEndpoints:
|
|||
assert "GATEWAY_PROXY_URL" not in managed
|
||||
assert "GATEWAY_PROXY_URL" in _MESSAGING_KEYS_PAGE_KEYS
|
||||
|
||||
def test_model_set_requires_confirmation_for_expensive_model(self, monkeypatch):
    """POST /api/model/set asks for confirmation, then accepts the confirmed retry."""
    monkeypatch.setattr(
        "hermes_cli.model_cost_guard.expensive_model_warning",
        lambda *_args, **_kwargs: SimpleNamespace(message="EXPENSIVE MODEL WARNING"),
    )
    payload = {
        "scope": "main",
        "provider": "nous",
        "model": "openai/gpt-5.5-pro",
    }

    # First attempt: the request succeeds at HTTP level but is not applied.
    first = self.client.post("/api/model/set", json=payload)

    assert first.status_code == 200
    body = first.json()
    assert body["ok"] is False
    assert body["confirm_required"] is True
    assert body["confirm_message"] == "EXPENSIVE MODEL WARNING"

    # Retry with the explicit confirmation flag.
    second = self.client.post(
        "/api/model/set",
        json={**payload, "confirm_expensive_model": True},
    )

    assert second.status_code == 200
    assert second.json()["ok"] is True
|
||||
|
||||
|
||||
def test_reveal_env_var(self, tmp_path):
|
||||
"""POST /api/env/reveal should return the real unredacted value."""
|
||||
from hermes_cli.config import save_env_value
|
||||
|
|
|
|||
|
|
@ -2397,7 +2397,7 @@ def test_config_set_model_waits_for_lazy_agent_before_switch(monkeypatch):
|
|||
target["agent"] = agent
|
||||
agent_ready.set()
|
||||
|
||||
def fake_apply(sid, target, raw):
|
||||
def fake_apply(sid, target, raw, **kwargs):
|
||||
calls.append(("apply", sid, target.get("agent"), raw))
|
||||
if target.get("agent") is not agent:
|
||||
raise AssertionError("model switch ran before lazy agent was ready")
|
||||
|
|
@ -2424,7 +2424,7 @@ def test_config_set_model_uses_live_switch_path(monkeypatch):
|
|||
server._sessions["sid"] = _session()
|
||||
seen = {}
|
||||
|
||||
def _fake_apply(sid, session, raw):
|
||||
def _fake_apply(sid, session, raw, **_kwargs):
|
||||
seen["args"] = (sid, session["session_key"], raw)
|
||||
return {"value": "new/model", "warning": "catalog unreachable"}
|
||||
|
||||
|
|
@ -2442,6 +2442,74 @@ def test_config_set_model_uses_live_switch_path(monkeypatch):
|
|||
assert seen["args"] == ("sid", "session-key", "new/model")
|
||||
|
||||
|
||||
def test_config_set_model_requires_confirmation_for_expensive_model(monkeypatch):
    """``config.set`` for an expensive model defers the switch until confirmed.

    ``switch_model`` is stubbed to return a result whose ``model_info``
    reports 25.0 input / 125.0 output cost; the cost guard itself is not
    stubbed in this test. The agent must stay untouched on the first
    request and be switched only when ``confirm_expensive_model`` is sent.
    """

    class _Agent:
        provider = "openrouter"
        model = "old/model"
        base_url = ""
        api_key = "sk-or"
        switched = False

        def switch_model(self, **_kwargs):
            self.switched = True

    priced_info = types.SimpleNamespace(
        has_cost_data=lambda: True,
        cost_input=25.0,
        cost_output=125.0,
    )
    switch_result = types.SimpleNamespace(
        success=True,
        new_model="openai/gpt-5.5-pro",
        target_provider="openrouter",
        api_key="sk-or",
        base_url="https://openrouter.ai/api/v1",
        api_mode="chat_completions",
        warning_message="",
        model_info=priced_info,
    )

    live_agent = _Agent()
    server._sessions["sid"] = _session(agent=live_agent)
    monkeypatch.setattr(
        "hermes_cli.model_switch.switch_model", lambda **_kwargs: switch_result
    )
    monkeypatch.setattr(server, "_restart_slash_worker", lambda sid, session: None)
    monkeypatch.setattr(server, "_emit", lambda *args, **kwargs: None)

    def _request_switch(request_id, **extra_params):
        # Issue the same config.set call, optionally with extra params.
        params = {
            "session_id": "sid",
            "key": "model",
            "value": "openai/gpt-5.5-pro --provider openrouter",
        }
        params.update(extra_params)
        return server.handle_request(
            {"id": request_id, "method": "config.set", "params": params}
        )

    unconfirmed = _request_switch("1")

    assert unconfirmed["result"]["confirm_required"] is True
    assert "did you mean to select openai/gpt-5.5?" in unconfirmed["result"]["confirm_message"]
    assert live_agent.switched is False

    confirmed = _request_switch("2", confirm_expensive_model=True)

    assert confirmed["result"]["confirm_required"] is False
    assert confirmed["result"]["value"] == "openai/gpt-5.5-pro"
    assert live_agent.switched is True
|
||||
|
||||
|
||||
def test_config_set_model_global_persists(monkeypatch):
|
||||
class _Agent:
|
||||
provider = "openrouter"
|
||||
|
|
@ -3944,7 +4012,7 @@ def test_config_set_model_rejects_while_running(monkeypatch):
|
|||
"""/model via config.set must reject during an in-flight turn."""
|
||||
seen = {"called": False}
|
||||
|
||||
def _fake_apply(sid, session, raw):
|
||||
def _fake_apply(sid, session, raw, **_kwargs):
|
||||
seen["called"] = True
|
||||
return {"value": raw, "warning": ""}
|
||||
|
||||
|
|
@ -3978,7 +4046,7 @@ def test_config_set_model_allowed_when_idle(monkeypatch):
|
|||
"""Regression guard: idle sessions can still switch models."""
|
||||
seen = {"called": False}
|
||||
|
||||
def _fake_apply(sid, session, raw):
|
||||
def _fake_apply(sid, session, raw, **_kwargs):
|
||||
seen["called"] = True
|
||||
return {"value": "newmodel", "warning": ""}
|
||||
|
||||
|
|
|
|||
|
|
@ -1696,7 +1696,13 @@ def _persist_model_switch(result) -> None:
|
|||
save_config(cfg)
|
||||
|
||||
|
||||
def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
|
||||
def _apply_model_switch(
|
||||
sid: str,
|
||||
session: dict,
|
||||
raw_input: str,
|
||||
*,
|
||||
confirm_expensive_model: bool = False,
|
||||
) -> dict:
|
||||
from hermes_cli.model_switch import parse_model_flags, switch_model
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
|
||||
|
|
@ -1753,6 +1759,27 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
|
|||
if not result.success:
|
||||
raise ValueError(result.error_message or "model switch failed")
|
||||
|
||||
if not confirm_expensive_model:
|
||||
try:
|
||||
from hermes_cli.model_cost_guard import expensive_model_warning
|
||||
|
||||
warning = expensive_model_warning(
|
||||
result.new_model,
|
||||
provider=result.target_provider,
|
||||
base_url=result.base_url or current_base_url,
|
||||
api_key=result.api_key or current_api_key,
|
||||
model_info=result.model_info,
|
||||
)
|
||||
except Exception:
|
||||
warning = None
|
||||
if warning is not None:
|
||||
return {
|
||||
"value": result.new_model,
|
||||
"warning": warning.message,
|
||||
"confirm_required": True,
|
||||
"confirm_message": warning.message,
|
||||
}
|
||||
|
||||
if agent:
|
||||
agent.switch_model(
|
||||
new_model=result.new_model,
|
||||
|
|
@ -1787,7 +1814,11 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
|
|||
}
|
||||
if persist_global:
|
||||
_persist_model_switch(result)
|
||||
return {"value": result.new_model, "warning": result.warning_message or ""}
|
||||
return {
|
||||
"value": result.new_model,
|
||||
"warning": result.warning_message or "",
|
||||
"confirm_required": False,
|
||||
}
|
||||
|
||||
|
||||
def _compress_session_history(
|
||||
|
|
@ -6196,13 +6227,31 @@ def _(rid, params: dict) -> dict:
|
|||
if session.get("agent") is None:
|
||||
return _err(rid, 5032, "agent initialization failed")
|
||||
result = _apply_model_switch(
|
||||
params.get("session_id", ""), session, value
|
||||
params.get("session_id", ""),
|
||||
session,
|
||||
value,
|
||||
confirm_expensive_model=bool(
|
||||
params.get("confirm_expensive_model", False)
|
||||
),
|
||||
)
|
||||
else:
|
||||
result = _apply_model_switch("", {"agent": None}, value)
|
||||
result = _apply_model_switch(
|
||||
"",
|
||||
{"agent": None},
|
||||
value,
|
||||
confirm_expensive_model=bool(
|
||||
params.get("confirm_expensive_model", False)
|
||||
),
|
||||
)
|
||||
return _ok(
|
||||
rid,
|
||||
{"key": key, "value": result["value"], "warning": result["warning"]},
|
||||
{
|
||||
"key": key,
|
||||
"value": result["value"],
|
||||
"warning": result["warning"],
|
||||
"confirm_required": result.get("confirm_required", False),
|
||||
"confirm_message": result.get("confirm_message", ""),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
return _err(rid, 5001, str(e))
|
||||
|
|
|
|||
|
|
@ -108,6 +108,7 @@ describe('createSlashHandler', () => {
|
|||
|
||||
expect(createSlashHandler(ctx)('/model x-model')).toBe(true)
|
||||
expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', {
|
||||
confirm_expensive_model: false,
|
||||
key: 'model',
|
||||
session_id: 'sid-abc',
|
||||
value: 'x-model'
|
||||
|
|
@ -128,6 +129,7 @@ describe('createSlashHandler', () => {
|
|||
createSlashHandler(ctx)(`/model anthropic/claude-sonnet-4.6 --provider openrouter ${TUI_SESSION_MODEL_FLAG}`)
|
||||
).toBe(true)
|
||||
expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', {
|
||||
confirm_expensive_model: false,
|
||||
key: 'model',
|
||||
session_id: 'sid-abc',
|
||||
value: 'anthropic/claude-sonnet-4.6 --provider openrouter'
|
||||
|
|
@ -140,6 +142,7 @@ describe('createSlashHandler', () => {
|
|||
|
||||
createSlashHandler(ctx)('/model x-model --global')
|
||||
expect(ctx.gateway.rpc).toHaveBeenCalledWith('config.set', {
|
||||
confirm_expensive_model: false,
|
||||
key: 'model',
|
||||
session_id: 'sid-abc',
|
||||
value: 'x-model --global'
|
||||
|
|
|
|||
|
|
@ -72,10 +72,25 @@ export const sessionCommands: SlashCommand[] = [
|
|||
return patchOverlayState({ modelPicker: true })
|
||||
}
|
||||
|
||||
ctx.gateway
|
||||
.rpc<ConfigSetResponse>('config.set', { key: 'model', session_id: ctx.sid, value: modelValueForConfigSet(arg) })
|
||||
const switchModel = (confirmExpensiveModel = false) => ctx.gateway
|
||||
.rpc<ConfigSetResponse>('config.set', { confirm_expensive_model: confirmExpensiveModel, key: 'model', session_id: ctx.sid, value: modelValueForConfigSet(arg) })
|
||||
.then(
|
||||
ctx.guarded<ConfigSetResponse>(r => {
|
||||
if (r.confirm_required) {
|
||||
patchOverlayState({
|
||||
confirm: {
|
||||
cancelLabel: 'Cancel',
|
||||
confirmLabel: 'Switch anyway',
|
||||
danger: true,
|
||||
detail: r.confirm_message || r.warning || 'This model has unusually high known pricing.',
|
||||
onConfirm: () => switchModel(true),
|
||||
title: 'Expensive model selection'
|
||||
}
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if (!r.value) {
|
||||
return ctx.transcript.sys('error: invalid response: model switch')
|
||||
}
|
||||
|
|
@ -89,6 +104,8 @@ export const sessionCommands: SlashCommand[] = [
|
|||
}))
|
||||
})
|
||||
)
|
||||
|
||||
switchModel()
|
||||
}
|
||||
},
|
||||
|
||||
|
|
|
|||
|
|
@ -103,6 +103,8 @@ export interface ConfigGetValueResponse {
|
|||
}
|
||||
|
||||
export interface ConfigSetResponse {
|
||||
confirm_message?: string
|
||||
confirm_required?: boolean
|
||||
credential_warning?: string
|
||||
history_reset?: boolean
|
||||
info?: SessionInfo
|
||||
|
|
|
|||
|
|
@ -292,24 +292,6 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) {
|
|||
setVersion((v) => v + 1);
|
||||
}, []);
|
||||
|
||||
// Picker hands us a fully-formed slash command (e.g. "/model anthropic/...").
|
||||
// Fire-and-forget through `slash.exec`; the TUI pane will render the result
|
||||
// via PTY, so the sidebar doesn't need to surface output of its own.
|
||||
const onModelSubmit = useCallback(
|
||||
(slashCommand: string) => {
|
||||
if (!sessionId) {
|
||||
return;
|
||||
}
|
||||
|
||||
void gw.request("slash.exec", {
|
||||
session_id: sessionId,
|
||||
command: slashCommand,
|
||||
});
|
||||
setModelOpen(false);
|
||||
},
|
||||
[gw, sessionId],
|
||||
);
|
||||
|
||||
const canPickModel = state === "open" && !!sessionId;
|
||||
const modelLabel = (info.model ?? "—").split("/").slice(-1)[0] ?? "—";
|
||||
const banner = error ?? info.credential_warning ?? null;
|
||||
|
|
@ -390,7 +372,6 @@ export function ChatSidebar({ channel, className }: ChatSidebarProps) {
|
|||
gw={gw}
|
||||
sessionId={sessionId}
|
||||
onClose={() => setModelOpen(false)}
|
||||
onSubmit={onModelSubmit}
|
||||
/>
|
||||
)}
|
||||
</aside>
|
||||
|
|
|
|||
122
web/src/components/ConfirmDialog.tsx
Normal file
122
web/src/components/ConfirmDialog.tsx
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import { Button } from "@nous-research/ui/ui/components/button";
|
||||
import { AlertTriangle } from "lucide-react";
|
||||
import { useEffect, useRef } from "react";
|
||||
import { createPortal } from "react-dom";
|
||||
import { cn, themedBody } from "@/lib/utils";
|
||||
|
||||
interface ConfirmDialogProps {
  /** Label for the dismiss button. Defaults to "Cancel". */
  cancelLabel?: string;
  /** Label for the confirm button. Defaults to "Confirm". */
  confirmLabel?: string;
  /** Optional body text; newlines are preserved when rendered. */
  description?: string;
  /** Shows the warning icon and styles the confirm button as destructive. */
  destructive?: boolean;
  /** Disables both buttons and blocks Escape/backdrop dismissal. */
  loading?: boolean;
  /** Called on Cancel, Escape, or a click on the backdrop. */
  onCancel: () => void;
  /** Called when the confirm button is clicked. */
  onConfirm: () => void;
  /** When false the dialog renders nothing. */
  open: boolean;
  /** Heading text; also the dialog's accessible label. */
  title: string;
}

/**
 * Modal confirmation dialog rendered into `document.body` via a portal.
 *
 * While open it focuses the confirm button, locks body scrolling, and
 * listens for Escape. On close it restores the previous overflow style and
 * returns focus to the element that was active when it opened.
 */
export function ConfirmDialog({
  cancelLabel = "Cancel",
  confirmLabel = "Confirm",
  description,
  destructive = false,
  loading = false,
  onCancel,
  onConfirm,
  open,
  title,
}: ConfirmDialogProps) {
  const dialogRef = useRef<HTMLDivElement>(null);

  // Latest props for the document-level key listener. Reading them through a
  // ref keeps the open/close effect keyed on `open` alone, so a parent that
  // passes a fresh inline `onCancel` on every render does not tear the effect
  // down and re-steal focus / re-toggle body overflow each time.
  const latest = useRef({ loading, onCancel });

  useEffect(() => {
    latest.current = { loading, onCancel };
  });

  useEffect(() => {
    if (!open) return;

    const prevActive = document.activeElement as HTMLElement | null;
    dialogRef.current
      ?.querySelector<HTMLButtonElement>("[data-confirm]")
      ?.focus();

    const onKey = (e: KeyboardEvent) => {
      if (e.key !== "Escape") return;

      e.preventDefault();

      // The Cancel button is disabled while loading; keep Escape consistent
      // so an in-flight confirmation cannot be dismissed from the keyboard.
      if (!latest.current.loading) {
        latest.current.onCancel();
      }
    };

    document.addEventListener("keydown", onKey);
    const prevOverflow = document.body.style.overflow;
    document.body.style.overflow = "hidden";

    return () => {
      document.removeEventListener("keydown", onKey);
      document.body.style.overflow = prevOverflow;
      prevActive?.focus?.();
    };
  }, [open]);

  if (!open) return null;

  return createPortal(
    <div
      role="dialog"
      aria-modal="true"
      aria-labelledby="confirm-dialog-title"
      aria-describedby={description ? "confirm-dialog-desc" : undefined}
      onClick={(e) => {
        // Only a click on the backdrop itself dismisses, and never mid-load.
        if (e.target === e.currentTarget && !loading) onCancel();
      }}
      className="fixed inset-0 z-[200] flex items-center justify-center bg-background/85 backdrop-blur-sm p-4"
    >
      <div
        ref={dialogRef}
        className={cn(
          themedBody,
          "relative w-full max-w-md border border-border bg-card shadow-2xl",
        )}
      >
        <div className="flex items-start gap-3 p-4 border-b border-border">
          {destructive && (
            <div aria-hidden className="mt-0.5 shrink-0 text-destructive">
              <AlertTriangle className="h-4 w-4" />
            </div>
          )}

          <div className="flex-1 min-w-0 flex flex-col gap-1">
            <h2
              id="confirm-dialog-title"
              className="font-mondwest text-display text-base tracking-wider"
            >
              {title}
            </h2>

            {description && (
              <p
                id="confirm-dialog-desc"
                className="text-xs text-muted-foreground leading-relaxed whitespace-pre-line"
              >
                {description}
              </p>
            )}
          </div>
        </div>

        <div className="flex items-center justify-end gap-2 p-3">
          <Button type="button" outlined onClick={onCancel} disabled={loading}>
            {cancelLabel}
          </Button>
          <Button
            data-confirm
            type="button"
            destructive={destructive}
            onClick={onConfirm}
            disabled={loading}
          >
            {loading ? "…" : confirmLabel}
          </Button>
        </div>
      </div>
    </div>,
    document.body,
  );
}
|
||||
|
|
@ -4,6 +4,7 @@ import { ListItem } from "@nous-research/ui/ui/components/list-item";
|
|||
import { Spinner } from "@nous-research/ui/ui/components/spinner";
|
||||
import { Input } from "@nous-research/ui/ui/components/input";
|
||||
import { Label } from "@nous-research/ui/ui/components/label";
|
||||
import { ConfirmDialog } from "@/components/ConfirmDialog";
|
||||
import type { GatewayClient } from "@/lib/gatewayClient";
|
||||
import { Check, Search, X } from "lucide-react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
|
@ -21,9 +22,8 @@ import { fuzzyRank } from "@/lib/fuzzy";
|
|||
* Two invocation modes:
|
||||
*
|
||||
* 1. Chat-session mode (ChatSidebar) — pass `gw` + `sessionId`. The picker
|
||||
* loads options via `model.options` JSON-RPC and emits the result as a
|
||||
* slash command string (`/model <model> --provider <slug> [--global]`)
|
||||
* through `onSubmit`, which the ChatPage pipes to `slashExec`.
|
||||
* loads options via `model.options` JSON-RPC and applies the choice via
|
||||
* `config.set`, so expensive-model confirmation can happen before switch.
|
||||
*
|
||||
* 2. Standalone mode (ModelsPage, Config settings) — pass a `loader` and
|
||||
* `onApply`. The picker fetches options via the REST endpoint and calls
|
||||
|
|
@ -47,6 +47,23 @@ interface ModelOptionsResponse {
|
|||
providers?: ModelOptionProvider[];
|
||||
}
|
||||
|
||||
interface ExpensiveModelConfirmResponse {
|
||||
confirm_message?: string;
|
||||
confirm_required?: boolean;
|
||||
warning?: string;
|
||||
}
|
||||
|
||||
interface ConfigSetResponse extends ExpensiveModelConfirmResponse {
|
||||
value?: string;
|
||||
}
|
||||
|
||||
interface PendingExpensiveConfirm {
|
||||
message: string;
|
||||
model: string;
|
||||
persistGlobal: boolean;
|
||||
provider: string;
|
||||
}
|
||||
|
||||
interface Props {
|
||||
/** Chat-mode: when present, picker emits a slash command via onSubmit. */
|
||||
gw?: GatewayClient;
|
||||
|
|
@ -56,10 +73,14 @@ interface Props {
|
|||
/** Standalone-mode: when present (and onSubmit absent), picker calls onApply. */
|
||||
loader?(): Promise<ModelOptionsResponse>;
|
||||
onApply?(args: {
|
||||
confirmExpensiveModel?: boolean;
|
||||
provider: string;
|
||||
model: string;
|
||||
persistGlobal: boolean;
|
||||
}): Promise<void> | void;
|
||||
}):
|
||||
| Promise<ExpensiveModelConfirmResponse | void>
|
||||
| ExpensiveModelConfirmResponse
|
||||
| void;
|
||||
|
||||
onClose(): void;
|
||||
title?: string;
|
||||
|
|
@ -90,6 +111,8 @@ export function ModelPickerDialog(props: Props) {
|
|||
const [query, setQuery] = useState("");
|
||||
const [persistGlobal, setPersistGlobal] = useState(alwaysGlobal);
|
||||
const [applying, setApplying] = useState(false);
|
||||
const [pendingConfirm, setPendingConfirm] =
|
||||
useState<PendingExpensiveConfirm | null>(null);
|
||||
const closedRef = useRef(false);
|
||||
|
||||
// Load providers + models on open.
|
||||
|
|
@ -179,16 +202,65 @@ export function ModelPickerDialog(props: Props) {
|
|||
|
||||
const canConfirm = !!selectedProvider && !!selectedModel && !applying;
|
||||
|
||||
const confirm = async () => {
|
||||
if (!canConfirm || !selectedProvider) return;
|
||||
const applySelection = async (
|
||||
confirmExpensiveModel = false,
|
||||
forced?: PendingExpensiveConfirm,
|
||||
) => {
|
||||
const providerSlug = forced?.provider ?? selectedProvider?.slug ?? "";
|
||||
const model = forced?.model ?? selectedModel;
|
||||
const shouldPersistGlobal = forced?.persistGlobal ?? persistGlobal;
|
||||
|
||||
if (!providerSlug || !model || applying) return;
|
||||
|
||||
if (standalone && onApply) {
|
||||
setApplying(true);
|
||||
try {
|
||||
await onApply({
|
||||
provider: selectedProvider.slug,
|
||||
model: selectedModel,
|
||||
persistGlobal,
|
||||
const result = await onApply({
|
||||
confirmExpensiveModel,
|
||||
provider: providerSlug,
|
||||
model,
|
||||
persistGlobal: shouldPersistGlobal,
|
||||
});
|
||||
if (result?.confirm_required) {
|
||||
setPendingConfirm({
|
||||
provider: providerSlug,
|
||||
model,
|
||||
persistGlobal: shouldPersistGlobal,
|
||||
message:
|
||||
result.confirm_message ||
|
||||
result.warning ||
|
||||
"This model has unusually high known pricing.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
onClose();
|
||||
} catch (e) {
|
||||
setError(e instanceof Error ? e.message : String(e));
|
||||
} finally {
|
||||
setApplying(false);
|
||||
}
|
||||
} else if (gw && sessionId) {
|
||||
setApplying(true);
|
||||
try {
|
||||
const global = shouldPersistGlobal ? " --global" : "";
|
||||
const result = await gw.request<ConfigSetResponse>("config.set", {
|
||||
confirm_expensive_model: confirmExpensiveModel,
|
||||
key: "model",
|
||||
session_id: sessionId,
|
||||
value: `${model} --provider ${providerSlug}${global}`,
|
||||
});
|
||||
if (result?.confirm_required) {
|
||||
setPendingConfirm({
|
||||
provider: providerSlug,
|
||||
model,
|
||||
persistGlobal: shouldPersistGlobal,
|
||||
message:
|
||||
result.confirm_message ||
|
||||
result.warning ||
|
||||
"This model has unusually high known pricing.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
onClose();
|
||||
} catch (e) {
|
||||
setError(e instanceof Error ? e.message : String(e));
|
||||
|
|
@ -196,14 +268,17 @@ export function ModelPickerDialog(props: Props) {
|
|||
setApplying(false);
|
||||
}
|
||||
} else if (onSubmit) {
|
||||
const global = persistGlobal ? " --global" : "";
|
||||
onSubmit(
|
||||
`/model ${selectedModel} --provider ${selectedProvider.slug}${global}`,
|
||||
);
|
||||
const global = shouldPersistGlobal ? " --global" : "";
|
||||
onSubmit(`/model ${model} --provider ${providerSlug}${global}`);
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
|
||||
const confirm = () => {
|
||||
if (!canConfirm) return;
|
||||
void applySelection();
|
||||
};
|
||||
|
||||
// Portal to document.body: the main dashboard column in App.tsx is
|
||||
// `relative z-2`, which creates a stacking context that traps fixed
|
||||
// descendants below the app sidebar (z-50). Without the portal this
|
||||
|
|
@ -280,8 +355,12 @@ export function ModelPickerDialog(props: Props) {
|
|||
onSelect={setSelectedModel}
|
||||
onConfirm={(m) => {
|
||||
setSelectedModel(m);
|
||||
// Confirm on next tick so state settles.
|
||||
window.setTimeout(confirm, 0);
|
||||
void applySelection(false, {
|
||||
provider: selectedProvider?.slug ?? "",
|
||||
model: m,
|
||||
persistGlobal,
|
||||
message: "",
|
||||
});
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
|
@ -320,6 +399,22 @@ export function ModelPickerDialog(props: Props) {
|
|||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
<ConfirmDialog
|
||||
open={!!pendingConfirm}
|
||||
title="Expensive Model Warning"
|
||||
description={pendingConfirm?.message}
|
||||
destructive
|
||||
confirmLabel="Switch anyway"
|
||||
cancelLabel="Cancel"
|
||||
loading={applying}
|
||||
onCancel={() => setPendingConfirm(null)}
|
||||
onConfirm={() => {
|
||||
const pending = pendingConfirm;
|
||||
if (!pending) return;
|
||||
setPendingConfirm(null);
|
||||
void applySelection(true, pending);
|
||||
}}
|
||||
/>
|
||||
</div>,
|
||||
document.body,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -1763,9 +1763,12 @@ export interface AuxiliaryModelsResponse {
|
|||
}
|
||||
|
||||
export interface ModelAssignmentRequest {
|
||||
confirm_expensive_model?: boolean;
|
||||
scope: "main" | "auxiliary";
|
||||
provider: string;
|
||||
model: string;
|
||||
/** Optional OpenAI-compatible endpoint URL for custom/local main providers. */
|
||||
base_url?: string;
|
||||
/** For auxiliary: task slot name, "" for all, "__reset__" to reset all. */
|
||||
task?: string;
|
||||
}
|
||||
|
|
@ -1779,6 +1782,8 @@ export interface StaleAuxAssignment {
|
|||
}
|
||||
|
||||
export interface ModelAssignmentResponse {
|
||||
confirm_message?: string;
|
||||
confirm_required?: boolean;
|
||||
ok: boolean;
|
||||
scope?: string;
|
||||
provider?: string;
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ import { Spinner } from "@nous-research/ui/ui/components/spinner";
|
|||
import { Stats } from "@nous-research/ui/ui/components/stats";
|
||||
import { Card, CardContent, CardHeader, CardTitle } from "@nous-research/ui/ui/components/card";
|
||||
import { Badge } from "@nous-research/ui/ui/components/badge";
|
||||
import { ConfirmDialog } from "@nous-research/ui/ui/components/confirm-dialog";
|
||||
import { ConfirmDialog } from "@/components/ConfirmDialog";
|
||||
import { useModalBehavior } from "@/hooks/useModalBehavior";
|
||||
import { usePageHeader } from "@/contexts/usePageHeader";
|
||||
import { useI18n } from "@/i18n";
|
||||
|
|
@ -209,10 +209,16 @@ function UseAsMenu({
|
|||
const [open, setOpen] = useState(false);
|
||||
const [busy, setBusy] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [pendingConfirm, setPendingConfirm] = useState<{
|
||||
message: string;
|
||||
scope: "main" | "auxiliary";
|
||||
task: string;
|
||||
} | null>(null);
|
||||
|
||||
const assign = async (
|
||||
scope: "main" | "auxiliary",
|
||||
task: string,
|
||||
confirmExpensiveModel = false,
|
||||
) => {
|
||||
if (!provider || !model) {
|
||||
setError("Missing provider/model");
|
||||
|
|
@ -221,7 +227,23 @@ function UseAsMenu({
|
|||
setBusy(true);
|
||||
setError(null);
|
||||
try {
|
||||
await api.setModelAssignment({ scope, provider, model, task });
|
||||
const result = await api.setModelAssignment({
|
||||
confirm_expensive_model: confirmExpensiveModel,
|
||||
scope,
|
||||
provider,
|
||||
model,
|
||||
task,
|
||||
});
|
||||
if (result.confirm_required) {
|
||||
setPendingConfirm({
|
||||
scope,
|
||||
task,
|
||||
message:
|
||||
result.confirm_message ||
|
||||
"This model has unusually high known pricing.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
onAssigned();
|
||||
setOpen(false);
|
||||
} catch (e) {
|
||||
|
|
@ -310,6 +332,22 @@ function UseAsMenu({
|
|||
)}
|
||||
</div>
|
||||
)}
|
||||
<ConfirmDialog
|
||||
open={!!pendingConfirm}
|
||||
title="Expensive Model Warning"
|
||||
description={pendingConfirm?.message}
|
||||
destructive
|
||||
confirmLabel="Switch anyway"
|
||||
cancelLabel="Cancel"
|
||||
loading={busy}
|
||||
onCancel={() => setPendingConfirm(null)}
|
||||
onConfirm={() => {
|
||||
const pending = pendingConfirm;
|
||||
if (!pending) return;
|
||||
setPendingConfirm(null);
|
||||
void assign(pending.scope, pending.task, true);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
@ -619,14 +657,16 @@ function AuxiliaryTasksModal({
|
|||
AUX_TASKS.find((t) => t.key === picker.task)?.label ??
|
||||
picker.task
|
||||
}`}
|
||||
onApply={async ({ provider, model }) => {
|
||||
await api.setModelAssignment({
|
||||
onApply={async ({ provider, model, confirmExpensiveModel }) => {
|
||||
const result = await api.setModelAssignment({
|
||||
confirm_expensive_model: confirmExpensiveModel,
|
||||
scope: "auxiliary",
|
||||
task: picker.task,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
onSaved();
|
||||
if (!result.confirm_required) onSaved();
|
||||
return result;
|
||||
}}
|
||||
onClose={() => setPicker(null)}
|
||||
/>
|
||||
|
|
@ -666,14 +706,23 @@ function ModelSettingsPanel({
|
|||
task,
|
||||
provider,
|
||||
model,
|
||||
confirmExpensiveModel,
|
||||
}: {
|
||||
confirmExpensiveModel?: boolean;
|
||||
scope: "main" | "auxiliary";
|
||||
task: string;
|
||||
provider: string;
|
||||
model: string;
|
||||
}) => {
|
||||
await api.setModelAssignment({ scope, task, provider, model });
|
||||
onSaved();
|
||||
const result = await api.setModelAssignment({
|
||||
confirm_expensive_model: confirmExpensiveModel,
|
||||
scope,
|
||||
task,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
if (!result.confirm_required) onSaved();
|
||||
return result;
|
||||
};
|
||||
|
||||
// Count how many aux tasks have overrides
|
||||
|
|
@ -749,14 +798,15 @@ function ModelSettingsPanel({
|
|||
loader={api.getModelOptions}
|
||||
alwaysGlobal
|
||||
title="Set Main Model"
|
||||
onApply={async ({ provider, model }) => {
|
||||
await applyAssignment({
|
||||
onApply={({ provider, model, confirmExpensiveModel }) =>
|
||||
applyAssignment({
|
||||
confirmExpensiveModel,
|
||||
scope: "main",
|
||||
task: "",
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
}}
|
||||
})
|
||||
}
|
||||
onClose={() => setPicker(null)}
|
||||
/>
|
||||
)}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue