diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 4f9c8dd868..1f23b2ac33 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1529,19 +1529,80 @@ def select_provider_and_model(args=None): def _named_custom_provider_map(cfg) -> dict[str, dict[str, str]]: from hermes_cli.config import read_raw_config - def _identity(entry): - return ( - str(entry.get("provider_key", "") or "").strip(), - str(entry.get("name", "") or "").strip(), - str(entry.get("base_url", "") or "").strip().rstrip("/"), - str(entry.get("model", "") or "").strip(), - ) + # Build a lookup of raw (un-expanded) api_key templates keyed by a + # stable identity. We intentionally bypass + # ``get_compatible_custom_providers(read_raw_config())`` here because + # its ``_normalize_custom_provider_entry`` step calls ``urlparse()`` + # on ``base_url`` and drops any entry whose ``base_url`` is itself an + # env-ref template (e.g. ``${NEURALWATT_API_BASE}``). Dropping those + # entries is exactly how env-ref preservation fails for the user + # config that motivated this fix. + raw_api_key_refs: dict[tuple, str] = {} + raw_cfg = read_raw_config() - raw_api_key_refs = {} - for raw_entry in get_compatible_custom_providers(read_raw_config()): - raw_api_key = str(raw_entry.get("api_key", "") or "").strip() - if "${" in raw_api_key: - raw_api_key_refs[_identity(raw_entry)] = raw_api_key + def _record_raw( + name: str, + provider_key: str, + model: str, + api_key: str, + ) -> None: + template = str(api_key or "").strip() + if "${" not in template: + return + name = str(name or "").strip() + provider_key = str(provider_key or "").strip() + model = str(model or "").strip() + # Index by every plausible identity the loaded (expanded) config + # might present: (name), (name, model), (provider_key), and + # (provider_key, model). Case-insensitive on name/provider_key so + # the loaded entry matches regardless of display casing. + if name: + raw_api_key_refs.setdefault((name.lower(),), template) + raw_api_key_refs.setdefault((name.lower(), model), template) + if provider_key: + raw_api_key_refs.setdefault((provider_key.lower(),), template) + raw_api_key_refs.setdefault( + (provider_key.lower(), model), template + ) + + raw_list = raw_cfg.get("custom_providers") + if isinstance(raw_list, list): + for raw_entry in raw_list: + if not isinstance(raw_entry, dict): + continue + _record_raw( + raw_entry.get("name", ""), + "", + raw_entry.get("model", "") + or raw_entry.get("default_model", ""), + raw_entry.get("api_key", ""), + ) + raw_providers = raw_cfg.get("providers") + if isinstance(raw_providers, dict): + for raw_key, raw_entry in raw_providers.items(): + if not isinstance(raw_entry, dict): + continue + _record_raw( + raw_entry.get("name", "") or raw_key, + raw_key, + raw_entry.get("model", "") + or raw_entry.get("default_model", ""), + raw_entry.get("api_key", ""), + ) + + def _lookup_ref(name: str, provider_key: str, model: str) -> str: + name_lc = str(name or "").strip().lower() + pkey_lc = str(provider_key or "").strip().lower() + model = str(model or "").strip() + for identity in ( + (pkey_lc, model), + (pkey_lc,), + (name_lc, model), + (name_lc,), + ): + if identity[0] and identity in raw_api_key_refs: + return raw_api_key_refs[identity] + return "" custom_provider_map = {} for entry in get_compatible_custom_providers(cfg): @@ -1566,7 +1627,9 @@ def select_provider_and_model(args=None): "model": entry.get("model", ""), "api_mode": entry.get("api_mode", ""), "provider_key": provider_key, - "api_key_ref": raw_api_key_refs.get(_identity(entry), ""), + "api_key_ref": _lookup_ref( + name, provider_key, entry.get("model", "") + ), } return custom_provider_map diff --git a/tests/hermes_cli/test_custom_provider_model_switch.py b/tests/hermes_cli/test_custom_provider_model_switch.py index 8235c93087..57706f2172 100644 --- a/tests/hermes_cli/test_custom_provider_model_switch.py +++ b/tests/hermes_cli/test_custom_provider_model_switch.py @@ -257,3 +257,68 @@ class TestCustomProviderModelSwitch: assert config["model"]["api_key"] == "${EXAMPLE_PROVIDER_API_KEY}" assert config["custom_providers"][0]["key_env"] == "EXAMPLE_PROVIDER_API_KEY" assert "sk-live-example-provider" not in config_path.read_text() + + def test_env_ref_base_url_preserves_api_key_ref_through_picker( + self, config_home, monkeypatch + ): + """Integration regression: when BOTH ``base_url`` and ``api_key`` use + ``${VAR}`` templates (the Discord-reported NeuralWatt case), the picker + must still preserve the env reference in ``model.api_key``. + + The earlier lookup went through ``get_compatible_custom_providers`` + which dropped entries whose ``base_url`` was an env-ref template + (``urlparse("${NEURALWATT_API_BASE}")`` has no scheme/netloc), causing + ``api_key_ref`` to stay empty and the resolved secret to be written to + ``config.yaml``. This test drives the real picker-callsite code path. + """ + import yaml + from hermes_cli.main import select_provider_and_model + + config_path = config_home / "config.yaml" + config_path.write_text( + "model:\n" + " default: old-model\n" + " provider: openrouter\n" + "custom_providers:\n" + "- name: NeuralWatt\n" + " base_url: ${NEURALWATT_API_BASE}\n" + " api_key: ${NEURALWATT_API_KEY}\n" + " model: qwen3.6-35b-fast\n" + " models: []\n" + ) + monkeypatch.setenv("NEURALWATT_API_BASE", "https://api.neuralwatt.com/v1") + monkeypatch.setenv("NEURALWATT_API_KEY", "sk-live-neuralwatt-secret") + + # Exercise the real picker: select "custom:neuralwatt" from the + # provider menu. ``select_provider_and_model`` prompts for a provider + # choice (returns an index), then hands off to + # ``_model_flow_named_custom`` with the provider_info built by + # ``_named_custom_provider_map``. + def _pick_neuralwatt(labels, default=0): + for i, label in enumerate(labels): + if "NeuralWatt" in label: + return i + raise AssertionError( + f"NeuralWatt entry missing from provider menu: {labels}" + ) + + with patch("hermes_cli.main._prompt_provider_choice", + side_effect=_pick_neuralwatt), \ + patch("hermes_cli.models.fetch_api_models", + return_value=["qwen3.6-35b-fast"]) as mock_fetch, \ + patch.dict("sys.modules", {"simple_term_menu": None}), \ + patch("builtins.input", return_value="1"), \ + patch("builtins.print"): + select_provider_and_model() + + # The live probe must still use the resolved secret. + mock_fetch.assert_called_once() + probe_args, probe_kwargs = mock_fetch.call_args + assert probe_args[0] == "sk-live-neuralwatt-secret" + + # But config.yaml must keep the env reference, not the plaintext secret. + saved = config_path.read_text() + config = yaml.safe_load(saved) or {} + assert config["model"]["api_key"] == "${NEURALWATT_API_KEY}" + assert config["custom_providers"][0]["api_key"] == "${NEURALWATT_API_KEY}" + assert "sk-live-neuralwatt-secret" not in saved