fix(gateway): web UI model picker now switches provider when selecting non-default provider model

This commit is contained in:
konsisumer 2026-04-25 00:41:18 +02:00
parent 00c3d848d8
commit 2305473cef
3 changed files with 358 additions and 13 deletions

View file

@ -249,6 +249,11 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = {
"description": "Context window override (0 = auto-detect from model metadata)",
"category": "general",
},
"model_provider": {
"type": "string",
"description": "Provider for the model (e.g. openrouter, anthropic, ollama-local). Leave empty to auto-detect.",
"category": "general",
},
"terminal.backend": {
"type": "select",
"description": "Terminal execution backend",
@ -406,14 +411,16 @@ def _build_schema_from_config(
CONFIG_SCHEMA = _build_schema_from_config(DEFAULT_CONFIG)
# Inject virtual fields that don't live in DEFAULT_CONFIG but are surfaced
# by the normalize/denormalize cycle. Insert model_context_length right after
# the "model" key so it renders adjacent in the frontend.
# by the normalize/denormalize cycle. Insert model_context_length and
# model_provider right after the "model" key so they render adjacent.
_mcl_entry = _SCHEMA_OVERRIDES["model_context_length"]
_mp_entry = _SCHEMA_OVERRIDES["model_provider"]
_ordered_schema: Dict[str, Dict[str, Any]] = {}
for _k, _v in CONFIG_SCHEMA.items():
_ordered_schema[_k] = _v
if _k == "model":
_ordered_schema["model_context_length"] = _mcl_entry
_ordered_schema["model_provider"] = _mp_entry
CONFIG_SCHEMA = _ordered_schema
@ -791,18 +798,21 @@ def _normalize_config_for_web(config: Dict[str, Any]) -> Dict[str, Any]:
from DEFAULT_CONFIG where ``model`` is a string, but user configs often have the
dict form. Normalize to the string form so the frontend schema matches.
Also surfaces ``model_context_length`` as a top-level field so the web UI can
display and edit it. A value of 0 means "auto-detect".
Also surfaces ``model_context_length`` and ``model_provider`` as top-level virtual
fields so the web UI can display and edit them. ``model_context_length`` of 0 means
"auto-detect"; ``model_provider`` empty means "use default resolution".
"""
config = dict(config) # shallow copy
model_val = config.get("model")
if isinstance(model_val, dict):
# Extract context_length before flattening the dict
# Extract context_length and provider before flattening the dict
ctx_len = model_val.get("context_length", 0)
config["model"] = model_val.get("default", model_val.get("name", ""))
config["model_context_length"] = ctx_len if isinstance(ctx_len, int) else 0
config["model_provider"] = model_val.get("provider", "")
else:
config["model_context_length"] = 0
config["model_provider"] = ""
return config
@ -910,6 +920,38 @@ def get_model_info():
return dict(_EMPTY_MODEL_INFO)
def _infer_provider_for_model(model_name: str, disk_config: Dict[str, Any]) -> Optional[str]:
"""Try to infer which configured provider serves a given model name.
Checks user-configured providers' explicit model lists first, then the
built-in curated catalogs. Returns the provider slug on a deterministic
match, or ``None`` when no match is found (caller should preserve the
existing provider).
"""
# 1. User-configured providers (providers: {} dict) with explicit model lists
user_providers = disk_config.get("providers", {})
if isinstance(user_providers, dict):
for slug, pdata in user_providers.items():
if not isinstance(pdata, dict):
continue
models = pdata.get("models", [])
if isinstance(models, list) and model_name in models:
return str(slug)
# 2. Built-in curated catalogs: openrouter first (largest aggregator), then others
try:
from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS
if model_name in {mid for mid, _ in OPENROUTER_MODELS}:
return "openrouter"
for provider_slug, models in _PROVIDER_MODELS.items():
if model_name in models:
return provider_slug
except Exception:
pass
return None
def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
"""Reverse _normalize_config_for_web before saving.
@ -921,13 +963,17 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
Also handles ``model_context_length`` writes it back into the model dict
as ``context_length``. A value of 0 or absent means "auto-detect" (omitted
from the dict so get_model_context_length() uses its normal resolution).
``model_provider`` is a virtual field the frontend may send to explicitly
override the provider. When absent or empty, the provider is inferred from
the new model name if it changed, then falls back to the on-disk value.
"""
config = dict(config)
# Remove any _model_meta that might have leaked in (shouldn't happen
# with the stripped GET response, but be defensive)
config.pop("_model_meta", None)
# Extract and remove model_context_length before processing model
# Extract and remove virtual fields before processing model
ctx_override = config.pop("model_context_length", 0)
if not isinstance(ctx_override, int):
try:
@ -935,6 +981,9 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
except (TypeError, ValueError):
ctx_override = 0
# model_provider: explicit provider override from the frontend (may be "")
explicit_provider = str(config.pop("model_provider", "") or "").strip()
model_val = config.get("model")
if isinstance(model_val, str) and model_val:
# Read the current disk config to recover model subkeys
@ -942,8 +991,26 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
disk_config = load_config()
disk_model = disk_config.get("model")
if isinstance(disk_model, dict):
old_default = disk_model.get("default", "")
# Preserve all subkeys, update default with the new value
disk_model["default"] = model_val
# Determine the correct provider for the (possibly new) model.
# Priority: explicit frontend override > inference > existing disk value.
disk_provider = disk_model.get("provider", "")
if explicit_provider and explicit_provider != disk_provider:
# Frontend sent an explicit provider change
disk_model["provider"] = explicit_provider
disk_model.pop("base_url", None)
disk_model.pop("api_mode", None)
elif model_val != old_default:
# Model name changed — try to infer the matching provider
inferred = _infer_provider_for_model(model_val, disk_config)
if inferred and inferred != disk_provider:
disk_model["provider"] = inferred
disk_model.pop("base_url", None)
disk_model.pop("api_mode", None)
# Write context_length into the model dict (0 = remove/auto)
if ctx_override > 0:
disk_model["context_length"] = ctx_override
@ -952,12 +1019,14 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
config["model"] = disk_model
else:
# Model was previously a bare string — upgrade to dict if
# user is setting a context_length override
if ctx_override > 0:
config["model"] = {
"default": model_val,
"context_length": ctx_override,
}
# user is setting a context_length override or explicit provider
if ctx_override > 0 or explicit_provider:
new_dict: Dict[str, Any] = {"default": model_val}
if ctx_override > 0:
new_dict["context_length"] = ctx_override
if explicit_provider:
new_dict["provider"] = explicit_provider
config["model"] = new_dict
except Exception:
pass # can't read disk config — just use the string form
return config