fix(computer_use): honor custom vision routing

This commit is contained in:
helix4u 2026-06-03 21:03:31 -06:00 committed by Teknium
parent ffe665277c
commit 591e6fb8f4
6 changed files with 207 additions and 7 deletions

View file

@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
import base64
import json import json
import os import os
import sys import sys
@ -360,7 +361,9 @@ class TestCaptureResponse:
def focus_app(self, app, raise_window=False): ... def focus_app(self, app, raise_window=False): ...
cu_tool.reset_backend_for_tests() cu_tool.reset_backend_for_tests()
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()): with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()), \
patch.object(cu_tool, "_should_route_through_aux_vision",
return_value=False):
out = cu_tool.handle_computer_use({"action": "capture", "mode": "vision"}) out = cu_tool.handle_computer_use({"action": "capture", "mode": "vision"})
assert isinstance(out, dict) assert isinstance(out, dict)
@ -398,7 +401,9 @@ class TestCaptureResponse:
def focus_app(self, app, raise_window=False): ... def focus_app(self, app, raise_window=False): ...
cu_tool.reset_backend_for_tests() cu_tool.reset_backend_for_tests()
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()): with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()), \
patch.object(cu_tool, "_should_route_through_aux_vision",
return_value=False):
out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"}) out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"})
assert isinstance(out, dict) assert isinstance(out, dict)
text_part = next(p for p in out["content"] if p.get("type") == "text") text_part = next(p for p in out["content"] if p.get("type") == "text")
@ -436,6 +441,7 @@ class TestCaptureResponse:
return FakeBackend() return FakeBackend()
def test_capture_ax_caps_elements_at_default_for_dense_trees(self): def test_capture_ax_caps_elements_at_default_for_dense_trees(self):
"""Regression for #22865: an Electron-style 600-element AX tree must """Regression for #22865: an Electron-style 600-element AX tree must
not emit the entire array verbatim into the tool result. not emit the entire array verbatim into the tool result.
@ -582,7 +588,9 @@ class TestCaptureResponse:
def focus_app(self, app, raise_window=False): ... def focus_app(self, app, raise_window=False): ...
cu_tool.reset_backend_for_tests() cu_tool.reset_backend_for_tests()
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()): with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()), \
patch.object(cu_tool, "_should_route_through_aux_vision",
return_value=False):
out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"}) out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"})
assert isinstance(out, dict) and out["_multimodal"] is True assert isinstance(out, dict) and out["_multimodal"] is True
@ -594,6 +602,32 @@ class TestCaptureResponse:
assert "truncated to" not in out["text_summary"] assert "truncated to" not in out["text_summary"]
class TestCuaCaptureImageDimensions:
def test_png_dimensions_are_sniffed_from_image_bytes(self):
from tools.computer_use.cua_backend import _image_dimensions_from_bytes
raw_png = base64.b64decode(
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42m"
"NkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=",
validate=False,
)
assert _image_dimensions_from_bytes(raw_png) == (1, 1)
def test_jpeg_dimensions_are_sniffed_from_sof_segment(self):
from tools.computer_use.cua_backend import _image_dimensions_from_bytes
raw_jpeg = (
b"\xff\xd8" +
b"\xff\xe0\x00\x10" + (b"0" * 14)
+ b"\xff\xc0\x00\x11\x08"
+ b"\x01\x2c" # height: 300
+ b"\x01\x90" # width: 400
+ b"\x03\x01\x11\x00\x02\x11\x00\x03\x11\x00"
+ b"\xff\xd9"
)
assert _image_dimensions_from_bytes(raw_jpeg) == (400, 300)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Anthropic adapter: multimodal tool-result conversion # Anthropic adapter: multimodal tool-result conversion
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View file

@ -241,6 +241,39 @@ class TestCaptureResponseRoutedToAuxVision:
assert observed_path["path"] assert observed_path["path"]
assert not os.path.exists(observed_path["path"]) assert not os.path.exists(observed_path["path"])
def test_aux_route_creates_missing_cache_dir(self, tmp_path):
from tools.computer_use import tool as cu_tool
cache_dir = tmp_path / "missing" / "cache_vision"
cap = _make_capture(mode="som")
observed_path = {}
def _fake_get(*_args, **_kw):
return cache_dir
def _fake_run_async(_coro):
return _stub_aux_analysis("description goes here")
def _fake_vat(image_path, _prompt):
observed_path["path"] = image_path
assert os.path.exists(image_path)
return "<coro>"
fake_vat = MagicMock(side_effect=_fake_vat)
with patch.object(cu_tool, "_should_route_through_aux_vision",
return_value=True), \
patch("hermes_constants.get_hermes_dir", _fake_get), \
patch("model_tools._run_async", side_effect=_fake_run_async), \
patch("tools.vision_tools.vision_analyze_tool",
new_callable=lambda: fake_vat):
resp = cu_tool._capture_response(cap)
assert isinstance(resp, str)
assert cache_dir.is_dir()
assert observed_path["path"]
assert not os.path.exists(observed_path["path"])
def test_temp_file_cleaned_up_even_when_aux_call_raises( def test_temp_file_cleaned_up_even_when_aux_call_raises(
self, tmp_cache_dir, self, tmp_cache_dir,
): ):

View file

@ -160,6 +160,42 @@ class TestRouteDecision:
"some-aggregator", "some-vision-model", {} "some-aggregator", "some-vision-model", {}
) is True ) is True
def test_user_declared_vision_support_keeps_custom_provider_native(self):
"""Local/custom VLMs use config as their tool-result image escape hatch."""
from tools.computer_use import vision_routing
cfg = {
"model": {
"default": "Qwen3.6-35B-A3B-local-vlm",
"provider": "omlx",
"supports_vision": True,
}
}
with patch.object(vision_routing,
"_provider_accepts_multimodal_tool_result",
return_value=False):
assert vision_routing.should_route_capture_to_aux_vision(
"custom", "Qwen3.6-35B-A3B-local-vlm", cfg
) is False
def test_user_declared_no_vision_routes_custom_provider_to_aux(self):
"""An explicit false override should not fall through to native routing."""
from tools.computer_use import vision_routing
cfg = {
"model": {
"default": "local-text-model",
"provider": "omlx",
"supports_vision": False,
}
}
with patch.object(vision_routing,
"_provider_accepts_multimodal_tool_result",
return_value=True):
assert vision_routing.should_route_capture_to_aux_vision(
"custom", "local-text-model", cfg
) is True
def test_unknown_provider_capabilities_fail_closed(self): def test_unknown_provider_capabilities_fail_closed(self):
"""When tool-result lookup returns None, route to aux (safe default).""" """When tool-result lookup returns None, route to aux (safe default)."""
from tools.computer_use import vision_routing from tools.computer_use import vision_routing

View file

@ -126,6 +126,45 @@ def _parse_elements_from_tree(markdown: str) -> List[UIElement]:
return elements return elements
def _image_dimensions_from_bytes(raw: bytes) -> Tuple[int, int]:
"""Best-effort PNG/JPEG dimension sniffing without extra dependencies."""
if raw.startswith(b"\x89PNG\r\n\x1a\n") and len(raw) >= 24:
width = int.from_bytes(raw[16:20], "big")
height = int.from_bytes(raw[20:24], "big")
if width > 0 and height > 0:
return width, height
if raw.startswith(b"\xff\xd8"):
i = 2
n = len(raw)
while i + 9 < n:
if raw[i] != 0xFF:
i += 1
continue
marker = raw[i + 1]
i += 2
if marker in {0xD8, 0xD9} or 0xD0 <= marker <= 0xD7:
continue
if i + 2 > n:
break
segment_len = int.from_bytes(raw[i:i + 2], "big")
if segment_len < 2 or i + segment_len > n:
break
if marker in {
0xC0, 0xC1, 0xC2, 0xC3, 0xC5, 0xC6, 0xC7,
0xC9, 0xCA, 0xCB, 0xCD, 0xCE, 0xCF,
}:
if segment_len >= 7:
height = int.from_bytes(raw[i + 3:i + 5], "big")
width = int.from_bytes(raw[i + 5:i + 7], "big")
if width > 0 and height > 0:
return width, height
break
i += segment_len
return 0, 0
def _split_tree_text(full_text: str) -> Tuple[str, str]: def _split_tree_text(full_text: str) -> Tuple[str, str]:
"""Split get_window_state text into (summary_line, tree_markdown).""" """Split get_window_state text into (summary_line, tree_markdown)."""
lines = full_text.split("\n", 1) lines = full_text.split("\n", 1)
@ -491,7 +530,12 @@ class CuaDriverBackend(ComputerUseBackend):
png_bytes_len = 0 png_bytes_len = 0
if png_b64: if png_b64:
try: try:
png_bytes_len = len(base64.b64decode(png_b64, validate=False)) raw = base64.b64decode(png_b64, validate=False)
png_bytes_len = len(raw)
detected_width, detected_height = _image_dimensions_from_bytes(raw)
if detected_width and detected_height:
width = detected_width
height = detected_height
except Exception: except Exception:
png_bytes_len = len(png_b64) * 3 // 4 png_bytes_len = len(png_b64) * 3 // 4

View file

@ -615,6 +615,7 @@ def _route_capture_through_aux_vision(
# MIME sniffing returns the right content-type. # MIME sniffing returns the right content-type.
ext = ".jpg" if cap.png_b64[:8].startswith("/9j/") else ".png" ext = ".jpg" if cap.png_b64[:8].startswith("/9j/") else ".png"
cache_dir = get_hermes_dir("cache/vision", "temp_vision_images") cache_dir = get_hermes_dir("cache/vision", "temp_vision_images")
cache_dir.mkdir(parents=True, exist_ok=True)
temp_image_path = cache_dir / f"computer_use_{_uuid.uuid4().hex}{ext}" temp_image_path = cache_dir / f"computer_use_{_uuid.uuid4().hex}{ext}"
temp_image_path.write_bytes(raw) temp_image_path.write_bytes(raw)

View file

@ -28,6 +28,10 @@ Behaviour (mirrors ``vision_analyze`` for consistency)
``provider``, ``model``, or ``base_url`` non-empty / not ``"auto"``), ``provider``, ``model``, or ``base_url`` non-empty / not ``"auto"``),
the screenshot is routed through the aux vision pipeline. Users who the screenshot is routed through the aux vision pipeline. Users who
pay for a dedicated vision model usually want it used. pay for a dedicated vision model usually want it used.
* Otherwise, if the user explicitly declared the active model vision-capable
via ``model.supports_vision`` / provider model config, return ``False``.
This is the escape hatch for custom/local OpenAI-compatible VLM routes that
are absent from models.dev and provider allowlists.
* Otherwise, if the active main model+provider can carry an image inside * Otherwise, if the active main model+provider can carry an image inside
a tool-result message AND the model reports ``supports_vision=True`` a tool-result message AND the model reports ``supports_vision=True``
in models.dev metadata, return ``False`` (use the multimodal path). in models.dev metadata, return ``False`` (use the multimodal path).
@ -76,10 +80,52 @@ def _explicit_aux_vision_override(cfg: Optional[Dict[str, Any]]) -> bool:
return True return True
def _lookup_supports_vision(provider: str, model: str) -> Optional[bool]: def _lookup_user_declared_supports_vision(
"""Return models.dev ``supports_vision`` for *(provider, model)* or None.""" provider: str,
model: str,
cfg: Optional[Dict[str, Any]],
) -> Optional[bool]:
"""Return config-declared ``supports_vision`` for the active route."""
try:
from agent.image_routing import _supports_vision_override
except Exception as exc: # pragma: no cover - defensive
logger.debug(
"computer_use vision_routing: config override lookup import failed: %s",
exc,
)
return None
try:
return _supports_vision_override(cfg, provider, model)
except Exception as exc: # pragma: no cover - defensive
logger.debug(
"computer_use vision_routing: config override lookup failed: %s",
exc,
)
return None
def _lookup_supports_vision(
provider: str,
model: str,
cfg: Optional[Dict[str, Any]] = None,
) -> Optional[bool]:
"""Return config/models.dev ``supports_vision`` for *(provider, model)*."""
if not provider or not model: if not provider or not model:
return None return None
try:
from agent.image_routing import _lookup_supports_vision as _lookup_image_supports
except Exception:
_lookup_image_supports = None
if _lookup_image_supports is not None:
try:
return _lookup_image_supports(provider, model, cfg)
except Exception as exc: # pragma: no cover - defensive
logger.debug(
"computer_use vision_routing: image-routing caps lookup failed "
"for %s:%s%s",
provider, model, exc,
)
return None
try: try:
from agent.models_dev import get_model_capabilities from agent.models_dev import get_model_capabilities
caps = get_model_capabilities(provider, model) caps = get_model_capabilities(provider, model)
@ -137,11 +183,17 @@ def should_route_capture_to_aux_vision(
if _explicit_aux_vision_override(cfg): if _explicit_aux_vision_override(cfg):
return True return True
user_declared = _lookup_user_declared_supports_vision(provider, model, cfg)
if user_declared is True:
return False
if user_declared is False:
return True
accepts_tool_image = _provider_accepts_multimodal_tool_result(provider, model) accepts_tool_image = _provider_accepts_multimodal_tool_result(provider, model)
if accepts_tool_image is None or accepts_tool_image is False: if accepts_tool_image is None or accepts_tool_image is False:
return True return True
supports_vision = _lookup_supports_vision(provider, model) supports_vision = _lookup_supports_vision(provider, model, cfg)
if supports_vision is True: if supports_vision is True:
return False return False
return True return True