diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 4066b59910f..1a14a1e0fe9 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1832,52 +1832,10 @@ def select_provider_and_model(args=None): config_provider or os.getenv("HERMES_INFERENCE_PROVIDER") or "auto" ) compatible_custom_providers = get_compatible_custom_providers(config) - active = None - if effective_provider != "auto": - active_def = resolve_provider_full( - effective_provider, - config.get("providers"), - compatible_custom_providers, - ) - if active_def is not None: - active = active_def.id - else: - warning = ( - f"Unknown provider '{effective_provider}'. Check 'hermes model' for " - "available providers, or run 'hermes doctor' to diagnose config " - "issues." - ) - print(f"Warning: {warning} Falling back to auto provider detection.") - if active is None: - try: - active = resolve_provider("auto") - except AuthError as exc: - if effective_provider == "auto": - warning = format_auth_error(exc) - print(f"Warning: {warning} Falling back to auto provider detection.") - active = None # no provider yet; default to first in list - - # Detect custom endpoint - if active == "openrouter" and get_env_value("OPENAI_BASE_URL"): - active = "custom" - - from hermes_cli.models import CANONICAL_PROVIDERS, _PROVIDER_LABELS - - provider_labels = dict(_PROVIDER_LABELS) # derive from canonical list - active_label = provider_labels.get(active, active) if active else "none" - - print() - print(f" Current model: {current_model}") - print(f" Active provider: {active_label}") - print() - - # Step 1: Provider selection — flat list from CANONICAL_PROVIDERS - all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS] - def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]: from hermes_cli.config import read_raw_config - # Build a lookup of raw (un-expanded) api_key templates keyed by a + # Build lookups of raw (un-expanded) templates keyed by a # stable identity. We intentionally bypass # ``get_compatible_custom_providers(read_raw_config())`` here because # its ``_normalize_custom_provider_entry`` step calls ``urlparse()`` @@ -1886,6 +1844,7 @@ def select_provider_and_model(args=None): # entries is exactly how env-ref preservation fails for the user # config that motivated this fix. raw_api_key_refs: dict[tuple, str] = {} + raw_base_url_refs: dict[tuple, str] = {} raw_cfg = read_raw_config() def _record_raw( @@ -1893,10 +1852,10 @@ def select_provider_and_model(args=None): provider_key: str, model: str, api_key: str, + base_url: str, ) -> None: template = str(api_key or "").strip() - if "${" not in template: - return + base_template = str(base_url or "").strip() name = str(name or "").strip() provider_key = str(provider_key or "").strip() model = str(model or "").strip() @@ -1904,12 +1863,19 @@ def select_provider_and_model(args=None): # might present: (name), (name, model), (provider_key), and # (provider_key, model). Case-insensitive on name/provider_key so # the loaded entry matches regardless of display casing. + identities = [] if name: - raw_api_key_refs.setdefault((name.lower(),), template) - raw_api_key_refs.setdefault((name.lower(), model), template) + identities.extend(((name.lower(),), (name.lower(), model))) if provider_key: - raw_api_key_refs.setdefault((provider_key.lower(),), template) - raw_api_key_refs.setdefault((provider_key.lower(), model), template) + identities.extend( + ((provider_key.lower(),), (provider_key.lower(), model)) + ) + if "${" in template: + for identity in identities: + raw_api_key_refs.setdefault(identity, template) + if "${" in base_template: + for identity in identities: + raw_base_url_refs.setdefault(identity, base_template) raw_list = raw_cfg.get("custom_providers") if isinstance(raw_list, list): @@ -1921,6 +1887,9 @@ def select_provider_and_model(args=None): "", raw_entry.get("model", "") or raw_entry.get("default_model", ""), raw_entry.get("api_key", ""), + raw_entry.get("base_url", "") + or raw_entry.get("url", "") + or raw_entry.get("api", ""), ) raw_providers = raw_cfg.get("providers") if isinstance(raw_providers, dict): @@ -1932,9 +1901,17 @@ def select_provider_and_model(args=None): raw_key, raw_entry.get("model", "") or raw_entry.get("default_model", ""), raw_entry.get("api_key", ""), + raw_entry.get("base_url", "") + or raw_entry.get("url", "") + or raw_entry.get("api", ""), ) - def _lookup_ref(name: str, provider_key: str, model: str) -> str: + def _lookup_ref( + refs: dict[tuple, str], + name: str, + provider_key: str, + model: str, + ) -> str: name_lc = str(name or "").strip().lower() pkey_lc = str(provider_key or "").strip().lower() model = str(model or "").strip() @@ -1944,8 +1921,8 @@ def select_provider_and_model(args=None): (name_lc, model), (name_lc,), ): - if identity[0] and identity in raw_api_key_refs: - return raw_api_key_refs[identity] + if identity[0] and identity in refs: + return refs[identity] return "" custom_provider_map = {} @@ -1971,14 +1948,81 @@ def select_provider_and_model(args=None): "model": entry.get("model", ""), "api_mode": entry.get("api_mode", ""), "provider_key": provider_key, - "api_key_ref": _lookup_ref(name, provider_key, entry.get("model", "")), + "api_key_ref": _lookup_ref( + raw_api_key_refs, name, provider_key, entry.get("model", "") + ), + "base_url_ref": _lookup_ref( + raw_base_url_refs, name, provider_key, entry.get("model", "") + ), } return custom_provider_map + def _norm_base_url(url: str) -> str: + return str(url or "").strip().rstrip("/").lower() + # Add user-defined custom providers from config.yaml _custom_provider_map = _named_custom_provider_map( config ) # key → {name, base_url, api_key} + + def _active_custom_key_from_base_url() -> str: + if effective_provider != "custom" or not isinstance(model_cfg, dict): + return "" + current_base = _norm_base_url(model_cfg.get("base_url", "")) + if not current_base: + return "" + for key, provider_info in _custom_provider_map.items(): + if _norm_base_url(provider_info.get("base_url", "")) == current_base: + return key + return "" + + active = _active_custom_key_from_base_url() + if active is None: + active = "" + if not active and effective_provider != "auto": + active_def = resolve_provider_full( + effective_provider, + config.get("providers"), + compatible_custom_providers, + ) + if active_def is not None: + active = active_def.id + else: + warning = ( + f"Unknown provider '{effective_provider}'. Check 'hermes model' for " + "available providers, or run 'hermes doctor' to diagnose config " + "issues." + ) + print(f"Warning: {warning} Falling back to auto provider detection.") + if not active: + try: + active = resolve_provider("auto") + except AuthError as exc: + if effective_provider == "auto": + warning = format_auth_error(exc) + print(f"Warning: {warning} Falling back to auto provider detection.") + active = None # no provider yet; default to first in list + + # Detect custom endpoint + if active == "openrouter" and get_env_value("OPENAI_BASE_URL"): + active = "custom" + + from hermes_cli.models import CANONICAL_PROVIDERS, _PROVIDER_LABELS + + provider_labels = dict(_PROVIDER_LABELS) # derive from canonical list + if active and active in _custom_provider_map: + active_label = _custom_provider_map[active]["name"] + else: + active_label = provider_labels.get(active, active) if active else "none" + + print() + print(f" Current model: {current_model}") + print(f" Active provider: {active_label}") + print() + + # Step 1: Provider selection — flat list from CANONICAL_PROVIDERS + all_providers = [(p.slug, p.tui_desc) for p in CANONICAL_PROVIDERS] + for key, provider_info in _custom_provider_map.items(): name = provider_info["name"] base_url = provider_info["base_url"] @@ -3501,6 +3545,14 @@ def _custom_provider_api_key_config_value(provider_info, resolved_api_key=""): return str(resolved_api_key or "").strip() +def _custom_provider_base_url_config_value(provider_info, resolved_base_url=""): + """Return the value that should be persisted for a custom provider URL.""" + base_url_ref = str(provider_info.get("base_url_ref", "") or "").strip() + if base_url_ref: + return base_url_ref + return str(resolved_base_url or "").strip() + + def _save_custom_provider( base_url, api_key="", model="", context_length=None, name=None, api_mode=None ): @@ -4114,7 +4166,9 @@ def _model_flow_named_custom(config, provider_info): model.pop("api_key", None) else: model["provider"] = "custom" - model["base_url"] = base_url + model["base_url"] = _custom_provider_base_url_config_value( + provider_info, base_url + ) if config_api_key: model["api_key"] = config_api_key # Apply api_mode from custom_providers entry, or clear stale value diff --git a/tests/hermes_cli/test_custom_provider_model_switch.py b/tests/hermes_cli/test_custom_provider_model_switch.py index d123120ed83..1c14b848439 100644 --- a/tests/hermes_cli/test_custom_provider_model_switch.py +++ b/tests/hermes_cli/test_custom_provider_model_switch.py @@ -327,6 +327,118 @@ class TestCustomProviderModelSwitch: assert config["custom_providers"][0]["api_key"] == "${NEURALWATT_API_KEY}" assert "sk-live-neuralwatt-secret" not in saved + def test_bare_custom_current_provider_matches_env_base_url_before_first_fallback( + self, config_home, monkeypatch + ): + """`hermes model` must mark the custom provider matching model.base_url + as current instead of falling back to the first saved custom provider. + + Regression: with ``model.provider: custom`` and multiple + ``custom_providers`` entries, the CLI resolved bare ``custom`` through + ``resolve_custom_provider()``, whose compatibility fallback returns the + first entry. A config with Cerebras first and NeuralWatt active then + showed Cerebras as current. + """ + from hermes_cli.main import select_provider_and_model + + config_path = config_home / "config.yaml" + config_path.write_text( + "model:\n" + " default: kimi-k2.6-fast\n" + " provider: custom\n" + " base_url: ${NEURALWATT_API_BASE}\n" + " api_key: ${NEURALWATT_API_KEY}\n" + "providers: {}\n" + "custom_providers:\n" + "- name: Cerebras.ai\n" + " base_url: ${CEREBRAS_API_BASE}\n" + " api_key: ${CEREBRAS_API_KEY}\n" + " model: qwen-3-235b-a22b-instruct-2507\n" + " models: []\n" + "- name: NeuralWatt\n" + " base_url: ${NEURALWATT_API_BASE}\n" + " api_key: ${NEURALWATT_API_KEY}\n" + " model: kimi-k2.6-fast\n" + " models: []\n" + ) + monkeypatch.setenv("CEREBRAS_API_BASE", "https://api.cerebras.ai/v1") + monkeypatch.setenv("CEREBRAS_API_KEY", "sk-live-cerebras-secret") + monkeypatch.setenv("NEURALWATT_API_BASE", "https://api.neuralwatt.com/v1") + monkeypatch.setenv("NEURALWATT_API_KEY", "sk-live-neuralwatt-secret") + + captured: dict = {} + + def _capture_and_cancel(labels, default=0): + captured["labels"] = labels + captured["default"] = default + return len(labels) - 1 # Leave unchanged + + with patch("hermes_cli.main._prompt_provider_choice", + side_effect=_capture_and_cancel), \ + patch("builtins.print"): + select_provider_and_model() + + labels = captured["labels"] + default_label = labels[captured["default"]] + assert "NeuralWatt" in default_label + assert "currently active" in default_label + assert "Cerebras.ai" not in default_label + assert not any( + "Cerebras.ai" in label and "currently active" in label + for label in labels + ) + + def test_named_custom_provider_selection_preserves_base_url_env_ref( + self, config_home, monkeypatch + ): + """Selecting an env-backed custom provider should not expand its + ``base_url`` template into ``model.base_url`` on disk.""" + import yaml + from hermes_cli.main import select_provider_and_model + + config_path = config_home / "config.yaml" + config_path.write_text( + "model:\n" + " default: old-model\n" + " provider: openrouter\n" + "custom_providers:\n" + "- name: NeuralWatt\n" + " base_url: ${NEURALWATT_API_BASE}\n" + " api_key: ${NEURALWATT_API_KEY}\n" + " model: qwen3.6-35b-fast\n" + " models: []\n" + ) + monkeypatch.setenv("NEURALWATT_API_BASE", "https://api.neuralwatt.com/v1") + monkeypatch.setenv("NEURALWATT_API_KEY", "sk-live-neuralwatt-secret") + + def _pick_neuralwatt(labels, default=0): + for i, label in enumerate(labels): + if "NeuralWatt" in label: + return i + raise AssertionError( + f"NeuralWatt entry missing from provider menu: {labels}" + ) + + with patch("hermes_cli.main._prompt_provider_choice", + side_effect=_pick_neuralwatt), \ + patch("hermes_cli.models.fetch_api_models", + return_value=["qwen3.6-35b-fast"]) as mock_fetch, \ + patch.dict("sys.modules", {"simple_term_menu": None}), \ + patch("builtins.input", return_value="1"), \ + patch("builtins.print"): + select_provider_and_model() + + mock_fetch.assert_called_once() + probe_args, _ = mock_fetch.call_args + assert probe_args[1] == "https://api.neuralwatt.com/v1" + + saved = config_path.read_text() + config = yaml.safe_load(saved) or {} + assert config["model"]["base_url"] == "${NEURALWATT_API_BASE}" + assert config["model"]["api_key"] == "${NEURALWATT_API_KEY}" + assert "https://api.neuralwatt.com/v1" not in saved + assert "sk-live-neuralwatt-secret" not in saved + def test_key_env_providers_dict_entry_does_not_add_api_key( self, config_home, monkeypatch ):