diff --git a/hermes_cli/web_server.py b/hermes_cli/web_server.py index 8c33a383e5..5166505f24 100644 --- a/hermes_cli/web_server.py +++ b/hermes_cli/web_server.py @@ -249,6 +249,11 @@ _SCHEMA_OVERRIDES: Dict[str, Dict[str, Any]] = { "description": "Context window override (0 = auto-detect from model metadata)", "category": "general", }, + "model_provider": { + "type": "string", + "description": "Provider for the model (e.g. openrouter, anthropic, ollama-local). Leave empty to auto-detect.", + "category": "general", + }, "terminal.backend": { "type": "select", "description": "Terminal execution backend", @@ -406,14 +411,16 @@ def _build_schema_from_config( CONFIG_SCHEMA = _build_schema_from_config(DEFAULT_CONFIG) # Inject virtual fields that don't live in DEFAULT_CONFIG but are surfaced -# by the normalize/denormalize cycle. Insert model_context_length right after -# the "model" key so it renders adjacent in the frontend. +# by the normalize/denormalize cycle. Insert model_context_length and +# model_provider right after the "model" key so they render adjacent. _mcl_entry = _SCHEMA_OVERRIDES["model_context_length"] +_mp_entry = _SCHEMA_OVERRIDES["model_provider"] _ordered_schema: Dict[str, Dict[str, Any]] = {} for _k, _v in CONFIG_SCHEMA.items(): _ordered_schema[_k] = _v if _k == "model": _ordered_schema["model_context_length"] = _mcl_entry + _ordered_schema["model_provider"] = _mp_entry CONFIG_SCHEMA = _ordered_schema @@ -791,18 +798,21 @@ def _normalize_config_for_web(config: Dict[str, Any]) -> Dict[str, Any]: from DEFAULT_CONFIG where ``model`` is a string, but user configs often have the dict form. Normalize to the string form so the frontend schema matches. - Also surfaces ``model_context_length`` as a top-level field so the web UI can - display and edit it. A value of 0 means "auto-detect". + Also surfaces ``model_context_length`` and ``model_provider`` as top-level virtual + fields so the web UI can display and edit them. 
``model_context_length`` of 0 means + "auto-detect"; ``model_provider`` empty means "use default resolution". """ config = dict(config) # shallow copy model_val = config.get("model") if isinstance(model_val, dict): - # Extract context_length before flattening the dict + # Extract context_length and provider before flattening the dict ctx_len = model_val.get("context_length", 0) config["model"] = model_val.get("default", model_val.get("name", "")) config["model_context_length"] = ctx_len if isinstance(ctx_len, int) else 0 + config["model_provider"] = model_val.get("provider", "") else: config["model_context_length"] = 0 + config["model_provider"] = "" return config @@ -910,6 +920,38 @@ def get_model_info(): return dict(_EMPTY_MODEL_INFO) +def _infer_provider_for_model(model_name: str, disk_config: Dict[str, Any]) -> Optional[str]: + """Try to infer which configured provider serves a given model name. + + Checks user-configured providers' explicit model lists first, then the + built-in curated catalogs. Returns the provider slug on a deterministic + match, or ``None`` when no match is found (caller should preserve the + existing provider). + """ + # 1. User-configured providers (providers: {} dict) with explicit model lists + user_providers = disk_config.get("providers", {}) + if isinstance(user_providers, dict): + for slug, pdata in user_providers.items(): + if not isinstance(pdata, dict): + continue + models = pdata.get("models", []) + if isinstance(models, list) and model_name in models: + return str(slug) + + # 2. 
Built-in curated catalogs: openrouter first (largest aggregator), then others + try: + from hermes_cli.models import OPENROUTER_MODELS, _PROVIDER_MODELS + if model_name in {mid for mid, _ in OPENROUTER_MODELS}: + return "openrouter" + for provider_slug, models in _PROVIDER_MODELS.items(): + if model_name in models: + return provider_slug + except Exception: + pass + + return None + + def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: """Reverse _normalize_config_for_web before saving. @@ -921,13 +963,17 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: Also handles ``model_context_length`` — writes it back into the model dict as ``context_length``. A value of 0 or absent means "auto-detect" (omitted from the dict so get_model_context_length() uses its normal resolution). + + ``model_provider`` is a virtual field the frontend may send to explicitly + override the provider. When absent or empty, the provider is inferred from + the new model name if it changed, then falls back to the on-disk value. 
""" config = dict(config) # Remove any _model_meta that might have leaked in (shouldn't happen # with the stripped GET response, but be defensive) config.pop("_model_meta", None) - # Extract and remove model_context_length before processing model + # Extract and remove virtual fields before processing model ctx_override = config.pop("model_context_length", 0) if not isinstance(ctx_override, int): try: @@ -935,6 +981,9 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: except (TypeError, ValueError): ctx_override = 0 + # model_provider: explicit provider override from the frontend (may be "") + explicit_provider = str(config.pop("model_provider", "") or "").strip() + model_val = config.get("model") if isinstance(model_val, str) and model_val: # Read the current disk config to recover model subkeys @@ -942,8 +991,26 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: disk_config = load_config() disk_model = disk_config.get("model") if isinstance(disk_model, dict): + old_default = disk_model.get("default", "") # Preserve all subkeys, update default with the new value disk_model["default"] = model_val + + # Determine the correct provider for the (possibly new) model. + # Priority: explicit frontend override > inference > existing disk value. 
+ disk_provider = disk_model.get("provider", "") + if explicit_provider and explicit_provider != disk_provider: + # Frontend sent an explicit provider change + disk_model["provider"] = explicit_provider + disk_model.pop("base_url", None) + disk_model.pop("api_mode", None) + elif model_val != old_default: + # Model name changed — try to infer the matching provider + inferred = _infer_provider_for_model(model_val, disk_config) + if inferred and inferred != disk_provider: + disk_model["provider"] = inferred + disk_model.pop("base_url", None) + disk_model.pop("api_mode", None) + # Write context_length into the model dict (0 = remove/auto) if ctx_override > 0: disk_model["context_length"] = ctx_override @@ -952,12 +1019,14 @@ def _denormalize_config_from_web(config: Dict[str, Any]) -> Dict[str, Any]: config["model"] = disk_model else: # Model was previously a bare string — upgrade to dict if - # user is setting a context_length override - if ctx_override > 0: - config["model"] = { - "default": model_val, - "context_length": ctx_override, - } + # user is setting a context_length override or explicit provider + if ctx_override > 0 or explicit_provider: + new_dict: Dict[str, Any] = {"default": model_val} + if ctx_override > 0: + new_dict["context_length"] = ctx_override + if explicit_provider: + new_dict["provider"] = explicit_provider + config["model"] = new_dict except Exception: pass # can't read disk config — just use the string form return config diff --git a/tests/hermes_cli/test_web_server.py b/tests/hermes_cli/test_web_server.py index e83f5bdeb3..63a5c50820 100644 --- a/tests/hermes_cli/test_web_server.py +++ b/tests/hermes_cli/test_web_server.py @@ -1925,3 +1925,270 @@ class TestPtyWebSocket: ): pass assert exc.value.code == 4400 + + +# --------------------------------------------------------------------------- +# model_provider virtual field — normalize / denormalize / schema +# --------------------------------------------------------------------------- + + 
class TestModelProviderVirtualField:
    """Tests for the model_provider virtual field surfaced by the config API.

    Regression: when the user picked a model from a different provider in the
    Web UI, the provider was never updated — the stale on-disk provider was
    preserved because _denormalize_config_from_web only updated `default`.
    """

    # ── normalize ──────────────────────────────────────────────────────────

    def test_normalize_extracts_provider_from_dict(self):
        """normalize should surface provider from model dict as model_provider."""
        from hermes_cli.web_server import _normalize_config_for_web

        cfg = {"model": {"default": "google/gemini-3-flash", "provider": "openrouter"}}
        result = _normalize_config_for_web(cfg)
        assert result["model"] == "google/gemini-3-flash"
        assert result["model_provider"] == "openrouter"

    def test_normalize_bare_string_yields_empty_provider(self):
        """normalize should set model_provider='' for a bare string model."""
        from hermes_cli.web_server import _normalize_config_for_web

        result = _normalize_config_for_web({"model": "anthropic/claude-sonnet-4"})
        assert result["model_provider"] == ""

    def test_normalize_dict_without_provider_yields_empty(self):
        """normalize should set model_provider='' when model dict has no provider."""
        from hermes_cli.web_server import _normalize_config_for_web

        cfg = {"model": {"default": "test/model"}}
        result = _normalize_config_for_web(cfg)
        assert result["model_provider"] == ""

    # ── denormalize: explicit model_provider override ──────────────────────

    def test_denormalize_explicit_provider_override_switches_provider(self):
        """Sending model_provider should update the disk provider."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            }
        })

        result = _denormalize_config_from_web({
            "model": "google/gemini-3-flash",
            "model_provider": "openrouter",
        })
        assert isinstance(result["model"], dict)
        assert result["model"]["default"] == "google/gemini-3-flash"
        assert result["model"]["provider"] == "openrouter"
        # provider-specific fields from old provider should be cleared
        assert "base_url" not in result["model"]
        assert "model_provider" not in result  # virtual field consumed

    def test_denormalize_explicit_provider_same_as_disk_no_op(self):
        """Sending model_provider matching disk provider should leave it unchanged."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "anthropic/claude-opus-4.6",
                "provider": "openrouter",
            }
        })

        result = _denormalize_config_from_web({
            "model": "anthropic/claude-sonnet-4.6",
            "model_provider": "openrouter",
        })
        assert result["model"]["provider"] == "openrouter"
        # The provider is a no-op, but the model name itself must still update.
        assert result["model"]["default"] == "anthropic/claude-sonnet-4.6"

    def test_denormalize_explicit_empty_provider_runs_inference(self):
        """Empty model_provider must not wipe the on-disk provider.

        The model name is unchanged here, so the inference branch is not
        reached; this pins that an empty virtual field alone is harmless.
        """
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            }
        })

        # Same model — no change → provider preserved
        result = _denormalize_config_from_web({
            "model": "llama3.2",
            "model_provider": "",
        })
        assert result["model"]["provider"] == "ollama-local"

    # ── denormalize: inference when model changes ──────────────────────────

    def test_denormalize_model_change_infers_provider_from_user_providers_list(self):
        """When model changes to one in user providers' models list, switch provider."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            },
            "providers": {
                "openrouter": {
                    "models": ["google/gemini-2.5-flash", "anthropic/claude-opus-4.7"],
                }
            },
        })

        result = _denormalize_config_from_web({
            "model": "google/gemini-2.5-flash",
            "model_provider": "",
        })
        assert result["model"]["provider"] == "openrouter"
        assert result["model"]["default"] == "google/gemini-2.5-flash"
        assert "base_url" not in result["model"]

    def test_denormalize_model_change_infers_provider_from_openrouter_catalog(self):
        """When model changes to one in OPENROUTER_MODELS, switch to openrouter."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config
        from hermes_cli.models import OPENROUTER_MODELS

        # Use the first model in OPENROUTER_MODELS as a known good entry
        if not OPENROUTER_MODELS:
            pytest.skip("OPENROUTER_MODELS is empty")
        known_openrouter_model = OPENROUTER_MODELS[0][0]

        save_config({
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            },
        })

        result = _denormalize_config_from_web({
            "model": known_openrouter_model,
            "model_provider": "",
        })
        assert result["model"]["provider"] == "openrouter"
        assert result["model"]["default"] == known_openrouter_model
        assert "base_url" not in result["model"]

    def test_denormalize_same_model_no_provider_change(self):
        """When model stays the same, provider should not change."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            }
        })

        result = _denormalize_config_from_web({
            "model": "llama3.2",
            "model_provider": "",
        })
        assert result["model"]["provider"] == "ollama-local"
        assert result["model"].get("base_url") == "http://localhost:11434/v1"

    def test_denormalize_unknown_model_preserves_provider(self):
        """When model changes to an unknown model, keep existing provider."""
        from hermes_cli.web_server import _denormalize_config_from_web
        from hermes_cli.config import save_config

        save_config({
            "model": {
                "default": "llama3.2",
                "provider": "ollama-local",
                "base_url": "http://localhost:11434/v1",
            }
        })

        result = _denormalize_config_from_web({
            "model": "totally-unknown-model:v999",
            "model_provider": "",
        })
        # Unknown model — provider (and its connection details) should not change
        assert result["model"]["provider"] == "ollama-local"
        assert result["model"].get("base_url") == "http://localhost:11434/v1"
        # ...but the requested model name is still written through
        assert result["model"]["default"] == "totally-unknown-model:v999"

    # ── schema ─────────────────────────────────────────────────────────────

    def test_schema_has_model_provider(self):
        """CONFIG_SCHEMA should include model_provider virtual field."""
        from hermes_cli.web_server import CONFIG_SCHEMA
        assert "model_provider" in CONFIG_SCHEMA

    def test_schema_model_provider_after_model_context_length(self):
        """model_provider should appear immediately after model_context_length."""
        from hermes_cli.web_server import CONFIG_SCHEMA
        keys = list(CONFIG_SCHEMA.keys())
        mcl_idx = keys.index("model_context_length")
        assert keys[mcl_idx + 1] == "model_provider"

    def test_schema_model_provider_is_string(self):
        """model_provider schema entry should have type=string."""
        from hermes_cli.web_server import CONFIG_SCHEMA
        entry = CONFIG_SCHEMA["model_provider"]
        assert entry["type"] == "string"
        assert entry["category"] == "general"


# ---------------------------------------------------------------------------
# _infer_provider_for_model unit tests
# ---------------------------------------------------------------------------


class TestInferProviderForModel:
    """Unit tests for _infer_provider_for_model."""

    def test_finds_model_in_user_providers_list(self):
        """A model listed under a user provider resolves to that provider's slug."""
        from hermes_cli.web_server import _infer_provider_for_model

        disk_cfg = {
            "providers": {
                "my-openrouter": {
                    "models": ["vendor/model-a", "vendor/model-b"],
                }
            }
        }
        assert _infer_provider_for_model("vendor/model-a", disk_cfg) == "my-openrouter"

    def test_user_providers_dict_is_not_list(self):
        """Provider entries whose value is not a dict are skipped without crashing."""
        from hermes_cli.web_server import _infer_provider_for_model

        disk_cfg = {"providers": {"bad": "not-a-dict"}}
        # should not crash; returns None for bad entries
        result = _infer_provider_for_model("anything", disk_cfg)
        assert result is None

    def test_no_providers_falls_through_to_catalog(self):
        """With no user providers, the built-in OpenRouter catalog is consulted."""
        from hermes_cli.web_server import _infer_provider_for_model
        from hermes_cli.models import OPENROUTER_MODELS

        if not OPENROUTER_MODELS:
            pytest.skip("OPENROUTER_MODELS is empty")
        model_id = OPENROUTER_MODELS[0][0]
        result = _infer_provider_for_model(model_id, {})
        assert result == "openrouter"

    def test_unknown_model_returns_none(self):
        """A model absent from every catalog yields None."""
        from hermes_cli.web_server import _infer_provider_for_model

        result = _infer_provider_for_model("nonexistent/model-xyzzy-9999", {})
        assert result is None