diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 3d63287393..2180eb0311 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1792,6 +1792,53 @@ def _reset_aux_to_auto() -> int: return count +def _set_builtin_provider_config( + cfg: dict, + provider: str, + base_url: str = "", + api_mode: str = "", +) -> dict: + """Set the main model's provider routing fields and clear stale credentials. + + Normalizes ``cfg["model"]`` to a dict (converting bare-string models), + sets ``provider`` / ``base_url`` / ``api_mode``, and **always** pops + ``api_key``. Built-in providers resolve their keys from env vars or + the credential pool at runtime — a leftover ``api_key`` from a prior + provider causes credential drift and 401 errors. + + Does NOT call ``save_config()`` — the caller is responsible for saving + after setting any additional fields (e.g. bedrock region, reasoning + effort). + + Returns the normalized model dict so callers can inspect or modify it + further before saving. + + Mirrors auth.py ``set_provider_in_config`` (line ~2764). (#14134) + """ + model = cfg.get("model") + if not isinstance(model, dict): + model = {"default": model} if model else {} + cfg["model"] = model + + model["provider"] = provider + + if base_url: + model["base_url"] = base_url + else: + model.pop("base_url", None) + + if api_mode: + model["api_mode"] = api_mode + else: + model.pop("api_mode", None) + + # Always clear stale api_key. Built-in providers never store their + # key in config.yaml — they use env vars / credential pool. + model.pop("api_key", None) + + return model + + def _aux_config_menu() -> None: """Top-level auxiliary-model picker — choose a task to configure. @@ -2088,17 +2135,10 @@ def _model_flow_openrouter(config, current_model=""): if selected: _save_model_choice(selected) - # Update config provider and deactivate any OAuth provider from hermes_cli.config import load_config, save_config cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "openrouter" - model["base_url"] = OPENROUTER_BASE_URL - model["api_mode"] = "chat_completions" + _set_builtin_provider_config(cfg, "openrouter", OPENROUTER_BASE_URL, "chat_completions") save_config(cfg) deactivate_provider() print(f"Default model set to: {selected} (via OpenRouter)") @@ -2149,13 +2189,7 @@ def _model_flow_ai_gateway(config, current_model=""): from hermes_cli.config import load_config, save_config cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "ai-gateway" - model["base_url"] = AI_GATEWAY_BASE_URL - model["api_mode"] = "chat_completions" + _set_builtin_provider_config(cfg, "ai-gateway", AI_GATEWAY_BASE_URL, "chat_completions") save_config(cfg) deactivate_provider() print(f"Default model set to: {selected} (via Vercel AI Gateway)") @@ -2318,6 +2352,7 @@ def _model_flow_nous(config, current_model="", args=None): model_cfg = {} model_cfg["provider"] = "nous" model_cfg["default"] = selected + model_cfg.pop("api_key", None) # built-in provider — key comes from credential pool if inference_url and inference_url.strip(): model_cfg["base_url"] = inference_url.rstrip("/") else: @@ -3310,16 +3345,11 @@ def _model_flow_copilot(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model["api_mode"] = copilot_model_api_mode( - selected, - catalog=catalog, - api_key=api_key, + _set_builtin_provider_config( + cfg, + provider_id, + effective_base, + copilot_model_api_mode(selected, catalog=catalog, api_key=api_key), ) if selected_effort is not None: _set_reasoning_effort(cfg, selected_effort) @@ -3436,13 +3466,7 @@ def _model_flow_copilot_acp(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model["api_mode"] = "chat_completions" + _set_builtin_provider_config(cfg, provider_id, effective_base, "chat_completions") save_config(cfg) deactivate_provider() @@ -3544,13 +3568,7 @@ def _model_flow_kimi(config, current_model=""): # Update config with provider and base URL cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model.pop("api_mode", None) # let runtime auto-detect from URL + _set_builtin_provider_config(cfg, provider_id, effective_base) # no api_mode — runtime auto-detects from URL save_config(cfg) deactivate_provider() @@ -3678,17 +3696,11 @@ def _model_flow_stepfun(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - model.pop("api_mode", None) + _set_builtin_provider_config(cfg, provider_id, effective_base) # no api_mode — runtime auto-detects save_config(cfg) deactivate_provider() - config["model"] = dict(model) + config["model"] = cfg["model"] print(f"Default model set to: {selected} (via {pconfig.name})") else: print("No change.") @@ -3754,13 +3766,7 @@ def _model_flow_bedrock_api_key(config, region, current_model=""): # Save as custom provider pointing to bedrock-mantle cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "custom" - model["base_url"] = mantle_base_url - model.pop("api_mode", None) # chat_completions is the default + _set_builtin_provider_config(cfg, "custom", mantle_base_url) # no api_mode — chat_completions is default # Also save region in bedrock config for reference bedrock_cfg = cfg.get("bedrock", {}) @@ -3943,13 +3949,11 @@ def _model_flow_bedrock(config, current_model=""): _save_model_choice(selected) cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "bedrock" - model["base_url"] = f"https://bedrock-runtime.{region}.amazonaws.com" - model.pop("api_mode", None) # bedrock_converse is auto-detected + _set_builtin_provider_config( + cfg, + "bedrock", + f"https://bedrock-runtime.{region}.amazonaws.com", + ) # no api_mode — bedrock_converse is auto-detected bedrock_cfg = cfg.get("bedrock", {}) if not isinstance(bedrock_cfg, dict): @@ -4181,21 +4185,12 @@ def _model_flow_api_key_provider(config, provider_id, current_model=""): # Update config with provider, base URL, and provider-specific API mode cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = provider_id - model["base_url"] = effective_base - if provider_id in {"opencode-zen", "opencode-go"}: - model["api_mode"] = opencode_model_api_mode(provider_id, selected) - else: - model.pop("api_mode", None) - # Clear stale api_key from a previous provider. Built-in providers - # get their keys from env vars / credential pool — a leftover key - # from a prior provider causes credential drift (401 errors). - # Mirrors auth.py set_provider_in_config (line ~2764). (#14134) - model.pop("api_key", None) + effective_api_mode = ( + opencode_model_api_mode(provider_id, selected) + if provider_id in {"opencode-zen", "opencode-go"} + else "" + ) + _set_builtin_provider_config(cfg, provider_id, effective_base, effective_api_mode) save_config(cfg) deactivate_provider() @@ -4422,12 +4417,7 @@ def _model_flow_anthropic(config, current_model=""): # Leaving a stale base_url in config can contaminate other # providers if the user switches without running 'hermes model'. cfg = load_config() - model = cfg.get("model") - if not isinstance(model, dict): - model = {"default": model} if model else {} - cfg["model"] = model - model["provider"] = "anthropic" - model.pop("base_url", None) + _set_builtin_provider_config(cfg, "anthropic") save_config(cfg) deactivate_provider() diff --git a/tests/hermes_cli/test_api_key_drift_provider_switch.py b/tests/hermes_cli/test_api_key_drift_provider_switch.py index a682d2b996..51a9d8eedf 100644 --- a/tests/hermes_cli/test_api_key_drift_provider_switch.py +++ b/tests/hermes_cli/test_api_key_drift_provider_switch.py @@ -1,14 +1,13 @@ -"""Tests that switching providers via _model_flow_api_key_provider -clears stale api_key from the model config dict. +"""Tests that _set_builtin_provider_config clears stale api_key. -Regression test for #14134: when switching from one API-key provider -to another, the old provider's api_key was left in model.api_key, -causing credential drift — the new provider would try to use the -old provider's key and fail with 401. +Regression test for #14134: when switching from one provider to another, +the old provider's api_key was left in model.api_key, causing credential +drift — the new provider would try to use the old provider's key and get +401 errors. -The sister function in auth.py (set_provider_in_config) correctly -pops both api_key and api_mode on provider switch. This test ensures -_model_flow_api_key_provider does the same. +The helper _set_builtin_provider_config is the single source of truth +for built-in provider config updates. All model-flow functions that +set provider/base_url/api_mode should route through it. """ import os @@ -53,12 +52,140 @@ def _read_model_config(home): return config.get("model", {}) +# ── Helper-level tests ────────────────────────────────────────────── + + +class TestSetBuiltinProviderConfig: + """Direct tests for _set_builtin_provider_config.""" + + def test_api_key_cleared(self, config_home): + """api_key from a previous provider must be popped.""" + _write_config(config_home, { + "default": "old-model", + "provider": "custom", + "api_key": "sk-stale-key", + }) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "openrouter", "https://openrouter.ai/v1", "chat_completions") + + model = cfg["model"] + assert model["provider"] == "openrouter" + assert "api_key" not in model, f"api_key should be popped, found: {model.get('api_key')}" + + def test_base_url_set_when_provided(self, config_home): + """base_url is set when explicitly provided.""" + _write_config(config_home, {"default": "m", "provider": "old"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "openrouter", "https://openrouter.ai/v1", "chat_completions") + + assert cfg["model"]["base_url"] == "https://openrouter.ai/v1" + + def test_base_url_cleared_when_empty(self, config_home): + """base_url is popped when empty string is passed (e.g. anthropic).""" + _write_config(config_home, { + "default": "m", + "provider": "custom", + "base_url": "https://stale.example.com/v1", + "api_key": "sk-old", + }) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "anthropic") + + assert "base_url" not in cfg["model"], "base_url should be cleared for anthropic" + + def test_api_mode_set_when_provided(self, config_home): + """api_mode is set when explicitly provided.""" + _write_config(config_home, {"default": "m", "provider": "old"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "openrouter", "https://openrouter.ai/v1", "chat_completions") + + assert cfg["model"]["api_mode"] == "chat_completions" + + def test_api_mode_cleared_when_empty(self, config_home): + """api_mode is popped when empty string is passed (e.g. kimi, stepfun).""" + _write_config(config_home, { + "default": "m", + "provider": "custom", + "api_mode": "anthropic_messages", + }) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "stepfun", "https://api.stepfun.com/v1") + + assert "api_mode" not in cfg["model"], "api_mode should be cleared when empty" + + def test_string_model_normalized_to_dict(self, config_home): + """A bare-string model value is normalized to a dict.""" + _write_config(config_home, "old-model-string") + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "ai-gateway", "https://gateway.ai/v1", "chat_completions") + + model = cfg["model"] + assert isinstance(model, dict) + assert model.get("provider") == "ai-gateway" + + def test_no_api_key_no_error(self, config_home): + """Pop on a config without api_key should not raise.""" + _write_config(config_home, {"default": "m", "provider": "old"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + _set_builtin_provider_config(cfg, "nous", "https://api.nous.com/v1") + + assert cfg["model"]["provider"] == "nous" + assert "api_key" not in cfg["model"] + + def test_returns_model_dict(self, config_home): + """The helper returns the normalized model dict for further modification.""" + _write_config(config_home, {"default": "m"}) + + from hermes_cli.config import load_config + from hermes_cli.main import _set_builtin_provider_config + + cfg = load_config() + result = _set_builtin_provider_config(cfg, "bedrock", "https://bedrock.us-east-1.amazonaws.com") + + assert isinstance(result, dict) + assert result["provider"] == "bedrock" + # Can still modify before saving + result["bedrock_region"] = "us-east-1" + assert cfg["model"]["bedrock_region"] == "us-east-1" + + +# ── Integration: api_key_provider flow ────────────────────────────── + + class TestApiKeyDriftOnProviderSwitch: """Switching from one api-key provider to another must clear the stale api_key from the model config dict.""" def test_api_key_cleared_on_provider_switch(self, config_home, monkeypatch): - """Start with model.api_key = 'sk-old-key' from provider A, + """Start with model.api_key from provider A, switch to provider B — api_key must be popped.""" from hermes_cli.auth import PROVIDER_REGISTRY @@ -66,12 +193,11 @@ class TestApiKeyDriftOnProviderSwitch: if not pconfig: pytest.skip("zai not in PROVIDER_REGISTRY") - # Start with a config from a *previous* provider that had an api_key _write_config(config_home, { "default": "some-old-model", "provider": "ollama-cloud", "base_url": "https://api.ola.cloud/v1", - "api_key": "sk-old-provider-key-12345", + "api_key": "sk-stale", }) monkeypatch.setenv("GLM_API_KEY", "test-key") @@ -86,11 +212,9 @@ class TestApiKeyDriftOnProviderSwitch: model = _read_model_config(config_home) assert isinstance(model, dict) - assert model.get("provider") == "zai", ( - f"provider should be 'zai', got {model.get('provider')}" - ) + assert model.get("provider") == "zai" assert "api_key" not in model, ( - f"api_key should be cleared on provider switch, but found: {model.get('api_key')}" + f"api_key should be cleared on provider switch, found: {model.get('api_key')}" ) def test_api_mode_also_cleared_on_non_opencode_switch(self, config_home, monkeypatch): @@ -102,12 +226,11 @@ class TestApiKeyDriftOnProviderSwitch: if not pconfig: pytest.skip("zai not in PROVIDER_REGISTRY") - # Start with custom-provider config that had api_mode and api_key _write_config(config_home, { "default": "custom-model", "provider": "custom", "base_url": "https://custom.old/v1", - "api_key": "sk-stale-custom-key", + "api_key": "sk-stale", "api_mode": "anthropic_messages", }) @@ -123,9 +246,7 @@ class TestApiKeyDriftOnProviderSwitch: model = _read_model_config(config_home) assert isinstance(model, dict) - assert "api_key" not in model, ( - f"api_key should be cleared, got: {model.get('api_key')}" - ) + assert "api_key" not in model assert "api_mode" not in model, ( f"api_mode should be cleared for non-opencode, got: {model.get('api_mode')}" ) @@ -142,7 +263,7 @@ class TestApiKeyDriftOnProviderSwitch: _write_config(config_home, { "default": "old-model-from-previous-provider", "provider": "ollama-cloud", - "api_key": "sk-orphaned-key", + "api_key": "sk-stale", }) monkeypatch.setenv("GLM_API_KEY", "test-key") @@ -156,9 +277,7 @@ class TestApiKeyDriftOnProviderSwitch: _model_flow_api_key_provider(load_config(), "zai", "old-model-from-previous-provider") model = _read_model_config(config_home) - assert model.get("default") == "glm-5", ( - f"model.default should be 'glm-5', got {model.get('default')}" - ) + assert model.get("default") == "glm-5" assert "api_key" not in model def test_no_api_key_no_error(self, config_home, monkeypatch): @@ -169,7 +288,6 @@ class TestApiKeyDriftOnProviderSwitch: if not pconfig: pytest.skip("zai not in PROVIDER_REGISTRY") - # Clean config, no api_key _write_config(config_home, { "default": "old-model", "provider": "ollama-cloud",