fix(models): pass model.base_url to fetch_models in /model picker

The /model interactive picker resolved a base_url from user credentials
but never passed it to ProviderProfile.fetch_models(), causing the
picker to always query the provider's hardcoded default endpoint
instead of the user's custom URL (e.g. a company litellm proxy).

- providers/base.py: add optional base_url parameter to fetch_models()
- hermes_cli/models.py: pass resolved base_url to fetch_models()
- Update all subclass overrides for signature compatibility
- Add 6 regression tests covering override, fallback, and integration
This commit is contained in:
liuhao1024 2026-06-16 18:33:21 +08:00 committed by Teknium
parent 9137b86a52
commit 1b962f001e
8 changed files with 155 additions and 7 deletions

View file

@ -2371,7 +2371,7 @@ def provider_model_ids(provider: Optional[str], *, force_refresh: bool = False)
if not base_url:
base_url = _p.base_url
if api_key:
live = _p.fetch_models(api_key=api_key)
live = _p.fetch_models(api_key=api_key, base_url=base_url or None)
if live:
# Merge static curated list with live API results so
# models that the live endpoint omits (stale cache,

View file

@ -17,6 +17,7 @@ class AnthropicProfile(ProviderProfile):
self,
*,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 8.0,
) -> list[str] | None:
"""Anthropic uses x-api-key header and anthropic-version."""

View file

@ -11,6 +11,7 @@ class BedrockProfile(ProviderProfile):
self,
*,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 8.0,
) -> list[str] | None:
"""Bedrock model listing requires AWS SDK, not a REST call."""

View file

@ -16,6 +16,7 @@ class CopilotACPProfile(ProviderProfile):
self,
*,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 8.0,
) -> list[str] | None:
"""Model listing is handled by the ACP subprocess."""

View file

@ -43,12 +43,13 @@ class CustomProfile(ProviderProfile):
self,
*,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 8.0,
) -> list[str] | None:
"""Custom/Ollama: base_url is user-configured; fetch if set."""
if not self.base_url:
if not (base_url or self.base_url):
return None
return super().fetch_models(api_key=api_key, timeout=timeout)
return super().fetch_models(api_key=api_key, base_url=base_url, timeout=timeout)
custom = CustomProfile(

View file

@ -51,6 +51,7 @@ class OpenRouterProfile(ProviderProfile):
self,
*,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 8.0,
) -> list[str] | None:
"""Fetch from public OpenRouter catalog — no auth required.
@ -64,7 +65,7 @@ class OpenRouterProfile(ProviderProfile):
if _CACHE is not None:
return _CACHE
try:
result = super().fetch_models(api_key=None, timeout=timeout)
result = super().fetch_models(api_key=None, base_url=base_url, timeout=timeout)
if result is not None:
_CACHE = result
return result

View file

@ -163,6 +163,7 @@ class ProviderProfile:
self,
*,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 8.0,
) -> list[str] | None:
"""Fetch the live model list from the provider's models endpoint.
@ -175,7 +176,8 @@ class ProviderProfile:
endpoint differs from the inference base URL, e.g. OpenRouter
exposes a public catalog at /api/v1/models while inference is
at /api/v1)
2. self.base_url + "/models" (standard OpenAI-compat fallback)
2. base_url + "/models" (caller override, e.g. the user-configured model.base_url)
3. self.base_url + "/models" (standard OpenAI-compat fallback)
The default implementation sends Bearer auth when api_key is given
and forwards self.default_headers. Override to customise auth, path,
@ -184,11 +186,12 @@ class ProviderProfile:
Callers must always fall back to the static _PROVIDER_MODELS list
when this returns None.
"""
effective_base = base_url or self.base_url
url = (self.models_url or "").strip()
if not url:
if not self.base_url:
if not effective_base:
return None
url = self.base_url.rstrip("/") + "/models"
url = effective_base.rstrip("/") + "/models"
import json
import urllib.request

View file

@ -0,0 +1,140 @@
"""Tests for ProviderProfile.fetch_models base_url override (issue #47009)."""
import json
from http.server import HTTPServer, BaseHTTPRequestHandler
from threading import Thread
from unittest.mock import patch, MagicMock
from providers.base import ProviderProfile
class _FakeModelHandler(BaseHTTPRequestHandler):
    """Tiny HTTP handler that answers ``GET /models`` with a JSON model list.

    Any other path receives an empty 404 response. The JSON body has the
    shape ``{"data": <models>}`` where ``<models>`` is the ``models``
    attribute below.
    """

    # Entries placed under the "data" key of the /models response.
    models = [{"id": "custom-model-1"}, {"id": "custom-model-2"}]

    def do_GET(self):
        """Serve the model list on /models (trailing slash tolerated), else 404."""
        if self.path.rstrip("/") != "/models":
            self.send_response(404)
            self.end_headers()
            return

        # Serialize before sending the status line, so a serialization
        # failure never leaves a half-written 200 response on the wire.
        payload = json.dumps({"data": self.models}).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.end_headers()
        self.wfile.write(payload)

    def log_message(self, format, *args):
        """Drop per-request log lines so test output stays quiet."""
        return None
def _start_server(models=None):
    """Start a local HTTP server that serves ``models`` on ``/models``.

    Args:
        models: List of model dicts to return under the ``"data"`` key.
            When ``None``, the server uses ``_FakeModelHandler``'s own
            ``models`` attribute.

    Returns:
        A ``(server, port)`` tuple. The server runs ``serve_forever`` on a
        daemon thread and is bound to an OS-assigned port on 127.0.0.1.
        The caller is responsible for stopping it.
    """
    handler_cls = _FakeModelHandler
    if models is not None:
        # Give each server its own handler subclass instead of reassigning
        # _FakeModelHandler.models: mutating the shared class attribute would
        # leak this list into every later (or concurrently running) server.
        handler_cls = type(
            "_ScopedFakeModelHandler",
            (_FakeModelHandler,),
            {"models": list(models)},
        )
    # Port 0 asks the OS for any free port; read back the one actually bound.
    server = HTTPServer(("127.0.0.1", 0), handler_cls)
    port = server.server_address[1]
    Thread(target=server.serve_forever, daemon=True).start()
    return server, port
class TestFetchModelsBaseUrlOverride:
    """fetch_models() should use caller-provided base_url when given."""

    def test_base_url_override_used(self):
        """When base_url is passed, it overrides self.base_url."""
        server, port = _start_server([{"id": "proxy-model-a"}])
        try:
            profile = ProviderProfile(
                name="test",
                base_url="http://127.0.0.1:1",  # wrong port — should not be used
            )
            result = profile.fetch_models(
                api_key="test-key",
                base_url=f"http://127.0.0.1:{port}",
            )
            assert result == ["proxy-model-a"]
        finally:
            # shutdown() only stops the serve loop; server_close() releases
            # the listening socket so it does not leak between tests.
            server.shutdown()
            server.server_close()

    def test_fallback_to_self_base_url(self):
        """When base_url is None, falls back to self.base_url."""
        server, port = _start_server([{"id": "default-model"}])
        try:
            profile = ProviderProfile(
                name="test",
                base_url=f"http://127.0.0.1:{port}",
            )
            result = profile.fetch_models(api_key="test-key")
            assert result == ["default-model"]
        finally:
            server.shutdown()
            server.server_close()

    def test_no_base_url_returns_none(self):
        """When both base_url and self.base_url are empty, returns None."""
        profile = ProviderProfile(name="test", base_url="")
        result = profile.fetch_models(api_key="test-key", base_url="")
        assert result is None

    def test_base_url_override_with_models_url_set(self):
        """When self.models_url is set, base_url override is ignored (models_url wins)."""
        server, port = _start_server([{"id": "from-models-url"}])
        try:
            profile = ProviderProfile(
                name="test",
                base_url="http://127.0.0.1:1",
                models_url=f"http://127.0.0.1:{port}/models",
            )
            # base_url override should NOT be used because models_url takes priority
            result = profile.fetch_models(
                api_key="test-key",
                base_url="http://127.0.0.1:1",
            )
            assert result == ["from-models-url"]
        finally:
            server.shutdown()
            server.server_close()
class TestCustomProviderBaseUrlPassthrough:
    """Custom provider (ollama/local) should pass base_url through to super."""

    def test_custom_passes_base_url(self):
        """CustomProfile.fetch_models passes base_url to super()."""
        server, port = _start_server([{"id": "ollama-model"}])
        try:
            from plugins.model_providers.custom import CustomProfile

            profile = CustomProfile(
                name="custom",
                base_url="http://127.0.0.1:1",  # wrong port
            )
            result = profile.fetch_models(
                api_key="",
                base_url=f"http://127.0.0.1:{port}",
            )
            assert result == ["ollama-model"]
        finally:
            # shutdown() only stops the serve loop; server_close() releases
            # the listening socket so it does not leak between tests.
            server.shutdown()
            server.server_close()
class TestModelPickerBaseUrlIntegration:
    """Checks that the /model picker hands the resolved base_url to fetch_models."""

    def test_picker_passes_base_url(self):
        """provider_model_ids() must forward the credential base_url as a keyword."""
        # Stand-in profile: its own base_url differs from the credential one,
        # so the assertion below can tell which URL was forwarded.
        fake_profile = MagicMock(
            auth_type="api_key",
            base_url="https://default.api.com",
        )
        fake_profile.fetch_models.return_value = ["model-a"]
        resolved_creds = {"api_key": "sk-test", "base_url": "https://custom.proxy.com"}

        with (
            patch("providers.get_provider_profile", return_value=fake_profile),
            patch(
                "hermes_cli.auth.resolve_api_key_provider_credentials",
                return_value=resolved_creds,
            ),
        ):
            from hermes_cli.models import provider_model_ids

            provider_model_ids("test-provider")

        # Exactly one live fetch, carrying the user-configured URL.
        fake_profile.fetch_models.assert_called_once()
        recorded_call = fake_profile.fetch_models.call_args
        assert recorded_call.kwargs.get("base_url") == "https://custom.proxy.com"