def probe_api_models(
    api_key: Optional[str],
    base_url: Optional[str],
    timeout: float = 5.0,
) -> dict[str, Any]:
    """Probe an OpenAI-compatible ``/models`` endpoint with light URL heuristics.

    Tries ``<base_url>/models`` first; if that fails, retries once with a
    trailing ``/v1`` toggled (appended when missing, stripped when present).

    Args:
        api_key: Optional bearer token, sent as ``Authorization: Bearer <key>``.
        base_url: Endpoint base URL; blank/``None`` short-circuits to a null probe.
        timeout: Per-request timeout in seconds.

    Returns:
        Dict with keys:
        - ``models``: list of model-id strings, or ``None`` if no candidate answered.
        - ``probed_url``: the ``/models`` URL that answered (on failure, the last one tried).
        - ``resolved_base_url``: the base URL that actually worked (on failure, the input).
        - ``suggested_base_url``: alternate base worth trying; ``None`` when the
          probe succeeded or no distinct alternate exists.
        - ``used_fallback``: ``True`` when the ``/v1``-toggled alternate answered.
    """
    normalized = (base_url or "").strip().rstrip("/")
    if not normalized:
        # Nothing to probe: return an inert payload so callers can branch uniformly.
        return {
            "models": None,
            "probed_url": None,
            "resolved_base_url": "",
            "suggested_base_url": None,
            "used_fallback": False,
        }

    # Build the alternate candidate by toggling a trailing "/v1".
    if normalized.endswith("/v1"):
        alternate_base = normalized[:-3].rstrip("/")
    else:
        alternate_base = normalized + "/v1"

    candidates: list[tuple[str, bool]] = [(normalized, False)]
    if alternate_base and alternate_base != normalized:
        candidates.append((alternate_base, True))

    tried: list[str] = []
    headers: dict[str, str] = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    for candidate_base, is_fallback in candidates:
        url = candidate_base.rstrip("/") + "/models"
        tried.append(url)
        req = urllib.request.Request(url, headers=headers)
        try:
            with urllib.request.urlopen(req, timeout=timeout) as resp:
                data = json.loads(resp.read().decode())
            return {
                "models": [m.get("id", "") for m in data.get("data", [])],
                "probed_url": url,
                "resolved_base_url": candidate_base.rstrip("/"),
                # Fix: never "suggest" another base once one has answered.
                # The previous expression suggested an unverified URL on a
                # primary success and, worse, the URL that had just FAILED
                # on a fallback success. No caller reads this key on success.
                "suggested_base_url": None,
                "used_fallback": is_fallback,
            }
        except Exception:
            # Any failure (HTTP error, timeout, bad JSON) just moves on to the
            # next candidate; callers only need the aggregate outcome.
            continue

    return {
        "models": None,
        "probed_url": tried[-1] if tried else normalized.rstrip("/") + "/models",
        "resolved_base_url": normalized,
        "suggested_base_url": alternate_base if alternate_base != normalized else None,
        "used_fallback": False,
    }
def fetch_api_models(
    api_key: Optional[str],
    base_url: Optional[str],
    timeout: float = 5.0,
) -> Optional[list[str]]:
    """Fetch the model list from an OpenAI-compatible ``/models`` endpoint.

    Thin wrapper over :func:`probe_api_models` that keeps the historical
    return contract: a list of model ID strings, or ``None`` if the endpoint
    could not be reached (network error, timeout, auth failure, etc.).
    """
    probe = probe_api_models(api_key, base_url, timeout=timeout)
    return probe.get("models")
+ ) + if probe.get("suggested_base_url"): + message += f"\n If this server expects `/v1`, try base URL: `{probe.get('suggested_base_url')}`" + return { "accepted": True, "persist": True, "recognized": False, - "message": None, + "message": message, } # Probe the live API to check if the model actually exists diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index 4c795438f..7e077d95f 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -933,11 +933,35 @@ def setup_model_provider(config: dict): base_url = prompt( " API base URL (e.g., https://api.example.com/v1)", current_url - ) + ).strip() api_key = prompt(" API key", password=True) model_name = prompt(" Model name (e.g., gpt-4, claude-3-opus)", current_model) if base_url: + from hermes_cli.models import probe_api_models + + probe = probe_api_models(api_key, base_url) + if probe.get("used_fallback") and probe.get("resolved_base_url"): + print_warning( + f"Endpoint verification worked at {probe['resolved_base_url']}/models, " + f"not the exact URL you entered. Saving the working base URL instead." + ) + base_url = probe["resolved_base_url"] + elif probe.get("models") is not None: + print_success( + f"Verified endpoint via {probe.get('probed_url')} " + f"({len(probe.get('models') or [])} model(s) visible)" + ) + else: + print_warning( + f"Could not verify this endpoint via {probe.get('probed_url')}. " + f"Hermes will still save it." 
def _validate(model, provider="openrouter", api_models=FAKE_API_MODELS, **kw):
    """Shortcut: call validate_requested_model with mocked API."""
    fake_probe = {
        "models": api_models,
        "probed_url": "http://localhost:11434/v1/models",
        "resolved_base_url": kw.get("base_url", "") or "http://localhost:11434/v1",
        "suggested_base_url": None,
        "used_fallback": False,
    }
    fetch_patch = patch("hermes_cli.models.fetch_api_models", return_value=api_models)
    probe_patch = patch("hermes_cli.models.probe_api_models", return_value=fake_probe)
    with fetch_patch, probe_patch:
        return validate_requested_model(model, provider, **kw)
patch("hermes_cli.models.urllib.request.urlopen", side_effect=_fake_urlopen): + probe = probe_api_models("key", "http://localhost:8000") + + assert calls == ["http://localhost:8000/models", "http://localhost:8000/v1/models"] + assert probe["models"] == ["local-model"] + assert probe["resolved_base_url"] == "http://localhost:8000/v1" + assert probe["used_fallback"] is True + # -- validate — format checks ----------------------------------------------- @@ -191,6 +227,7 @@ class TestValidateApiFound: ) assert result["accepted"] is True assert result["persist"] is True + assert result["recognized"] is True # -- validate — API not found ------------------------------------------------ @@ -232,3 +269,26 @@ class TestValidateApiFallback: result = _validate("some-model", provider="totally-unknown", api_models=None) assert result["accepted"] is True assert result["persist"] is True + + def test_custom_endpoint_warns_with_probed_url_and_v1_hint(self): + with patch( + "hermes_cli.models.probe_api_models", + return_value={ + "models": None, + "probed_url": "http://localhost:8000/v1/models", + "resolved_base_url": "http://localhost:8000", + "suggested_base_url": "http://localhost:8000/v1", + "used_fallback": False, + }, + ): + result = validate_requested_model( + "qwen3", + "custom", + api_key="local-key", + base_url="http://localhost:8000", + ) + + assert result["accepted"] is True + assert result["persist"] is True + assert "http://localhost:8000/v1/models" in result["message"] + assert "http://localhost:8000/v1" in result["message"] diff --git a/tests/hermes_cli/test_setup_model_provider.py b/tests/hermes_cli/test_setup_model_provider.py index 34b491066..daf0ce680 100644 --- a/tests/hermes_cli/test_setup_model_provider.py +++ b/tests/hermes_cli/test_setup_model_provider.py @@ -75,6 +75,58 @@ def test_setup_keep_current_custom_from_config_does_not_fall_through(tmp_path, m assert calls["count"] == 1 +def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, 
def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch):
    # Simulates a user entering a bare base URL whose /v1 variant is the one
    # that actually answers: setup should persist the verified /v1 URL.
    monkeypatch.setenv("HERMES_HOME", str(tmp_path))
    _clear_provider_env(monkeypatch)

    config = load_config()

    def fake_prompt_choice(question, choices, default=0):
        # Only the provider menu and the vision menu are expected in this flow.
        if question == "Select your inference provider:":
            return 3  # Custom endpoint
        if question == "Configure vision:":
            return len(choices) - 1  # Skip
        raise AssertionError(f"Unexpected prompt_choice call: {question}")

    def fake_prompt(message, current=None, **kwargs):
        # Canned interactive answers keyed on prompt text.
        if "API base URL" in message:
            return "http://localhost:8000"
        if "API key" in message:
            return "local-key"
        if "Model name" in message:
            return "llm"
        return ""

    monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice)
    monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt)
    monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False)
    monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None)
    monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: [])
    monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: [])
    # Patching the models module attribute works because setup imports
    # probe_api_models lazily (inside setup_model_provider), after this patch.
    # The fake reports a fallback success: /v1 answered, the typed URL did not.
    monkeypatch.setattr(
        "hermes_cli.models.probe_api_models",
        lambda api_key, base_url: {
            "models": ["llm"],
            "probed_url": "http://localhost:8000/v1/models",
            "resolved_base_url": "http://localhost:8000/v1",
            "suggested_base_url": "http://localhost:8000/v1",
            "used_fallback": True,
        },
    )

    setup_model_provider(config)
    save_config(config)

    env = _read_env(tmp_path)
    reloaded = load_config()

    # The VERIFIED /v1 base URL — not the one the user typed — must be saved
    # both in the env file and in the persisted config.
    assert env.get("OPENAI_BASE_URL") == "http://localhost:8000/v1"
    assert env.get("OPENAI_API_KEY") == "local-key"
    assert reloaded["model"]["provider"] == "custom"
    assert reloaded["model"]["base_url"] == "http://localhost:8000/v1"
    assert reloaded["model"]["default"] == "llm"
def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys):
    """_model_flow_custom should persist the probe-verified /v1 base URL."""
    # No pre-existing env values. The original lambda's conditional was dead
    # code (both branches returned ""), so return "" unconditionally.
    monkeypatch.setattr("hermes_cli.config.get_env_value", lambda key: "")
    saved_env = {}
    monkeypatch.setattr(
        "hermes_cli.config.save_env_value",
        lambda key, value: saved_env.__setitem__(key, value),
    )
    monkeypatch.setattr(
        "hermes_cli.auth._save_model_choice",
        lambda model: saved_env.__setitem__("MODEL", model),
    )
    monkeypatch.setattr("hermes_cli.auth.deactivate_provider", lambda: None)
    monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None)
    # Fake a fallback success: the typed URL failed, its /v1 variant verified.
    monkeypatch.setattr(
        "hermes_cli.models.probe_api_models",
        lambda api_key, base_url: {
            "models": ["llm"],
            "probed_url": "http://localhost:8000/v1/models",
            "resolved_base_url": "http://localhost:8000/v1",
            "suggested_base_url": "http://localhost:8000/v1",
            "used_fallback": True,
        },
    )
    monkeypatch.setattr(
        "hermes_cli.config.load_config",
        lambda: {"model": {"default": "", "provider": "custom", "base_url": ""}},
    )
    monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None)

    # Interactive answers, in prompt order: base URL, API key, model name.
    answers = iter(["http://localhost:8000", "local-key", "llm"])
    monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers))

    hermes_main._model_flow_custom({})
    output = capsys.readouterr().out

    # The flow must announce the substitution and save the VERIFIED URL.
    assert "Saving the working base URL instead" in output
    assert saved_env["OPENAI_BASE_URL"] == "http://localhost:8000/v1"
    assert saved_env["OPENAI_API_KEY"] == "local-key"
    assert saved_env["MODEL"] == "llm"