mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
refactor(model): extract shared switch_model() from CLI and gateway handlers
Phase 4 of the /model command overhaul. Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers had ~50 lines of duplicated core logic: parsing, provider detection, credential resolution, and model validation. This extracts that pipeline into hermes_cli/model_switch.py. New module exports: - ModelSwitchResult: dataclass with all fields both handlers need - CustomAutoResult: dataclass for bare '/model custom' results - switch_model(): core pipeline — parse → detect → resolve → validate - switch_to_custom_provider(): resolve endpoint + auto-detect model The shared functions are pure (no I/O side effects). Each caller handles its own platform-specific concerns: - CLI: sets self.model/provider/etc, calls save_config_value(), prints - Gateway: writes config.yaml directly, sets env vars, returns markdown Net result: -244 lines from handlers, +234 lines in shared module. The handlers are now ~80 lines each (down from ~150+) and can't drift apart on core logic.
This commit is contained in:
parent
ce39f9cc44
commit
2e524272b1
3 changed files with 359 additions and 258 deletions
172
cli.py
172
cli.py
|
|
@ -3562,151 +3562,83 @@ class HermesCLI:
|
|||
# Use original case so model names like "Anthropic/Claude-Opus-4" are preserved
|
||||
parts = cmd_original.split(maxsplit=1)
|
||||
if len(parts) > 1:
|
||||
from hermes_cli.auth import resolve_provider
|
||||
from hermes_cli.models import (
|
||||
parse_model_input,
|
||||
validate_requested_model,
|
||||
_PROVIDER_LABELS,
|
||||
)
|
||||
from hermes_cli.model_switch import switch_model, switch_to_custom_provider
|
||||
|
||||
raw_input = parts[1].strip()
|
||||
|
||||
# Handle bare "/model custom" — switch to custom provider
|
||||
# and auto-detect the model from the endpoint.
|
||||
if raw_input.strip().lower() == "custom":
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider,
|
||||
_auto_detect_local_model,
|
||||
)
|
||||
try:
|
||||
runtime = resolve_runtime_provider(requested="custom")
|
||||
cust_base = runtime.get("base_url", "")
|
||||
cust_key = runtime.get("api_key", "")
|
||||
if not cust_base or "openrouter.ai" in cust_base:
|
||||
print("(>_<) No custom endpoint configured.")
|
||||
print(" Set model.base_url in config.yaml, or set OPENAI_BASE_URL in .env,")
|
||||
print(" or run: hermes setup → Custom OpenAI-compatible endpoint")
|
||||
return True
|
||||
detected_model = _auto_detect_local_model(cust_base)
|
||||
if detected_model:
|
||||
self.model = detected_model
|
||||
self.requested_provider = "custom"
|
||||
self.provider = "custom"
|
||||
self.api_key = cust_key
|
||||
self.base_url = cust_base
|
||||
self.agent = None
|
||||
save_config_value("model.default", detected_model)
|
||||
save_config_value("model.provider", "custom")
|
||||
save_config_value("model.base_url", cust_base)
|
||||
print(f"(^_^)b Model changed to: {detected_model} [provider: Custom]")
|
||||
print(f" Endpoint: {cust_base}")
|
||||
print(f" Status: connected (model auto-detected)")
|
||||
else:
|
||||
print(f"(>_<) Custom endpoint at {cust_base} is reachable but no single model was auto-detected.")
|
||||
print(f" Specify the model explicitly: /model custom:<model-name>")
|
||||
except Exception as e:
|
||||
print(f"(>_<) Could not resolve custom endpoint: {e}")
|
||||
result = switch_to_custom_provider()
|
||||
if result.success:
|
||||
self.model = result.model
|
||||
self.requested_provider = "custom"
|
||||
self.provider = "custom"
|
||||
self.api_key = result.api_key
|
||||
self.base_url = result.base_url
|
||||
self.agent = None
|
||||
save_config_value("model.default", result.model)
|
||||
save_config_value("model.provider", "custom")
|
||||
save_config_value("model.base_url", result.base_url)
|
||||
print(f"(^_^)b Model changed to: {result.model} [provider: Custom]")
|
||||
print(f" Endpoint: {result.base_url}")
|
||||
print(f" Status: connected (model auto-detected)")
|
||||
else:
|
||||
print(f"(>_<) {result.error_message}")
|
||||
return True
|
||||
|
||||
# Parse provider:model syntax (e.g. "openrouter:anthropic/claude-sonnet-4.5")
|
||||
# Core model-switching pipeline (shared with gateway)
|
||||
current_provider = self.provider or self.requested_provider or "openrouter"
|
||||
target_provider, new_model = parse_model_input(raw_input, current_provider)
|
||||
# Auto-detect provider when no explicit provider:model syntax was used.
|
||||
# Skip auto-detection for custom providers — the model name might
|
||||
# coincidentally match a known provider's catalog, but the user
|
||||
# intends to use it on their custom endpoint. Require explicit
|
||||
# provider:model syntax (e.g. /model openai-codex:gpt-5.2-codex)
|
||||
# to switch away from a custom endpoint.
|
||||
_base = self.base_url or ""
|
||||
is_custom = current_provider == "custom" or (
|
||||
"localhost" in _base or "127.0.0.1" in _base
|
||||
result = switch_model(
|
||||
raw_input,
|
||||
current_provider,
|
||||
current_base_url=self.base_url or "",
|
||||
current_api_key=self.api_key or "",
|
||||
)
|
||||
if target_provider == current_provider and not is_custom:
|
||||
from hermes_cli.models import detect_provider_for_model
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
provider_changed = target_provider != current_provider
|
||||
|
||||
# If provider is changing, re-resolve credentials for the new provider
|
||||
api_key_for_probe = self.api_key
|
||||
base_url_for_probe = self.base_url
|
||||
if provider_changed:
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=target_provider)
|
||||
api_key_for_probe = runtime.get("api_key", "")
|
||||
base_url_for_probe = runtime.get("base_url", "")
|
||||
except Exception as e:
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
if target_provider == "custom":
|
||||
print(f"(>_<) Custom endpoint not configured. Set OPENAI_BASE_URL and OPENAI_API_KEY,")
|
||||
print(f" or run: hermes setup → Custom OpenAI-compatible endpoint")
|
||||
else:
|
||||
print(f"(>_<) Could not resolve credentials for provider '{provider_label}': {e}")
|
||||
print(f"(^_^) Current model unchanged: {self.model}")
|
||||
return True
|
||||
|
||||
try:
|
||||
validation = validate_requested_model(
|
||||
new_model,
|
||||
target_provider,
|
||||
api_key=api_key_for_probe,
|
||||
base_url=base_url_for_probe,
|
||||
)
|
||||
except Exception:
|
||||
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
print(f"(>_<) {validation.get('message')}")
|
||||
print(f" Model unchanged: {self.model}")
|
||||
if "Did you mean" not in (validation.get("message") or ""):
|
||||
print(" Tip: Use /model to see available models, /provider to see providers")
|
||||
if not result.success:
|
||||
print(f"(>_<) {result.error_message}")
|
||||
if "Did you mean" not in result.error_message:
|
||||
print(f" Model unchanged: {self.model}")
|
||||
if "credentials" not in result.error_message.lower():
|
||||
print(" Tip: Use /model to see available models, /provider to see providers")
|
||||
else:
|
||||
self.model = new_model
|
||||
self.model = result.new_model
|
||||
self.agent = None # Force re-init
|
||||
|
||||
if provider_changed:
|
||||
self.requested_provider = target_provider
|
||||
self.provider = target_provider
|
||||
self.api_key = api_key_for_probe
|
||||
self.base_url = base_url_for_probe
|
||||
if result.provider_changed:
|
||||
self.requested_provider = result.target_provider
|
||||
self.provider = result.target_provider
|
||||
self.api_key = result.api_key
|
||||
self.base_url = result.base_url
|
||||
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
provider_note = f" [provider: {provider_label}]" if provider_changed else ""
|
||||
provider_note = f" [provider: {result.provider_label}]" if result.provider_changed else ""
|
||||
|
||||
if validation.get("persist"):
|
||||
saved_model = save_config_value("model.default", new_model)
|
||||
if provider_changed:
|
||||
save_config_value("model.provider", target_provider)
|
||||
# Persist base_url for custom endpoints so it
|
||||
# survives restart; clear it when switching away
|
||||
# from custom to prevent stale URLs leaking into
|
||||
# the new provider's resolution (#2562 Phase 2).
|
||||
if base_url_for_probe and "openrouter.ai" not in (base_url_for_probe or ""):
|
||||
save_config_value("model.base_url", base_url_for_probe)
|
||||
if result.persist:
|
||||
saved_model = save_config_value("model.default", result.new_model)
|
||||
if result.provider_changed:
|
||||
save_config_value("model.provider", result.target_provider)
|
||||
# Persist base_url for custom endpoints; clear
|
||||
# when switching away from custom (#2562 Phase 2).
|
||||
if result.base_url and "openrouter.ai" not in (result.base_url or ""):
|
||||
save_config_value("model.base_url", result.base_url)
|
||||
else:
|
||||
save_config_value("model.base_url", None)
|
||||
if saved_model:
|
||||
print(f"(^_^)b Model changed to: {new_model}{provider_note} (saved to config)")
|
||||
print(f"(^_^)b Model changed to: {result.new_model}{provider_note} (saved to config)")
|
||||
else:
|
||||
print(f"(^_^) Model changed to: {new_model}{provider_note} (this session only)")
|
||||
print(f"(^_^) Model changed to: {result.new_model}{provider_note} (this session only)")
|
||||
else:
|
||||
message = validation.get("message") or ""
|
||||
print(f"(^_^) Model changed to: {new_model}{provider_note} (this session only)")
|
||||
if message:
|
||||
print(f" Reason: {message}")
|
||||
print(f"(^_^) Model changed to: {result.new_model}{provider_note} (this session only)")
|
||||
if result.warning_message:
|
||||
print(f" Reason: {result.warning_message}")
|
||||
print(" Note: Model will revert on restart. Use a verified model to save to config.")
|
||||
|
||||
# Show endpoint info for custom providers
|
||||
_target_is_custom = target_provider == "custom" or (
|
||||
base_url_for_probe and "openrouter.ai" not in (base_url_for_probe or "")
|
||||
and ("localhost" in (base_url_for_probe or "") or "127.0.0.1" in (base_url_for_probe or ""))
|
||||
)
|
||||
if _target_is_custom or (is_custom and not provider_changed):
|
||||
endpoint = base_url_for_probe or self.base_url or "custom endpoint"
|
||||
if result.is_custom_target:
|
||||
endpoint = result.base_url or self.base_url or "custom endpoint"
|
||||
print(f" Endpoint: {endpoint}")
|
||||
if not provider_changed:
|
||||
if not result.provider_changed:
|
||||
print(f" Tip: To switch providers, use /model provider:model")
|
||||
print(f" e.g. /model openai-codex:gpt-5.2-codex")
|
||||
else:
|
||||
|
|
|
|||
211
gateway/run.py
211
gateway/run.py
|
|
@ -2854,117 +2854,10 @@ class GatewayRunner:
|
|||
# Handle bare "/model custom" — switch to custom provider
|
||||
# and auto-detect the model from the endpoint.
|
||||
if args.strip().lower() == "custom":
|
||||
from hermes_cli.runtime_provider import (
|
||||
resolve_runtime_provider as _rtp_custom,
|
||||
_auto_detect_local_model,
|
||||
)
|
||||
try:
|
||||
runtime = _rtp_custom(requested="custom")
|
||||
cust_base = runtime.get("base_url", "")
|
||||
if not cust_base or "openrouter.ai" in cust_base:
|
||||
return (
|
||||
"⚠️ No custom endpoint configured.\n"
|
||||
"Set `model.base_url` in config.yaml, or `OPENAI_BASE_URL` in .env,\n"
|
||||
"or run: `hermes setup` → Custom OpenAI-compatible endpoint"
|
||||
)
|
||||
detected_model = _auto_detect_local_model(cust_base)
|
||||
if detected_model:
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||
user_config["model"] = {}
|
||||
user_config["model"]["default"] = detected_model
|
||||
user_config["model"]["provider"] = "custom"
|
||||
user_config["model"]["base_url"] = cust_base
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
except Exception as e:
|
||||
return f"⚠️ Failed to save model change: {e}"
|
||||
os.environ["HERMES_MODEL"] = detected_model
|
||||
os.environ["HERMES_INFERENCE_PROVIDER"] = "custom"
|
||||
self._effective_model = None
|
||||
self._effective_provider = None
|
||||
return (
|
||||
f"🤖 Model changed to `{detected_model}` (saved to config)\n"
|
||||
f"**Provider:** Custom\n"
|
||||
f"**Endpoint:** `{cust_base}`\n"
|
||||
f"_Model auto-detected from endpoint. Takes effect on next message._"
|
||||
)
|
||||
else:
|
||||
return (
|
||||
f"⚠️ Custom endpoint at `{cust_base}` is reachable but no single model was auto-detected.\n"
|
||||
f"Specify the model explicitly: `/model custom:<model-name>`"
|
||||
)
|
||||
except Exception as e:
|
||||
return f"⚠️ Could not resolve custom endpoint: {e}"
|
||||
|
||||
# Parse provider:model syntax
|
||||
target_provider, new_model = parse_model_input(args, current_provider)
|
||||
|
||||
# Detect custom/local provider — skip auto-detection to prevent
|
||||
# silently accepting an OpenRouter model name on a localhost endpoint.
|
||||
# Users must use explicit provider:model syntax to switch away.
|
||||
_resolved_base = ""
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp
|
||||
_resolved_base = _rtp(requested=current_provider).get("base_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
is_custom = current_provider == "custom" or (
|
||||
"localhost" in _resolved_base or "127.0.0.1" in _resolved_base
|
||||
)
|
||||
|
||||
# Auto-detect provider when no explicit provider:model syntax was used
|
||||
if target_provider == current_provider and not is_custom:
|
||||
from hermes_cli.models import detect_provider_for_model
|
||||
detected = detect_provider_for_model(new_model, current_provider)
|
||||
if detected:
|
||||
target_provider, new_model = detected
|
||||
provider_changed = target_provider != current_provider
|
||||
|
||||
# Resolve credentials for the target provider (for API probe)
|
||||
api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or ""
|
||||
base_url = "https://openrouter.ai/api/v1"
|
||||
if provider_changed:
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=target_provider)
|
||||
api_key = runtime.get("api_key", "")
|
||||
base_url = runtime.get("base_url", "")
|
||||
except Exception as e:
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}"
|
||||
else:
|
||||
# Use current provider's base_url from config or registry
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider
|
||||
runtime = resolve_runtime_provider(requested=current_provider)
|
||||
api_key = runtime.get("api_key", "")
|
||||
base_url = runtime.get("base_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Validate the model against the live API
|
||||
try:
|
||||
validation = validate_requested_model(
|
||||
new_model,
|
||||
target_provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception:
|
||||
validation = {"accepted": True, "persist": True, "recognized": False, "message": None}
|
||||
|
||||
if not validation.get("accepted"):
|
||||
msg = validation.get("message", "Invalid model")
|
||||
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
|
||||
return f"⚠️ {msg}{tip}"
|
||||
|
||||
# Persist to config only if validation approves
|
||||
if validation.get("persist"):
|
||||
from hermes_cli.model_switch import switch_to_custom_provider
|
||||
cust_result = switch_to_custom_provider()
|
||||
if not cust_result.success:
|
||||
return f"⚠️ {cust_result.error_message}"
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
|
|
@ -2972,14 +2865,63 @@ class GatewayRunner:
|
|||
user_config = yaml.safe_load(f) or {}
|
||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||
user_config["model"] = {}
|
||||
user_config["model"]["default"] = new_model
|
||||
if provider_changed:
|
||||
user_config["model"]["provider"] = target_provider
|
||||
# Persist base_url for custom endpoints so it survives
|
||||
# restart; clear it when switching away from custom to
|
||||
# prevent stale URLs leaking (#2562 Phase 2).
|
||||
if base_url and "openrouter.ai" not in (base_url or ""):
|
||||
user_config["model"]["base_url"] = base_url
|
||||
user_config["model"]["default"] = cust_result.model
|
||||
user_config["model"]["provider"] = "custom"
|
||||
user_config["model"]["base_url"] = cust_result.base_url
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
yaml.dump(user_config, f, default_flow_style=False, sort_keys=False)
|
||||
except Exception as e:
|
||||
return f"⚠️ Failed to save model change: {e}"
|
||||
os.environ["HERMES_MODEL"] = cust_result.model
|
||||
os.environ["HERMES_INFERENCE_PROVIDER"] = "custom"
|
||||
self._effective_model = None
|
||||
self._effective_provider = None
|
||||
return (
|
||||
f"🤖 Model changed to `{cust_result.model}` (saved to config)\n"
|
||||
f"**Provider:** Custom\n"
|
||||
f"**Endpoint:** `{cust_result.base_url}`\n"
|
||||
f"_Model auto-detected from endpoint. Takes effect on next message._"
|
||||
)
|
||||
|
||||
# Core model-switching pipeline (shared with CLI)
|
||||
from hermes_cli.model_switch import switch_model
|
||||
|
||||
# Resolve current base_url for is_custom detection
|
||||
_resolved_base = ""
|
||||
try:
|
||||
from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp
|
||||
_resolved_base = _rtp(requested=current_provider).get("base_url", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
result = switch_model(
|
||||
args,
|
||||
current_provider,
|
||||
current_base_url=_resolved_base,
|
||||
current_api_key=os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or "",
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
msg = result.error_message
|
||||
tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else ""
|
||||
return f"⚠️ {msg}{tip}"
|
||||
|
||||
# Persist to config only if validation approves
|
||||
if result.persist:
|
||||
try:
|
||||
user_config = {}
|
||||
if config_path.exists():
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
user_config = yaml.safe_load(f) or {}
|
||||
if "model" not in user_config or not isinstance(user_config["model"], dict):
|
||||
user_config["model"] = {}
|
||||
user_config["model"]["default"] = result.new_model
|
||||
if result.provider_changed:
|
||||
user_config["model"]["provider"] = result.target_provider
|
||||
# Persist base_url for custom endpoints; clear when
|
||||
# switching away from custom (#2562 Phase 2).
|
||||
if result.base_url and "openrouter.ai" not in (result.base_url or ""):
|
||||
user_config["model"]["base_url"] = result.base_url
|
||||
else:
|
||||
user_config["model"].pop("base_url", None)
|
||||
with open(config_path, 'w', encoding="utf-8") as f:
|
||||
|
|
@ -2988,41 +2930,34 @@ class GatewayRunner:
|
|||
return f"⚠️ Failed to save model change: {e}"
|
||||
|
||||
# Set env vars so the next agent run picks up the change
|
||||
os.environ["HERMES_MODEL"] = new_model
|
||||
if provider_changed:
|
||||
os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider
|
||||
os.environ["HERMES_MODEL"] = result.new_model
|
||||
if result.provider_changed:
|
||||
os.environ["HERMES_INFERENCE_PROVIDER"] = result.target_provider
|
||||
|
||||
provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
|
||||
provider_note = f"\n**Provider:** {provider_label}" if provider_changed else ""
|
||||
provider_note = f"\n**Provider:** {result.provider_label}" if result.provider_changed else ""
|
||||
|
||||
warning = ""
|
||||
if validation.get("message"):
|
||||
warning = f"\n⚠️ {validation['message']}"
|
||||
if result.warning_message:
|
||||
warning = f"\n⚠️ {result.warning_message}"
|
||||
|
||||
persist_note = "saved to config" if result.persist else "this session only — will revert on restart"
|
||||
|
||||
if validation.get("persist"):
|
||||
persist_note = "saved to config"
|
||||
else:
|
||||
persist_note = "this session only — will revert on restart"
|
||||
# Clear fallback state since user explicitly chose a model
|
||||
self._effective_model = None
|
||||
self._effective_provider = None
|
||||
|
||||
# Show endpoint info for custom providers
|
||||
_target_is_custom = target_provider == "custom" or (
|
||||
base_url and "openrouter.ai" not in (base_url or "")
|
||||
and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or ""))
|
||||
)
|
||||
custom_hint = ""
|
||||
if _target_is_custom or (is_custom and not provider_changed):
|
||||
endpoint = base_url or _resolved_base or "custom endpoint"
|
||||
if result.is_custom_target:
|
||||
endpoint = result.base_url or _resolved_base or "custom endpoint"
|
||||
custom_hint = f"\n**Endpoint:** `{endpoint}`"
|
||||
if not provider_changed:
|
||||
if not result.provider_changed:
|
||||
custom_hint += (
|
||||
"\n_To switch providers, use_ `/model provider:model`"
|
||||
"\n_e.g._ `/model openrouter:anthropic/claude-sonnet-4`"
|
||||
)
|
||||
|
||||
return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_"
|
||||
return f"🤖 Model changed to `{result.new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_"
|
||||
|
||||
async def _handle_provider_command(self, event: MessageEvent) -> str:
|
||||
"""Handle /provider command - show available providers."""
|
||||
|
|
|
|||
234
hermes_cli/model_switch.py
Normal file
234
hermes_cli/model_switch.py
Normal file
|
|
@ -0,0 +1,234 @@
|
|||
"""Shared model-switching logic for CLI and gateway /model commands.
|
||||
|
||||
Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers
|
||||
share the same core pipeline:
|
||||
|
||||
parse_model_input → is_custom detection → auto-detect provider
|
||||
→ credential resolution → validate model → return result
|
||||
|
||||
This module extracts that shared pipeline into pure functions that
|
||||
return result objects. The callers handle all platform-specific
|
||||
concerns: state mutation, config persistence, output formatting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass
class ModelSwitchResult:
    """Result of a model switch attempt."""

    # True when the full pipeline (parse → detect → resolve → validate)
    # succeeded; when False, only error_message / target_provider are reliable.
    success: bool
    # The (possibly auto-detect-rewritten) model name to switch to.
    new_model: str = ""
    # Provider the switch targets (may differ from the caller's current one).
    target_provider: str = ""
    # True when target_provider differs from the provider passed in.
    provider_changed: bool = False
    # Credentials resolved for target_provider (or the caller's originals
    # when the provider did not change and re-resolution failed).
    api_key: str = ""
    base_url: str = ""
    # True when validation approved saving the model to config (vs. a
    # session-only switch).
    persist: bool = False
    # Human-readable failure reason; set only when success is False.
    error_message: str = ""
    # Non-fatal validation message (e.g. unrecognized-model warning) to show
    # alongside a successful switch.
    warning_message: str = ""
    # True when the target looks like a custom/local endpoint (provider
    # "custom", or a localhost/127.0.0.1 base_url that is not openrouter.ai).
    is_custom_target: bool = False
    # Display label for target_provider (from _PROVIDER_LABELS).
    provider_label: str = ""
|
||||
|
||||
|
||||
@dataclass
class CustomAutoResult:
    """Result of switching to bare 'custom' provider with auto-detect."""

    # True when an endpoint was resolved and exactly one model auto-detected.
    success: bool
    # The auto-detected model name; empty on failure.
    model: str = ""
    # Resolved custom endpoint URL (also populated on the "reachable but no
    # single model detected" failure, so callers can report it).
    base_url: str = ""
    # API key resolved for the custom endpoint.
    api_key: str = ""
    # Human-readable failure reason; set only when success is False.
    error_message: str = ""
|
||||
|
||||
|
||||
def switch_model(
    raw_input: str,
    current_provider: str,
    current_base_url: str = "",
    current_api_key: str = "",
) -> ModelSwitchResult:
    """Core model-switching pipeline shared between CLI and gateway.

    Handles parsing, provider detection, credential resolution, and
    model validation. Does NOT handle config persistence, state
    mutation, or output formatting — those are caller responsibilities.

    Args:
        raw_input: The user's model input (e.g. "claude-sonnet-4",
            "zai:glm-5", "custom:local:qwen").
        current_provider: The currently active provider.
        current_base_url: The currently active base URL (used for
            is_custom detection).
        current_api_key: The currently active API key.

    Returns:
        ModelSwitchResult with all information the caller needs to
        apply the switch and format output. ``error_message`` is always
        a non-empty string when ``success`` is False, so callers may
        safely run substring checks on it.
    """
    from hermes_cli.models import (
        parse_model_input,
        detect_provider_for_model,
        validate_requested_model,
        _PROVIDER_LABELS,
    )
    from hermes_cli.runtime_provider import resolve_runtime_provider

    # Step 1: Parse provider:model syntax
    target_provider, new_model = parse_model_input(raw_input, current_provider)

    # Step 2: Detect if we're currently on a custom endpoint
    _base = current_base_url or ""
    is_custom = current_provider == "custom" or (
        "localhost" in _base or "127.0.0.1" in _base
    )

    # Step 3: Auto-detect provider when no explicit provider:model syntax
    # was used. Skip for custom providers — the model name might
    # coincidentally match a known provider's catalog.
    if target_provider == current_provider and not is_custom:
        detected = detect_provider_for_model(new_model, current_provider)
        if detected:
            target_provider, new_model = detected

    provider_changed = target_provider != current_provider

    # Step 4: Resolve credentials for target provider
    api_key = current_api_key
    base_url = current_base_url
    if provider_changed:
        try:
            runtime = resolve_runtime_provider(requested=target_provider)
            api_key = runtime.get("api_key", "")
            base_url = runtime.get("base_url", "")
        except Exception as e:
            provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
            if target_provider == "custom":
                return ModelSwitchResult(
                    success=False,
                    target_provider=target_provider,
                    error_message=(
                        "No custom endpoint configured. Set model.base_url "
                        "in config.yaml, or set OPENAI_BASE_URL in .env, "
                        "or run: hermes setup → Custom OpenAI-compatible endpoint"
                    ),
                )
            return ModelSwitchResult(
                success=False,
                target_provider=target_provider,
                error_message=(
                    f"Could not resolve credentials for provider "
                    f"'{provider_label}': {e}"
                ),
            )
    else:
        # Gateway also resolves for unchanged provider to get accurate
        # base_url for validation probing; fall back to the caller's
        # values if resolution fails.
        try:
            runtime = resolve_runtime_provider(requested=current_provider)
            api_key = runtime.get("api_key", "")
            base_url = runtime.get("base_url", "")
        except Exception:
            pass

    # Step 5: Validate the model (best-effort: a probe failure must not
    # block the switch, so fall back to an accepting result).
    try:
        validation = validate_requested_model(
            new_model,
            target_provider,
            api_key=api_key,
            base_url=base_url,
        )
    except Exception:
        validation = {
            "accepted": True,
            "persist": True,
            "recognized": False,
            "message": None,
        }

    if not validation.get("accepted"):
        # Use `or` rather than a .get() default: validators may return an
        # explicit message=None, and .get("message", "Invalid model") would
        # propagate None into error_message — breaking callers that do
        # substring checks like `"Did you mean" not in result.error_message`.
        msg = validation.get("message") or "Invalid model"
        return ModelSwitchResult(
            success=False,
            new_model=new_model,
            target_provider=target_provider,
            error_message=msg,
        )

    # Step 6: Build result
    provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
    is_custom_target = target_provider == "custom" or (
        base_url
        and "openrouter.ai" not in (base_url or "")
        and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or ""))
    )

    return ModelSwitchResult(
        success=True,
        new_model=new_model,
        target_provider=target_provider,
        provider_changed=provider_changed,
        api_key=api_key,
        base_url=base_url,
        persist=bool(validation.get("persist")),
        warning_message=validation.get("message") or "",
        is_custom_target=is_custom_target,
        provider_label=provider_label,
    )
|
||||
|
||||
|
||||
def switch_to_custom_provider() -> CustomAutoResult:
    """Handle bare '/model custom' — resolve endpoint and auto-detect model.

    Returns a result object; the caller handles persistence and output.
    """
    from hermes_cli.runtime_provider import (
        resolve_runtime_provider,
        _auto_detect_local_model,
    )

    def _failure(message: str, base: str = "", key: str = "") -> CustomAutoResult:
        # Shared constructor for every unsuccessful outcome.
        return CustomAutoResult(
            success=False,
            base_url=base,
            api_key=key,
            error_message=message,
        )

    try:
        rt = resolve_runtime_provider(requested="custom")
    except Exception as exc:
        return _failure(f"Could not resolve custom endpoint: {exc}")

    endpoint = rt.get("base_url", "")
    credential = rt.get("api_key", "")

    # An empty or OpenRouter base URL means no real custom endpoint is set.
    if not endpoint or "openrouter.ai" in endpoint:
        return _failure(
            "No custom endpoint configured. Set model.base_url in config.yaml, "
            "or set OPENAI_BASE_URL in .env, or run: hermes setup → "
            "Custom OpenAI-compatible endpoint"
        )

    model_name = _auto_detect_local_model(endpoint)
    if model_name:
        return CustomAutoResult(
            success=True,
            model=model_name,
            base_url=endpoint,
            api_key=credential,
        )

    # Endpoint answered but did not expose exactly one model; still report
    # the endpoint/key so the caller can surface where it looked.
    return _failure(
        f"Custom endpoint at {endpoint} is reachable but no single model "
        f"was auto-detected. Specify the model explicitly: "
        f"/model custom:<model-name>",
        base=endpoint,
        key=credential,
    )
|
||||
Loading…
Add table
Add a link
Reference in a new issue