diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 7de68d2cb4..2180eb0311 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1792,6 +1792,53 @@ def _reset_aux_to_auto() -> int: return count +def _set_builtin_provider_config( + cfg: dict, + provider: str, + base_url: str = "", + api_mode: str = "", +) -> dict: + """Set the main model's provider routing fields and clear stale credentials. + + Normalizes ``cfg["model"]`` to a dict (converting bare-string models), + sets ``provider`` / ``base_url`` / ``api_mode``, and **always** pops + ``api_key``. Built-in providers resolve their keys from env vars or + the credential pool at runtime — a leftover ``api_key`` from a prior + provider causes credential drift and 401 errors. + + Does NOT call ``save_config()`` — the caller is responsible for saving + after setting any additional fields (e.g. bedrock region, reasoning + effort). + + Returns the normalized model dict so callers can inspect or modify it + further before saving. + + Mirrors auth.py ``set_provider_in_config`` (line ~2764). (#14134) + """ + model = cfg.get("model") + if not isinstance(model, dict): + model = {"default": model} if model else {} + cfg["model"] = model + + model["provider"] = provider + + if base_url: + model["base_url"] = base_url + else: + model.pop("base_url", None) + + if api_mode: + model["api_mode"] = api_mode + else: + model.pop("api_mode", None) + + # Always clear stale api_key. Built-in providers never store their + # key in config.yaml — they use env vars / credential pool. + model.pop("api_key", None) + + return model + + def _aux_config_menu() -> None: """Top-level auxiliary-model picker — choose a task to configure. @@ -2088,17 +2135,10 @@ def _model_flow_openrouter(config, current_model=""): if selected: _save_model_choice(selected) - # Update config provider and deactivate any OAuth provider from hermes_cli.config import load_config, save_config cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "openrouter" - model["base_url"] = OPENROUTER_BASE_URL - model["api_mode"] = "chat_completions" + _set_builtin_provider_config(cfg, "openrouter", OPENROUTER_BASE_URL, "chat_completions") save_config(cfg) deactivate_provider() print(f"Default model set to: {selected} (via OpenRouter)") @@ -2149,13 +2189,7 @@ def _model_flow_ai_gateway(config, current_model=""): from hermes_cli.config import load_config, save_config cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "ai-gateway" - model["base_url"] = AI_GATEWAY_BASE_URL - model["api_mode"] = "chat_completions" + _set_builtin_provider_config(cfg, "ai-gateway", AI_GATEWAY_BASE_URL, "chat_completions") save_config(cfg) deactivate_provider() print(f"Default model set to: {selected} (via Vercel AI Gateway)") @@ -2318,6 +2352,7 @@ def _model_flow_nous(config, current_model="", args=None): model_cfg = {} model_cfg["provider"] = "nous" model_cfg["default"] = selected + model_cfg.pop("api_key", None) # built-in provider — key comes from credential pool if inference_url and inference_url.strip(): model_cfg["base_url"] = inference_url.rstrip("/") else: @@ -3310,16 +3345,11 @@ def _model_flow_copilot(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model["api_mode"] = copilot_model_api_mode( - selected, - catalog=catalog, - api_key=api_key, + _set_builtin_provider_config( + cfg, + provider_id, + effective_base, + copilot_model_api_mode(selected, catalog=catalog, api_key=api_key), ) if selected_effort is not None: _set_reasoning_effort(cfg, selected_effort) @@ -3436,13 +3466,7 @@ def _model_flow_copilot_acp(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model["api_mode"] = "chat_completions" + _set_builtin_provider_config(cfg, provider_id, effective_base, "chat_completions") save_config(cfg) deactivate_provider() @@ -3544,13 +3568,7 @@ def _model_flow_kimi(config, current_model=""): # Update config with provider and base URL cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model.pop("api_mode", None) # let runtime auto-detect from URL + _set_builtin_provider_config(cfg, provider_id, effective_base) # no api_mode — runtime auto-detects from URL save_config(cfg) deactivate_provider() @@ -3678,17 +3696,11 @@ def _model_flow_stepfun(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model.pop("api_mode", None) + _set_builtin_provider_config(cfg, provider_id, effective_base) # no api_mode — runtime auto-detects save_config(cfg) deactivate_provider() - config["model"] = dict(model) + config["model"] = cfg["model"] print(f"Default model set to: {selected} (via {pconfig.name})") else: print("No change.") @@ -3754,13 +3766,7 @@ def _model_flow_bedrock_api_key(config, region, current_model=""): # Save as custom provider pointing to bedrock-mantle cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "custom" - model["base_url"] = mantle_base_url - model.pop("api_mode", None) # chat_completions is the default + _set_builtin_provider_config(cfg, "custom", mantle_base_url) # no api_mode — chat_completions is default # Also save region in bedrock config for reference bedrock_cfg = cfg.get("bedrock", {}) @@ -3943,13 +3949,11 @@ def _model_flow_bedrock(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "bedrock" - model["base_url"] = f"https://bedrock-runtime.{region}.amazonaws.com" - model.pop("api_mode", None) # bedrock_converse is auto-detected + _set_builtin_provider_config( + cfg, + "bedrock", + f"https://bedrock-runtime.{region}.amazonaws.com", + ) # no api_mode — bedrock_converse is auto-detected bedrock_cfg = cfg.get("bedrock", {}) if not isinstance(bedrock_cfg, dict): @@ -4181,16 +4185,12 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""): # Update config with provider, base URL, and provider-specific API mode cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - if provider_id in {"opencode-zen", "opencode-go"}: - model["api_mode"] = opencode_model_api_mode(provider_id, selected) - else: - model.pop("api_mode", None) + effective_api_mode = ( + opencode_model_api_mode(provider_id, selected) + if provider_id in {"opencode-zen", "opencode-go"} + else "" + ) + _set_builtin_provider_config(cfg, provider_id, effective_base, effective_api_mode) save_config(cfg) deactivate_provider() @@ -4417,12 +4417,7 @@ def _model_flow_anthropic(config, current_model=""): # Leaving a stale base_url in config can contaminate other # providers if the user switches without running 'hermes model'. cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "anthropic" - model.pop("base_url", None) + _set_builtin_provider_config(cfg, "anthropic") save_config(cfg) deactivate_provider() diff --git a/tests/hermes_cli/test_api_key_drift_provider_switch.py b/tests/hermes_cli/test_api_key_drift_provider_switch.py new file mode 100644 index 0000000000..51a9d8eedf --- /dev/null +++ b/tests/hermes_cli/test_api_key_drift_provider_switch.py @@ -0,0 +1,309 @@ +"""Tests that _set_builtin_provider_config clears stale api_key. + +Regression test for #14134: when switching from one provider to another, +the old provider's api_key was left in model.api_key, causing credential +drift — the new provider would try to use the old provider's key and get +401 errors. + +The helper _set_builtin_provider_config is the single source of truth +for built-in provider config updates. All model-flow functions that +set provider/base_url/api_mode should route through it. +""" + +import os +from unittest.mock import patch + +import pytest + + +@pytest.fixture +def config_home(tmp_path, monkeypatch): + """Isolated HERMES_HOME with config that has a stale api_key.""" + home = tmp_path / "hermes" + home.mkdir() + env_file = home / ".env" + env_file.write_text("") + monkeypatch.setenv("HERMES_HOME", str(home)) + # Clear env vars that could interfere + monkeypatch.delenv("HERMES_MODEL", raising=False) + monkeypatch.delenv("LLM_MODEL", raising=False) + monkeypatch.delenv("HERMES_INFERENCE_PROVIDER", raising=False) + monkeypatch.delenv("GITHUB_TOKEN", raising=False) + monkeypatch.delenv("GH_TOKEN", raising=False) + monkeypatch.delenv("OPENAI_BASE_URL", raising=False) + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + monkeypatch.delenv("STEPFUN_API_KEY", raising=False) + monkeypatch.delenv("STEPFUN_BASE_URL", raising=False) + return home + + +def _write_config(home, model_dict): + """Write a config.yaml with the given model dict.""" + import yaml + config_yaml = home / "config.yaml" + config_yaml.write_text(yaml.dump({"model": model_dict})) + + +def _read_model_config(home): + """Read the model section from config.yaml.""" + import yaml + config = yaml.safe_load((home / "config.yaml").read_text()) or {} + return config.get("model", {}) + + +# ── Helper-level tests ────────────────────────────────────────────── + + +class TestSetBuiltinProviderConfig: + """Direct tests for _set_builtin_provider_config.""" + + def test_api_key_cleared(self, config_home): + """api_key from a previous provider must be popped.""" + _write_config(config_home, { + "default": "old-model", + "provider": "custom", + "api_key": "sk-stale-key", + }) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "openrouter", "https://openrouter.ai/v1", "chat_completions") + + model = cfg["model"] + assert model["provider"] == "openrouter" + assert "api_key" not in model, f"api_key should be popped, found: {model.get('api_key')}" + + def test_base_url_set_when_provided(self, config_home): + """base_url is set when explicitly provided.""" + _write_config(config_home, {"default": "m", "provider": "old"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "openrouter", "https://openrouter.ai/v1", "chat_completions") + + assert cfg["model"]["base_url"] == "https://openrouter.ai/v1" + + def test_base_url_cleared_when_empty(self, config_home): + """base_url is popped when empty string is passed (e.g. anthropic).""" + _write_config(config_home, { + "default": "m", + "provider": "custom", + "base_url": "https://stale.example.com/v1", + "api_key": "sk-old", + }) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "anthropic") + + assert "base_url" not in cfg["model"], "base_url should be cleared for anthropic" + + def test_api_mode_set_when_provided(self, config_home): + """api_mode is set when explicitly provided.""" + _write_config(config_home, {"default": "m", "provider": "old"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "openrouter", "https://openrouter.ai/v1", "chat_completions") + + assert cfg["model"]["api_mode"] == "chat_completions" + + def test_api_mode_cleared_when_empty(self, config_home): + """api_mode is popped when empty string is passed (e.g. kimi, stepfun).""" + _write_config(config_home, { + "default": "m", + "provider": "custom", + "api_mode": "anthropic_messages", + }) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "stepfun", "https://api.stepfun.com/v1") + + assert "api_mode" not in cfg["model"], "api_mode should be cleared when empty" + + def test_string_model_normalized_to_dict(self, config_home): + """A bare-string model value is normalized to a dict.""" + _write_config(config_home, "old-model-string") + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "ai-gateway", "https://gateway.ai/v1", "chat_completions") + + model = cfg["model"] + assert isinstance(model, dict) + assert model.get("provider") == "ai-gateway" + + def test_no_api_key_no_error(self, config_home): + """Pop on a config without api_key should not raise.""" + _write_config(config_home, {"default": "m", "provider": "old"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "nous", "https://api.nous.com/v1") + + assert cfg["model"]["provider"] == "nous" + assert "api_key" not in cfg["model"] + + def test_returns_model_dict(self, config_home): + """The helper returns the normalized model dict for further modification.""" + _write_config(config_home, {"default": "m"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + result = _set_builtin_provider_config(cfg, "bedrock", "https://bedrock.us-east-1.amazonaws.com") + + assert isinstance(result, dict) + assert result["provider"] == "bedrock" + # Can still modify before saving + result["bedrock_region"] = "us-east-1" + assert cfg["model"]["bedrock_region"] == "us-east-1" + + +# ── Integration: api_key_provider flow ────────────────────────────── + + +class TestApiKeyDriftOnProviderSwitch: + """Switching from one api-key provider to another must clear the + stale api_key from the model config dict.""" + + def test_api_key_cleared_on_provider_switch(self, config_home, monkeypatch): + """Start with model.api_key from provider A, + switch to provider B — api_key must be popped.""" + from hermes_cli.auth import PROVIDER_REGISTRY + + pconfig = PROVIDER_REGISTRY.get("zai") + if not pconfig: + pytest.skip("zai not in PROVIDER_REGISTRY") + + _write_config(config_home, { + "default": "some-old-model", + "provider": "ollama-cloud", + "base_url": "https://api.ola.cloud/v1", + "api_key": "sk-stale", + }) + + monkeypatch.setenv("GLM_API_KEY", "test-key") + + from hermes_cli.main import _model_flow_api_key_provider + from hermes_cli.config import load_config + + with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \ + patch("hermes_cli.auth.deactivate_provider"), \ + patch("builtins.input", return_value=""): + _model_flow_api_key_provider(load_config(), "zai", "some-old-model") + + model = _read_model_config(config_home) + assert isinstance(model, dict) + assert model.get("provider") == "zai" + assert "api_key" not in model, ( + f"api_key should be cleared on provider switch, found: {model.get('api_key')}" + ) + + def test_api_mode_also_cleared_on_non_opencode_switch(self, config_home, monkeypatch): + """A stale api_mode from a previous custom provider must also + be cleared when switching to a non-opencode provider.""" + from hermes_cli.auth import PROVIDER_REGISTRY + + pconfig = PROVIDER_REGISTRY.get("zai") + if not pconfig: + pytest.skip("zai not in PROVIDER_REGISTRY") + + _write_config(config_home, { + "default": "custom-model", + "provider": "custom", + "base_url": "https://custom.old/v1", + "api_key": "sk-stale", + "api_mode": "anthropic_messages", + }) + + monkeypatch.setenv("GLM_API_KEY", "test-key") + + from hermes_cli.main import _model_flow_api_key_provider + from hermes_cli.config import load_config + + with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \ + patch("hermes_cli.auth.deactivate_provider"), \ + patch("builtins.input", return_value=""): + _model_flow_api_key_provider(load_config(), "zai", "custom-model") + + model = _read_model_config(config_home) + assert isinstance(model, dict) + assert "api_key" not in model + assert "api_mode" not in model, ( + f"api_mode should be cleared for non-opencode, got: {model.get('api_mode')}" + ) + + def test_switch_preserves_default_model(self, config_home, monkeypatch): + """The model.default should be updated to the new selection even + when there was a stale api_key.""" + from hermes_cli.auth import PROVIDER_REGISTRY + + pconfig = PROVIDER_REGISTRY.get("zai") + if not pconfig: + pytest.skip("zai not in PROVIDER_REGISTRY") + + _write_config(config_home, { + "default": "old-model-from-previous-provider", + "provider": "ollama-cloud", + "api_key": "sk-stale", + }) + + monkeypatch.setenv("GLM_API_KEY", "test-key") + + from hermes_cli.main import _model_flow_api_key_provider + from hermes_cli.config import load_config + + with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \ + patch("hermes_cli.auth.deactivate_provider"), \ + patch("builtins.input", return_value=""): + _model_flow_api_key_provider(load_config(), "zai", "old-model-from-previous-provider") + + model = _read_model_config(config_home) + assert model.get("default") == "glm-5" + assert "api_key" not in model + + def test_no_api_key_no_error(self, config_home, monkeypatch): + """If config has no stale api_key, switching should still work fine.""" + from hermes_cli.auth import PROVIDER_REGISTRY + + pconfig = PROVIDER_REGISTRY.get("zai") + if not pconfig: + pytest.skip("zai not in PROVIDER_REGISTRY") + + _write_config(config_home, { + "default": "old-model", + "provider": "ollama-cloud", + }) + + monkeypatch.setenv("GLM_API_KEY", "test-key") + + from hermes_cli.main import _model_flow_api_key_provider + from hermes_cli.config import load_config + + with patch("hermes_cli.auth._prompt_model_selection", return_value="glm-5"), \ + patch("hermes_cli.auth.deactivate_provider"), \ + patch("builtins.input", return_value=""): + _model_flow_api_key_provider(load_config(), "zai", "old-model") + + model = _read_model_config(config_home) + assert model.get("provider") == "zai" + assert model.get("default") == "glm-5" + assert "api_key" not in model \ No newline at end of file