fix(gateway): web UI model picker now switches provider when selecting non-default provider model

This commit is contained in:
konsisumer 2026-04-25 00:41:18 +02:00
parent 00c3d848d8
commit 2305473cef
3 changed files with 358 additions and 13 deletions

View file

@ -249,6 +249,11 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = {
"description": "Context window override (0 = auto-detect from model metadata)",
"category": "general",
},
"model_provider": {
"type": "string",
"description": "Provider for the model (e.g. openrouter, anthropic, ollama-local). Leave empty to auto-detect.",
"category": "general",
},
"terminal.backend": {
"type": "select",
"description": "Terminal execution backend",
@ -406,14 +411,16 @@ def _build_schema_from_config(
CONFIG_SCHEMA = _build_schema_from_config(DEFAULT_CONFIG)
# Inject virtual fields that don't live in DEFAULT_CONFIG but are surfaced
# by the normalize/denormalize cycle. Insert model_context_length right after
# the "model" key so it renders adjacent in the frontend.
# by the normalize/denormalize cycle. Insert model_context_length and
# model_provider right after the "model" key so they render adjacent.
_mcl_entry = _SCHEMA_OVERRIDES["model_context_length"]
_mp_entry = _SCHEMA_OVERRIDES["model_provider"]
_ordered_schema: Dict[str, Dict[str, Any]] = {}
for _k, _v in CONFIG_SCHEMA.items():
_ordered_schema[_k] = _v
if _k == "model":
_ordered_schema["model_context_length"] = _mcl_entry
_ordered_schema["model_provider"] = _mp_entry
CONFIG_SCHEMA = _ordered_schema
@ -791,18 +798,21 @@ def _normalize_config_for_web(config: Dict[str, Any]) -> Dict[str, Any]:
from DEFAULT_CONFIG where ``model`` is a string, but user configs often have the
dict form. Normalize to the string form so the frontend schema matches.
Also surfaces ``model_context_length`` as a top-level field so the web UI can
display and edit it. A value of 0 means "auto-detect".
Also surfaces ``model_context_length`` and ``model_provider`` as top-level virtual
fields so the web UI can display and edit them. ``model_context_length`` of 0 means
"auto-detect"; ``model_provider`` empty means "use default resolution".
"""
config = dict(config) # shallow copy
model_val = config.get("model")
if isinstance(model_val, dict):
# Extract context_length before flattening the dict
# Extract context_length and provider before flattening the dict
ctx_len = model_val.get("context_length", 0)
config["model"] = model_val.get("default", model_val.get("name", ""))
config["model_context_length"] = ctx_len if isinstance(ctx_len, int) else 0
config["model_provider"] = model_val.get("provider", "")
else:
config["model_context_length"] = 0
config["model_provider"] = ""
return config
@ -910,6 +920,38 @@ def get_model_info():
return dict(_EMPTY_MODEL_INFO)
def _infer_provider_for_model(model_name: str, disk_config: Dict[str, Any]) -> Optional[str]:
"""Try to infer which configured provider serves a given model name.
Checks user-configured providers' explicit model lists first, then the
built-in curated catalogs. Returns the provider slug on a deterministic
match, or ``None`` when no match is found (caller should preserve the
existing provider).
"""
# 1. User-configured providers (providers: {} dict) with explicit model lists
user_providers = disk_config.get("providers", {})
if isinstance(user_providers, dict):
for slug, pdata in user_providers.items():
if not isinstance(pdata, dict):
continue
models = pdata.get("models", [])
if isinstance(models, list) and model_name in models:
return str(slug)
# 2. Built-in curated catalogs: openrouter first (largest aggregator), then others
try:
from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS
if model_name in {mid for mid, _ in OPENROUTER_MODELS}:
return "openrouter"
for provider_slug, models in _PROVIDER_MODELS.items():
if model_name in models:
return provider_slug
except Exception:
pass
return None
def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
"""Reverse _normalize_config_for_web before saving.
@ -921,13 +963,17 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
Also handles ``model_context_length`` writes it back into the model dict
as ``context_length``. A value of 0 or absent means "auto-detect" (omitted
from the dict so get_model_context_length() uses its normal resolution).
``model_provider`` is a virtual field the frontend may send to explicitly
override the provider. When absent or empty, the provider is inferred from
the new model name if it changed, then falls back to the on-disk value.
"""
config = dict(config)
# Remove any _model_meta that might have leaked in (shouldn't happen
# with the stripped GET response, but be defensive)
config.pop("_model_meta", None)
# Extract and remove model_context_length before processing model
# Extract and remove virtual fields before processing model
ctx_override = config.pop("model_context_length", 0)
if not isinstance(ctx_override, int):
try:
@ -935,6 +981,9 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
except (TypeError, ValueError):
ctx_override = 0
# model_provider: explicit provider override from the frontend (may be "")
explicit_provider = str(config.pop("model_provider", "") or "").strip()
model_val = config.get("model")
if isinstance(model_val, str) and model_val:
# Read the current disk config to recover model subkeys
@ -942,8 +991,26 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
disk_config = load_config()
disk_model = disk_config.get("model")
if isinstance(disk_model, dict):
old_default = disk_model.get("default", "")
# Preserve all subkeys, update default with the new value
disk_model["default"] = model_val
# Determine the correct provider for the (possibly new) model.
# Priority: explicit frontend override > inference > existing disk value.
disk_provider = disk_model.get("provider", "")
if explicit_provider and explicit_provider != disk_provider:
# Frontend sent an explicit provider change
disk_model["provider"] = explicit_provider
disk_model.pop("base_url", None)
disk_model.pop("api_mode", None)
elif model_val != old_default:
# Model name changed — try to infer the matching provider
inferred = _infer_provider_for_model(model_val, disk_config)
if inferred and inferred != disk_provider:
disk_model["provider"] = inferred
disk_model.pop("base_url", None)
disk_model.pop("api_mode", None)
# Write context_length into the model dict (0 = remove/auto)
if ctx_override > 0:
disk_model["context_length"] = ctx_override
@ -952,12 +1019,14 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
config["model"] = disk_model
else:
# Model was previously a bare string — upgrade to dict if
# user is setting a context_length override
if ctx_override > 0:
config["model"] = {
"default": model_val,
"context_length": ctx_override,
}
# user is setting a context_length override or explicit provider
if ctx_override > 0 or explicit_provider:
new_dict: Dict[str, Any] = {"default": model_val}
if ctx_override > 0:
new_dict["context_length"] = ctx_override
if explicit_provider:
new_dict["provider"] = explicit_provider
config["model"] = new_dict
except Exception:
pass # can't read disk config — just use the string form
return config

View file

@ -1925,3 +1925,270 @@ class TestPtyWebSocket:
):
pass
assert exc.value.code == 4400
# ---------------------------------------------------------------------------
# model_provider virtual field — normalize / denormalize / schema
# ---------------------------------------------------------------------------
class TestModelProviderVirtualField:
"""Tests for the model_provider virtual field surfaced by the config API.
Regression: when the user picked a model from a different provider in the
Web UI, the provider was never updated the stale on-disk provider was
preserved because _denormalize_config_from_web only updated `default`.
"""
# ── normalize ──────────────────────────────────────────────────────────
def test_normalize_extracts_provider_from_dict(self):
"""normalize should surface provider from model dict as model_provider."""
from hermes_cli.web_server import _normalize_config_for_web
cfg = {"model": {"default": "google/gemini-3-flash", "provider": "openrouter"}}
result = _normalize_config_for_web(cfg)
assert result["model"] == "google/gemini-3-flash"
assert result["model_provider"] == "openrouter"
def test_normalize_bare_string_yields_empty_provider(self):
"""normalize should set model_provider='' for a bare string model."""
from hermes_cli.web_server import _normalize_config_for_web
result = _normalize_config_for_web({"model": "anthropic/claude-sonnet-4"})
assert result["model_provider"] == ""
def test_normalize_dict_without_provider_yields_empty(self):
"""normalize should set model_provider='' when model dict has no provider."""
from hermes_cli.web_server import _normalize_config_for_web
cfg = {"model": {"default": "test/model"}}
result = _normalize_config_for_web(cfg)
assert result["model_provider"] == ""
# ── denormalize: explicit model_provider override ──────────────────────
def test_denormalize_explicit_provider_override_switches_provider(self):
"""Sending model_provider should update the disk provider."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
save_config({
"model": {
"default": "llama3.2",
"provider": "ollama-local",
"base_url": "http://localhost:11434/v1",
}
})
result = _denormalize_config_from_web({
"model": "google/gemini-3-flash",
"model_provider": "openrouter",
})
assert isinstance(result["model"], dict)
assert result["model"]["default"] == "google/gemini-3-flash"
assert result["model"]["provider"] == "openrouter"
# provider-specific fields from old provider should be cleared
assert "base_url" not in result["model"]
assert "model_provider" not in result # virtual field consumed
def test_denormalize_explicit_provider_same_as_disk_no_op(self):
"""Sending model_provider matching disk provider should be a no-op."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
save_config({
"model": {
"default": "anthropic/claude-opus-4.6",
"provider": "openrouter",
}
})
result = _denormalize_config_from_web({
"model": "anthropic/claude-sonnet-4.6",
"model_provider": "openrouter",
})
assert result["model"]["provider"] == "openrouter"
def test_denormalize_explicit_empty_provider_runs_inference(self):
"""Empty model_provider should fall through to inference, not wipe provider."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
save_config({
"model": {
"default": "llama3.2",
"provider": "ollama-local",
"base_url": "http://localhost:11434/v1",
}
})
# Same model — no change → provider preserved
result = _denormalize_config_from_web({
"model": "llama3.2",
"model_provider": "",
})
assert result["model"]["provider"] == "ollama-local"
# ── denormalize: inference when model changes ──────────────────────────
def test_denormalize_model_change_infers_provider_from_user_providers_list(self):
"""When model changes to one in user providers' models list, switch provider."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
save_config({
"model": {
"default": "llama3.2",
"provider": "ollama-local",
"base_url": "http://localhost:11434/v1",
},
"providers": {
"openrouter": {
"models": ["google/gemini-2.5-flash", "anthropic/claude-opus-4.7"],
}
},
})
result = _denormalize_config_from_web({
"model": "google/gemini-2.5-flash",
"model_provider": "",
})
assert result["model"]["provider"] == "openrouter"
assert result["model"]["default"] == "google/gemini-2.5-flash"
assert "base_url" not in result["model"]
def test_denormalize_model_change_infers_provider_from_openrouter_catalog(self):
"""When model changes to one in OPENROUTER_MODELS, switch to openrouter."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
from hermes_cli.models import OPENROUTER_MODELS
from unittest.mock import patch
# Use the first model in OPENROUTER_MODELS as a known good entry
if not OPENROUTER_MODELS:
pytest.skip("OPENROUTER_MODELS is empty")
known_openrouter_model = OPENROUTER_MODELS[0][0]
save_config({
"model": {
"default": "llama3.2",
"provider": "ollama-local",
"base_url": "http://localhost:11434/v1",
},
})
result = _denormalize_config_from_web({
"model": known_openrouter_model,
"model_provider": "",
})
assert result["model"]["provider"] == "openrouter"
assert result["model"]["default"] == known_openrouter_model
assert "base_url" not in result["model"]
def test_denormalize_same_model_no_provider_change(self):
"""When model stays the same, provider should not change."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
save_config({
"model": {
"default": "llama3.2",
"provider": "ollama-local",
"base_url": "http://localhost:11434/v1",
}
})
result = _denormalize_config_from_web({
"model": "llama3.2",
"model_provider": "",
})
assert result["model"]["provider"] == "ollama-local"
assert result["model"].get("base_url") == "http://localhost:11434/v1"
def test_denormalize_unknown_model_preserves_provider(self):
"""When model changes to an unknown model, keep existing provider."""
from hermes_cli.web_server import _denormalize_config_from_web
from hermes_cli.config import save_config
save_config({
"model": {
"default": "llama3.2",
"provider": "ollama-local",
"base_url": "http://localhost:11434/v1",
}
})
result = _denormalize_config_from_web({
"model": "totally-unknown-model:v999",
"model_provider": "",
})
# Unknown model — provider should not change
assert result["model"]["provider"] == "ollama-local"
# ── schema ─────────────────────────────────────────────────────────────
def test_schema_has_model_provider(self):
"""CONFIG_SCHEMA should include model_provider virtual field."""
from hermes_cli.web_server import CONFIG_SCHEMA
assert "model_provider" in CONFIG_SCHEMA
def test_schema_model_provider_after_model_context_length(self):
"""model_provider should appear immediately after model_context_length."""
from hermes_cli.web_server import CONFIG_SCHEMA
keys = list(CONFIG_SCHEMA.keys())
mcl_idx = keys.index("model_context_length")
assert keys[mcl_idx + 1] == "model_provider"
def test_schema_model_provider_is_string(self):
"""model_provider schema entry should have type=string."""
from hermes_cli.web_server import CONFIG_SCHEMA
entry = CONFIG_SCHEMA["model_provider"]
assert entry["type"] == "string"
assert entry["category"] == "general"
# ---------------------------------------------------------------------------
# _infer_provider_for_model unit tests
# ---------------------------------------------------------------------------
class TestInferProviderForModel:
"""Unit tests for _infer_provider_for_model."""
def test_finds_model_in_user_providers_list(self):
from hermes_cli.web_server import _infer_provider_for_model
disk_cfg = {
"providers": {
"my-openrouter": {
"models": ["vendor/model-a", "vendor/model-b"],
}
}
}
assert _infer_provider_for_model("vendor/model-a", disk_cfg) == "my-openrouter"
def test_user_providers_dict_is_not_list(self):
"""providers: {} should be a dict; non-dict values are skipped."""
from hermes_cli.web_server import _infer_provider_for_model
disk_cfg = {"providers": {"bad": "not-a-dict"}}
# should not crash; returns None for bad entries
result = _infer_provider_for_model("anything", disk_cfg)
assert result is None
def test_no_providers_falls_through_to_catalog(self):
from hermes_cli.web_server import _infer_provider_for_model
from hermes_cli.models import OPENROUTER_MODELS
if not OPENROUTER_MODELS:
pytest.skip("OPENROUTER_MODELS is empty")
model_id = OPENROUTER_MODELS[0][0]
result = _infer_provider_for_model(model_id, {})
assert result == "openrouter"
def test_unknown_model_returns_none(self):
from hermes_cli.web_server import _infer_provider_for_model
result = _infer_provider_for_model("nonexistent/model-xyzzy-9999", {})
assert result is None

View file

@ -303,7 +303,16 @@ export default function ConfigPage() {
schemaKey={key}
schema={s}
value={getNestedValue(config, key)}
onChange={(v) => setConfig(setNestedValue(config, key, v))}
onChange={(v) => {
let next = setNestedValue(config, key, v);
// When the model name changes, clear model_provider so the
// backend can infer the correct provider for the new model
// rather than inheriting the stale one.
if (key === "model") {
next = setNestedValue(next, "model_provider", "");
}
setConfig(next);
}}
/>
</div>
</div>