mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): web UI model picker now switches provider when selecting non-default provider model
This commit is contained in:
parent
00c3d848d8
commit
2305473cef
3 changed files with 358 additions and 13 deletions
|
|
@ -249,6 +249,11 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = {
|
|||
"description": "Context window override (0 = auto-detect from model metadata)",
|
||||
"category": "general",
|
||||
},
|
||||
"model_provider": {
|
||||
"type": "string",
|
||||
"description": "Provider for the model (e.g. openrouter, anthropic, ollama-local). Leave empty to auto-detect.",
|
||||
"category": "general",
|
||||
},
|
||||
"terminal.backend": {
|
||||
"type": "select",
|
||||
"description": "Terminal execution backend",
|
||||
|
|
@ -406,14 +411,16 @@ def _build_schema_from_config(
|
|||
CONFIG_SCHEMA = _build_schema_from_config(DEFAULT_CONFIG)

# Inject virtual fields that don't live in DEFAULT_CONFIG but are surfaced
# by the normalize/denormalize cycle. Insert model_context_length and
# model_provider right after the "model" key so they render adjacent.
_mcl_entry = _SCHEMA_OVERRIDES["model_context_length"]
_mp_entry = _SCHEMA_OVERRIDES["model_provider"]
_ordered_schema: Dict[str, Dict[str, Any]] = {}
for _k, _v in CONFIG_SCHEMA.items():
    _ordered_schema[_k] = _v
    if _k == "model":
        # Dicts preserve insertion order, so adding the virtual fields here
        # places them directly after "model" in the rendered schema.
        _ordered_schema["model_context_length"] = _mcl_entry
        _ordered_schema["model_provider"] = _mp_entry
# If the generated schema has no "model" key the loop above never injects the
# virtual fields; append them at the end so they are still exposed. When they
# were already inserted, setdefault leaves both value and position untouched.
_ordered_schema.setdefault("model_context_length", _mcl_entry)
_ordered_schema.setdefault("model_provider", _mp_entry)
CONFIG_SCHEMA = _ordered_schema
|
||||
|
||||
|
||||
|
|
@ -791,18 +798,21 @@ def _normalize_config_for_web(config: Dict[str, Any]) -> Dict[str, Any]:
|
|||
from DEFAULT_CONFIG where ``model`` is a string, but user configs often have the
|
||||
dict form. Normalize to the string form so the frontend schema matches.
|
||||
|
||||
Also surfaces ``model_context_length`` as a top-level field so the web UI can
|
||||
display and edit it. A value of 0 means "auto-detect".
|
||||
Also surfaces ``model_context_length`` and ``model_provider`` as top-level virtual
|
||||
fields so the web UI can display and edit them. ``model_context_length`` of 0 means
|
||||
"auto-detect"; ``model_provider`` empty means "use default resolution".
|
||||
"""
|
||||
config = dict(config) # shallow copy
|
||||
model_val = config.get("model")
|
||||
if isinstance(model_val, dict):
|
||||
# Extract context_length before flattening the dict
|
||||
# Extract context_length and provider before flattening the dict
|
||||
ctx_len = model_val.get("context_length", 0)
|
||||
config["model"] = model_val.get("default", model_val.get("name", ""))
|
||||
config["model_context_length"] = ctx_len if isinstance(ctx_len, int) else 0
|
||||
config["model_provider"] = model_val.get("provider", "")
|
||||
else:
|
||||
config["model_context_length"] = 0
|
||||
config["model_provider"] = ""
|
||||
return config
|
||||
|
||||
|
||||
|
|
@ -910,6 +920,38 @@ def get_model_info():
|
|||
return dict(_EMPTY_MODEL_INFO)
|
||||
|
||||
|
||||
def _infer_provider_for_model(model_name: str, disk_config: Dict[str, Any]) -> Optional[str]:
    """Try to infer which configured provider serves a given model name.

    Lookup order (first match wins):

    1. ``disk_config["providers"]`` -- user-configured providers whose entry
       is a dict carrying an explicit ``models`` list that contains
       ``model_name``.
    2. ``hermes_cli.models.OPENROUTER_MODELS`` -- consulted before the other
       built-in catalogs; a hit returns ``"openrouter"``.
    3. ``hermes_cli.models._PROVIDER_MODELS`` -- the remaining built-in
       catalogs, in the mapping's iteration order.

    Args:
        model_name: Model identifier to look up (exact match).
        disk_config: Parsed on-disk configuration. Malformed ``providers``
            data (non-dict container, non-dict entries, non-list ``models``)
            is skipped instead of raising.

    Returns:
        The provider slug on a match, or ``None`` when ``model_name`` is
        empty or no catalog lists it (the caller should then preserve the
        existing provider).
    """
    # An empty name can never identify a model; bail out before any
    # membership test could produce a spurious hit on malformed catalog data.
    if not model_name:
        return None

    # 1. User-configured providers (providers: {} dict) with explicit model lists
    user_providers = disk_config.get("providers") if isinstance(disk_config, dict) else None
    if isinstance(user_providers, dict):
        for slug, provider_cfg in user_providers.items():
            if not isinstance(provider_cfg, dict):
                continue
            listed_models = provider_cfg.get("models")
            if isinstance(listed_models, list) and model_name in listed_models:
                return str(slug)

    # 2. Built-in curated catalogs: openrouter first (largest aggregator), then others
    try:
        from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS

        if any(model_id == model_name for model_id, _ in OPENROUTER_MODELS):
            return "openrouter"
        for provider_slug, catalog in _PROVIDER_MODELS.items():
            if model_name in catalog:
                return provider_slug
    except Exception:
        # Deliberately best-effort: a missing or malformed catalog must not
        # abort the config save that triggered this lookup. Treat it as
        # "no match" so the caller keeps the current provider.
        return None

    return None
|
||||
|
||||
|
||||
def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Reverse _normalize_config_for_web before saving.
|
||||
|
||||
|
|
@ -921,13 +963,17 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
|
|||
Also handles ``model_context_length`` — writes it back into the model dict
|
||||
as ``context_length``. A value of 0 or absent means "auto-detect" (omitted
|
||||
from the dict so get_model_context_length() uses its normal resolution).
|
||||
|
||||
``model_provider`` is a virtual field the frontend may send to explicitly
|
||||
override the provider. When absent or empty, the provider is inferred from
|
||||
the new model name if it changed, then falls back to the on-disk value.
|
||||
"""
|
||||
config = dict(config)
|
||||
# Remove any _model_meta that might have leaked in (shouldn't happen
|
||||
# with the stripped GET response, but be defensive)
|
||||
config.pop("_model_meta", None)
|
||||
|
||||
# Extract and remove model_context_length before processing model
|
||||
# Extract and remove virtual fields before processing model
|
||||
ctx_override = config.pop("model_context_length", 0)
|
||||
if not isinstance(ctx_override, int):
|
||||
try:
|
||||
|
|
@ -935,6 +981,9 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
|
|||
except (TypeError, ValueError):
|
||||
ctx_override = 0
|
||||
|
||||
# model_provider: explicit provider override from the frontend (may be "")
|
||||
explicit_provider = str(config.pop("model_provider", "") or "").strip()
|
||||
|
||||
model_val = config.get("model")
|
||||
if isinstance(model_val, str) and model_val:
|
||||
# Read the current disk config to recover model subkeys
|
||||
|
|
@ -942,8 +991,26 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
|
|||
disk_config = load_config()
|
||||
disk_model = disk_config.get("model")
|
||||
if isinstance(disk_model, dict):
|
||||
old_default = disk_model.get("default", "")
|
||||
# Preserve all subkeys, update default with the new value
|
||||
disk_model["default"] = model_val
|
||||
|
||||
# Determine the correct provider for the (possibly new) model.
|
||||
# Priority: explicit frontend override > inference > existing disk value.
|
||||
disk_provider = disk_model.get("provider", "")
|
||||
if explicit_provider and explicit_provider != disk_provider:
|
||||
# Frontend sent an explicit provider change
|
||||
disk_model["provider"] = explicit_provider
|
||||
disk_model.pop("base_url", None)
|
||||
disk_model.pop("api_mode", None)
|
||||
elif model_val != old_default:
|
||||
# Model name changed — try to infer the matching provider
|
||||
inferred = _infer_provider_for_model(model_val, disk_config)
|
||||
if inferred and inferred != disk_provider:
|
||||
disk_model["provider"] = inferred
|
||||
disk_model.pop("base_url", None)
|
||||
disk_model.pop("api_mode", None)
|
||||
|
||||
# Write context_length into the model dict (0 = remove/auto)
|
||||
if ctx_override > 0:
|
||||
disk_model["context_length"] = ctx_override
|
||||
|
|
@ -952,12 +1019,14 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]:
|
|||
config["model"] = disk_model
|
||||
else:
|
||||
# Model was previously a bare string — upgrade to dict if
|
||||
# user is setting a context_length override
|
||||
if ctx_override > 0:
|
||||
config["model"] = {
|
||||
"default": model_val,
|
||||
"context_length": ctx_override,
|
||||
}
|
||||
# user is setting a context_length override or explicit provider
|
||||
if ctx_override > 0 or explicit_provider:
|
||||
new_dict: Dict[str, Any] = {"default": model_val}
|
||||
if ctx_override > 0:
|
||||
new_dict["context_length"] = ctx_override
|
||||
if explicit_provider:
|
||||
new_dict["provider"] = explicit_provider
|
||||
config["model"] = new_dict
|
||||
except Exception:
|
||||
pass # can't read disk config — just use the string form
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -1925,3 +1925,270 @@ class TestPtyWebSocket:
|
|||
):
|
||||
pass
|
||||
assert exc.value.code == 4400
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# model_provider virtual field — normalize / denormalize / schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestModelProviderVirtualField:
    """Tests for the model_provider virtual field surfaced by the config API.

    Regression: when the user picked a model from a different provider in the
    Web UI, the provider was never updated — the stale on-disk provider was
    preserved because _denormalize_config_from_web only updated `default`.
    """

    @staticmethod
    def _save_ollama_config(**extra):
        """Persist a config whose model is served by a local Ollama provider.

        Keyword arguments become additional top-level config keys (for
        example ``providers=...``). A fresh dict is built on every call so
        tests never share mutable fixture state.
        """
        from hermes_cli.config import save_config

        disk = {
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            },
        }
        disk.update(extra)
        save_config(disk)

    # ── normalize ──────────────────────────────────────────────────────────

    def test_normalize_extracts_provider_from_dict(self):
        """normalize should surface provider from model dict as model_provider."""
        from hermes_cli.web_server import _normalize_config_for_web

        cfg = {"model": {"default": "google/gemini-3-flash", "provider": "openrouter"}}
        result = _normalize_config_for_web(cfg)
        assert result["model"] == "google/gemini-3-flash"
        assert result["model_provider"] == "openrouter"

    def test_normalize_bare_string_yields_empty_provider(self):
        """normalize should set model_provider='' for a bare string model."""
        from hermes_cli.web_server import _normalize_config_for_web

        result = _normalize_config_for_web({"model": "anthropic/claude-sonnet-4"})
        assert result["model_provider"] == ""

    def test_normalize_dict_without_provider_yields_empty(self):
        """normalize should set model_provider='' when model dict has no provider."""
        from hermes_cli.web_server import _normalize_config_for_web

        cfg = {"model": {"default": "test/model"}}
        result = _normalize_config_for_web(cfg)
        assert result["model_provider"] == ""

    # ── denormalize: explicit model_provider override ──────────────────────

    def test_denormalize_explicit_provider_override_switches_provider(self):
        """Sending model_provider should update the disk provider."""
        from hermes_cli.web_server import _denormalize_config_from_web

        self._save_ollama_config()

        result = _denormalize_config_from_web({
            "model": "google/gemini-3-flash",
            "model_provider": "openrouter",
        })
        assert isinstance(result["model"], dict)
        assert result["model"]["default"] == "google/gemini-3-flash"
        assert result["model"]["provider"] == "openrouter"
        # provider-specific fields from old provider should be cleared
        assert "base_url" not in result["model"]
        assert "model_provider" not in result  # virtual field consumed

    def test_denormalize_explicit_provider_same_as_disk_no_op(self):
        """Sending model_provider matching disk provider should be a no-op."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "anthropic/claude-opus-4.6",
                "provider": "openrouter",
            }
        })

        result = _denormalize_config_from_web({
            "model": "anthropic/claude-sonnet-4.6",
            "model_provider": "openrouter",
        })
        assert result["model"]["provider"] == "openrouter"
        # The provider is a no-op, but the model selection itself must land.
        assert result["model"]["default"] == "anthropic/claude-sonnet-4.6"
        assert "model_provider" not in result

    def test_denormalize_explicit_empty_provider_runs_inference(self):
        """Empty model_provider should fall through to inference, not wipe provider.

        Every "empty" spelling the frontend might send (empty string, None,
        whitespace only) is treated as "no explicit override". With the model
        unchanged there is nothing to infer, so the on-disk provider and its
        provider-specific settings must survive.
        """
        from hermes_cli.web_server import _denormalize_config_from_web

        self._save_ollama_config()

        for empty_value in ("", None, "   "):
            result = _denormalize_config_from_web({
                "model": "llama3.2",
                "model_provider": empty_value,
            })
            assert result["model"]["provider"] == "ollama-local"
            assert result["model"].get("base_url") == "http://localhost:11434/v1"
            assert "model_provider" not in result

    # ── denormalize: inference when model changes ──────────────────────────

    def test_denormalize_model_change_infers_provider_from_user_providers_list(self):
        """When model changes to one in user providers' models list, switch provider."""
        from hermes_cli.web_server import _denormalize_config_from_web

        self._save_ollama_config(
            providers={
                "openrouter": {
                    "models": ["google/gemini-2.5-flash", "anthropic/claude-opus-4.7"],
                }
            },
        )

        result = _denormalize_config_from_web({
            "model": "google/gemini-2.5-flash",
            "model_provider": "",
        })
        assert result["model"]["provider"] == "openrouter"
        assert result["model"]["default"] == "google/gemini-2.5-flash"
        assert "base_url" not in result["model"]

    def test_denormalize_model_change_infers_provider_from_openrouter_catalog(self):
        """When model changes to one in OPENROUTER_MODELS, switch to openrouter."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.models import OPENROUTER_MODELS

        # Use the first model in OPENROUTER_MODELS as a known good entry
        if not OPENROUTER_MODELS:
            pytest.skip("OPENROUTER_MODELS is empty")
        known_openrouter_model = OPENROUTER_MODELS[0][0]

        self._save_ollama_config()

        result = _denormalize_config_from_web({
            "model": known_openrouter_model,
            "model_provider": "",
        })
        assert result["model"]["provider"] == "openrouter"
        assert result["model"]["default"] == known_openrouter_model
        assert "base_url" not in result["model"]

    def test_denormalize_same_model_no_provider_change(self):
        """When model stays the same, provider should not change."""
        from hermes_cli.web_server import _denormalize_config_from_web

        self._save_ollama_config()

        result = _denormalize_config_from_web({
            "model": "llama3.2",
            "model_provider": "",
        })
        assert result["model"]["provider"] == "ollama-local"
        assert result["model"].get("base_url") == "http://localhost:11434/v1"

    def test_denormalize_unknown_model_preserves_provider(self):
        """When model changes to an unknown model, keep existing provider."""
        from hermes_cli.web_server import _denormalize_config_from_web

        self._save_ollama_config()

        result = _denormalize_config_from_web({
            "model": "totally-unknown-model:v999",
            "model_provider": "",
        })
        # Unknown model — provider (and its connection settings) should not change
        assert result["model"]["provider"] == "ollama-local"
        assert result["model"].get("base_url") == "http://localhost:11434/v1"
        assert result["model"]["default"] == "totally-unknown-model:v999"

    # ── schema ─────────────────────────────────────────────────────────────

    def test_schema_has_model_provider(self):
        """CONFIG_SCHEMA should include model_provider virtual field."""
        from hermes_cli.web_server import CONFIG_SCHEMA

        assert "model_provider" in CONFIG_SCHEMA

    def test_schema_model_provider_after_model_context_length(self):
        """model_provider should appear immediately after model_context_length."""
        from hermes_cli.web_server import CONFIG_SCHEMA

        keys = list(CONFIG_SCHEMA.keys())
        mcl_idx = keys.index("model_context_length")
        assert keys[mcl_idx + 1] == "model_provider"

    def test_schema_model_provider_is_string(self):
        """model_provider schema entry should have type=string."""
        from hermes_cli.web_server import CONFIG_SCHEMA

        entry = CONFIG_SCHEMA["model_provider"]
        assert entry["type"] == "string"
        assert entry["category"] == "general"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _infer_provider_for_model unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestInferProviderForModel:
    """Direct unit tests for _infer_provider_for_model (no config file on disk)."""

    def test_finds_model_in_user_providers_list(self):
        """A model listed under a user-configured provider resolves to that slug."""
        from hermes_cli.web_server import _infer_provider_for_model as infer

        user_providers = {
            "my-openrouter": {"models": ["vendor/model-a", "vendor/model-b"]},
        }
        resolved = infer("vendor/model-a", {"providers": user_providers})
        assert resolved == "my-openrouter"

    def test_user_providers_dict_is_not_list(self):
        """A provider entry that is not a dict is skipped without raising."""
        from hermes_cli.web_server import _infer_provider_for_model as infer

        malformed_cfg = {"providers": {"bad": "not-a-dict"}}
        # The bad entry must be ignored, leaving nothing to match against.
        assert infer("anything", malformed_cfg) is None

    def test_no_providers_falls_through_to_catalog(self):
        """With no user providers, the built-in OpenRouter catalog is consulted."""
        from hermes_cli.web_server import _infer_provider_for_model as infer
        from hermes_cli.models import OPENROUTER_MODELS

        if not OPENROUTER_MODELS:
            pytest.skip("OPENROUTER_MODELS is empty")
        catalog_model = OPENROUTER_MODELS[0][0]
        assert infer(catalog_model, {}) == "openrouter"

    def test_unknown_model_returns_none(self):
        """A model found in no catalog yields None."""
        from hermes_cli.web_server import _infer_provider_for_model as infer

        assert infer("nonexistent/model-xyzzy-9999", {}) is None
|
||||
|
|
|
|||
|
|
@ -303,7 +303,16 @@ export default function ConfigPage() {
|
|||
schemaKey={key}
|
||||
schema={s}
|
||||
value={getNestedValue(config, key)}
|
||||
onChange={(v) => setConfig(setNestedValue(config, key, v))}
|
||||
onChange={(v) => {
|
||||
let next = setNestedValue(config, key, v);
|
||||
// When the model name changes, clear model_provider so the
|
||||
// backend can infer the correct provider for the new model
|
||||
// rather than inheriting the stale one.
|
||||
if (key === "model") {
|
||||
next = setNestedValue(next, "model_provider", "");
|
||||
}
|
||||
setConfig(next);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue