fix(model): preserve custom endpoint credentials during /model switch

Named custom providers (ollama-launch, lmstudio, etc.) resolve to
provider="custom" via _resolve_named_custom_runtime(). When switch_model()
re-resolves credentials for "custom", _get_named_custom_provider("custom")
rejects the bare name and the fallback chain reaches OpenRouter — silently
changing the endpoint.

Detect custom/local endpoints and skip credential re-resolution when the
provider hasn't changed. The existing api_key and base_url from the live
agent are already correct.

Reproduced: ollama launch hermes → /model gemma3:4b → probed OpenRouter
instead of localhost:11434. After fix: model switch stays on local Ollama.
This commit is contained in:
kshitijk4poor 2026-04-24 16:39:12 +05:30
parent 3a86f70969
commit ca8df75626
2 changed files with 82 additions and 4 deletions

View file

@@ -770,6 +770,7 @@ def switch_model(
api_mode = ""
if provider_changed or explicit_provider:
# Switching providers — resolve fresh credentials for the target.
try:
runtime = resolve_runtime_provider(requested=target_provider)
api_key = runtime.get("api_key", "")
@@ -787,13 +788,41 @@ def switch_model(
),
)
else:
# Same provider, different model. Re-resolve to pick up credential
# rotation and provider-specific base_url adjustments (e.g. OpenCode
# strips /v1 for Anthropic models, restores it for chat_completions).
# However, re-resolution fails for custom/local endpoints: the
# runtime resolver maps named providers (e.g. "ollama-launch") to the
# generic "custom" label, which cannot round-trip and falls through to
# OpenRouter. Detect this and keep the caller's existing credentials.
_keep_existing = False
try:
runtime = resolve_runtime_provider(requested=current_provider)
api_key = runtime.get("api_key", "") _resolved_base = (runtime.get("base_url") or "").rstrip("/")
base_url = runtime.get("base_url", "") _current_base = (current_base_url or "").rstrip("/")
api_mode = runtime.get("api_mode", "") # Guard: for custom/local providers, if re-resolution produced a
# different base_url it means the resolver couldn't find the
# original endpoint and fell through (e.g. "custom" → OpenRouter).
# Known providers (opencode, openrouter, etc.) may legitimately
# return a different base_url, so only guard "custom"/"local".
if (
current_provider in ("custom", "local")
and _current_base
and _resolved_base != _current_base
):
_keep_existing = True
else:
api_key = runtime.get("api_key", "") or api_key
base_url = runtime.get("base_url", "") or base_url
api_mode = runtime.get("api_mode", "")
except Exception:
pass _keep_existing = True
if _keep_existing:
# Resolution failed or fell through — keep the caller's
# known-good credentials and detect api_mode from the URL.
from hermes_cli.runtime_provider import _detect_api_mode_for_url
api_mode = _detect_api_mode_for_url(base_url) or "chat_completions"
# --- Direct alias override: use exact base_url from the alias if set ---
if resolved_alias:
@@ -824,6 +853,39 @@ def switch_model(
"message": f"Could not validate `{new_model}`: {e}", "message": f"Could not validate `{new_model}`: {e}",
} }
# Fallback: if the /v1/models probe rejected the model but the config's
# provider entry lists it, accept anyway. This covers models that the
# endpoint serves but doesn't advertise (e.g. Ollama cloud models like
# "kimi-k2.5:cloud" work via cloud routing but don't appear in the local
# /v1/models response).
if not validation.get("accepted") and target_provider in ("custom", "local"):
_config_models: set = set()
if user_providers and isinstance(user_providers, dict):
for _ep_cfg in user_providers.values():
if isinstance(_ep_cfg, dict):
_dm = _ep_cfg.get("default_model", "")
if _dm:
_config_models.add(_dm)
_ml = _ep_cfg.get("models", [])
if isinstance(_ml, (list, dict)):
_config_models.update(m for m in _ml if m)
if custom_providers and isinstance(custom_providers, list):
for _cp in custom_providers:
if isinstance(_cp, dict):
_dm = _cp.get("model", "")
if _dm:
_config_models.add(_dm)
_ml = _cp.get("models", [])
if isinstance(_ml, (list, dict)):
_config_models.update(m for m in _ml if m)
if new_model in _config_models:
validation = {
"accepted": True,
"persist": True,
"recognized": True,
"message": None,
}
if not validation.get("accepted"): if not validation.get("accepted"):
msg = validation.get("message", "Invalid model") msg = validation.get("message", "Invalid model")
return ModelSwitchResult( return ModelSwitchResult(

View file

@@ -667,6 +667,20 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
current_base_url = str(runtime.get("base_url", "") or "")
current_api_key = str(runtime.get("api_key", "") or "")
# Load user-defined providers from config so switch_model can resolve
# named custom endpoints (e.g. "ollama-launch") and use the config's
# model list as a validation fallback for models the /v1/models probe
# doesn't advertise (e.g. Ollama cloud models). Parity with cli.py's
# _handle_model_switch which passes the same two fields.
cfg = _load_cfg()
user_provs = cfg.get("providers") if isinstance(cfg.get("providers"), dict) else None
custom_provs = None
try:
from hermes_cli.config import get_compatible_custom_providers
custom_provs = get_compatible_custom_providers(cfg)
except Exception:
pass
result = switch_model(
raw_input=model_input,
current_provider=current_provider,
@@ -675,6 +689,8 @@ def _apply_model_switch(sid: str, session: dict, raw_input: str) -> dict:
current_api_key=current_api_key,
is_global=persist_global,
explicit_provider=explicit_provider,
user_providers=user_provs,
custom_providers=custom_provs,
)
if not result.success:
raise ValueError(result.error_message or "model switch failed")