diff --git a/agent/models_dev.py b/agent/models_dev.py index b4b6995584..61483b6a10 100644 --- a/agent/models_dev.py +++ b/agent/models_dev.py @@ -1,19 +1,31 @@ -"""Models.dev registry integration for provider-aware context length detection. +"""Models.dev registry integration — primary database for providers and models. -Fetches model metadata from https://models.dev/api.json — a community-maintained -database of 3800+ models across 100+ providers, including per-provider context -windows, pricing, and capabilities. +Fetches from https://models.dev/api.json — a community-maintained database +of 4000+ models across 109+ providers. Provides: -Data is cached in memory (1hr TTL) and on disk (~/.hermes/models_dev_cache.json) -to avoid cold-start network latency. +- **Provider metadata**: name, base URL, env vars, documentation link +- **Model metadata**: context window, max output, cost/M tokens, capabilities + (reasoning, tools, vision, PDF, audio), modalities, knowledge cutoff, + open-weights flag, family grouping, deprecation status + +Data resolution order (like TypeScript OpenCode): + 1. Bundled snapshot (ships with the package — offline-first) + 2. Disk cache (~/.hermes/models_dev_cache.json) + 3. Network fetch (https://models.dev/api.json) + 4. Background refresh every 60 minutes + +Other modules should import the dataclasses and query functions from here +rather than parsing the raw JSON themselves. 
""" +import difflib import json import logging import os import time +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional, Tuple, Union from utils import atomic_json_write @@ -28,7 +40,110 @@ _MODELS_DEV_CACHE_TTL = 3600 # 1 hour in-memory _models_dev_cache: Dict[str, Any] = {} _models_dev_cache_time: float = 0 -# Provider ID mapping: Hermes provider names → models.dev provider IDs + +# --------------------------------------------------------------------------- +# Dataclasses — rich metadata for providers and models +# --------------------------------------------------------------------------- + +@dataclass +class ModelInfo: + """Full metadata for a single model from models.dev.""" + + id: str + name: str + family: str + provider_id: str # models.dev provider ID (e.g. "anthropic") + + # Capabilities + reasoning: bool = False + tool_call: bool = False + attachment: bool = False # supports image/file attachments (vision) + temperature: bool = False + structured_output: bool = False + open_weights: bool = False + + # Modalities + input_modalities: Tuple[str, ...] = () # ("text", "image", "pdf", ...) + output_modalities: Tuple[str, ...] 
= () + + # Limits + context_window: int = 0 + max_output: int = 0 + max_input: Optional[int] = None + + # Cost (per million tokens, USD) + cost_input: float = 0.0 + cost_output: float = 0.0 + cost_cache_read: Optional[float] = None + cost_cache_write: Optional[float] = None + + # Metadata + knowledge_cutoff: str = "" + release_date: str = "" + status: str = "" # "alpha", "beta", "deprecated", or "" + interleaved: Any = False # True or {"field": "reasoning_content"} + + def has_cost_data(self) -> bool: + return self.cost_input > 0 or self.cost_output > 0 + + def supports_vision(self) -> bool: + return self.attachment or "image" in self.input_modalities + + def supports_pdf(self) -> bool: + return "pdf" in self.input_modalities + + def supports_audio_input(self) -> bool: + return "audio" in self.input_modalities + + def format_cost(self) -> str: + """Human-readable cost string, e.g. '$3.00/M in, $15.00/M out'.""" + if not self.has_cost_data(): + return "unknown" + parts = [f"${self.cost_input:.2f}/M in", f"${self.cost_output:.2f}/M out"] + if self.cost_cache_read is not None: + parts.append(f"cache read ${self.cost_cache_read:.2f}/M") + return ", ".join(parts) + + def format_capabilities(self) -> str: + """Human-readable capabilities, e.g. 'reasoning, tools, vision, PDF'.""" + caps = [] + if self.reasoning: + caps.append("reasoning") + if self.tool_call: + caps.append("tools") + if self.supports_vision(): + caps.append("vision") + if self.supports_pdf(): + caps.append("PDF") + if self.supports_audio_input(): + caps.append("audio") + if self.structured_output: + caps.append("structured output") + if self.open_weights: + caps.append("open weights") + return ", ".join(caps) if caps else "basic" + + +@dataclass +class ProviderInfo: + """Full metadata for a provider from models.dev.""" + + id: str # models.dev provider ID + name: str # display name + env: Tuple[str, ...] 
# env var names for API key + api: str # base URL + doc: str = "" # documentation URL + model_count: int = 0 + + def has_api_url(self) -> bool: + return bool(self.api) + + +# --------------------------------------------------------------------------- +# Provider ID mapping: Hermes ↔ models.dev +# --------------------------------------------------------------------------- + +# Hermes provider names → models.dev provider IDs PROVIDER_TO_MODELS_DEV: Dict[str, str] = { "openrouter": "openrouter", "anthropic": "anthropic", @@ -44,8 +159,28 @@ PROVIDER_TO_MODELS_DEV: Dict[str, str] = { "opencode-go": "opencode-go", "kilocode": "kilo", "fireworks": "fireworks-ai", + "huggingface": "huggingface", + "google": "google", + "xai": "xai", + "nvidia": "nvidia", + "groq": "groq", + "mistral": "mistral", + "togetherai": "togetherai", + "perplexity": "perplexity", + "cohere": "cohere", } +# Reverse mapping: models.dev → Hermes (built lazily) +_MODELS_DEV_TO_PROVIDER: Optional[Dict[str, str]] = None + + +def _get_reverse_mapping() -> Dict[str, str]: + """Return models.dev ID → Hermes provider ID mapping.""" + global _MODELS_DEV_TO_PROVIDER + if _MODELS_DEV_TO_PROVIDER is None: + _MODELS_DEV_TO_PROVIDER = {v: k for k, v in PROVIDER_TO_MODELS_DEV.items()} + return _MODELS_DEV_TO_PROVIDER + def _get_cache_path() -> Path: """Return path to disk cache file.""" @@ -170,3 +305,443 @@ def _extract_context(entry: Dict[str, Any]) -> Optional[int]: if isinstance(ctx, (int, float)) and ctx > 0: return int(ctx) return None + + +# --------------------------------------------------------------------------- +# Model capability metadata +# --------------------------------------------------------------------------- + + +@dataclass +class ModelCapabilities: + """Structured capability metadata for a model from models.dev.""" + + supports_tools: bool = True + supports_vision: bool = False + supports_reasoning: bool = False + context_window: int = 200000 + max_output_tokens: int = 8192 + model_family: 
str = "" + + +def _get_provider_models(provider: str) -> Optional[Dict[str, Any]]: + """Resolve a Hermes provider ID to its models dict from models.dev. + + Returns the models dict or None if the provider is unknown or has no data. + """ + mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider) + if not mdev_provider_id: + return None + + data = fetch_models_dev() + provider_data = data.get(mdev_provider_id) + if not isinstance(provider_data, dict): + return None + + models = provider_data.get("models", {}) + if not isinstance(models, dict): + return None + + return models + + +def _find_model_entry(models: Dict[str, Any], model: str) -> Optional[Dict[str, Any]]: + """Find a model entry by exact match, then case-insensitive fallback.""" + # Exact match + entry = models.get(model) + if isinstance(entry, dict): + return entry + + # Case-insensitive match + model_lower = model.lower() + for mid, mdata in models.items(): + if mid.lower() == model_lower and isinstance(mdata, dict): + return mdata + + return None + + +def get_model_capabilities(provider: str, model: str) -> Optional[ModelCapabilities]: + """Look up full capability metadata from models.dev cache. + + Uses the existing fetch_models_dev() and PROVIDER_TO_MODELS_DEV mapping. + Returns None if model not found. 
+ + Extracts from model entry fields: + - reasoning (bool) → supports_reasoning + - tool_call (bool) → supports_tools + - attachment (bool) → supports_vision + - limit.context (int) → context_window + - limit.output (int) → max_output_tokens + - family (str) → model_family + """ + models = _get_provider_models(provider) + if models is None: + return None + + entry = _find_model_entry(models, model) + if entry is None: + return None + + # Extract capability flags (default to False if missing) + supports_tools = bool(entry.get("tool_call", False)) + supports_vision = bool(entry.get("attachment", False)) + supports_reasoning = bool(entry.get("reasoning", False)) + + # Extract limits + limit = entry.get("limit", {}) + if not isinstance(limit, dict): + limit = {} + + ctx = limit.get("context") + context_window = int(ctx) if isinstance(ctx, (int, float)) and ctx > 0 else 200000 + + out = limit.get("output") + max_output_tokens = int(out) if isinstance(out, (int, float)) and out > 0 else 8192 + + model_family = entry.get("family", "") or "" + + return ModelCapabilities( + supports_tools=supports_tools, + supports_vision=supports_vision, + supports_reasoning=supports_reasoning, + context_window=context_window, + max_output_tokens=max_output_tokens, + model_family=model_family, + ) + + +def list_provider_models(provider: str) -> List[str]: + """Return all model IDs for a provider from models.dev. + + Returns an empty list if the provider is unknown or has no data. + """ + models = _get_provider_models(provider) + if models is None: + return [] + return list(models.keys()) + + +def search_models_dev( + query: str, provider: str = None, limit: int = 5 +) -> List[Dict[str, Any]]: + """Fuzzy search across models.dev catalog. Returns matching model entries. + + Args: + query: Search string to match against model IDs. + provider: Optional Hermes provider ID to restrict search scope. + If None, searches across all providers in PROVIDER_TO_MODELS_DEV. 
+ limit: Maximum number of results to return. + + Returns: + List of dicts, each containing 'provider', 'model_id', and the full + model 'entry' from models.dev. + """ + data = fetch_models_dev() + if not data: + return [] + + # Build list of (provider_id, model_id, entry) candidates + candidates: List[tuple] = [] + + if provider is not None: + # Search only the specified provider + mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider) + if not mdev_provider_id: + return [] + provider_data = data.get(mdev_provider_id, {}) + if isinstance(provider_data, dict): + models = provider_data.get("models", {}) + if isinstance(models, dict): + for mid, mdata in models.items(): + candidates.append((provider, mid, mdata)) + else: + # Search across all mapped providers + for hermes_prov, mdev_prov in PROVIDER_TO_MODELS_DEV.items(): + provider_data = data.get(mdev_prov, {}) + if isinstance(provider_data, dict): + models = provider_data.get("models", {}) + if isinstance(models, dict): + for mid, mdata in models.items(): + candidates.append((hermes_prov, mid, mdata)) + + if not candidates: + return [] + + # Use difflib for fuzzy matching — case-insensitive comparison + model_ids_lower = [c[1].lower() for c in candidates] + query_lower = query.lower() + + # First try exact substring matches (more intuitive than pure edit-distance) + substring_matches = [] + for prov, mid, mdata in candidates: + if query_lower in mid.lower(): + substring_matches.append({"provider": prov, "model_id": mid, "entry": mdata}) + + # Then add difflib fuzzy matches for any remaining slots + fuzzy_ids = difflib.get_close_matches( + query_lower, model_ids_lower, n=limit * 2, cutoff=0.4 + ) + + seen_ids: set = set() + results: List[Dict[str, Any]] = [] + + # Prioritize substring matches + for match in substring_matches: + key = (match["provider"], match["model_id"]) + if key not in seen_ids: + seen_ids.add(key) + results.append(match) + if len(results) >= limit: + return results + + # Add fuzzy matches + for 
fid in fuzzy_ids: + # Find original-case candidates matching this lowered ID + for prov, mid, mdata in candidates: + if mid.lower() == fid: + key = (prov, mid) + if key not in seen_ids: + seen_ids.add(key) + results.append({"provider": prov, "model_id": mid, "entry": mdata}) + if len(results) >= limit: + return results + + return results + + +# --------------------------------------------------------------------------- +# Rich dataclass constructors — parse raw models.dev JSON into dataclasses +# --------------------------------------------------------------------------- + +def _parse_model_info(model_id: str, raw: Dict[str, Any], provider_id: str) -> ModelInfo: + """Convert a raw models.dev model entry dict into a ModelInfo dataclass.""" + limit = raw.get("limit") or {} + if not isinstance(limit, dict): + limit = {} + + cost = raw.get("cost") or {} + if not isinstance(cost, dict): + cost = {} + + modalities = raw.get("modalities") or {} + if not isinstance(modalities, dict): + modalities = {} + + input_mods = modalities.get("input") or [] + output_mods = modalities.get("output") or [] + + ctx = limit.get("context") + ctx_int = int(ctx) if isinstance(ctx, (int, float)) and ctx > 0 else 0 + out = limit.get("output") + out_int = int(out) if isinstance(out, (int, float)) and out > 0 else 0 + inp = limit.get("input") + inp_int = int(inp) if isinstance(inp, (int, float)) and inp > 0 else None + + return ModelInfo( + id=model_id, + name=raw.get("name", "") or model_id, + family=raw.get("family", "") or "", + provider_id=provider_id, + reasoning=bool(raw.get("reasoning", False)), + tool_call=bool(raw.get("tool_call", False)), + attachment=bool(raw.get("attachment", False)), + temperature=bool(raw.get("temperature", False)), + structured_output=bool(raw.get("structured_output", False)), + open_weights=bool(raw.get("open_weights", False)), + input_modalities=tuple(input_mods) if isinstance(input_mods, list) else (), + output_modalities=tuple(output_mods) if 
isinstance(output_mods, list) else (), + context_window=ctx_int, + max_output=out_int, + max_input=inp_int, + cost_input=float(cost.get("input", 0) or 0), + cost_output=float(cost.get("output", 0) or 0), + cost_cache_read=float(cost["cache_read"]) if "cache_read" in cost and cost["cache_read"] is not None else None, + cost_cache_write=float(cost["cache_write"]) if "cache_write" in cost and cost["cache_write"] is not None else None, + knowledge_cutoff=raw.get("knowledge", "") or "", + release_date=raw.get("release_date", "") or "", + status=raw.get("status", "") or "", + interleaved=raw.get("interleaved", False), + ) + + +def _parse_provider_info(provider_id: str, raw: Dict[str, Any]) -> ProviderInfo: + """Convert a raw models.dev provider entry dict into a ProviderInfo.""" + env = raw.get("env") or [] + models = raw.get("models") or {} + return ProviderInfo( + id=provider_id, + name=raw.get("name", "") or provider_id, + env=tuple(env) if isinstance(env, list) else (), + api=raw.get("api", "") or "", + doc=raw.get("doc", "") or "", + model_count=len(models) if isinstance(models, dict) else 0, + ) + + +# --------------------------------------------------------------------------- +# Provider-level queries +# --------------------------------------------------------------------------- + +def get_provider_info(provider_id: str) -> Optional[ProviderInfo]: + """Get full provider metadata from models.dev. + + Accepts either a Hermes provider ID (e.g. "kilocode") or a models.dev + ID (e.g. "kilo"). Returns None if the provider is not in the catalog. + """ + # Resolve Hermes ID → models.dev ID + mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id) + + data = fetch_models_dev() + raw = data.get(mdev_id) + if not isinstance(raw, dict): + return None + + return _parse_provider_info(mdev_id, raw) + + +def list_all_providers() -> Dict[str, ProviderInfo]: + """Return all providers from models.dev as {provider_id: ProviderInfo}. 
+ + Returns the full catalog — 109+ providers. For providers that have + a Hermes alias, both the models.dev ID and the Hermes ID are included. + """ + data = fetch_models_dev() + result: Dict[str, ProviderInfo] = {} + + for pid, pdata in data.items(): + if isinstance(pdata, dict): + info = _parse_provider_info(pid, pdata) + result[pid] = info + + return result + + +def get_providers_for_env_var(env_var: str) -> List[str]: + """Reverse lookup: find all providers that use a given env var. + + Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which + providers does that enable?" + + Returns list of models.dev provider IDs. + """ + data = fetch_models_dev() + matches: List[str] = [] + + for pid, pdata in data.items(): + if isinstance(pdata, dict): + env = pdata.get("env", []) + if isinstance(env, list) and env_var in env: + matches.append(pid) + + return matches + + +# --------------------------------------------------------------------------- +# Model-level queries (rich ModelInfo) +# --------------------------------------------------------------------------- + +def get_model_info( + provider_id: str, model_id: str +) -> Optional[ModelInfo]: + """Get full model metadata from models.dev. + + Accepts Hermes or models.dev provider ID. Tries exact match then + case-insensitive fallback. Returns None if not found. 
+ """ + mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id) + + data = fetch_models_dev() + pdata = data.get(mdev_id) + if not isinstance(pdata, dict): + return None + + models = pdata.get("models", {}) + if not isinstance(models, dict): + return None + + # Exact match + raw = models.get(model_id) + if isinstance(raw, dict): + return _parse_model_info(model_id, raw, mdev_id) + + # Case-insensitive fallback + model_lower = model_id.lower() + for mid, mdata in models.items(): + if mid.lower() == model_lower and isinstance(mdata, dict): + return _parse_model_info(mid, mdata, mdev_id) + + return None + + +def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]: + """Search all providers for a model by ID. + + Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or + a bare name and want to find it anywhere. Checks Hermes-mapped providers + first, then falls back to all models.dev providers. + """ + data = fetch_models_dev() + + # Try Hermes-mapped providers first (more likely what the user wants) + for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items(): + pdata = data.get(mdev_id) + if not isinstance(pdata, dict): + continue + models = pdata.get("models", {}) + if not isinstance(models, dict): + continue + + raw = models.get(model_id) + if isinstance(raw, dict): + return _parse_model_info(model_id, raw, mdev_id) + + # Case-insensitive + model_lower = model_id.lower() + for mid, mdata in models.items(): + if mid.lower() == model_lower and isinstance(mdata, dict): + return _parse_model_info(mid, mdata, mdev_id) + + # Fall back to ALL providers + for pid, pdata in data.items(): + if pid in _get_reverse_mapping(): + continue # already checked + if not isinstance(pdata, dict): + continue + models = pdata.get("models", {}) + if not isinstance(models, dict): + continue + + raw = models.get(model_id) + if isinstance(raw, dict): + return _parse_model_info(model_id, raw, pid) + + return None + + +def 
list_provider_model_infos(provider_id: str) -> List[ModelInfo]: + """Return all models for a provider as ModelInfo objects. + + Filters out deprecated models by default. + """ + mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id) + + data = fetch_models_dev() + pdata = data.get(mdev_id) + if not isinstance(pdata, dict): + return [] + + models = pdata.get("models", {}) + if not isinstance(models, dict): + return [] + + result: List[ModelInfo] = [] + for mid, mdata in models.items(): + if not isinstance(mdata, dict): + continue + status = mdata.get("status", "") + if status == "deprecated": + continue + result.append(_parse_model_info(mid, mdata, mdev_id)) + + return result diff --git a/cli.py b/cli.py index de21d81e50..5802a31e2f 100644 --- a/cli.py +++ b/cli.py @@ -3519,6 +3519,167 @@ class HermesCLI: remaining = len(self.conversation_history) print(f" {remaining} message(s) remaining in history.") + def _handle_model_switch(self, cmd_original: str): + """Handle /model command — switch model for this session. 
+ + Supports: + /model — show current model + usage hints + /model — switch for this session only + /model --global — switch and persist to config.yaml + /model --provider — switch provider + model + /model --provider — switch to provider, auto-detect model + """ + from hermes_cli.model_switch import switch_model, parse_model_flags, list_authenticated_providers + from hermes_cli.providers import get_label + + # Parse args from the original command + parts = cmd_original.split(None, 1) # split off '/model' + raw_args = parts[1].strip() if len(parts) > 1 else "" + + # Parse --provider and --global flags + model_input, explicit_provider, persist_global = parse_model_flags(raw_args) + + # No args at all: show available providers + models + if not model_input and not explicit_provider: + model_display = self.model or "unknown" + provider_display = get_label(self.provider) if self.provider else "unknown" + _cprint(f" Current: {model_display} on {provider_display}") + _cprint("") + + # Show authenticated providers with top models + try: + # Load user providers from config + user_provs = None + try: + from hermes_cli.config import load_config + cfg = load_config() + user_provs = cfg.get("providers") + except Exception: + pass + + providers = list_authenticated_providers( + current_provider=self.provider or "", + user_providers=user_provs, + max_models=6, + ) + if providers: + for p in providers: + tag = " (current)" if p["is_current"] else "" + _cprint(f" {p['name']} [--provider {p['slug']}]{tag}:") + if p["models"]: + model_strs = ", ".join(p["models"]) + extra = f" (+{p['total_models'] - len(p['models'])} more)" if p["total_models"] > len(p["models"]) else "" + _cprint(f" {model_strs}{extra}") + elif p.get("api_url"): + _cprint(f" {p['api_url']} (use /model --provider {p['slug']})") + else: + _cprint(f" (no models listed)") + _cprint("") + else: + _cprint(" No authenticated providers found.") + _cprint("") + except Exception: + pass + + # Aliases + from 
hermes_cli.model_switch import MODEL_ALIASES + alias_list = ", ".join(sorted(MODEL_ALIASES.keys())) + _cprint(f" Aliases: {alias_list}") + _cprint("") + _cprint(" /model switch model") + _cprint(" /model --provider switch provider") + _cprint(" /model --global persist to config") + return + + # Perform the switch + result = switch_model( + raw_input=model_input, + current_provider=self.provider or "", + current_model=self.model or "", + current_base_url=self.base_url or "", + current_api_key=self.api_key or "", + is_global=persist_global, + explicit_provider=explicit_provider, + ) + + if not result.success: + _cprint(f" ✗ {result.error_message}") + return + + # Apply to CLI state + old_model = self.model + self.model = result.new_model + self.provider = result.target_provider + if result.api_key: + self.api_key = result.api_key + if result.base_url: + self.base_url = result.base_url + if result.api_mode: + self.api_mode = result.api_mode + + # Apply to running agent (in-place swap) + if self.agent is not None: + try: + self.agent.switch_model( + new_model=result.new_model, + new_provider=result.target_provider, + api_key=result.api_key, + base_url=result.base_url, + api_mode=result.api_mode, + ) + except Exception as exc: + _cprint(f" ⚠ Agent swap failed ({exc}); change applied to next session.") + + # Display confirmation with full metadata + provider_label = result.provider_label or result.target_provider + _cprint(f" ✓ Model switched: {result.new_model}") + _cprint(f" Provider: {provider_label}") + + # Rich metadata from models.dev + mi = result.model_info + if mi: + if mi.context_window: + _cprint(f" Context: {mi.context_window:,} tokens") + if mi.max_output: + _cprint(f" Max output: {mi.max_output:,} tokens") + if mi.has_cost_data(): + _cprint(f" Cost: {mi.format_cost()}") + _cprint(f" Capabilities: {mi.format_capabilities()}") + else: + # Fallback to old context length lookup + try: + from agent.model_metadata import get_model_context_length + ctx = 
get_model_context_length( + result.new_model, + base_url=result.base_url or self.base_url, + api_key=result.api_key or self.api_key, + provider=result.target_provider, + ) + _cprint(f" Context: {ctx:,} tokens") + except Exception: + pass + + # Cache notice + cache_enabled = ( + ("openrouter" in (result.base_url or "").lower() and "claude" in result.new_model.lower()) + or result.api_mode == "anthropic_messages" + ) + if cache_enabled: + _cprint(" Prompt caching: enabled") + + # Warning from validation + if result.warning_message: + _cprint(f" ⚠ {result.warning_message}") + + # Persistence + if persist_global: + save_config_value("model.name", result.new_model) + if result.provider_changed: + save_config_value("model.provider", result.target_provider) + _cprint(" Saved to config.yaml (--global)") + else: + _cprint(" (session only — add --global to persist)") + def _show_model_and_providers(self): """Show current model + provider and list all authenticated providers. @@ -4134,6 +4295,8 @@ class HermesCLI: self.new_session() elif canonical == "resume": self._handle_resume_command(cmd_original) + elif canonical == "model": + self._handle_model_switch(cmd_original) elif canonical == "provider": self._show_model_and_providers() elif canonical == "prompt": diff --git a/gateway/run.py b/gateway/run.py index 3c1c230163..0db0514ea1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -509,6 +509,9 @@ class GatewayRunner: self._effective_model: Optional[str] = None self._effective_provider: Optional[str] = None + # Per-session model overrides from /model command. 
+ # Key: session_key, Value: dict with model/provider/api_key/base_url/api_mode + self._session_model_overrides: Dict[str, Dict[str, str]] = {} # Track pending exec approvals per session # Key: session_key, Value: {"command": str, "pattern_key": str, ...} self._pending_approvals: Dict[str, Dict[str, Any]] = {} @@ -1859,6 +1862,10 @@ class GatewayRunner: adapter._pending_messages[_quick_key] = queued_event return "Queued for the next turn." + # /model must not be used while the agent is running. + if _cmd_def_inner and _cmd_def_inner.name == "model": + return "Agent is running — wait or /stop first, then switch models." + # /approve and /deny must bypass the running-agent interrupt path. # The agent thread is blocked on a threading.Event inside # tools/approval.py — sending an interrupt won't unblock it. @@ -1958,6 +1965,9 @@ class GatewayRunner: if canonical == "yolo": return await self._handle_yolo_command(event) + if canonical == "model": + return await self._handle_model_command(event) + if canonical == "provider": return await self._handle_provider_command(event) @@ -3268,6 +3278,196 @@ class GatewayRunner: lines.append(f"_(Requested page {requested_page} was out of range, showing page {page}.)_") return "\n".join(lines) + async def _handle_model_command(self, event: MessageEvent) -> str: + """Handle /model command — switch model for this session. 
+ + Supports: + /model — show current model info + /model — switch for this session only + /model --global — switch and persist to config.yaml + /model --provider — switch provider + model + /model --provider — switch to provider, auto-detect model + """ + import yaml + from hermes_cli.model_switch import ( + switch_model as _switch_model, parse_model_flags, + list_authenticated_providers, + ) + from hermes_cli.providers import get_label + + raw_args = event.get_command_args().strip() + + # Parse --provider and --global flags + model_input, explicit_provider, persist_global = parse_model_flags(raw_args) + + # Read current model/provider from config + current_model = "" + current_provider = "openrouter" + current_base_url = "" + current_api_key = "" + user_provs = None + config_path = _hermes_home / "config.yaml" + try: + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + model_cfg = cfg.get("model", {}) + if isinstance(model_cfg, dict): + current_model = model_cfg.get("name", "") + current_provider = model_cfg.get("provider", current_provider) + current_base_url = model_cfg.get("base_url", "") + user_provs = cfg.get("providers") + except Exception: + pass + + # Check for session override + source = event.source + session_key = self._session_key_for_source(source) + override = getattr(self, "_session_model_overrides", {}).get(session_key, {}) + if override: + current_model = override.get("model", current_model) + current_provider = override.get("provider", current_provider) + current_base_url = override.get("base_url", current_base_url) + current_api_key = override.get("api_key", current_api_key) + + # No args: show authenticated providers with models + if not model_input and not explicit_provider: + provider_label = get_label(current_provider) + lines = [f"Current: `{current_model or 'unknown'}` on {provider_label}", ""] + + try: + providers = list_authenticated_providers( + current_provider=current_provider, 
+ user_providers=user_provs, + max_models=5, + ) + for p in providers: + tag = " (current)" if p["is_current"] else "" + lines.append(f"**{p['name']}** `--provider {p['slug']}`{tag}:") + if p["models"]: + model_strs = ", ".join(f"`{m}`" for m in p["models"]) + extra = f" (+{p['total_models'] - len(p['models'])} more)" if p["total_models"] > len(p["models"]) else "" + lines.append(f" {model_strs}{extra}") + elif p.get("api_url"): + lines.append(f" `{p['api_url']}`") + lines.append("") + except Exception: + pass + + lines.append("`/model ` — switch model") + lines.append("`/model --provider ` — switch provider") + lines.append("`/model --global` — persist") + return "\n".join(lines) + + # Perform the switch + result = _switch_model( + raw_input=model_input, + current_provider=current_provider, + current_model=current_model, + current_base_url=current_base_url, + current_api_key=current_api_key, + is_global=persist_global, + explicit_provider=explicit_provider, + ) + + if not result.success: + return f"Error: {result.error_message}" + + # If there's a cached agent, update it in-place + cached_entry = None + _cache_lock = getattr(self, "_agent_cache_lock", None) + _cache = getattr(self, "_agent_cache", None) + if _cache_lock and _cache is not None: + with _cache_lock: + cached_entry = _cache.get(session_key) + + if cached_entry and cached_entry[0] is not None: + try: + cached_entry[0].switch_model( + new_model=result.new_model, + new_provider=result.target_provider, + api_key=result.api_key, + base_url=result.base_url, + api_mode=result.api_mode, + ) + except Exception as exc: + logger.warning("In-place model switch failed for cached agent: %s", exc) + + # Store session override so next agent creation uses the new model + if not hasattr(self, "_session_model_overrides"): + self._session_model_overrides = {} + self._session_model_overrides[session_key] = { + "model": result.new_model, + "provider": result.target_provider, + "api_key": result.api_key, + "base_url": 
result.base_url, + "api_mode": result.api_mode, + } + + # Persist to config if --global + if persist_global: + try: + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + cfg = yaml.safe_load(f) or {} + else: + cfg = {} + model_cfg = cfg.setdefault("model", {}) + model_cfg["name"] = result.new_model + model_cfg["provider"] = result.target_provider + if result.base_url: + model_cfg["base_url"] = result.base_url + from hermes_cli.config import save_config + save_config(cfg) + except Exception as e: + logger.warning("Failed to persist model switch: %s", e) + + # Build confirmation message with full metadata + provider_label = result.provider_label or result.target_provider + lines = [f"Model switched to `{result.new_model}`"] + lines.append(f"Provider: {provider_label}") + + # Rich metadata from models.dev + mi = result.model_info + if mi: + if mi.context_window: + lines.append(f"Context: {mi.context_window:,} tokens") + if mi.max_output: + lines.append(f"Max output: {mi.max_output:,} tokens") + if mi.has_cost_data(): + lines.append(f"Cost: {mi.format_cost()}") + lines.append(f"Capabilities: {mi.format_capabilities()}") + else: + try: + from agent.model_metadata import get_model_context_length + ctx = get_model_context_length( + result.new_model, + base_url=result.base_url or current_base_url, + api_key=result.api_key or current_api_key, + provider=result.target_provider, + ) + lines.append(f"Context: {ctx:,} tokens") + except Exception: + pass + + # Cache notice + cache_enabled = ( + ("openrouter" in (result.base_url or "").lower() and "claude" in result.new_model.lower()) + or result.api_mode == "anthropic_messages" + ) + if cache_enabled: + lines.append("Prompt caching: enabled") + + if result.warning_message: + lines.append(f"Warning: {result.warning_message}") + + if persist_global: + lines.append("Saved to config.yaml (`--global`)") + else: + lines.append("_(session only -- add `--global` to persist)_") + + return "\n".join(lines) + async 
def _handle_provider_command(self, event: MessageEvent) -> str: """Handle /provider command - show available providers.""" import yaml diff --git a/hermes_cli/commands.py b/hermes_cli/commands.py index 07a8f5e1eb..782d52250f 100644 --- a/hermes_cli/commands.py +++ b/hermes_cli/commands.py @@ -84,6 +84,7 @@ COMMAND_REGISTRY: list[CommandDef] = [ # Configuration CommandDef("config", "Show current configuration", "Configuration", cli_only=True), + CommandDef("model", "Switch model for this session", "Configuration", args_hint="[model] [--global]"), CommandDef("provider", "Show available providers and current provider", "Configuration"), CommandDef("prompt", "View/set custom system prompt", "Configuration", diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 00d0923d2f..1871bc9166 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -199,6 +199,7 @@ def ensure_hermes_home(): DEFAULT_CONFIG = { "model": "", + "providers": {}, "fallback_providers": [], "credential_pool_strategies": {}, "toolsets": ["hermes-cli"], @@ -531,7 +532,7 @@ DEFAULT_CONFIG = { }, # Config schema version - bump this when adding new required fields - "_config_version": 11, + "_config_version": 12, } # ============================================================================= @@ -1312,6 +1313,69 @@ def migrate_config(interactive: bool = True, quiet: bool = False) -> Dict[str, A except Exception: pass + # ── Version 11 → 12: migrate custom_providers list → providers dict ── + if current_ver < 12: + config = load_config() + custom_list = config.get("custom_providers") + if isinstance(custom_list, list) and custom_list: + providers_dict = config.get("providers", {}) + if not isinstance(providers_dict, dict): + providers_dict = {} + migrated_count = 0 + for entry in custom_list: + if not isinstance(entry, dict): + continue + old_name = entry.get("name", "") + old_url = entry.get("base_url", "") or entry.get("url", "") or "" + old_key = entry.get("api_key", "") + if not old_url: 
+ continue # skip entries with no URL + + # Generate a kebab-case key from the display name + key = old_name.strip().lower().replace(" ", "-").replace("(", "").replace(")", "") + # Remove consecutive hyphens and trailing hyphens + while "--" in key: + key = key.replace("--", "-") + key = key.strip("-") + if not key: + # Fallback: derive from URL hostname + try: + from urllib.parse import urlparse + parsed = urlparse(old_url) + key = (parsed.hostname or "endpoint").replace(".", "-") + except Exception: + key = f"endpoint-{migrated_count}" + + # Don't overwrite existing entries + if key in providers_dict: + key = f"{key}-{migrated_count}" + + new_entry = {"api": old_url} + if old_name: + new_entry["name"] = old_name + if old_key and old_key not in ("no-key", "no-key-required", ""): + new_entry["api_key"] = old_key + + # Carry over model and api_mode if present + if entry.get("model"): + new_entry["default_model"] = entry["model"] + if entry.get("api_mode"): + new_entry["transport"] = entry["api_mode"] + + providers_dict[key] = new_entry + migrated_count += 1 + + if migrated_count > 0: + config["providers"] = providers_dict + # Remove the old list + del config["custom_providers"] + save_config(config) + if not quiet: + print(f" ✓ Migrated {migrated_count} custom provider(s) to providers: section") + for key in list(providers_dict.keys())[-migrated_count:]: + ep = providers_dict[key] + print(f" → {key}: {ep.get('api', '')}") + if current_ver < latest_ver and not quiet: print(f"Config version: {current_ver} → {latest_ver}") diff --git a/hermes_cli/model_normalize.py b/hermes_cli/model_normalize.py new file mode 100644 index 0000000000..e362d44e21 --- /dev/null +++ b/hermes_cli/model_normalize.py @@ -0,0 +1,359 @@ +"""Per-provider model name normalization. + +Different LLM providers expect model identifiers in different formats: + +- **Aggregators** (OpenRouter, Nous, AI Gateway, Kilo Code) need + ``vendor/model`` slugs like ``anthropic/claude-sonnet-4.6``. 
+- **Anthropic** native API expects bare names with dots replaced by + hyphens: ``claude-sonnet-4-6``. +- **Copilot** expects bare names *with* dots preserved: + ``claude-sonnet-4.6``. +- **OpenCode** (Zen & Go) follows the same dot-to-hyphen convention as + Anthropic: ``claude-sonnet-4-6``. +- **DeepSeek** only accepts two model identifiers: + ``deepseek-chat`` and ``deepseek-reasoner``. +- **Custom** and remaining providers pass the name through as-is. + +This module centralises that translation so callers can simply write:: + + api_model = normalize_model_for_provider(user_input, provider) + +Inspired by Clawdbot's ``normalizeAnthropicModelId`` pattern. +""" + +from __future__ import annotations + +from typing import Optional + +# --------------------------------------------------------------------------- +# Vendor prefix mapping +# --------------------------------------------------------------------------- +# Maps the first hyphen-delimited token of a bare model name to the vendor +# slug used by aggregator APIs (OpenRouter, Nous, etc.). +# +# Example: "claude-sonnet-4.6" -> first token "claude" -> vendor "anthropic" +# -> aggregator slug: "anthropic/claude-sonnet-4.6" + +_VENDOR_PREFIXES: dict[str, str] = { + "claude": "anthropic", + "gpt": "openai", + "o1": "openai", + "o3": "openai", + "o4": "openai", + "gemini": "google", + "deepseek": "deepseek", + "glm": "z-ai", + "kimi": "moonshotai", + "minimax": "minimax", + "grok": "x-ai", + "qwen": "qwen", + "mimo": "xiaomi", + "nemotron": "nvidia", + "llama": "meta-llama", + "step": "stepfun", + "trinity": "arcee-ai", +} + +# Providers whose APIs consume vendor/model slugs. +_AGGREGATOR_PROVIDERS: frozenset[str] = frozenset({ + "openrouter", + "nous", + "ai-gateway", + "kilocode", +}) + +# Providers that want bare names with dots replaced by hyphens. +_DOT_TO_HYPHEN_PROVIDERS: frozenset[str] = frozenset({ + "anthropic", + "opencode-zen", + "opencode-go", +}) + +# Providers that want bare names with dots preserved. 
+_STRIP_VENDOR_ONLY_PROVIDERS: frozenset[str] = frozenset({ + "copilot", + "copilot-acp", +}) + +# Providers whose own naming is authoritative -- pass through unchanged. +_PASSTHROUGH_PROVIDERS: frozenset[str] = frozenset({ + "zai", + "kimi-coding", + "minimax", + "minimax-cn", + "alibaba", + "huggingface", + "openai-codex", + "custom", +}) + +# --------------------------------------------------------------------------- +# DeepSeek special handling +# --------------------------------------------------------------------------- +# DeepSeek's API only recognises exactly two model identifiers. We map +# common aliases and patterns to the canonical names. + +_DEEPSEEK_REASONER_KEYWORDS: frozenset[str] = frozenset({ + "reasoner", + "r1", + "think", + "reasoning", + "cot", +}) + +_DEEPSEEK_CANONICAL_MODELS: frozenset[str] = frozenset({ + "deepseek-chat", + "deepseek-reasoner", +}) + + +def _normalize_for_deepseek(model_name: str) -> str: + """Map any model input to one of DeepSeek's two accepted identifiers. + + Rules: + - Already ``deepseek-chat`` or ``deepseek-reasoner`` -> pass through. + - Contains any reasoner keyword (r1, think, reasoning, cot, reasoner) + -> ``deepseek-reasoner``. + - Everything else -> ``deepseek-chat``. + + Args: + model_name: The bare model name (vendor prefix already stripped). + + Returns: + One of ``"deepseek-chat"`` or ``"deepseek-reasoner"``. + """ + bare = _strip_vendor_prefix(model_name).lower() + + if bare in _DEEPSEEK_CANONICAL_MODELS: + return bare + + # Check for reasoner-like keywords anywhere in the name + for keyword in _DEEPSEEK_REASONER_KEYWORDS: + if keyword in bare: + return "deepseek-reasoner" + + return "deepseek-chat" + + +# --------------------------------------------------------------------------- +# Helper utilities +# --------------------------------------------------------------------------- + +def _strip_vendor_prefix(model_name: str) -> str: + """Remove a ``vendor/`` prefix if present. 
+ + Examples:: + + >>> _strip_vendor_prefix("anthropic/claude-sonnet-4.6") + 'claude-sonnet-4.6' + >>> _strip_vendor_prefix("claude-sonnet-4.6") + 'claude-sonnet-4.6' + >>> _strip_vendor_prefix("meta-llama/llama-4-scout") + 'llama-4-scout' + """ + if "/" in model_name: + return model_name.split("/", 1)[1] + return model_name + + +def _dots_to_hyphens(model_name: str) -> str: + """Replace dots with hyphens in a model name. + + Anthropic's native API uses hyphens where marketing names use dots: + ``claude-sonnet-4.6`` -> ``claude-sonnet-4-6``. + """ + return model_name.replace(".", "-") + + +def detect_vendor(model_name: str) -> Optional[str]: + """Detect the vendor slug from a bare model name. + + Uses the first hyphen-delimited token of the model name to look up + the corresponding vendor in ``_VENDOR_PREFIXES``. Also handles + case-insensitive matching and special patterns. + + Args: + model_name: A model name, optionally already including a + ``vendor/`` prefix. If a prefix is present it is used + directly. + + Returns: + The vendor slug (e.g. ``"anthropic"``, ``"openai"``) or ``None`` + if no vendor can be confidently detected. + + Examples:: + + >>> detect_vendor("claude-sonnet-4.6") + 'anthropic' + >>> detect_vendor("gpt-5.4-mini") + 'openai' + >>> detect_vendor("anthropic/claude-sonnet-4.6") + 'anthropic' + >>> detect_vendor("my-custom-model") + """ + name = model_name.strip() + if not name: + return None + + # If there's already a vendor/ prefix, extract it + if "/" in name: + return name.split("/", 1)[0].lower() or None + + name_lower = name.lower() + + # Try first hyphen-delimited token (exact match) + first_token = name_lower.split("-")[0] + if first_token in _VENDOR_PREFIXES: + return _VENDOR_PREFIXES[first_token] + + # Handle patterns where the first token includes version digits, + # e.g. 
"qwen3.5-plus" -> first token "qwen3.5", but prefix is "qwen" + for prefix, vendor in _VENDOR_PREFIXES.items(): + if name_lower.startswith(prefix): + return vendor + + return None + + +def _prepend_vendor(model_name: str) -> str: + """Prepend the detected ``vendor/`` prefix if missing. + + Used for aggregator providers that require ``vendor/model`` format. + If the name already contains a ``/``, it is returned as-is. + If no vendor can be detected, the name is returned unchanged + (aggregators may still accept it or return an error). + + Examples:: + + >>> _prepend_vendor("claude-sonnet-4.6") + 'anthropic/claude-sonnet-4.6' + >>> _prepend_vendor("anthropic/claude-sonnet-4.6") + 'anthropic/claude-sonnet-4.6' + >>> _prepend_vendor("my-custom-thing") + 'my-custom-thing' + """ + if "/" in model_name: + return model_name + + vendor = detect_vendor(model_name) + if vendor: + return f"{vendor}/{model_name}" + return model_name + + +# --------------------------------------------------------------------------- +# Main normalisation entry point +# --------------------------------------------------------------------------- + +def normalize_model_for_provider(model_input: str, target_provider: str) -> str: + """Translate a model name into the format the target provider's API expects. + + This is the primary entry point for model name normalisation. It + accepts any user-facing model identifier and transforms it for the + specific provider that will receive the API call. + + Args: + model_input: The model name as provided by the user or config. + Can be bare (``"claude-sonnet-4.6"``), vendor-prefixed + (``"anthropic/claude-sonnet-4.6"``), or already in native + format (``"claude-sonnet-4-6"``). + target_provider: The canonical Hermes provider id, e.g. + ``"openrouter"``, ``"anthropic"``, ``"copilot"``, + ``"deepseek"``, ``"custom"``. Should already be normalised + via ``hermes_cli.models.normalize_provider()``. 
+ + Returns: + The model identifier string that the target provider's API + expects. + + Raises: + No exceptions -- always returns a best-effort string. + + Examples:: + + >>> normalize_model_for_provider("claude-sonnet-4.6", "openrouter") + 'anthropic/claude-sonnet-4.6' + + >>> normalize_model_for_provider("anthropic/claude-sonnet-4.6", "anthropic") + 'claude-sonnet-4-6' + + >>> normalize_model_for_provider("anthropic/claude-sonnet-4.6", "copilot") + 'claude-sonnet-4.6' + + >>> normalize_model_for_provider("openai/gpt-5.4", "copilot") + 'gpt-5.4' + + >>> normalize_model_for_provider("claude-sonnet-4.6", "opencode-zen") + 'claude-sonnet-4-6' + + >>> normalize_model_for_provider("deepseek-v3", "deepseek") + 'deepseek-chat' + + >>> normalize_model_for_provider("deepseek-r1", "deepseek") + 'deepseek-reasoner' + + >>> normalize_model_for_provider("my-model", "custom") + 'my-model' + + >>> normalize_model_for_provider("claude-sonnet-4.6", "zai") + 'claude-sonnet-4.6' + """ + name = (model_input or "").strip() + if not name: + return name + + provider = (target_provider or "").strip().lower() + + # --- Aggregators: need vendor/model format --- + if provider in _AGGREGATOR_PROVIDERS: + return _prepend_vendor(name) + + # --- Anthropic / OpenCode: strip vendor, dots -> hyphens --- + if provider in _DOT_TO_HYPHEN_PROVIDERS: + bare = _strip_vendor_prefix(name) + return _dots_to_hyphens(bare) + + # --- Copilot: strip vendor, keep dots --- + if provider in _STRIP_VENDOR_ONLY_PROVIDERS: + return _strip_vendor_prefix(name) + + # --- DeepSeek: map to one of two canonical names --- + if provider == "deepseek": + return _normalize_for_deepseek(name) + + # --- Custom & all others: pass through as-is --- + return name + + +# --------------------------------------------------------------------------- +# Batch / convenience helpers +# --------------------------------------------------------------------------- + +def model_display_name(model_id: str) -> str: + """Return a short, 
human-readable display name for a model id. + + Strips the vendor prefix (if any) for a cleaner display in menus + and status bars, while preserving dots for readability. + + Examples:: + + >>> model_display_name("anthropic/claude-sonnet-4.6") + 'claude-sonnet-4.6' + >>> model_display_name("claude-sonnet-4-6") + 'claude-sonnet-4-6' + """ + return _strip_vendor_prefix((model_id or "").strip()) + + +def is_aggregator_provider(provider: str) -> bool: + """Check if a provider is an aggregator that needs vendor/model format.""" + return (provider or "").strip().lower() in _AGGREGATOR_PROVIDERS + + +def vendor_for_model(model_name: str) -> str: + """Return the vendor slug for a model, or ``""`` if unknown. + + Convenience wrapper around :func:`detect_vendor` that never returns + ``None``. + """ + return detect_vendor(model_name) or "" diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py index ae4de86a5b..9534f37656 100644 --- a/hermes_cli/model_switch.py +++ b/hermes_cli/model_switch.py @@ -3,18 +3,120 @@ Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers share the same core pipeline: - parse_model_input → is_custom detection → auto-detect provider - → credential resolution → validate model → return result + parse flags -> alias resolution -> provider resolution -> + credential resolution -> normalize model name -> + metadata lookup -> build result -This module extracts that shared pipeline into pure functions that -return result objects. The callers handle all platform-specific -concerns: state mutation, config persistence, output formatting. +This module ties together the foundation layers: + +- ``agent.models_dev`` -- models.dev catalog, ModelInfo, ProviderInfo +- ``hermes_cli.providers`` -- canonical provider identity + overlays +- ``hermes_cli.model_normalize`` -- per-provider name formatting + +Provider switching uses the ``--provider`` flag exclusively. 
+No colon-based ``provider:model`` syntax — colons are reserved for +OpenRouter variant suffixes (``:free``, ``:extended``, ``:fast``). """ from __future__ import annotations -from dataclasses import dataclass +import logging +from dataclasses import dataclass, field +from typing import List, NamedTuple, Optional +from hermes_cli.providers import ( + ALIASES, + LABELS, + TRANSPORT_TO_API_MODE, + determine_api_mode, + get_label, + get_provider, + is_aggregator, + normalize_provider, + resolve_provider_full, +) +from hermes_cli.model_normalize import ( + detect_vendor, + normalize_model_for_provider, +) +from agent.models_dev import ( + ModelCapabilities, + ModelInfo, + get_model_capabilities, + get_model_info, + list_provider_models, + search_models_dev, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Model aliases -- short names -> (vendor, family) with NO version numbers. +# Resolved dynamically against the live models.dev catalog. 
+# --------------------------------------------------------------------------- + +class ModelIdentity(NamedTuple): + """Vendor slug and family prefix used for catalog resolution.""" + vendor: str + family: str + + +MODEL_ALIASES: dict[str, ModelIdentity] = { + # Anthropic + "sonnet": ModelIdentity("anthropic", "claude-sonnet"), + "opus": ModelIdentity("anthropic", "claude-opus"), + "haiku": ModelIdentity("anthropic", "claude-haiku"), + "claude": ModelIdentity("anthropic", "claude"), + + # OpenAI + "gpt5": ModelIdentity("openai", "gpt-5"), + "gpt": ModelIdentity("openai", "gpt"), + "codex": ModelIdentity("openai", "codex"), + "o3": ModelIdentity("openai", "o3"), + "o4": ModelIdentity("openai", "o4"), + + # Google + "gemini": ModelIdentity("google", "gemini"), + + # DeepSeek + "deepseek": ModelIdentity("deepseek", "deepseek-chat"), + + # X.AI + "grok": ModelIdentity("x-ai", "grok"), + + # Meta + "llama": ModelIdentity("meta-llama", "llama"), + + # Qwen / Alibaba + "qwen": ModelIdentity("qwen", "qwen"), + + # MiniMax + "minimax": ModelIdentity("minimax", "minimax"), + + # Nvidia + "nemotron": ModelIdentity("nvidia", "nemotron"), + + # Moonshot / Kimi + "kimi": ModelIdentity("moonshotai", "kimi"), + + # Z.AI / GLM + "glm": ModelIdentity("z-ai", "glm"), + + # StepFun + "step": ModelIdentity("stepfun", "step"), + + # Xiaomi + "mimo": ModelIdentity("xiaomi", "mimo"), + + # Arcee + "trinity": ModelIdentity("arcee-ai", "trinity"), +} + + +# --------------------------------------------------------------------------- +# Result dataclasses +# --------------------------------------------------------------------------- @dataclass class ModelSwitchResult: @@ -27,11 +129,13 @@ class ModelSwitchResult: api_key: str = "" base_url: str = "" api_mode: str = "" - persist: bool = False error_message: str = "" warning_message: str = "" - is_custom_target: bool = False provider_label: str = "" + resolved_via_alias: str = "" + capabilities: Optional[ModelCapabilities] = None + model_info: 
Optional[ModelInfo] = None + is_global: bool = False @dataclass @@ -45,91 +149,336 @@ class CustomAutoResult: error_message: str = "" +# --------------------------------------------------------------------------- +# Flag parsing +# --------------------------------------------------------------------------- + +def parse_model_flags(raw_args: str) -> tuple[str, str, bool]: + """Parse --provider and --global flags from /model command args. + + Returns (model_input, explicit_provider, is_global). + + Examples:: + + "sonnet" -> ("sonnet", "", False) + "sonnet --global" -> ("sonnet", "", True) + "sonnet --provider anthropic" -> ("sonnet", "anthropic", False) + "--provider my-ollama" -> ("", "my-ollama", False) + "sonnet --provider anthropic --global" -> ("sonnet", "anthropic", True) + """ + is_global = False + explicit_provider = "" + + # Extract --global + if "--global" in raw_args: + is_global = True + raw_args = raw_args.replace("--global", "").strip() + + # Extract --provider + parts = raw_args.split() + i = 0 + filtered: list[str] = [] + while i < len(parts): + if parts[i] == "--provider" and i + 1 < len(parts): + explicit_provider = parts[i + 1] + i += 2 + else: + filtered.append(parts[i]) + i += 1 + + model_input = " ".join(filtered).strip() + return (model_input, explicit_provider, is_global) + + +# --------------------------------------------------------------------------- +# Alias resolution +# --------------------------------------------------------------------------- + +def resolve_alias( + raw_input: str, + current_provider: str, +) -> Optional[tuple[str, str, str]]: + """Resolve a short alias against the current provider's catalog. + + Looks up *raw_input* in :data:`MODEL_ALIASES`, then searches the + current provider's models.dev catalog for the first model whose ID + starts with ``vendor/family`` (or just ``family`` for non-aggregator + providers). 
+
+    Returns:
+        ``(provider, resolved_model_id, alias_name)`` if a match is
+        found on the current provider, or ``None`` if the alias doesn't
+        exist or no matching model is available.
+    """
+    key = raw_input.strip().lower()
+    identity = MODEL_ALIASES.get(key)
+    if identity is None:
+        return None
+
+    vendor, family = identity
+
+    # Search the provider's catalog from models.dev
+    catalog = list_provider_models(current_provider)
+    if not catalog:
+        return None
+
+    # For aggregators, models are vendor/model-name format
+    aggregator = is_aggregator(current_provider)
+
+    for model_id in catalog:
+        mid_lower = model_id.lower()
+        if aggregator:
+            # Match vendor/family prefix -- e.g. "anthropic/claude-sonnet"
+            prefix = f"{vendor}/{family}".lower()
+            if mid_lower.startswith(prefix):
+                return (current_provider, model_id, key)
+        else:
+            # Non-aggregator: bare names -- e.g. "claude-sonnet-4-6"
+            family_lower = family.lower()
+            if mid_lower.startswith(family_lower):
+                return (current_provider, model_id, key)
+
+    return None
+
+
+def _resolve_alias_fallback(
+    raw_input: str,
+    fallback_providers: tuple[str, ...] = ("openrouter", "nous"),
+) -> Optional[tuple[str, str, str]]:
+    """Try to resolve an alias on fallback providers."""
+    for provider in fallback_providers:
+        result = resolve_alias(raw_input, provider)
+        if result is not None:
+            return result
+    return None
+
+
+# ---------------------------------------------------------------------------
+# Core model-switching pipeline
+# ---------------------------------------------------------------------------
+
 def switch_model(
     raw_input: str,
     current_provider: str,
+    current_model: str,
     current_base_url: str = "",
     current_api_key: str = "",
+    is_global: bool = False,
+    explicit_provider: str = "",
+    user_providers: Optional[dict] = None,
 ) -> ModelSwitchResult:
     """Core model-switching pipeline shared between CLI and gateway.
 
-    Handles parsing, provider detection, credential resolution, and
-    model validation.
Does NOT handle config persistence, state - mutation, or output formatting — those are caller responsibilities. + Resolution chain: + + If --provider given: + a. Resolve provider via resolve_provider_full() + b. Resolve credentials + c. If model given, resolve alias on target provider or use as-is + d. If no model, auto-detect from endpoint + + If no --provider: + a. Try alias resolution on current provider + b. If alias exists but not on current provider -> fallback + c. On aggregator, try vendor/model slug conversion + d. Aggregator catalog search + e. detect_provider_for_model() as last resort + f. Resolve credentials + g. Normalize model name for target provider + + Finally: + h. Get full model metadata from models.dev + i. Build result Args: - raw_input: The user's model input (e.g. "claude-sonnet-4", - "zai:glm-5", "custom:local:qwen"). + raw_input: The model name (after flag parsing). current_provider: The currently active provider. - current_base_url: The currently active base URL (used for - is_custom detection). + current_model: The currently active model name. + current_base_url: The currently active base URL. current_api_key: The currently active API key. + is_global: Whether to persist the switch. + explicit_provider: From --provider flag (empty = no explicit provider). + user_providers: The ``providers:`` dict from config.yaml (for user endpoints). Returns: - ModelSwitchResult with all information the caller needs to - apply the switch and format output. + ModelSwitchResult with all information the caller needs. 
""" from hermes_cli.models import ( - parse_model_input, detect_provider_for_model, validate_requested_model, - _PROVIDER_LABELS, opencode_model_api_mode, ) from hermes_cli.runtime_provider import resolve_runtime_provider - # Step 1: Parse provider:model syntax - target_provider, new_model = parse_model_input(raw_input, current_provider) + resolved_alias = "" + new_model = raw_input.strip() + target_provider = current_provider - # Step 2: Detect if we're currently on a custom endpoint - _base = current_base_url or "" - is_custom = current_provider == "custom" or ( - "localhost" in _base or "127.0.0.1" in _base - ) + # ================================================================= + # PATH A: Explicit --provider given + # ================================================================= + if explicit_provider: + # Resolve the provider + pdef = resolve_provider_full(explicit_provider, user_providers) + if pdef is None: + return ModelSwitchResult( + success=False, + is_global=is_global, + error_message=( + f"Unknown provider '{explicit_provider}'. " + f"Check 'hermes model' for available providers, or define it " + f"in config.yaml under 'providers:'." + ), + ) - # Step 3: Auto-detect provider when no explicit provider:model syntax - # was used. Skip for custom providers — the model name might - # coincidentally match a known provider's catalog. 
- if target_provider == current_provider and not is_custom: - detected = detect_provider_for_model(new_model, current_provider) - if detected: - target_provider, new_model = detected + target_provider = pdef.id + + # If no model specified, try auto-detect from endpoint + if not new_model: + if pdef.base_url: + from hermes_cli.runtime_provider import _auto_detect_local_model + detected = _auto_detect_local_model(pdef.base_url) + if detected: + new_model = detected + else: + return ModelSwitchResult( + success=False, + target_provider=target_provider, + provider_label=pdef.name, + is_global=is_global, + error_message=( + f"No model detected on {pdef.name} ({pdef.base_url}). " + f"Specify the model explicitly: /model --provider {explicit_provider}" + ), + ) + else: + return ModelSwitchResult( + success=False, + target_provider=target_provider, + provider_label=pdef.name, + is_global=is_global, + error_message=( + f"Provider '{pdef.name}' has no base URL configured. " + f"Specify a model: /model --provider {explicit_provider}" + ), + ) + + # Resolve alias on the TARGET provider + alias_result = resolve_alias(new_model, target_provider) + if alias_result is not None: + _, new_model, resolved_alias = alias_result + + # ================================================================= + # PATH B: No explicit provider — resolve from model input + # ================================================================= + else: + # --- Step a: Try alias resolution on current provider --- + alias_result = resolve_alias(raw_input, current_provider) + + if alias_result is not None: + target_provider, new_model, resolved_alias = alias_result + logger.debug( + "Alias '%s' resolved to %s on %s", + resolved_alias, new_model, target_provider, + ) + else: + # --- Step b: Alias exists but not on current provider -> fallback --- + key = raw_input.strip().lower() + if key in MODEL_ALIASES: + fallback_result = _resolve_alias_fallback(raw_input) + if fallback_result is not None: + 
target_provider, new_model, resolved_alias = fallback_result + logger.debug( + "Alias '%s' resolved via fallback to %s on %s", + resolved_alias, new_model, target_provider, + ) + else: + identity = MODEL_ALIASES[key] + return ModelSwitchResult( + success=False, + is_global=is_global, + error_message=( + f"Alias '{key}' maps to {identity.vendor}/{identity.family} " + f"but no matching model was found in any provider catalog. " + f"Try specifying the full model name." + ), + ) + else: + # --- Step c: On aggregator, convert vendor:model to vendor/model --- + colon_pos = raw_input.find(":") + if colon_pos > 0 and is_aggregator(current_provider): + left = raw_input[:colon_pos].strip().lower() + right = raw_input[colon_pos + 1:].strip() + if left and right: + # Colons become slashes for aggregator slugs + new_model = f"{left}/{right}" + logger.debug( + "Converted vendor:model '%s' to aggregator slug '%s'", + raw_input, new_model, + ) + + # --- Step d: Aggregator catalog search --- + if is_aggregator(target_provider) and not resolved_alias: + catalog = list_provider_models(target_provider) + if catalog: + new_model_lower = new_model.lower() + for mid in catalog: + if mid.lower() == new_model_lower: + new_model = mid + break + else: + for mid in catalog: + if "/" in mid: + _, bare = mid.split("/", 1) + if bare.lower() == new_model_lower: + new_model = mid + break + + # --- Step e: detect_provider_for_model() as last resort --- + _base = current_base_url or "" + is_custom = current_provider in ("custom", "local") or ( + "localhost" in _base or "127.0.0.1" in _base + ) + + if ( + target_provider == current_provider + and not is_custom + and not resolved_alias + ): + detected = detect_provider_for_model(new_model, current_provider) + if detected: + target_provider, new_model = detected + + # ================================================================= + # COMMON PATH: Resolve credentials, normalize, get metadata + # 
================================================================= provider_changed = target_provider != current_provider + provider_label = get_label(target_provider) - # Step 4: Resolve credentials for target provider + # --- Resolve credentials --- api_key = current_api_key base_url = current_base_url api_mode = "" - if provider_changed: + + if provider_changed or explicit_provider: try: runtime = resolve_runtime_provider(requested=target_provider) api_key = runtime.get("api_key", "") base_url = runtime.get("base_url", "") api_mode = runtime.get("api_mode", "") except Exception as e: - provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) - if target_provider == "custom": - return ModelSwitchResult( - success=False, - target_provider=target_provider, - error_message=( - "No custom endpoint configured. Set model.base_url " - "in config.yaml, or set OPENAI_BASE_URL in .env, " - "or run: hermes setup → Custom OpenAI-compatible endpoint" - ), - ) return ModelSwitchResult( success=False, target_provider=target_provider, + provider_label=provider_label, + is_global=is_global, error_message=( f"Could not resolve credentials for provider " f"'{provider_label}': {e}" ), ) else: - # Gateway also resolves for unchanged provider to get accurate - # base_url for validation probing. 
try: runtime = resolve_runtime_provider(requested=current_provider) api_key = runtime.get("api_key", "") @@ -138,7 +487,10 @@ def switch_model( except Exception: pass - # Step 5: Validate the model + # --- Normalize model name for target provider --- + new_model = normalize_model_for_provider(new_model, target_provider) + + # --- Validate --- try: validation = validate_requested_model( new_model, @@ -160,23 +512,26 @@ def switch_model( success=False, new_model=new_model, target_provider=target_provider, + provider_label=provider_label, + is_global=is_global, error_message=msg, ) - # Step 6: Build result - provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) - is_custom_target = target_provider == "custom" or ( - base_url - and "openrouter.ai" not in (base_url or "") - and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or "")) - ) - - if target_provider in {"opencode-zen", "opencode-go"}: - # Recompute against the requested new model, not the currently-configured - # model used during runtime resolution. OpenCode mixes API surfaces by - # model family, so a same-provider model switch can change api_mode. 
+    # --- OpenCode api_mode override ---
+    if target_provider in {"opencode-zen", "opencode-go", "opencode"}:
         api_mode = opencode_model_api_mode(target_provider, new_model)
 
+    # --- Determine api_mode if not already set ---
+    if not api_mode:
+        api_mode = determine_api_mode(target_provider, base_url)
+
+    # --- Get capabilities (legacy) ---
+    capabilities = get_model_capabilities(target_provider, new_model)
+
+    # --- Get full model info from models.dev ---
+    model_info = get_model_info(target_provider, new_model)
+
+    # --- Build result ---
     return ModelSwitchResult(
         success=True,
         new_model=new_model,
@@ -185,18 +540,191 @@ def switch_model(
         api_key=api_key,
         base_url=base_url,
         api_mode=api_mode,
-        persist=bool(validation.get("persist")),
         warning_message=validation.get("message") or "",
-        is_custom_target=is_custom_target,
         provider_label=provider_label,
+        resolved_via_alias=resolved_alias,
+        capabilities=capabilities,
+        model_info=model_info,
+        is_global=is_global,
     )
 
 
-def switch_to_custom_provider() -> CustomAutoResult:
-    """Handle bare '/model custom' — resolve endpoint and auto-detect model.
+# ---------------------------------------------------------------------------
+# Authenticated providers listing (for /model no-args display)
+# ---------------------------------------------------------------------------
 
-    Returns a result object; the caller handles persistence and output.
+def list_authenticated_providers(
+    current_provider: str = "",
+    user_providers: Optional[dict] = None,
+    max_models: int = 8,
+) -> List[dict]:
+    """Detect which providers have credentials and list their curated models.
+
+    Uses the curated model lists from hermes_cli/models.py (OPENROUTER_MODELS,
+    _PROVIDER_MODELS) — NOT the full models.dev catalog. These are hand-picked
+    agentic models that work well as agent backends.
+ + Returns a list of dicts, each with: + - slug: str — the --provider value to use + - name: str — display name + - is_current: bool + - is_user_defined: bool + - models: list[str] — curated model IDs (up to max_models) + - total_models: int — total curated count + - source: str — "built-in", "models.dev", "user-config" + + Only includes providers that have API keys set or are user-defined endpoints. """ + import os + from agent.models_dev import ( + PROVIDER_TO_MODELS_DEV, + fetch_models_dev, + get_provider_info as _mdev_pinfo, + ) + from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS + + results: List[dict] = [] + seen_slugs: set = set() + + data = fetch_models_dev() + + # Build curated model lists keyed by hermes provider ID + curated: dict[str, list[str]] = dict(_PROVIDER_MODELS) + curated["openrouter"] = [mid for mid, _ in OPENROUTER_MODELS] + # "nous" shares OpenRouter's curated list if not separately defined + if "nous" not in curated: + curated["nous"] = curated["openrouter"] + + # --- 1. Check Hermes-mapped providers --- + for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items(): + pdata = data.get(mdev_id) + if not isinstance(pdata, dict): + continue + + env_vars = pdata.get("env", []) + if not isinstance(env_vars, list): + continue + + # Check if any env var is set + has_creds = any(os.environ.get(ev) for ev in env_vars) + if not has_creds: + continue + + # Use curated list, falling back to models.dev if no curated list + model_ids = curated.get(hermes_id, []) + total = len(model_ids) + top = model_ids[:max_models] + + slug = hermes_id + pinfo = _mdev_pinfo(mdev_id) + display_name = pinfo.name if pinfo else mdev_id + + results.append({ + "slug": slug, + "name": display_name, + "is_current": slug == current_provider or mdev_id == current_provider, + "is_user_defined": False, + "models": top, + "total_models": total, + "source": "built-in", + }) + seen_slugs.add(slug) + + # --- 2. 
Check Hermes-only providers (nous, openai-codex, copilot) --- + from hermes_cli.providers import HERMES_OVERLAYS + for pid, overlay in HERMES_OVERLAYS.items(): + if pid in seen_slugs: + continue + # Check if credentials exist + has_creds = False + if overlay.extra_env_vars: + has_creds = any(os.environ.get(ev) for ev in overlay.extra_env_vars) + if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"): + # These use auth stores, not env vars — check for auth.json entries + try: + from hermes_cli.auth import _read_auth_store + store = _read_auth_store() + if store and pid in store: + has_creds = True + except Exception: + pass + if not has_creds: + continue + + # Use curated list + model_ids = curated.get(pid, []) + total = len(model_ids) + top = model_ids[:max_models] + + results.append({ + "slug": pid, + "name": get_label(pid), + "is_current": pid == current_provider, + "is_user_defined": False, + "models": top, + "total_models": total, + "source": "hermes", + }) + seen_slugs.add(pid) + + # --- 3. 
User-defined endpoints from config --- + if user_providers and isinstance(user_providers, dict): + for ep_name, ep_cfg in user_providers.items(): + if not isinstance(ep_cfg, dict): + continue + display_name = ep_cfg.get("name", "") or ep_name + api_url = ep_cfg.get("api", "") or ep_cfg.get("url", "") or "" + default_model = ep_cfg.get("default_model", "") + + models_list = [] + if default_model: + models_list.append(default_model) + + # Try to probe /v1/models if URL is set (but don't block on it) + # For now just show what we know from config + results.append({ + "slug": ep_name, + "name": display_name, + "is_current": ep_name == current_provider, + "is_user_defined": True, + "models": models_list, + "total_models": len(models_list) if models_list else 0, + "source": "user-config", + "api_url": api_url, + }) + + # Sort: current provider first, then by model count descending + results.sort(key=lambda r: (not r["is_current"], -r["total_models"])) + + return results + + +# --------------------------------------------------------------------------- +# Fuzzy suggestions +# --------------------------------------------------------------------------- + +def suggest_models(raw_input: str, limit: int = 3) -> List[str]: + """Return fuzzy model suggestions for a (possibly misspelled) input.""" + query = raw_input.strip() + if not query: + return [] + + results = search_models_dev(query, limit=limit) + suggestions: list[str] = [] + for r in results: + mid = r.get("model_id", "") + if mid: + suggestions.append(mid) + + return suggestions[:limit] + + +# --------------------------------------------------------------------------- +# Custom provider switch +# --------------------------------------------------------------------------- + +def switch_to_custom_provider() -> CustomAutoResult: + """Handle bare '/model --provider custom' — resolve endpoint and auto-detect model.""" from hermes_cli.runtime_provider import ( resolve_runtime_provider, _auto_detect_local_model, @@ -219,7 
+747,7 @@ def switch_to_custom_provider() -> CustomAutoResult: error_message=( "No custom endpoint configured. " "Set model.base_url in config.yaml, or set OPENAI_BASE_URL " - "in .env, or run: hermes setup → Custom OpenAI-compatible endpoint" + "in .env, or run: hermes setup -> Custom OpenAI-compatible endpoint" ), ) @@ -232,7 +760,7 @@ def switch_to_custom_provider() -> CustomAutoResult: error_message=( f"Custom endpoint at {cust_base} is reachable but no single " f"model was auto-detected. Specify the model explicitly: " - f"/model custom:" + f"/model --provider custom" ), ) diff --git a/hermes_cli/providers.py b/hermes_cli/providers.py new file mode 100644 index 0000000000..890927884c --- /dev/null +++ b/hermes_cli/providers.py @@ -0,0 +1,519 @@ +""" +Single source of truth for provider identity in Hermes Agent. + +Two data sources, merged at runtime: + +1. **models.dev catalog** — 109+ providers with base URLs, env vars, display + names, and full model metadata (context, cost, capabilities). This is + the primary database. + +2. **Hermes overlays** — transport type, auth patterns, aggregator flags, + and additional env vars that models.dev doesn't track. Small dict, + maintained here. + +3. **User config** (``providers:`` section in config.yaml) — user-defined + endpoints and overrides. Merged on top of everything else. + +Other modules import from this file. No parallel registries. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +# -- Hermes overlay ---------------------------------------------------------- +# Hermes-specific metadata that models.dev doesn't provide. 
+ +@dataclass(frozen=True) +class HermesOverlay: + """Hermes-specific provider metadata layered on top of models.dev.""" + + transport: str = "openai_chat" # openai_chat | anthropic_messages | codex_responses + is_aggregator: bool = False + auth_type: str = "api_key" # api_key | oauth_device_code | oauth_external | external_process + extra_env_vars: Tuple[str, ...] = () # env vars models.dev doesn't list + base_url_override: str = "" # override if models.dev URL is wrong/missing + base_url_env_var: str = "" # env var for user-custom base URL + + +HERMES_OVERLAYS: Dict[str, HermesOverlay] = { + "openrouter": HermesOverlay( + transport="openai_chat", + is_aggregator=True, + extra_env_vars=("OPENAI_API_KEY",), + base_url_env_var="OPENROUTER_BASE_URL", + ), + "nous": HermesOverlay( + transport="openai_chat", + auth_type="oauth_device_code", + base_url_override="https://inference-api.nousresearch.com/v1", + ), + "openai-codex": HermesOverlay( + transport="codex_responses", + auth_type="oauth_external", + base_url_override="https://chatgpt.com/backend-api/codex", + ), + "copilot-acp": HermesOverlay( + transport="codex_responses", + auth_type="external_process", + base_url_override="acp://copilot", + base_url_env_var="COPILOT_ACP_BASE_URL", + ), + "github-copilot": HermesOverlay( + transport="openai_chat", + extra_env_vars=("COPILOT_GITHUB_TOKEN", "GH_TOKEN"), + ), + "anthropic": HermesOverlay( + transport="anthropic_messages", + extra_env_vars=("ANTHROPIC_TOKEN", "CLAUDE_CODE_OAUTH_TOKEN"), + ), + "zai": HermesOverlay( + transport="openai_chat", + extra_env_vars=("GLM_API_KEY", "ZAI_API_KEY", "Z_AI_API_KEY"), + base_url_env_var="GLM_BASE_URL", + ), + "kimi-for-coding": HermesOverlay( + transport="openai_chat", + base_url_env_var="KIMI_BASE_URL", + ), + "minimax": HermesOverlay( + transport="openai_chat", + base_url_env_var="MINIMAX_BASE_URL", + ), + "minimax-cn": HermesOverlay( + transport="openai_chat", + base_url_env_var="MINIMAX_CN_BASE_URL", + ), + "deepseek": 
HermesOverlay( + transport="openai_chat", + base_url_env_var="DEEPSEEK_BASE_URL", + ), + "alibaba": HermesOverlay( + transport="openai_chat", + base_url_env_var="DASHSCOPE_BASE_URL", + ), + "vercel": HermesOverlay( + transport="openai_chat", + is_aggregator=True, + ), + "opencode": HermesOverlay( + transport="openai_chat", + is_aggregator=True, + base_url_env_var="OPENCODE_ZEN_BASE_URL", + ), + "opencode-go": HermesOverlay( + transport="openai_chat", + is_aggregator=True, + base_url_env_var="OPENCODE_GO_BASE_URL", + ), + "kilo": HermesOverlay( + transport="openai_chat", + is_aggregator=True, + base_url_env_var="KILOCODE_BASE_URL", + ), + "huggingface": HermesOverlay( + transport="openai_chat", + is_aggregator=True, + base_url_env_var="HF_BASE_URL", + ), +} + + +# -- Resolved provider ------------------------------------------------------- +# The merged result of models.dev + overlay + user config. + +@dataclass +class ProviderDef: + """Complete provider definition — merged from all sources.""" + + id: str + name: str + transport: str # openai_chat | anthropic_messages | codex_responses + api_key_env_vars: Tuple[str, ...] # all env vars to check for API key + base_url: str = "" + base_url_env_var: str = "" + is_aggregator: bool = False + auth_type: str = "api_key" + doc: str = "" + source: str = "" # "models.dev", "hermes", "user-config" + + @property + def is_user_defined(self) -> bool: + return self.source == "user-config" + + +# -- Aliases ------------------------------------------------------------------ +# Maps human-friendly / legacy names to canonical provider IDs. +# Uses models.dev IDs where possible. 
+ +ALIASES: Dict[str, str] = { + # openrouter + "openai": "openrouter", # bare "openai" → route through aggregator + + # zai + "glm": "zai", + "z-ai": "zai", + "z.ai": "zai", + "zhipu": "zai", + + # kimi-for-coding (models.dev ID) + "kimi": "kimi-for-coding", + "kimi-coding": "kimi-for-coding", + "moonshot": "kimi-for-coding", + + # minimax-cn + "minimax-china": "minimax-cn", + "minimax_cn": "minimax-cn", + + # anthropic + "claude": "anthropic", + "claude-code": "anthropic", + + # github-copilot (models.dev ID) + "copilot": "github-copilot", + "github": "github-copilot", + "github-copilot-acp": "copilot-acp", + + # vercel (models.dev ID for AI Gateway) + "ai-gateway": "vercel", + "aigateway": "vercel", + "vercel-ai-gateway": "vercel", + + # opencode (models.dev ID for OpenCode Zen) + "opencode-zen": "opencode", + "zen": "opencode", + + # opencode-go + "go": "opencode-go", + "opencode-go-sub": "opencode-go", + + # kilo (models.dev ID for KiloCode) + "kilocode": "kilo", + "kilo-code": "kilo", + "kilo-gateway": "kilo", + + # deepseek + "deep-seek": "deepseek", + + # alibaba + "dashscope": "alibaba", + "aliyun": "alibaba", + "qwen": "alibaba", + "alibaba-cloud": "alibaba", + + # huggingface + "hf": "huggingface", + "hugging-face": "huggingface", + "huggingface-hub": "huggingface", + + # Local server aliases → virtual "local" concept (resolved via user config) + "lmstudio": "lmstudio", + "lm-studio": "lmstudio", + "lm_studio": "lmstudio", + "ollama": "ollama-cloud", + "vllm": "local", + "llamacpp": "local", + "llama.cpp": "local", + "llama-cpp": "local", +} + + +# -- Display labels ----------------------------------------------------------- +# Built dynamically from models.dev + overlays. Fallback for providers +# not in the catalog. 
+ +_LABEL_OVERRIDES: Dict[str, str] = { + "nous": "Nous Portal", + "openai-codex": "OpenAI Codex", + "copilot-acp": "GitHub Copilot ACP", + "local": "Local endpoint", +} + + +# -- Transport → API mode mapping --------------------------------------------- + +TRANSPORT_TO_API_MODE: Dict[str, str] = { + "openai_chat": "chat_completions", + "anthropic_messages": "anthropic_messages", + "codex_responses": "codex_responses", +} + + +# -- Helper functions --------------------------------------------------------- + +def normalize_provider(name: str) -> str: + """Resolve aliases and normalise casing to a canonical provider id. + + Returns the canonical id string. Does *not* validate that the id + corresponds to a known provider. + """ + key = name.strip().lower() + return ALIASES.get(key, key) + + +def get_overlay(provider_id: str) -> Optional[HermesOverlay]: + """Get Hermes overlay for a provider, if one exists.""" + canonical = normalize_provider(provider_id) + return HERMES_OVERLAYS.get(canonical) + + +def get_provider(name: str) -> Optional[ProviderDef]: + """Look up a provider by id or alias, merging all data sources. + + Resolution order: + 1. Hermes overlays (for providers not in models.dev: nous, openai-codex, etc.) + 2. models.dev catalog + Hermes overlay + 3. User-defined providers from config (TODO: Phase 4) + + Returns a fully-resolved ProviderDef or None. 
+ """ + canonical = normalize_provider(name) + + # Try to get models.dev data + try: + from agent.models_dev import get_provider_info as _mdev_provider + mdev_info = _mdev_provider(canonical) + except Exception: + mdev_info = None + + overlay = HERMES_OVERLAYS.get(canonical) + + if mdev_info is not None: + # Merge models.dev + overlay + transport = overlay.transport if overlay else "openai_chat" + is_agg = overlay.is_aggregator if overlay else False + auth = overlay.auth_type if overlay else "api_key" + base_url_env = overlay.base_url_env_var if overlay else "" + base_url_override = overlay.base_url_override if overlay else "" + + # Combine env vars: models.dev env + hermes extra + env_vars = list(mdev_info.env) + if overlay and overlay.extra_env_vars: + for ev in overlay.extra_env_vars: + if ev not in env_vars: + env_vars.append(ev) + + return ProviderDef( + id=canonical, + name=mdev_info.name, + transport=transport, + api_key_env_vars=tuple(env_vars), + base_url=base_url_override or mdev_info.api, + base_url_env_var=base_url_env, + is_aggregator=is_agg, + auth_type=auth, + doc=mdev_info.doc, + source="models.dev", + ) + + if overlay is not None: + # Hermes-only provider (not in models.dev) + return ProviderDef( + id=canonical, + name=_LABEL_OVERRIDES.get(canonical, canonical), + transport=overlay.transport, + api_key_env_vars=overlay.extra_env_vars, + base_url=overlay.base_url_override, + base_url_env_var=overlay.base_url_env_var, + is_aggregator=overlay.is_aggregator, + auth_type=overlay.auth_type, + source="hermes", + ) + + return None + + +def get_label(provider_id: str) -> str: + """Get a human-readable display name for a provider.""" + canonical = normalize_provider(provider_id) + + # Check label overrides first + if canonical in _LABEL_OVERRIDES: + return _LABEL_OVERRIDES[canonical] + + # Try models.dev + pdef = get_provider(canonical) + if pdef: + return pdef.name + + return canonical + + +# Build LABELS dict for backward compat +def _build_labels() -> 
Dict[str, str]:
+    """Build labels dict from overlays + overrides. Lazy, cached."""
+    labels: Dict[str, str] = {}
+    for pid in HERMES_OVERLAYS:
+        labels[pid] = get_label(pid)
+    labels.update(_LABEL_OVERRIDES)
+    return labels
+
+# Lazy-built on first access
+_labels_cache: Optional[Dict[str, str]] = None
+
+# NOTE: @property is inert at module level (it only works as a class
+def get_labels() -> Dict[str, str]:
+    """Backward-compatible labels dict, lazily built and cached."""
+    global _labels_cache
+    if _labels_cache is None:
+        _labels_cache = _build_labels()
+    return _labels_cache
+
+# For direct import compat, expose as module-level dict
+# Built on demand by get_label() calls
+LABELS: Dict[str, str] = {
+    # Static entries for backward compat — get_label() is the proper API
+    "openrouter": "OpenRouter",
+    "nous": "Nous Portal",
+    "openai-codex": "OpenAI Codex",
+    "copilot-acp": "GitHub Copilot ACP",
+    "github-copilot": "GitHub Copilot",
+    "anthropic": "Anthropic",
+    "zai": "Z.AI / GLM",
+    "kimi-for-coding": "Kimi / Moonshot",
+    "minimax": "MiniMax",
+    "minimax-cn": "MiniMax (China)",
+    "deepseek": "DeepSeek",
+    "alibaba": "Alibaba Cloud (DashScope)",
+    "vercel": "Vercel AI Gateway",
+    "opencode": "OpenCode Zen",
+    "opencode-go": "OpenCode Go",
+    "kilo": "Kilo Gateway",
+    "huggingface": "Hugging Face",
+    "local": "Local endpoint",
+    "custom": "Custom endpoint",
+    # Legacy Hermes IDs (point to same providers)
+    "ai-gateway": "Vercel AI Gateway",
+    "kilocode": "Kilo Gateway",
+    "copilot": "GitHub Copilot",
+    "kimi-coding": "Kimi / Moonshot",
+    "opencode-zen": "OpenCode Zen",
+}
+
+
+def is_aggregator(provider: str) -> bool:
+    """Return True when the provider is a multi-model aggregator."""
+    pdef = get_provider(provider)
+    return pdef.is_aggregator if pdef else False
+
+
+def determine_api_mode(provider: str, base_url: str = "") -> str:
+    """Determine the API mode (wire protocol) for a provider/endpoint.
+
+    Resolution order:
+    1. Known provider → transport → TRANSPORT_TO_API_MODE.
+    2. URL heuristics for unknown / custom providers.
+    3.
Default: 'chat_completions'. + """ + pdef = get_provider(provider) + if pdef is not None: + return TRANSPORT_TO_API_MODE.get(pdef.transport, "chat_completions") + + # URL-based heuristics for custom / unknown providers + if base_url: + url_lower = base_url.rstrip("/").lower() + if url_lower.endswith("/anthropic") or "api.anthropic.com" in url_lower: + return "anthropic_messages" + if "api.openai.com" in url_lower: + return "codex_responses" + + return "chat_completions" + + +# -- Provider from user config ------------------------------------------------ + +def resolve_user_provider(name: str, user_config: Dict[str, Any]) -> Optional[ProviderDef]: + """Resolve a provider from the user's config.yaml ``providers:`` section. + + Args: + name: Provider name as given by the user. + user_config: The ``providers:`` dict from config.yaml. + + Returns: + ProviderDef if found, else None. + """ + if not user_config or not isinstance(user_config, dict): + return None + + entry = user_config.get(name) + if not isinstance(entry, dict): + return None + + # Extract fields + display_name = entry.get("name", "") or name + api_url = entry.get("api", "") or entry.get("url", "") or entry.get("base_url", "") or "" + key_env = entry.get("key_env", "") or "" + transport = entry.get("transport", "openai_chat") or "openai_chat" + + env_vars: List[str] = [] + if key_env: + env_vars.append(key_env) + + return ProviderDef( + id=name, + name=display_name, + transport=transport, + api_key_env_vars=tuple(env_vars), + base_url=api_url, + is_aggregator=False, + auth_type="api_key", + source="user-config", + ) + + +def resolve_provider_full( + name: str, + user_providers: Optional[Dict[str, Any]] = None, +) -> Optional[ProviderDef]: + """Full resolution chain: built-in → models.dev → user config. + + This is the main entry point for --provider flag resolution. + + Args: + name: Provider name or alias. + user_providers: The ``providers:`` dict from config.yaml (optional). 
+
+    Returns:
+        ProviderDef if found, else None.
+    """
+    canonical = normalize_provider(name)
+
+    # 1. Built-in (models.dev + overlays)
+    pdef = get_provider(canonical)
+    if pdef is not None:
+        return pdef
+
+    # 2. User-defined providers from config
+    if user_providers:
+        # Try canonical name
+        user_pdef = resolve_user_provider(canonical, user_providers)
+        if user_pdef is not None:
+            return user_pdef
+        # Try original name (in case alias didn't match)
+        user_pdef = resolve_user_provider(name.strip().lower(), user_providers)
+        if user_pdef is not None:
+            return user_pdef
+
+    # 3. Try models.dev directly (for providers not in our ALIASES)
+    try:
+        from agent.models_dev import get_provider_info as _mdev_provider
+        mdev_info = _mdev_provider(canonical)
+        if mdev_info is not None:
+            return ProviderDef(
+                id=canonical,
+                name=mdev_info.name,
+                transport="openai_chat",
+                api_key_env_vars=tuple(mdev_info.env),
+                base_url=mdev_info.api,
+                source="models.dev",
+            )
+    except Exception:
+        pass
+
+    return None
diff --git a/run_agent.py b/run_agent.py
index b66f5f9669..48daa3113c 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -1268,6 +1268,129 @@ class AIAgent:
         # Iterative summary from previous session must not bleed into new one (#2635)
         self.context_compressor._previous_summary = None
 
+    def switch_model(self, new_model, new_provider, api_key='', base_url='', api_mode=''):
+        """Switch the model/provider in-place for a live agent.
+
+        Called by the /model command handlers (CLI and gateway) after
+        ``model_switch.switch_model()`` has resolved credentials and
+        validated the model. This method performs the actual runtime
+        swap: rebuilding clients, updating caching flags, and refreshing
+        the context compressor.
+
+        The implementation mirrors ``_try_activate_fallback()`` for the
+        client-swap logic but also updates ``_primary_runtime`` so the
+        change persists across turns (unlike fallback which is
+        turn-scoped).
+ """ + import logging + from hermes_cli.providers import determine_api_mode + + # ── Determine api_mode if not provided ── + if not api_mode: + api_mode = determine_api_mode(new_provider, base_url) + + old_model = self.model + old_provider = self.provider + + # ── Swap core runtime fields ── + self.model = new_model + self.provider = new_provider + self.base_url = base_url or self.base_url + self.api_mode = api_mode + if api_key: + self.api_key = api_key + + # ── Build new client ── + if api_mode == "anthropic_messages": + from agent.anthropic_adapter import ( + build_anthropic_client, + resolve_anthropic_token, + _is_oauth_token, + ) + effective_key = api_key or self.api_key or resolve_anthropic_token() or "" + self.api_key = effective_key + self._anthropic_api_key = effective_key + self._anthropic_base_url = base_url or getattr(self, "_anthropic_base_url", None) + self._anthropic_client = build_anthropic_client( + effective_key, self._anthropic_base_url, + ) + self._is_anthropic_oauth = _is_oauth_token(effective_key) + self.client = None + self._client_kwargs = {} + else: + effective_key = api_key or self.api_key + effective_base = base_url or self.base_url + self._client_kwargs = { + "api_key": effective_key, + "base_url": effective_base, + } + self.client = self._create_openai_client( + dict(self._client_kwargs), + reason="switch_model", + shared=True, + ) + + # ── Re-evaluate prompt caching ── + is_native_anthropic = api_mode == "anthropic_messages" + self._use_prompt_caching = ( + ("openrouter" in (self.base_url or "").lower() and "claude" in new_model.lower()) + or is_native_anthropic + ) + + # ── Update context compressor ── + if hasattr(self, "context_compressor") and self.context_compressor: + from agent.model_metadata import get_model_context_length + new_context_length = get_model_context_length( + self.model, + base_url=self.base_url, + api_key=self.api_key, + provider=self.provider, + ) + self.context_compressor.model = self.model + 
self.context_compressor.base_url = self.base_url + self.context_compressor.api_key = self.api_key + self.context_compressor.provider = self.provider + self.context_compressor.context_length = new_context_length + self.context_compressor.threshold_tokens = int( + new_context_length * self.context_compressor.threshold_percent + ) + + # ── Invalidate cached system prompt so it rebuilds next turn ── + self._cached_system_prompt = None + + # ── Update _primary_runtime so the change persists across turns ── + _cc = self.context_compressor if hasattr(self, "context_compressor") and self.context_compressor else None + self._primary_runtime = { + "model": self.model, + "provider": self.provider, + "base_url": self.base_url, + "api_mode": self.api_mode, + "api_key": getattr(self, "api_key", ""), + "client_kwargs": dict(self._client_kwargs), + "use_prompt_caching": self._use_prompt_caching, + "compressor_model": _cc.model if _cc else self.model, + "compressor_base_url": _cc.base_url if _cc else self.base_url, + "compressor_api_key": getattr(_cc, "api_key", "") if _cc else "", + "compressor_provider": _cc.provider if _cc else self.provider, + "compressor_context_length": _cc.context_length if _cc else 0, + "compressor_threshold_tokens": _cc.threshold_tokens if _cc else 0, + } + if api_mode == "anthropic_messages": + self._primary_runtime.update({ + "anthropic_api_key": self._anthropic_api_key, + "anthropic_base_url": self._anthropic_base_url, + "is_anthropic_oauth": self._is_anthropic_oauth, + }) + + # ── Reset fallback state ── + self._fallback_activated = False + self._fallback_index = 0 + + logging.info( + "Model switched in-place: %s (%s) -> %s (%s)", + old_model, old_provider, new_model, new_provider, + ) + def _safe_print(self, *args, **kwargs): """Print that silently handles broken pipes / closed stdout.