mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-18 09:51:59 +00:00
The /model interactive picker resolved a base_url from user credentials but never passed it to ProviderProfile.fetch_models(), causing the picker to always query the provider's hardcoded default endpoint instead of the user's custom URL (e.g. a company litellm proxy). - providers/base.py: add optional base_url parameter to fetch_models() - hermes_cli/models.py: pass resolved base_url to fetch_models() - Update all subclass overrides for signature compatibility - Add 6 regression tests covering override, fallback, and integration
140 lines
5.2 KiB
Python
140 lines
5.2 KiB
Python
"""Tests for ProviderProfile.fetch_models base_url override (issue #47009)."""
|
|
|
|
import json
|
|
from http.server import HTTPServer, BaseHTTPRequestHandler
|
|
from threading import Thread
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from providers.base import ProviderProfile
|
|
|
|
|
|
class _FakeModelHandler(BaseHTTPRequestHandler):
|
|
"""Serves /models with a configurable model list."""
|
|
|
|
models = [{"id": "custom-model-1"}, {"id": "custom-model-2"}]
|
|
|
|
def do_GET(self):
|
|
if self.path.rstrip("/") == "/models":
|
|
body = json.dumps({"data": self.models}).encode()
|
|
self.send_response(200)
|
|
self.send_header("Content-Type", "application/json")
|
|
self.end_headers()
|
|
self.wfile.write(body)
|
|
else:
|
|
self.send_response(404)
|
|
self.end_headers()
|
|
|
|
def log_message(self, format, *args):
|
|
pass # suppress noise
|
|
|
|
|
|
def _start_server(models=None):
|
|
"""Start a local HTTP server returning given models. Returns (server, port)."""
|
|
if models is not None:
|
|
_FakeModelHandler.models = models
|
|
server = HTTPServer(("127.0.0.1", 0), _FakeModelHandler)
|
|
port = server.server_address[1]
|
|
thread = Thread(target=server.serve_forever, daemon=True)
|
|
thread.start()
|
|
return server, port
|
|
|
|
|
|
class TestFetchModelsBaseUrlOverride:
|
|
"""fetch_models() should use caller-provided base_url when given."""
|
|
|
|
def test_base_url_override_used(self):
|
|
"""When base_url is passed, it overrides self.base_url."""
|
|
server, port = _start_server([{"id": "proxy-model-a"}])
|
|
try:
|
|
profile = ProviderProfile(
|
|
name="test",
|
|
base_url="http://127.0.0.1:1", # wrong port — should not be used
|
|
)
|
|
result = profile.fetch_models(
|
|
api_key="test-key",
|
|
base_url=f"http://127.0.0.1:{port}",
|
|
)
|
|
assert result == ["proxy-model-a"]
|
|
finally:
|
|
server.shutdown()
|
|
|
|
def test_fallback_to_self_base_url(self):
|
|
"""When base_url is None, falls back to self.base_url."""
|
|
server, port = _start_server([{"id": "default-model"}])
|
|
try:
|
|
profile = ProviderProfile(
|
|
name="test",
|
|
base_url=f"http://127.0.0.1:{port}",
|
|
)
|
|
result = profile.fetch_models(api_key="test-key")
|
|
assert result == ["default-model"]
|
|
finally:
|
|
server.shutdown()
|
|
|
|
def test_no_base_url_returns_none(self):
|
|
"""When both base_url and self.base_url are empty, returns None."""
|
|
profile = ProviderProfile(name="test", base_url="")
|
|
result = profile.fetch_models(api_key="test-key", base_url="")
|
|
assert result is None
|
|
|
|
def test_base_url_override_with_models_url_set(self):
|
|
"""When self.models_url is set, base_url override is ignored (models_url wins)."""
|
|
server, port = _start_server([{"id": "from-models-url"}])
|
|
try:
|
|
profile = ProviderProfile(
|
|
name="test",
|
|
base_url="http://127.0.0.1:1",
|
|
models_url=f"http://127.0.0.1:{port}/models",
|
|
)
|
|
# base_url override should NOT be used because models_url takes priority
|
|
result = profile.fetch_models(
|
|
api_key="test-key",
|
|
base_url="http://127.0.0.1:1",
|
|
)
|
|
assert result == ["from-models-url"]
|
|
finally:
|
|
server.shutdown()
|
|
|
|
|
|
class TestCustomProviderBaseUrlPassthrough:
|
|
"""Custom provider (ollama/local) should pass base_url through to super."""
|
|
|
|
def test_custom_passes_base_url(self):
|
|
"""CustomProfile.fetch_models passes base_url to super()."""
|
|
server, port = _start_server([{"id": "ollama-model"}])
|
|
try:
|
|
from plugins.model_providers.custom import CustomProfile
|
|
profile = CustomProfile(
|
|
name="custom",
|
|
base_url="http://127.0.0.1:1", # wrong port
|
|
)
|
|
result = profile.fetch_models(
|
|
api_key="",
|
|
base_url=f"http://127.0.0.1:{port}",
|
|
)
|
|
assert result == ["ollama-model"]
|
|
finally:
|
|
server.shutdown()
|
|
|
|
|
|
class TestModelPickerBaseUrlIntegration:
|
|
"""The /model picker path should pass model.base_url to fetch_models."""
|
|
|
|
def test_picker_passes_base_url(self):
|
|
"""Verify models.py caller passes base_url to fetch_models."""
|
|
mock_profile = MagicMock()
|
|
mock_profile.auth_type = "api_key"
|
|
mock_profile.base_url = "https://default.api.com"
|
|
mock_profile.fetch_models.return_value = ["model-a"]
|
|
|
|
with (
|
|
patch("providers.get_provider_profile", return_value=mock_profile),
|
|
patch("hermes_cli.auth.resolve_api_key_provider_credentials",
|
|
return_value={"api_key": "sk-test", "base_url": "https://custom.proxy.com"}),
|
|
):
|
|
from hermes_cli.models import provider_model_ids
|
|
result = provider_model_ids("test-provider")
|
|
# Verify fetch_models was called with base_url
|
|
mock_profile.fetch_models.assert_called_once()
|
|
call_kwargs = mock_profile.fetch_models.call_args
|
|
assert call_kwargs.kwargs.get("base_url") == "https://custom.proxy.com"
|