fix(agent): preserve auto vision fallback for custom backends

This commit is contained in:
Tranquil-Flow 2026-04-25 10:09:09 +10:00
parent 1b4be27642
commit 887b3d4e38
2 changed files with 248 additions and 16 deletions

View file

@ -183,6 +183,37 @@ def _is_likely_vision_model(model: str) -> bool:
lower = model.lower() lower = model.lower()
return any(indicator in lower for indicator in _VISION_MODEL_INDICATORS) return any(indicator in lower for indicator in _VISION_MODEL_INDICATORS)
def _get_named_custom_provider_entry(provider: Optional[str]) -> Optional[Dict[str, Any]]:
"""Return the named custom-provider entry for *provider*, if any."""
if not provider or provider == "custom":
return None
try:
from hermes_cli.runtime_provider import _get_named_custom_provider
entry = _get_named_custom_provider(provider)
except Exception:
return None
return entry if isinstance(entry, dict) else None
def _configured_vision_model(provider: Optional[str], model: Optional[str]) -> Optional[str]:
    """Read per-model vision routing hints from named custom-provider config.

    Precedence: an explicit ``vision_model``/``multimodal_model`` string
    wins; otherwise a literal ``True`` on ``supports_vision``/``multimodal``
    means *model* itself handles vision.  Returns ``None`` when nothing is
    configured for this provider/model pair.
    """
    entry = _get_named_custom_provider_entry(provider)
    if not entry or not model:
        return None
    per_model = entry.get("models")
    if not isinstance(per_model, dict):
        return None
    cfg = per_model.get(model)
    if not isinstance(cfg, dict):
        return None
    # A dedicated vision model, if declared, takes priority.
    override = cfg.get("vision_model") or cfg.get("multimodal_model")
    if isinstance(override, str) and override.strip():
        return override.strip()
    # ``is True`` deliberately rejects truthy non-bool config values.
    vision_capable = cfg.get("supports_vision") is True or cfg.get("multimodal") is True
    return model if vision_capable else None
# OpenRouter app attribution headers # OpenRouter app attribution headers
_OR_HEADERS = { _OR_HEADERS = {
"HTTP-Referer": "https://hermes-agent.nousresearch.com", "HTTP-Referer": "https://hermes-agent.nousresearch.com",
@ -1361,6 +1392,25 @@ def _is_connection_error(exc: Exception) -> bool:
return False return False
def _is_vision_capability_error(exc: Exception) -> bool:
"""Detect errors from sending image input to a text-only model."""
status = getattr(exc, "status_code", None)
if status not in (None, 400, 415, 422):
return False
err_lower = str(exc).lower()
return any(phrase in err_lower for phrase in (
"does not support image",
"doesn't support image",
"images are not supported",
"image input is not supported",
"vision is not supported",
"vision not supported",
"multimodal input is not supported",
"model is not multimodal",
"text-only model",
))
def _is_auth_error(exc: Exception) -> bool: def _is_auth_error(exc: Exception) -> bool:
"""Detect auth failures that should trigger provider-specific refresh.""" """Detect auth failures that should trigger provider-specific refresh."""
status = getattr(exc, "status_code", None) status = getattr(exc, "status_code", None)
@ -1479,6 +1529,41 @@ def _try_payment_fallback(
return None, None, "" return None, None, ""
def _try_vision_fallback(
    failed_provider: str,
    *,
    async_mode: bool = False,
    reason: str = "vision capability error",
) -> Tuple[Optional[Any], Optional[str], str]:
    """Try the next strict vision backend after a main-provider vision failure.

    Walks ``_VISION_AUTO_PROVIDER_ORDER`` (skipping *failed_provider*) and
    returns ``(client, model, label)`` for the first backend that resolves,
    wrapping the client for async use when *async_mode* is set.  Returns
    ``(None, None, "")`` when no fallback backend is available.
    """
    skip = _normalize_vision_provider(failed_provider)
    attempted = []
    for candidate in _VISION_AUTO_PROVIDER_ORDER:
        if candidate == skip:
            continue
        client, model = _resolve_strict_vision_backend(candidate)
        if client is None:
            attempted.append(candidate)
            continue
        if async_mode:
            client, model = _to_async_client(client, model)
        logger.info(
            "Auxiliary vision: %s on %s — falling back to %s (%s)",
            reason,
            failed_provider,
            candidate,
            model or "default",
        )
        return client, model, candidate
    logger.warning(
        "Auxiliary vision: %s on %s and no fallback available (tried: %s)",
        reason,
        failed_provider,
        ", ".join(attempted),
    )
    return None, None, ""
def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Optional[OpenAI], Optional[str]]: def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Optional[OpenAI], Optional[str]]:
"""Full auto-detection chain. """Full auto-detection chain.
@ -2229,12 +2314,25 @@ def resolve_vision_provider_client(
) )
return _finalize(main_provider, sync_client, default_model) return _finalize(main_provider, sync_client, default_model)
else: else:
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model) vision_model = (
_configured_vision_model(main_provider, main_model)
or _PROVIDER_VISION_MODELS.get(main_provider)
or main_model
)
# When no explicit vision override exists (fell through to # When no explicit vision override exists (fell through to
# main_model), check if the main model looks vision-capable. # main_model), check if the main model looks vision-capable.
# Trust direct/named custom providers: users may intentionally
# point Hermes at a self-hosted multimodal model whose name is
# unknown to our heuristic.
trusted_main_model = (
main_provider == "custom"
or _get_named_custom_provider_entry(main_provider) is not None
)
# If not, skip directly to aggregator fallbacks instead of # If not, skip directly to aggregator fallbacks instead of
# sending an image payload to a text-only model (#14744). # sending an image payload to a clearly text-only model.
if vision_model == main_model and not _is_likely_vision_model(main_model): if (vision_model == main_model
and not trusted_main_model
and not _is_likely_vision_model(main_model)):
logger.info( logger.info(
"Vision auto-detect: skipping main provider %s " "Vision auto-detect: skipping main provider %s "
"(model %r is not vision-capable)", "(model %r is not vision-capable)",
@ -2900,6 +2998,7 @@ def call_llm(
""" """
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
task, provider, model, base_url, api_key) task, provider, model, base_url, api_key)
allow_auto_fallback = resolved_provider in ("auto", "", None)
effective_extra_body = _get_task_extra_body(task) effective_extra_body = _get_task_extra_body(task)
effective_extra_body.update(extra_body or {}) effective_extra_body.update(extra_body or {})
@ -3081,17 +3180,29 @@ def call_llm(
# Codex/OAuth tokens that authenticate but whose endpoint is down, # Codex/OAuth tokens that authenticate but whose endpoint is down,
# and providers the user never configured that got picked up by # and providers the user never configured that got picked up by
# the auto-detection chain. # the auto-detection chain.
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) vision_capability_error = task == "vision" and _is_vision_capability_error(first_err)
should_fallback = (
_is_payment_error(first_err)
or _is_connection_error(first_err)
or vision_capability_error
)
# Only try alternative providers when the user didn't explicitly # Only try alternative providers when the user didn't explicitly
# configure this task's provider. Explicit provider = hard constraint; # configure this task's provider. Explicit provider = hard constraint;
# auto (the default) = best-effort fallback chain. (#7559) # auto (the default) = best-effort fallback chain. (#7559)
is_auto = resolved_provider in ("auto", "", None) if should_fallback and allow_auto_fallback:
if should_fallback and is_auto: if vision_capability_error:
reason = "payment error" if _is_payment_error(first_err) else "connection error" reason = "vision capability error"
fb_client, fb_model, fb_label = _try_vision_fallback(
resolved_provider,
async_mode=False,
reason=reason,
)
else:
reason = "payment error" if _is_payment_error(first_err) else "connection error"
fb_client, fb_model, fb_label = _try_payment_fallback(
resolved_provider, task, reason=reason)
logger.info("Auxiliary %s: %s on %s (%s), trying fallback", logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
task or "call", reason, resolved_provider, first_err) task or "call", reason, resolved_provider, first_err)
fb_client, fb_model, fb_label = _try_payment_fallback(
resolved_provider, task, reason=reason)
if fb_client is not None: if fb_client is not None:
fb_kwargs = _build_call_kwargs( fb_kwargs = _build_call_kwargs(
fb_label, fb_model, messages, fb_label, fb_model, messages,
@ -3180,6 +3291,7 @@ async def async_call_llm(
""" """
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
task, provider, model, base_url, api_key) task, provider, model, base_url, api_key)
allow_auto_fallback = resolved_provider in ("auto", "", None)
effective_extra_body = _get_task_extra_body(task) effective_extra_body = _get_task_extra_body(task)
effective_extra_body.update(extra_body or {}) effective_extra_body.update(extra_body or {})
@ -3332,14 +3444,26 @@ async def async_call_llm(
await retry_client.chat.completions.create(**retry_kwargs), task) await retry_client.chat.completions.create(**retry_kwargs), task)
# ── Payment / connection fallback (mirrors sync call_llm) ───── # ── Payment / connection fallback (mirrors sync call_llm) ─────
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) vision_capability_error = task == "vision" and _is_vision_capability_error(first_err)
is_auto = resolved_provider in ("auto", "", None) should_fallback = (
if should_fallback and is_auto: _is_payment_error(first_err)
reason = "payment error" if _is_payment_error(first_err) else "connection error" or _is_connection_error(first_err)
or vision_capability_error
)
if should_fallback and allow_auto_fallback:
if vision_capability_error:
reason = "vision capability error"
fb_client, fb_model, fb_label = _try_vision_fallback(
resolved_provider,
async_mode=True,
reason=reason,
)
else:
reason = "payment error" if _is_payment_error(first_err) else "connection error"
fb_client, fb_model, fb_label = _try_payment_fallback(
resolved_provider, task, reason=reason)
logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback", logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback",
task or "call", reason, resolved_provider, first_err) task or "call", reason, resolved_provider, first_err)
fb_client, fb_model, fb_label = _try_payment_fallback(
resolved_provider, task, reason=reason)
if fb_client is not None: if fb_client is not None:
fb_kwargs = _build_call_kwargs( fb_kwargs = _build_call_kwargs(
fb_label, fb_model, messages, fb_label, fb_model, messages,

View file

@ -6,11 +6,12 @@ the auto-detection should skip directly to aggregator fallbacks instead of
sending an image payload to a text-only model. sending an image payload to a text-only model.
""" """
from types import SimpleNamespace
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
import pytest import pytest
from agent.auxiliary_client import _is_likely_vision_model from agent.auxiliary_client import _is_likely_vision_model, call_llm
# -- _is_likely_vision_model heuristic -- # -- _is_likely_vision_model heuristic --
@ -149,6 +150,67 @@ class TestVisionNonVisionFallthrough:
assert client is mock_client assert client is mock_client
assert model == "mimo-v2.5" assert model == "mimo-v2.5"
def test_named_custom_provider_unknown_model_is_trusted(self):
    """Named custom providers should not be skipped by the name heuristic."""
    stub_client = MagicMock()
    with patch(
        "agent.auxiliary_client._read_main_provider",
        return_value="beans",
    ), patch(
        "agent.auxiliary_client._read_main_model",
        return_value="my-company-vlm",
    ), patch(
        "agent.auxiliary_client._resolve_task_provider_model",
        return_value=("auto", None, None, None, None),
    ), patch(
        "agent.auxiliary_client._get_named_custom_provider_entry",
        return_value={"name": "beans", "base_url": "http://vlm.test/v1"},
    ), patch(
        "agent.auxiliary_client.resolve_provider_client",
        return_value=(stub_client, "my-company-vlm"),
    ):
        from agent.auxiliary_client import resolve_vision_provider_client

        provider, client, model = resolve_vision_provider_client()
    # The unknown model name must be trusted, not routed to aggregators.
    assert provider == "beans"
    assert client is stub_client
    assert model == "my-company-vlm"
def test_named_custom_provider_can_declare_vision_model_override(self):
    """Named custom providers can route vision to a dedicated model."""
    stub_client = MagicMock()
    with patch(
        "agent.auxiliary_client._read_main_provider",
        return_value="beans",
    ), patch(
        "agent.auxiliary_client._read_main_model",
        return_value="chat-model",
    ), patch(
        "agent.auxiliary_client._resolve_task_provider_model",
        return_value=("auto", None, None, None, None),
    ), patch(
        "agent.auxiliary_client._get_named_custom_provider_entry",
        return_value={
            "name": "beans",
            "base_url": "http://vlm.test/v1",
            "models": {"chat-model": {"vision_model": "vision-model"}},
        },
    ), patch(
        "agent.auxiliary_client.resolve_provider_client",
        return_value=(stub_client, "vision-model"),
    ) as resolve_spy:
        from agent.auxiliary_client import resolve_vision_provider_client

        provider, client, model = resolve_vision_provider_client()
    assert provider == "beans"
    assert client is stub_client
    assert model == "vision-model"
    # The configured override, not the chat model, must be requested.
    assert resolve_spy.call_args.args[:2] == ("beans", "vision-model")
def test_non_vision_model_all_aggregators_fail(self): def test_non_vision_model_all_aggregators_fail(self):
"""Non-vision main + no aggregators available must return None.""" """Non-vision main + no aggregators available must return None."""
with patch( with patch(
@ -170,3 +232,49 @@ class TestVisionNonVisionFallthrough:
assert client is None assert client is None
assert model is None assert model is None
class VisionUnsupportedError(Exception):
    """Stand-in provider error carrying an HTTP-style ``status_code``.

    Mirrors the attribute ``_is_vision_capability_error`` inspects; the
    default 400 is one of the statuses it treats as a capability failure.
    """

    def __init__(self, message, status_code=400):
        super().__init__(message)
        self.status_code = status_code
class TestVisionCapabilityFallback:
    """call_llm must retry via strict vision backends on capability errors."""

    def test_call_llm_retries_auto_vision_on_capability_error(self):
        """A text-only main provider should fall through to strict vision backends."""
        failing_client = MagicMock()
        rescue_client = MagicMock()
        fake_response = SimpleNamespace(
            choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))]
        )
        # First provider rejects image input; the fallback succeeds.
        failing_client.chat.completions.create.side_effect = VisionUnsupportedError(
            "This model does not support image input"
        )
        rescue_client.chat.completions.create.return_value = fake_response
        with patch(
            "agent.auxiliary_client._resolve_task_provider_model",
            return_value=("auto", None, None, None, None),
        ), patch(
            "agent.auxiliary_client.resolve_vision_provider_client",
            return_value=("ollama-cloud", failing_client, "llama3"),
        ), patch(
            "agent.auxiliary_client._build_call_kwargs",
            return_value={
                "model": "llama3",
                "messages": [{"role": "user", "content": "analyze"}],
            },
        ), patch(
            "agent.auxiliary_client._try_vision_fallback",
            return_value=(rescue_client, "google/gemini-3-flash-preview", "openrouter"),
        ):
            result = call_llm(
                task="vision",
                messages=[{"role": "user", "content": "analyze"}],
            )
        assert result is fake_response
        # Exactly one attempt per client: fail once, succeed once.
        assert failing_client.chat.completions.create.call_count == 1
        assert rescue_client.chat.completions.create.call_count == 1