diff --git a/agent/auxiliary_client.py b/agent/auxiliary_client.py index b802a796d..5632096c4 100644 --- a/agent/auxiliary_client.py +++ b/agent/auxiliary_client.py @@ -183,6 +183,37 @@ def _is_likely_vision_model(model: str) -> bool: lower = model.lower() return any(indicator in lower for indicator in _VISION_MODEL_INDICATORS) + +def _get_named_custom_provider_entry(provider: Optional[str]) -> Optional[Dict[str, Any]]: + """Return the named custom-provider entry for *provider*, if any.""" + if not provider or provider == "custom": + return None + try: + from hermes_cli.runtime_provider import _get_named_custom_provider + entry = _get_named_custom_provider(provider) + except Exception: + return None + return entry if isinstance(entry, dict) else None + + +def _configured_vision_model(provider: Optional[str], model: Optional[str]) -> Optional[str]: + """Read per-model vision routing hints from named custom-provider config.""" + entry = _get_named_custom_provider_entry(provider) + if not entry or not model: + return None + models_cfg = entry.get("models") + if not isinstance(models_cfg, dict): + return None + model_cfg = models_cfg.get(model) + if not isinstance(model_cfg, dict): + return None + explicit = model_cfg.get("vision_model") or model_cfg.get("multimodal_model") + if isinstance(explicit, str) and explicit.strip(): + return explicit.strip() + if model_cfg.get("supports_vision") is True or model_cfg.get("multimodal") is True: + return model + return None + # OpenRouter app attribution headers _OR_HEADERS = { "HTTP-Referer": "https://hermes-agent.nousresearch.com", @@ -1361,6 +1392,25 @@ def _is_connection_error(exc: Exception) -> bool: return False +def _is_vision_capability_error(exc: Exception) -> bool: + """Detect errors from sending image input to a text-only model.""" + status = getattr(exc, "status_code", None) + if status not in (None, 400, 415, 422): + return False + err_lower = str(exc).lower() + return any(phrase in err_lower for phrase in ( + "does 
not support image", + "doesn't support image", + "images are not supported", + "image input is not supported", + "vision is not supported", + "vision not supported", + "multimodal input is not supported", + "model is not multimodal", + "text-only model", + )) + + def _is_auth_error(exc: Exception) -> bool: """Detect auth failures that should trigger provider-specific refresh.""" status = getattr(exc, "status_code", None) @@ -1479,6 +1529,41 @@ def _try_payment_fallback( return None, None, "" +def _try_vision_fallback( + failed_provider: str, + *, + async_mode: bool = False, + reason: str = "vision capability error", +) -> Tuple[Optional[Any], Optional[str], str]: + """Try the next strict vision backend after a main-provider vision failure.""" + skip = _normalize_vision_provider(failed_provider) + tried = [] + for label in _VISION_AUTO_PROVIDER_ORDER: + if label == skip: + continue + client, model = _resolve_strict_vision_backend(label) + if client is not None: + if async_mode: + client, model = _to_async_client(client, model) + logger.info( + "Auxiliary vision: %s on %s — falling back to %s (%s)", + reason, + failed_provider, + label, + model or "default", + ) + return client, model, label + tried.append(label) + + logger.warning( + "Auxiliary vision: %s on %s and no fallback available (tried: %s)", + reason, + failed_provider, + ", ".join(tried), + ) + return None, None, "" + + def _resolve_auto(main_runtime: Optional[Dict[str, Any]] = None) -> Tuple[Optional[OpenAI], Optional[str]]: """Full auto-detection chain. 
@@ -2229,12 +2314,25 @@ def resolve_vision_provider_client( ) return _finalize(main_provider, sync_client, default_model) else: - vision_model = _PROVIDER_VISION_MODELS.get(main_provider, main_model) + vision_model = ( + _configured_vision_model(main_provider, main_model) + or _PROVIDER_VISION_MODELS.get(main_provider) + or main_model + ) # When no explicit vision override exists (fell through to # main_model), check if the main model looks vision-capable. + # Trust direct/named custom providers: users may intentionally + # point Hermes at a self-hosted multimodal model whose name is + # unknown to our heuristic. + trusted_main_model = ( + main_provider == "custom" + or _get_named_custom_provider_entry(main_provider) is not None + ) # If not, skip directly to aggregator fallbacks instead of - # sending an image payload to a text-only model (#14744). - if vision_model == main_model and not _is_likely_vision_model(main_model): + # sending an image payload to a clearly text-only model. + if (vision_model == main_model + and not trusted_main_model + and not _is_likely_vision_model(main_model)): logger.info( "Vision auto-detect: skipping main provider %s " "(model %r is not vision-capable)", @@ -2900,6 +2998,7 @@ def call_llm( """ resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( task, provider, model, base_url, api_key) + allow_auto_fallback = resolved_provider in ("auto", "", None) effective_extra_body = _get_task_extra_body(task) effective_extra_body.update(extra_body or {}) @@ -3081,17 +3180,29 @@ def call_llm( # Codex/OAuth tokens that authenticate but whose endpoint is down, # and providers the user never configured that got picked up by # the auto-detection chain. 
- should_fallback = _is_payment_error(first_err) or _is_connection_error(first_err) + vision_capability_error = task == "vision" and _is_vision_capability_error(first_err) + should_fallback = ( + _is_payment_error(first_err) + or _is_connection_error(first_err) + or vision_capability_error + ) # Only try alternative providers when the user didn't explicitly # configure this task's provider. Explicit provider = hard constraint; # auto (the default) = best-effort fallback chain. (#7559) - is_auto = resolved_provider in ("auto", "", None) - if should_fallback and is_auto: - reason = "payment error" if _is_payment_error(first_err) else "connection error" + if should_fallback and allow_auto_fallback: + if vision_capability_error: + reason = "vision capability error" + fb_client, fb_model, fb_label = _try_vision_fallback( + resolved_provider, + async_mode=False, + reason=reason, + ) + else: + reason = "payment error" if _is_payment_error(first_err) else "connection error" + fb_client, fb_model, fb_label = _try_payment_fallback( + resolved_provider, task, reason=reason) logger.info("Auxiliary %s: %s on %s (%s), trying fallback", task or "call", reason, resolved_provider, first_err) - fb_client, fb_model, fb_label = _try_payment_fallback( - resolved_provider, task, reason=reason) if fb_client is not None: fb_kwargs = _build_call_kwargs( fb_label, fb_model, messages, @@ -3180,6 +3291,7 @@ async def async_call_llm( """ resolved_provider, resolved_model, resolved_base_url, resolved_api_key, resolved_api_mode = _resolve_task_provider_model( task, provider, model, base_url, api_key) + allow_auto_fallback = resolved_provider in ("auto", "", None) effective_extra_body = _get_task_extra_body(task) effective_extra_body.update(extra_body or {}) @@ -3332,14 +3444,26 @@ async def async_call_llm( await retry_client.chat.completions.create(**retry_kwargs), task) # ── Payment / connection fallback (mirrors sync call_llm) ───── - should_fallback = _is_payment_error(first_err) or 
_is_connection_error(first_err) - is_auto = resolved_provider in ("auto", "", None) - if should_fallback and is_auto: - reason = "payment error" if _is_payment_error(first_err) else "connection error" + vision_capability_error = task == "vision" and _is_vision_capability_error(first_err) + should_fallback = ( + _is_payment_error(first_err) + or _is_connection_error(first_err) + or vision_capability_error + ) + if should_fallback and allow_auto_fallback: + if vision_capability_error: + reason = "vision capability error" + fb_client, fb_model, fb_label = _try_vision_fallback( + resolved_provider, + async_mode=True, + reason=reason, + ) + else: + reason = "payment error" if _is_payment_error(first_err) else "connection error" + fb_client, fb_model, fb_label = _try_payment_fallback( + resolved_provider, task, reason=reason) logger.info("Auxiliary %s (async): %s on %s (%s), trying fallback", task or "call", reason, resolved_provider, first_err) - fb_client, fb_model, fb_label = _try_payment_fallback( - resolved_provider, task, reason=reason) if fb_client is not None: fb_kwargs = _build_call_kwargs( fb_label, fb_model, messages, diff --git a/tests/agent/test_vision_non_vision_fallthrough.py b/tests/agent/test_vision_non_vision_fallthrough.py index 0eb8ab2c4..faba0b152 100644 --- a/tests/agent/test_vision_non_vision_fallthrough.py +++ b/tests/agent/test_vision_non_vision_fallthrough.py @@ -6,11 +6,12 @@ the auto-detection should skip directly to aggregator fallbacks instead of sending an image payload to a text-only model. 
""" +from types import SimpleNamespace from unittest.mock import patch, MagicMock import pytest -from agent.auxiliary_client import _is_likely_vision_model +from agent.auxiliary_client import _is_likely_vision_model, call_llm # -- _is_likely_vision_model heuristic -- @@ -149,6 +150,67 @@ class TestVisionNonVisionFallthrough: assert client is mock_client assert model == "mimo-v2.5" + def test_named_custom_provider_unknown_model_is_trusted(self): + """Named custom providers should not be skipped by the name heuristic.""" + mock_client = MagicMock() + + with patch( + "agent.auxiliary_client._read_main_provider", + return_value="beans", + ), patch( + "agent.auxiliary_client._read_main_model", + return_value="my-company-vlm", + ), patch( + "agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", None, None, None, None), + ), patch( + "agent.auxiliary_client._get_named_custom_provider_entry", + return_value={"name": "beans", "base_url": "http://vlm.test/v1"}, + ), patch( + "agent.auxiliary_client.resolve_provider_client", + return_value=(mock_client, "my-company-vlm"), + ): + from agent.auxiliary_client import resolve_vision_provider_client + + provider, client, model = resolve_vision_provider_client() + + assert provider == "beans" + assert client is mock_client + assert model == "my-company-vlm" + + def test_named_custom_provider_can_declare_vision_model_override(self): + """Named custom providers can route vision to a dedicated model.""" + mock_client = MagicMock() + + with patch( + "agent.auxiliary_client._read_main_provider", + return_value="beans", + ), patch( + "agent.auxiliary_client._read_main_model", + return_value="chat-model", + ), patch( + "agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", None, None, None, None), + ), patch( + "agent.auxiliary_client._get_named_custom_provider_entry", + return_value={ + "name": "beans", + "base_url": "http://vlm.test/v1", + "models": {"chat-model": {"vision_model": 
"vision-model"}}, + }, + ), patch( + "agent.auxiliary_client.resolve_provider_client", + return_value=(mock_client, "vision-model"), + ) as mock_resolve: + from agent.auxiliary_client import resolve_vision_provider_client + + provider, client, model = resolve_vision_provider_client() + + assert provider == "beans" + assert client is mock_client + assert model == "vision-model" + assert mock_resolve.call_args.args[:2] == ("beans", "vision-model") + def test_non_vision_model_all_aggregators_fail(self): """Non-vision main + no aggregators available must return None.""" with patch( @@ -170,3 +232,49 @@ class TestVisionNonVisionFallthrough: assert client is None assert model is None + + +class VisionUnsupportedError(Exception): + def __init__(self, message, status_code=400): + super().__init__(message) + self.status_code = status_code + + +class TestVisionCapabilityFallback: + def test_call_llm_retries_auto_vision_on_capability_error(self): + """A text-only main provider should fall through to strict vision backends.""" + primary_client = MagicMock() + fallback_client = MagicMock() + response = SimpleNamespace( + choices=[SimpleNamespace(message=SimpleNamespace(content="ok"))] + ) + + primary_client.chat.completions.create.side_effect = VisionUnsupportedError( + "This model does not support image input" + ) + fallback_client.chat.completions.create.return_value = response + + with patch( + "agent.auxiliary_client._resolve_task_provider_model", + return_value=("auto", None, None, None, None), + ), patch( + "agent.auxiliary_client.resolve_vision_provider_client", + return_value=("ollama-cloud", primary_client, "llama3"), + ), patch( + "agent.auxiliary_client._build_call_kwargs", + return_value={ + "model": "llama3", + "messages": [{"role": "user", "content": "analyze"}], + }, + ), patch( + "agent.auxiliary_client._try_vision_fallback", + return_value=(fallback_client, "google/gemini-3-flash-preview", "openrouter"), + ): + result = call_llm( + task="vision", + 
messages=[{"role": "user", "content": "analyze"}], + ) + + assert result is response + assert primary_client.chat.completions.create.call_count == 1 + assert fallback_client.chat.completions.create.call_count == 1