mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-17 09:41:58 +00:00
fix(models): pass model.base_url to fetch_models in /model picker
The /model interactive picker resolved a base_url from user credentials but never passed it to ProviderProfile.fetch_models(), causing the picker to always query the provider's hardcoded default endpoint instead of the user's custom URL (e.g. a company litellm proxy). - providers/base.py: add optional base_url parameter to fetch_models() - hermes_cli/models.py: pass resolved base_url to fetch_models() - Update all subclass overrides for signature compatibility - Add 6 regression tests covering override, fallback, and integration
This commit is contained in:
parent
9137b86a52
commit
1b962f001e
8 changed files with 155 additions and 7 deletions
|
|
@ -2371,7 +2371,7 @@ def provider_model_ids(provider: Optional[str], *, force_refresh: bool = False)
|
|||
if not base_url:
|
||||
base_url = _p.base_url
|
||||
if api_key:
|
||||
live = _p.fetch_models(api_key=api_key)
|
||||
live = _p.fetch_models(api_key=api_key, base_url=base_url or None)
|
||||
if live:
|
||||
# Merge static curated list with live API results so
|
||||
# models that the live endpoint omits (stale cache,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ class AnthropicProfile(ProviderProfile):
|
|||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float = 8.0,
|
||||
) -> list[str] | None:
|
||||
"""Anthropic uses x-api-key header and anthropic-version."""
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ class BedrockProfile(ProviderProfile):
|
|||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float = 8.0,
|
||||
) -> list[str] | None:
|
||||
"""Bedrock model listing requires AWS SDK, not a REST call."""
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ class CopilotACPProfile(ProviderProfile):
|
|||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float = 8.0,
|
||||
) -> list[str] | None:
|
||||
"""Model listing is handled by the ACP subprocess."""
|
||||
|
|
|
|||
|
|
@ -43,12 +43,13 @@ class CustomProfile(ProviderProfile):
|
|||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float = 8.0,
|
||||
) -> list[str] | None:
|
||||
"""Custom/Ollama: base_url is user-configured; fetch if set."""
|
||||
if not self.base_url:
|
||||
if not (base_url or self.base_url):
|
||||
return None
|
||||
return super().fetch_models(api_key=api_key, timeout=timeout)
|
||||
return super().fetch_models(api_key=api_key, base_url=base_url, timeout=timeout)
|
||||
|
||||
|
||||
custom = CustomProfile(
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ class OpenRouterProfile(ProviderProfile):
|
|||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float = 8.0,
|
||||
) -> list[str] | None:
|
||||
"""Fetch from public OpenRouter catalog — no auth required.
|
||||
|
|
@ -64,7 +65,7 @@ class OpenRouterProfile(ProviderProfile):
|
|||
if _CACHE is not None:
|
||||
return _CACHE
|
||||
try:
|
||||
result = super().fetch_models(api_key=None, timeout=timeout)
|
||||
result = super().fetch_models(api_key=None, base_url=base_url, timeout=timeout)
|
||||
if result is not None:
|
||||
_CACHE = result
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -163,6 +163,7 @@ class ProviderProfile:
|
|||
self,
|
||||
*,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float = 8.0,
|
||||
) -> list[str] | None:
|
||||
"""Fetch the live model list from the provider's models endpoint.
|
||||
|
|
@ -175,7 +176,8 @@ class ProviderProfile:
|
|||
endpoint differs from the inference base URL, e.g. OpenRouter
|
||||
exposes a public catalog at /api/v1/models while inference is
|
||||
at /api/v1)
|
||||
2. self.base_url + "/models" (standard OpenAI-compat fallback)
|
||||
2. base_url (caller override — user-configured model.base_url)
|
||||
3. self.base_url + "/models" (standard OpenAI-compat fallback)
|
||||
|
||||
The default implementation sends Bearer auth when api_key is given
|
||||
and forwards self.default_headers. Override to customise auth, path,
|
||||
|
|
@ -184,11 +186,12 @@ class ProviderProfile:
|
|||
Callers must always fall back to the static _PROVIDER_MODELS list
|
||||
when this returns None.
|
||||
"""
|
||||
effective_base = base_url or self.base_url
|
||||
url = (self.models_url or "").strip()
|
||||
if not url:
|
||||
if not self.base_url:
|
||||
if not effective_base:
|
||||
return None
|
||||
url = self.base_url.rstrip("/") + "/models"
|
||||
url = effective_base.rstrip("/") + "/models"
|
||||
|
||||
import json
|
||||
import urllib.request
|
||||
|
|
|
|||
140
tests/providers/test_fetch_models_base_url.py
Normal file
140
tests/providers/test_fetch_models_base_url.py
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
"""Tests for ProviderProfile.fetch_models base_url override (issue #47009)."""
|
||||
|
||||
import json
|
||||
from http.server import HTTPServer, BaseHTTPRequestHandler
|
||||
from threading import Thread
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from providers.base import ProviderProfile
|
||||
|
||||
|
||||
class _FakeModelHandler(BaseHTTPRequestHandler):
|
||||
"""Serves /models with a configurable model list."""
|
||||
|
||||
models = [{"id": "custom-model-1"}, {"id": "custom-model-2"}]
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.rstrip("/") == "/models":
|
||||
body = json.dumps({"data": self.models}).encode()
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "application/json")
|
||||
self.end_headers()
|
||||
self.wfile.write(body)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
pass # suppress noise
|
||||
|
||||
|
||||
def _start_server(models=None):
|
||||
"""Start a local HTTP server returning given models. Returns (server, port)."""
|
||||
if models is not None:
|
||||
_FakeModelHandler.models = models
|
||||
server = HTTPServer(("127.0.0.1", 0), _FakeModelHandler)
|
||||
port = server.server_address[1]
|
||||
thread = Thread(target=server.serve_forever, daemon=True)
|
||||
thread.start()
|
||||
return server, port
|
||||
|
||||
|
||||
class TestFetchModelsBaseUrlOverride:
|
||||
"""fetch_models() should use caller-provided base_url when given."""
|
||||
|
||||
def test_base_url_override_used(self):
|
||||
"""When base_url is passed, it overrides self.base_url."""
|
||||
server, port = _start_server([{"id": "proxy-model-a"}])
|
||||
try:
|
||||
profile = ProviderProfile(
|
||||
name="test",
|
||||
base_url="http://127.0.0.1:1", # wrong port — should not be used
|
||||
)
|
||||
result = profile.fetch_models(
|
||||
api_key="test-key",
|
||||
base_url=f"http://127.0.0.1:{port}",
|
||||
)
|
||||
assert result == ["proxy-model-a"]
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
def test_fallback_to_self_base_url(self):
|
||||
"""When base_url is None, falls back to self.base_url."""
|
||||
server, port = _start_server([{"id": "default-model"}])
|
||||
try:
|
||||
profile = ProviderProfile(
|
||||
name="test",
|
||||
base_url=f"http://127.0.0.1:{port}",
|
||||
)
|
||||
result = profile.fetch_models(api_key="test-key")
|
||||
assert result == ["default-model"]
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
def test_no_base_url_returns_none(self):
|
||||
"""When both base_url and self.base_url are empty, returns None."""
|
||||
profile = ProviderProfile(name="test", base_url="")
|
||||
result = profile.fetch_models(api_key="test-key", base_url="")
|
||||
assert result is None
|
||||
|
||||
def test_base_url_override_with_models_url_set(self):
|
||||
"""When self.models_url is set, base_url override is ignored (models_url wins)."""
|
||||
server, port = _start_server([{"id": "from-models-url"}])
|
||||
try:
|
||||
profile = ProviderProfile(
|
||||
name="test",
|
||||
base_url="http://127.0.0.1:1",
|
||||
models_url=f"http://127.0.0.1:{port}/models",
|
||||
)
|
||||
# base_url override should NOT be used because models_url takes priority
|
||||
result = profile.fetch_models(
|
||||
api_key="test-key",
|
||||
base_url="http://127.0.0.1:1",
|
||||
)
|
||||
assert result == ["from-models-url"]
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
class TestCustomProviderBaseUrlPassthrough:
|
||||
"""Custom provider (ollama/local) should pass base_url through to super."""
|
||||
|
||||
def test_custom_passes_base_url(self):
|
||||
"""CustomProfile.fetch_models passes base_url to super()."""
|
||||
server, port = _start_server([{"id": "ollama-model"}])
|
||||
try:
|
||||
from plugins.model_providers.custom import CustomProfile
|
||||
profile = CustomProfile(
|
||||
name="custom",
|
||||
base_url="http://127.0.0.1:1", # wrong port
|
||||
)
|
||||
result = profile.fetch_models(
|
||||
api_key="",
|
||||
base_url=f"http://127.0.0.1:{port}",
|
||||
)
|
||||
assert result == ["ollama-model"]
|
||||
finally:
|
||||
server.shutdown()
|
||||
|
||||
|
||||
class TestModelPickerBaseUrlIntegration:
|
||||
"""The /model picker path should pass model.base_url to fetch_models."""
|
||||
|
||||
def test_picker_passes_base_url(self):
|
||||
"""Verify models.py caller passes base_url to fetch_models."""
|
||||
mock_profile = MagicMock()
|
||||
mock_profile.auth_type = "api_key"
|
||||
mock_profile.base_url = "https://default.api.com"
|
||||
mock_profile.fetch_models.return_value = ["model-a"]
|
||||
|
||||
with (
|
||||
patch("providers.get_provider_profile", return_value=mock_profile),
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials",
|
||||
return_value={"api_key": "sk-test", "base_url": "https://custom.proxy.com"}),
|
||||
):
|
||||
from hermes_cli.models import provider_model_ids
|
||||
result = provider_model_ids("test-provider")
|
||||
# Verify fetch_models was called with base_url
|
||||
mock_profile.fetch_models.assert_called_once()
|
||||
call_kwargs = mock_profile.fetch_models.call_args
|
||||
assert call_kwargs.kwargs.get("base_url") == "https://custom.proxy.com"
|
||||
Loading…
Add table
Add a link
Reference in a new issue