From 1b962f001e7855a7cecf7e7db98ac08d71be0578 Mon Sep 17 00:00:00 2001 From: liuhao1024 Date: Tue, 16 Jun 2026 18:33:21 +0800 Subject: [PATCH] fix(models): pass model.base_url to fetch_models in /model picker The /model interactive picker resolved a base_url from user credentials but never passed it to ProviderProfile.fetch_models(), causing the picker to always query the provider's hardcoded default endpoint instead of the user's custom URL (e.g. a company litellm proxy). - providers/base.py: add optional base_url parameter to fetch_models() - hermes_cli/models.py: pass resolved base_url to fetch_models() - Update all subclass overrides for signature compatibility - Add 6 regression tests covering override, fallback, and integration --- hermes_cli/models.py | 2 +- plugins/model-providers/anthropic/__init__.py | 1 + plugins/model-providers/bedrock/__init__.py | 1 + .../model-providers/copilot-acp/__init__.py | 1 + plugins/model-providers/custom/__init__.py | 5 +- .../model-providers/openrouter/__init__.py | 3 +- providers/base.py | 9 +- tests/providers/test_fetch_models_base_url.py | 140 ++++++++++++++++++ 8 files changed, 155 insertions(+), 7 deletions(-) create mode 100644 tests/providers/test_fetch_models_base_url.py diff --git a/hermes_cli/models.py b/hermes_cli/models.py index a80a54e84f9..dc605ab001f 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -2371,7 +2371,7 @@ def provider_model_ids(provider: Optional[str], *, force_refresh: bool = False) if not base_url: base_url = _p.base_url if api_key: - live = _p.fetch_models(api_key=api_key) + live = _p.fetch_models(api_key=api_key, base_url=base_url or None) if live: # Merge static curated list with live API results so # models that the live endpoint omits (stale cache, diff --git a/plugins/model-providers/anthropic/__init__.py b/plugins/model-providers/anthropic/__init__.py index f1f45eb82c7..3acf6e02fa3 100644 --- a/plugins/model-providers/anthropic/__init__.py +++ b/plugins/model-providers/anthropic/__init__.py @@ -17,6 +17,7 @@ class AnthropicProfile(ProviderProfile): self, *, api_key: str | None = None, + base_url: str | None = None, timeout: float = 8.0, ) -> list[str] | None: """Anthropic uses x-api-key header and anthropic-version.""" diff --git a/plugins/model-providers/bedrock/__init__.py b/plugins/model-providers/bedrock/__init__.py index 6fdbbe834da..d0ee99c58cf 100644 --- a/plugins/model-providers/bedrock/__init__.py +++ b/plugins/model-providers/bedrock/__init__.py @@ -11,6 +11,7 @@ class BedrockProfile(ProviderProfile): self, *, api_key: str | None = None, + base_url: str | None = None, timeout: float = 8.0, ) -> list[str] | None: """Bedrock model listing requires AWS SDK, not a REST call.""" diff --git a/plugins/model-providers/copilot-acp/__init__.py b/plugins/model-providers/copilot-acp/__init__.py index 21ec7da2e99..6e452706c3d 100644 --- a/plugins/model-providers/copilot-acp/__init__.py +++ b/plugins/model-providers/copilot-acp/__init__.py @@ -16,6 +16,7 @@ class CopilotACPProfile(ProviderProfile): self, *, api_key: str | None = None, + base_url: str | None = None, timeout: float = 8.0, ) -> list[str] | None: """Model listing is handled by the ACP subprocess.""" diff --git a/plugins/model-providers/custom/__init__.py b/plugins/model-providers/custom/__init__.py index 6b7b13d5bdb..e893be95b27 100644 --- a/plugins/model-providers/custom/__init__.py +++ b/plugins/model-providers/custom/__init__.py @@ -43,12 +43,13 @@ class CustomProfile(ProviderProfile): self, *, api_key: str | None = None, + base_url: str | None = None, timeout: float = 8.0, ) -> list[str] | None: """Custom/Ollama: base_url is user-configured; fetch if set.""" - if not self.base_url: + if not (base_url or self.base_url): return None - return super().fetch_models(api_key=api_key, timeout=timeout) + return super().fetch_models(api_key=api_key, base_url=base_url, timeout=timeout) custom = CustomProfile( diff --git a/plugins/model-providers/openrouter/__init__.py b/plugins/model-providers/openrouter/__init__.py index 6566d604fc5..11c814591ec 100644 --- a/plugins/model-providers/openrouter/__init__.py +++ b/plugins/model-providers/openrouter/__init__.py @@ -51,6 +51,7 @@ class OpenRouterProfile(ProviderProfile): self, *, api_key: str | None = None, + base_url: str | None = None, timeout: float = 8.0, ) -> list[str] | None: """Fetch from public OpenRouter catalog — no auth required. @@ -64,7 +65,7 @@ class OpenRouterProfile(ProviderProfile): if _CACHE is not None: return _CACHE try: - result = super().fetch_models(api_key=None, timeout=timeout) + result = super().fetch_models(api_key=None, base_url=base_url, timeout=timeout) if result is not None: _CACHE = result return result diff --git a/providers/base.py b/providers/base.py index 07100a3b52a..4a045a6765d 100644 --- a/providers/base.py +++ b/providers/base.py @@ -163,6 +163,7 @@ class ProviderProfile: self, *, api_key: str | None = None, + base_url: str | None = None, timeout: float = 8.0, ) -> list[str] | None: """Fetch the live model list from the provider's models endpoint. @@ -175,7 +176,8 @@ class ProviderProfile: endpoint differs from the inference base URL, e.g. OpenRouter exposes a public catalog at /api/v1/models while inference is at /api/v1) - 2. self.base_url + "/models" (standard OpenAI-compat fallback) + 2. base_url (caller override — user-configured model.base_url) + 3. self.base_url + "/models" (standard OpenAI-compat fallback) The default implementation sends Bearer auth when api_key is given and forwards self.default_headers. Override to customise auth, path, @@ -184,11 +186,12 @@ class ProviderProfile: Callers must always fall back to the static _PROVIDER_MODELS list when this returns None. """ + effective_base = base_url or self.base_url url = (self.models_url or "").strip() if not url: - if not self.base_url: + if not effective_base: return None - url = self.base_url.rstrip("/") + "/models" + url = effective_base.rstrip("/") + "/models" import json import urllib.request diff --git a/tests/providers/test_fetch_models_base_url.py b/tests/providers/test_fetch_models_base_url.py new file mode 100644 index 00000000000..5db1f61d1cd --- /dev/null +++ b/tests/providers/test_fetch_models_base_url.py @@ -0,0 +1,140 @@ +"""Tests for ProviderProfile.fetch_models base_url override (issue #47009).""" + +import json +from http.server import HTTPServer, BaseHTTPRequestHandler +from threading import Thread +from unittest.mock import patch, MagicMock + +from providers.base import ProviderProfile + + +class _FakeModelHandler(BaseHTTPRequestHandler): + """Serves /models with a configurable model list.""" + + models = [{"id": "custom-model-1"}, {"id": "custom-model-2"}] + + def do_GET(self): + if self.path.rstrip("/") == "/models": + body = json.dumps({"data": self.models}).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(body) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + pass # suppress noise + + +def _start_server(models=None): + """Start a local HTTP server returning given models. Returns (server, port).""" + if models is not None: + _FakeModelHandler.models = models + server = HTTPServer(("127.0.0.1", 0), _FakeModelHandler) + port = server.server_address[1] + thread = Thread(target=server.serve_forever, daemon=True) + thread.start() + return server, port + + +class TestFetchModelsBaseUrlOverride: + """fetch_models() should use caller-provided base_url when given.""" + + def test_base_url_override_used(self): + """When base_url is passed, it overrides self.base_url.""" + server, port = _start_server([{"id": "proxy-model-a"}]) + try: + profile = ProviderProfile( + name="test", + base_url="http://127.0.0.1:1", # wrong port — should not be used + ) + result = profile.fetch_models( + api_key="test-key", + base_url=f"http://127.0.0.1:{port}", + ) + assert result == ["proxy-model-a"] + finally: + server.shutdown() + + def test_fallback_to_self_base_url(self): + """When base_url is None, falls back to self.base_url.""" + server, port = _start_server([{"id": "default-model"}]) + try: + profile = ProviderProfile( + name="test", + base_url=f"http://127.0.0.1:{port}", + ) + result = profile.fetch_models(api_key="test-key") + assert result == ["default-model"] + finally: + server.shutdown() + + def test_no_base_url_returns_none(self): + """When both base_url and self.base_url are empty, returns None.""" + profile = ProviderProfile(name="test", base_url="") + result = profile.fetch_models(api_key="test-key", base_url="") + assert result is None + + def test_base_url_override_with_models_url_set(self): + """When self.models_url is set, base_url override is ignored (models_url wins).""" + server, port = _start_server([{"id": "from-models-url"}]) + try: + profile = ProviderProfile( + name="test", + base_url="http://127.0.0.1:1", + models_url=f"http://127.0.0.1:{port}/models", + ) + # base_url override should NOT be used because models_url takes priority + result = profile.fetch_models( + api_key="test-key", + base_url="http://127.0.0.1:1", + ) + assert result == ["from-models-url"] + finally: + server.shutdown() + + +class TestCustomProviderBaseUrlPassthrough: + """Custom provider (ollama/local) should pass base_url through to super.""" + + def test_custom_passes_base_url(self): + """CustomProfile.fetch_models passes base_url to super().""" + server, port = _start_server([{"id": "ollama-model"}]) + try: + from plugins.model_providers.custom import CustomProfile + profile = CustomProfile( + name="custom", + base_url="http://127.0.0.1:1", # wrong port + ) + result = profile.fetch_models( + api_key="", + base_url=f"http://127.0.0.1:{port}", + ) + assert result == ["ollama-model"] + finally: + server.shutdown() + + +class TestModelPickerBaseUrlIntegration: + """The /model picker path should pass model.base_url to fetch_models.""" + + def test_picker_passes_base_url(self): + """Verify models.py caller passes base_url to fetch_models.""" + mock_profile = MagicMock() + mock_profile.auth_type = "api_key" + mock_profile.base_url = "https://default.api.com" + mock_profile.fetch_models.return_value = ["model-a"] + + with ( + patch("providers.get_provider_profile", return_value=mock_profile), + patch("hermes_cli.auth.resolve_api_key_provider_credentials", + return_value={"api_key": "sk-test", "base_url": "https://custom.proxy.com"}), + ): + from hermes_cli.models import provider_model_ids + result = provider_model_ids("test-provider") + # Verify fetch_models was called with base_url + mock_profile.fetch_models.assert_called_once() + call_kwargs = mock_profile.fetch_models.call_args + assert call_kwargs.kwargs.get("base_url") == "https://custom.proxy.com"