diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 2712a01eab..46a7e2c5f9 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1618,6 +1618,10 @@ def _model_flow_custom(config): model_name = input("Model name (e.g. gpt-4, llama-3-70b): ").strip() context_length_str = input("Context length in tokens [leave blank for auto-detect]: ").strip() + + # Prompt for a display name — shown in the provider menu on future runs + default_name = _auto_provider_name(effective_url) + display_name = input(f"Display name [{default_name}]: ").strip() or default_name except (KeyboardInterrupt, EOFError): print("\nCancelled.") return @@ -1673,15 +1677,37 @@ def _model_flow_custom(config): print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.") # Auto-save to custom_providers so it appears in the menu next time - _save_custom_provider(effective_url, effective_key, model_name or "", context_length=context_length) + _save_custom_provider(effective_url, effective_key, model_name or "", + context_length=context_length, name=display_name) -def _save_custom_provider(base_url, api_key="", model="", context_length=None): +def _auto_provider_name(base_url: str) -> str: + """Generate a display name from a custom endpoint URL. + + Returns a human-friendly label like "Local (localhost:11434)" or + "RunPod (xyz.runpod.io)". Used as the default when prompting the + user for a display name during custom endpoint setup. + """ + import re + clean = base_url.replace("https://", "").replace("http://", "").rstrip("/") + clean = re.sub(r"/v1/?$", "", clean) + name = clean.split("/")[0] + if "localhost" in name or "127.0.0.1" in name: + name = f"Local ({name})" + elif "runpod" in name.lower(): + name = f"RunPod ({name})" + else: + name = name.capitalize() + return name + + +def _save_custom_provider(base_url, api_key="", model="", context_length=None, + name=None): """Save a custom endpoint to custom_providers in config.yaml. Deduplicates by base_url — if the URL already exists, updates the model name and context_length but doesn't add a duplicate entry. - Auto-generates a display name from the URL hostname. + Uses *name* when provided, otherwise auto-generates from the URL. """ from hermes_cli.config import load_config, save_config @@ -1709,20 +1735,9 @@ def _save_custom_provider(base_url, api_key="", model="", context_length=None): save_config(cfg) return # already saved, updated if needed - # Auto-generate a name from the URL - import re - clean = base_url.replace("https://", "").replace("http://", "").rstrip("/") - # Remove /v1 suffix for cleaner names - clean = re.sub(r"/v1/?$", "", clean) - # Use hostname:port as the name - name = clean.split("/")[0] - # Capitalize for readability - if "localhost" in name or "127.0.0.1" in name: - name = f"Local ({name})" - elif "runpod" in name.lower(): - name = f"RunPod ({name})" - else: - name = name.capitalize() + # Use provided name or auto-generate from URL + if not name: + name = _auto_provider_name(base_url) entry = {"name": name, "base_url": base_url} if api_key: diff --git a/tests/cli/test_cli_provider_resolution.py b/tests/cli/test_cli_provider_resolution.py index 353b3234eb..9c5bf0cca4 100644 --- a/tests/cli/test_cli_provider_resolution.py +++ b/tests/cli/test_cli_provider_resolution.py @@ -576,8 +576,9 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys): monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None) # After the probe detects a single model ("llm"), the flow asks - # "Use this model? [Y/n]:" — confirm with Enter, then context length. - answers = iter(["http://localhost:8000", "local-key", "", ""]) + # "Use this model? [Y/n]:" — confirm with Enter, then context length, + # then display name. + answers = iter(["http://localhost:8000", "local-key", "", "", ""]) monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) monkeypatch.setattr("getpass.getpass", lambda _prompt="": next(answers)) @@ -641,3 +642,46 @@ def test_cmd_model_forwards_nous_login_tls_options(monkeypatch): "ca_bundle": "/tmp/local-ca.pem", "insecure": True, } + + +# --------------------------------------------------------------------------- +# _auto_provider_name — unit tests +# --------------------------------------------------------------------------- + +def test_auto_provider_name_localhost(): + from hermes_cli.main import _auto_provider_name + assert _auto_provider_name("http://localhost:11434/v1") == "Local (localhost:11434)" + assert _auto_provider_name("http://127.0.0.1:1234/v1") == "Local (127.0.0.1:1234)" + + +def test_auto_provider_name_runpod(): + from hermes_cli.main import _auto_provider_name + assert "RunPod" in _auto_provider_name("https://xyz.runpod.io/v1") + + +def test_auto_provider_name_remote(): + from hermes_cli.main import _auto_provider_name + result = _auto_provider_name("https://api.together.xyz/v1") + assert result == "Api.together.xyz" + + +def test_save_custom_provider_uses_provided_name(monkeypatch, tmp_path): + """When a display name is passed, it should appear in the saved entry.""" + import yaml + from hermes_cli.main import _save_custom_provider + + cfg_path = tmp_path / "config.yaml" + cfg_path.write_text(yaml.dump({})) + + monkeypatch.setattr( + "hermes_cli.config.load_config", lambda: yaml.safe_load(cfg_path.read_text()) or {}, + ) + saved = {} + def _save(cfg): + saved.update(cfg) + monkeypatch.setattr("hermes_cli.config.save_config", _save) + + _save_custom_provider("http://localhost:11434/v1", name="Ollama") + entries = saved.get("custom_providers", []) + assert len(entries) == 1 + assert entries[0]["name"] == "Ollama"