fix(agent): preserve auto vision fallback for custom backends

This commit is contained in:
Tranquil-Flow 2026-04-25 10:09:09 +10:00
parent 1b4be27642
commit 887b3d4e38
2 changed files with 248 additions and 16 deletions

View file

@ -6,11 +6,12 @@ the auto-detection should skip directly to aggregator fallbacks instead of
sending an image payload to a text-only model.
"""
from types import SimpleNamespace
from unittest.mock import patch, MagicMock
import pytest
from agent.auxiliary_client import _is_likely_vision_model
from agent.auxiliary_client import _is_likely_vision_model, call_llm
# -- _is_likely_vision_model heuristic --
@ -149,6 +150,67 @@ class TestVisionNonVisionFallthrough:
assert client is mock_client
assert model == "mimo-v2.5"
def test_named_custom_provider_unknown_model_is_trusted(self):
    """Named custom providers should not be skipped by the name heuristic."""
    stub_client = MagicMock()
    # Main provider is a named custom backend whose model name carries no
    # vision hint; the resolver must still trust it rather than skipping it.
    with patch("agent.auxiliary_client._read_main_provider",
               return_value="beans"), \
         patch("agent.auxiliary_client._read_main_model",
               return_value="my-company-vlm"), \
         patch("agent.auxiliary_client._resolve_task_provider_model",
               return_value=("auto", None, None, None, None)), \
         patch("agent.auxiliary_client._get_named_custom_provider_entry",
               return_value={"name": "beans", "base_url": "http://vlm.test/v1"}), \
         patch("agent.auxiliary_client.resolve_provider_client",
               return_value=(stub_client, "my-company-vlm")):
        # Import inside the patch scope, mirroring the module's test style.
        from agent.auxiliary_client import resolve_vision_provider_client
        resolved_provider, resolved_client, resolved_model = (
            resolve_vision_provider_client()
        )
    assert resolved_provider == "beans"
    assert resolved_client is stub_client
    assert resolved_model == "my-company-vlm"
def test_named_custom_provider_can_declare_vision_model_override(self):
    """Named custom providers can route vision to a dedicated model."""
    stub_client = MagicMock()
    # Provider entry maps the chat model to a dedicated vision model.
    provider_entry = {
        "name": "beans",
        "base_url": "http://vlm.test/v1",
        "models": {"chat-model": {"vision_model": "vision-model"}},
    }
    with patch("agent.auxiliary_client._read_main_provider",
               return_value="beans"), \
         patch("agent.auxiliary_client._read_main_model",
               return_value="chat-model"), \
         patch("agent.auxiliary_client._resolve_task_provider_model",
               return_value=("auto", None, None, None, None)), \
         patch("agent.auxiliary_client._get_named_custom_provider_entry",
               return_value=provider_entry), \
         patch("agent.auxiliary_client.resolve_provider_client",
               return_value=(stub_client, "vision-model")) as resolve_spy:
        from agent.auxiliary_client import resolve_vision_provider_client
        resolved_provider, resolved_client, resolved_model = (
            resolve_vision_provider_client()
        )
    assert resolved_provider == "beans"
    assert resolved_client is stub_client
    assert resolved_model == "vision-model"
    # The override model — not the chat model — must reach the client factory.
    assert resolve_spy.call_args.args[:2] == ("beans", "vision-model")
def test_non_vision_model_all_aggregators_fail(self):
"""Non-vision main + no aggregators available must return None."""
with patch(
@ -170,3 +232,49 @@ class TestVisionNonVisionFallthrough:
assert client is None
assert model is None
class VisionUnsupportedError(Exception):
    """Stub provider error carrying an HTTP-style status code.

    Mimics the capability-rejection error a backend raises when it is sent
    an image payload it cannot handle; defaults to a 400 status.
    """

    def __init__(self, message, status_code=400):
        # Record the status first, then defer to Exception for the message.
        self.status_code = status_code
        super().__init__(message)
class TestVisionCapabilityFallback:
    """call_llm must recover when the primary backend rejects image input."""

    def test_call_llm_retries_auto_vision_on_capability_error(self):
        """A text-only main provider should fall through to strict vision backends."""
        text_only_client = MagicMock()
        vision_client = MagicMock()
        ok_response = SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))]
        )
        # Primary backend rejects the image payload; fallback answers normally.
        text_only_client.chat.completions.create.side_effect = VisionUnsupportedError(
            "This model does not support image input"
        )
        vision_client.chat.completions.create.return_value = ok_response
        with patch("agent.auxiliary_client._resolve_task_provider_model",
                   return_value=("auto", None, None, None, None)), \
             patch("agent.auxiliary_client.resolve_vision_provider_client",
                   return_value=("ollama-cloud", text_only_client, "llama3")), \
             patch("agent.auxiliary_client._build_call_kwargs",
                   return_value={
                       "model": "llama3",
                       "messages": [{"role": "user", "content": "analyze"}],
                   }), \
             patch("agent.auxiliary_client._try_vision_fallback",
                   return_value=(vision_client, "google/gemini-3-flash-preview", "openrouter")):
            outcome = call_llm(
                task="vision",
                messages=[{"role": "user", "content": "analyze"}],
            )
        assert outcome is ok_response
        # Exactly one attempt per backend: fail once, then succeed once.
        assert text_only_client.chat.completions.create.call_count == 1
        assert vision_client.chat.completions.create.call_count == 1