mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(agent): preserve auto vision fallback for custom backends
This commit is contained in:
parent
1b4be27642
commit
887b3d4e38
2 changed files with 248 additions and 16 deletions
|
|
@ -183,6 +183,37 @@ def _is_likely_vision_model(model: str) -> bool:
|
|||
lower = model.lower()
|
||||
return any(indicator in lower for indicator in _VISION_MODEL_INDICATORS)
|
||||
|
||||
|
||||
def _get_named_custom_provider_entry(provider: Optional[str]) -> Optional[Dict[str, Any]]:
|
||||
"""Return the named custom-provider entry for *provider*, if any."""
|
||||
if not provider or provider == "custom":
|
||||
return None
|
||||
try:
|
||||
from hermes_cli.runtime_provider import _get_named_custom_provider
|
||||
entry = _get_named_custom_provider(provider)
|
||||
except Exception:
|
||||
return None
|
||||
return entry if isinstance(entry, dict) else None
|
||||
|
||||
|
||||
def _configured_vision_model(provider: Optional[str], model: Optional[str]) -> Optional[str]:
|
||||
"""Read per-model vision routing hints from named custom-provider config."""
|
||||
entry = _get_named_custom_provider_entry(provider)
|
||||
if not entry or not model:
|
||||
return None
|
||||
models_cfg = entry.get("models")
|
||||
if not isinstance(models_cfg, dict):
|
||||
return None
|
||||
model_cfg = models_cfg.get(model)
|
||||
if not isinstance(model_cfg, dict):
|
||||
return None
|
||||
explicit = model_cfg.get("vision_model") or model_cfg.get("multimodal_model")
|
||||
if isinstance(explicit, str) and explicit.strip():
|
||||
return explicit.strip()
|
||||
if model_cfg.get("supports_vision") is True or model_cfg.get("multimodal") is True:
|
||||
return model
|
||||
return None
|
||||
|
||||
# OpenRouter app attribution headers
|
||||
_OR_HEADERS = {
|
||||
"HTTP-Referer": "https://hermes-agent.nousresearch.com",
|
||||
|
|
@ -1361,6 +1392,25 @@ def _is_connection_error(exc: Exception) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _is_vision_capability_error(exc: Exception) -> bool:
|
||||
"""Detect errors from sending image input to a text-only model."""
|
||||
status = getattr(exc, "status_code", None)
|
||||
if status not in (None, 400, 415, 422):
|
||||
return False
|
||||
err_lower = str(exc).lower()
|
||||
return any(phrase in err_lower for phrase in (
|
||||
"does not support image",
|
||||
"doesn't support image",
|
||||
"images are not supported",
|
||||
"image input is not supported",
|
||||
"vision is not supported",
|
||||
"vision not supported",
|
||||
"multimodal input is not supported",
|
||||
"model is not multimodal",
|
||||
"text-only model",
|
||||
))
|
||||
|
||||
|
||||
def _is_auth_error(exc: Exception) -> bool:
|
||||
"""Detect auth failures that should trigger provider-specific refresh."""
|
||||
status = getattr(exc, "status_code", None)
|
||||
|
|
@ -1479,6 +1529,41 @@ def _try_payment_fallback(
|
|||
return None, None, ""
|
||||
|
||||
|
||||
def _try_vision_fallback(
|
||||
failed_provider: str,
|
||||
*,
|
||||
async_mode: bool = False,
|
||||
reason: str = "vision capability error",
|
||||
) -> Tuple[Optional[Any], Optional[str], str]:
|
||||
"""Try the next strict vision backend after a main-provider vision failure."""
|
||||
skip = _normalize_vision_provider(failed_provider)
|
||||
tried = []
|
||||
for label in _VISION_AUTO_PROVIDER_ORDER:
|
||||
if label == skip:
|
||||
continue
|
||||
client, model = _resolve_strict_vision_backend(label)
|
||||
if client is not None:
|
||||
if async_mode:
|
||||
client, model = _to_async_client(client, model)
|
||||
logger.info(
|
||||
"Auxiliary vision: %s on %s — falling back to %s (%s)",
|
||||
reason,
|
||||
failed_provider,
|
||||
label,
|
||||
model or "default",
|
||||
)
|
||||
return client, model, label
|
||||
tried.append(label)
|
||||
|
||||
logger.warning(
|
||||
"Auxiliary vision: %s on %s and no fallback available (tried: %s)",
|
||||
reason,
|
||||
failed_provider,
|
||||
", ".join(tried),
|
||||
)
|
||||
return None, None, ""
|
||||
|
||||
|
||||
def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
"""Full auto-detection chain.
|
||||
|
||||
|
|
@ -2229,12 +2314,25 @@ def resolve_vision_provider_client(
|
|||
)
|
||||
return _finalize(main_provider, sync_client, default_model)
|
||||
else:
|
||||
vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model)
|
||||
vision_model = (
|
||||
_configured_vision_model(main_provider, main_model)
|
||||
or _PROVIDER_VISION_MODELS.get(main_provider)
|
||||
or main_model
|
||||
)
|
||||
# When no explicit vision override exists (fell through to
|
||||
# main_model), check if the main model looks vision-capable.
|
||||
# Trust direct/named custom providers: users may intentionally
|
||||
# point Hermes at a self-hosted multimodal model whose name is
|
||||
# unknown to our heuristic.
|
||||
trusted_main_model = (
|
||||
main_provider == "custom"
|
||||
or _get_named_custom_provider_entry(main_provider) is not None
|
||||
)
|
||||
# If not, skip directly to aggregator fallbacks instead of
|
||||
# sending an image payload to a text-only model (#14744).
|
||||
if vision_model == main_model and not _is_likely_vision_model(main_model):
|
||||
# sending an image payload to a clearly text-only model.
|
||||
if (vision_model == main_model
|
||||
and not trusted_main_model
|
||||
and not _is_likely_vision_model(main_model)):
|
||||
logger.info(
|
||||
"Vision auto-detect: skipping main provider %s "
|
||||
"(model %r is not vision-capable)",
|
||||
|
|
@ -2900,6 +2998,7 @@ def call_llm(
|
|||
"""
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
allow_auto_fallback = resolved_provider in ("auto", "", None)
|
||||
effective_extra_body = _get_task_extra_body(task)
|
||||
effective_extra_body.update(extra_body or {})
|
||||
|
||||
|
|
@ -3081,17 +3180,29 @@ def call_llm(
|
|||
# Codex/OAuth tokens that authenticate but whose endpoint is down,
|
||||
# and providers the user never configured that got picked up by
|
||||
# the auto-detection chain.
|
||||
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
vision_capability_error = task == "vision" and _is_vision_capability_error(first_err)
|
||||
should_fallback = (
|
||||
_is_payment_error(first_err)
|
||||
or _is_connection_error(first_err)
|
||||
or vision_capability_error
|
||||
)
|
||||
# Only try alternative providers when the user didn't explicitly
|
||||
# configure this task's provider. Explicit provider = hard constraint;
|
||||
# auto (the default) = best-effort fallback chain. (#7559)
|
||||
is_auto = resolved_provider in ("auto", "", None)
|
||||
if should_fallback and is_auto:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
if should_fallback and allow_auto_fallback:
|
||||
if vision_capability_error:
|
||||
reason = "vision capability error"
|
||||
fb_client, fb_model, fb_label = _try_vision_fallback(
|
||||
resolved_provider,
|
||||
async_mode=False,
|
||||
reason=reason,
|
||||
)
|
||||
else:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
logger.info("Auxiliary %s: %s on %s (%s), trying fallback",
|
||||
task or "call", reason, resolved_provider, first_err)
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
if fb_client is not None:
|
||||
fb_kwargs = _build_call_kwargs(
|
||||
fb_label, fb_model, messages,
|
||||
|
|
@ -3180,6 +3291,7 @@ async def async_call_llm(
|
|||
"""
|
||||
resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model(
|
||||
task, provider, model, base_url, api_key)
|
||||
allow_auto_fallback = resolved_provider in ("auto", "", None)
|
||||
effective_extra_body = _get_task_extra_body(task)
|
||||
effective_extra_body.update(extra_body or {})
|
||||
|
||||
|
|
@ -3332,14 +3444,26 @@ async def async_call_llm(
|
|||
await retry_client.chat.completions.create(**retry_kwargs), task)
|
||||
|
||||
# ── Payment / connection fallback (mirrors sync call_llm) ─────
|
||||
should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err)
|
||||
is_auto = resolved_provider in ("auto", "", None)
|
||||
if should_fallback and is_auto:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
vision_capability_error = task == "vision" and _is_vision_capability_error(first_err)
|
||||
should_fallback = (
|
||||
_is_payment_error(first_err)
|
||||
or _is_connection_error(first_err)
|
||||
or vision_capability_error
|
||||
)
|
||||
if should_fallback and allow_auto_fallback:
|
||||
if vision_capability_error:
|
||||
reason = "vision capability error"
|
||||
fb_client, fb_model, fb_label = _try_vision_fallback(
|
||||
resolved_provider,
|
||||
async_mode=False,
|
||||
reason=reason,
|
||||
)
|
||||
else:
|
||||
reason = "payment error" if _is_payment_error(first_err) else "connection error"
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback",
|
||||
task or "call", reason, resolved_provider, first_err)
|
||||
fb_client, fb_model, fb_label = _try_payment_fallback(
|
||||
resolved_provider, task, reason=reason)
|
||||
if fb_client is not None:
|
||||
fb_kwargs = _build_call_kwargs(
|
||||
fb_label, fb_model, messages,
|
||||
|
|
|
|||
|
|
@ -6,11 +6,12 @@ the auto-detection should skip directly to aggregator fallbacks instead of
|
|||
sending an image payload to a text-only model.
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.auxiliary_client import _is_likely_vision_model
|
||||
from agent.auxiliary_client import _is_likely_vision_model, call_llm
|
||||
|
||||
|
||||
# -- _is_likely_vision_model heuristic --
|
||||
|
|
@ -149,6 +150,67 @@ class TestVisionNonVisionFallthrough:
|
|||
assert client is mock_client
|
||||
assert model == "mimo-v2.5"
|
||||
|
||||
def test_named_custom_provider_unknown_model_is_trusted(self):
|
||||
"""Named custom providers should not be skipped by the name heuristic."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="beans",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="my-company-vlm",
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._get_named_custom_provider_entry",
|
||||
return_value={"name": "beans", "base_url": "http://vlm.test/v1"},
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "my-company-vlm"),
|
||||
):
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "beans"
|
||||
assert client is mock_client
|
||||
assert model == "my-company-vlm"
|
||||
|
||||
def test_named_custom_provider_can_declare_vision_model_override(self):
|
||||
"""Named custom providers can route vision to a dedicated model."""
|
||||
mock_client = MagicMock()
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._read_main_provider",
|
||||
return_value="beans",
|
||||
), patch(
|
||||
"agent.auxiliary_client._read_main_model",
|
||||
return_value="chat-model",
|
||||
), patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client._get_named_custom_provider_entry",
|
||||
return_value={
|
||||
"name": "beans",
|
||||
"base_url": "http://vlm.test/v1",
|
||||
"models": {"chat-model": {"vision_model": "vision-model"}},
|
||||
},
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_provider_client",
|
||||
return_value=(mock_client, "vision-model"),
|
||||
) as mock_resolve:
|
||||
from agent.auxiliary_client import resolve_vision_provider_client
|
||||
|
||||
provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert provider == "beans"
|
||||
assert client is mock_client
|
||||
assert model == "vision-model"
|
||||
assert mock_resolve.call_args.args[:2] == ("beans", "vision-model")
|
||||
|
||||
def test_non_vision_model_all_aggregators_fail(self):
|
||||
"""Non-vision main + no aggregators available must return None."""
|
||||
with patch(
|
||||
|
|
@ -170,3 +232,49 @@ class TestVisionNonVisionFallthrough:
|
|||
|
||||
assert client is None
|
||||
assert model is None
|
||||
|
||||
|
||||
class VisionUnsupportedError(Exception):
|
||||
def __init__(self, message, status_code=400):
|
||||
super().__init__(message)
|
||||
self.status_code = status_code
|
||||
|
||||
|
||||
class TestVisionCapabilityFallback:
|
||||
def test_call_llm_retries_auto_vision_on_capability_error(self):
|
||||
"""A text-only main provider should fall through to strict vision backends."""
|
||||
primary_client = MagicMock()
|
||||
fallback_client = MagicMock()
|
||||
response = SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))]
|
||||
)
|
||||
|
||||
primary_client.chat.completions.create.side_effect = VisionUnsupportedError(
|
||||
"This model does not support image input"
|
||||
)
|
||||
fallback_client.chat.completions.create.return_value = response
|
||||
|
||||
with patch(
|
||||
"agent.auxiliary_client._resolve_task_provider_model",
|
||||
return_value=("auto", None, None, None, None),
|
||||
), patch(
|
||||
"agent.auxiliary_client.resolve_vision_provider_client",
|
||||
return_value=("ollama-cloud", primary_client, "llama3"),
|
||||
), patch(
|
||||
"agent.auxiliary_client._build_call_kwargs",
|
||||
return_value={
|
||||
"model": "llama3",
|
||||
"messages": [{"role": "user", "content": "analyze"}],
|
||||
},
|
||||
), patch(
|
||||
"agent.auxiliary_client._try_vision_fallback",
|
||||
return_value=(fallback_client, "google/gemini-3-flash-preview", "openrouter"),
|
||||
):
|
||||
result = call_llm(
|
||||
task="vision",
|
||||
messages=[{"role": "user", "content": "analyze"}],
|
||||
)
|
||||
|
||||
assert result is response
|
||||
assert primary_client.chat.completions.create.call_count == 1
|
||||
assert fallback_client.chat.completions.create.call_count == 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue