mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix(computer_use): honor custom vision routing
This commit is contained in:
parent
ffe665277c
commit
591e6fb8f4
6 changed files with 207 additions and 7 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
|
@ -360,7 +361,9 @@ class TestCaptureResponse:
|
|||
def focus_app(self, app, raise_window=False): ...
|
||||
|
||||
cu_tool.reset_backend_for_tests()
|
||||
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()):
|
||||
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()), \
|
||||
patch.object(cu_tool, "_should_route_through_aux_vision",
|
||||
return_value=False):
|
||||
out = cu_tool.handle_computer_use({"action": "capture", "mode": "vision"})
|
||||
|
||||
assert isinstance(out, dict)
|
||||
|
|
@ -398,7 +401,9 @@ class TestCaptureResponse:
|
|||
def focus_app(self, app, raise_window=False): ...
|
||||
|
||||
cu_tool.reset_backend_for_tests()
|
||||
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()):
|
||||
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()), \
|
||||
patch.object(cu_tool, "_should_route_through_aux_vision",
|
||||
return_value=False):
|
||||
out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"})
|
||||
assert isinstance(out, dict)
|
||||
text_part = next(p for p in out["content"] if p.get("type") == "text")
|
||||
|
|
@ -436,6 +441,7 @@ class TestCaptureResponse:
|
|||
|
||||
return FakeBackend()
|
||||
|
||||
|
||||
def test_capture_ax_caps_elements_at_default_for_dense_trees(self):
|
||||
"""Regression for #22865: an Electron-style 600-element AX tree must
|
||||
not emit the entire array verbatim into the tool result.
|
||||
|
|
@ -582,7 +588,9 @@ class TestCaptureResponse:
|
|||
def focus_app(self, app, raise_window=False): ...
|
||||
|
||||
cu_tool.reset_backend_for_tests()
|
||||
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()):
|
||||
with patch.object(cu_tool, "_get_backend", return_value=FakeBackend()), \
|
||||
patch.object(cu_tool, "_should_route_through_aux_vision",
|
||||
return_value=False):
|
||||
out = cu_tool.handle_computer_use({"action": "capture", "mode": "som"})
|
||||
|
||||
assert isinstance(out, dict) and out["_multimodal"] is True
|
||||
|
|
@ -594,6 +602,32 @@ class TestCaptureResponse:
|
|||
assert "truncated to" not in out["text_summary"]
|
||||
|
||||
|
||||
class TestCuaCaptureImageDimensions:
|
||||
def test_png_dimensions_are_sniffed_from_image_bytes(self):
|
||||
from tools.computer_use.cua_backend import _image_dimensions_from_bytes
|
||||
|
||||
raw_png = base64.b64decode(
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42m"
|
||||
"NkYAAAAAYAAjCB0C8AAAAASUVORK5CYII=",
|
||||
validate=False,
|
||||
)
|
||||
assert _image_dimensions_from_bytes(raw_png) == (1, 1)
|
||||
|
||||
def test_jpeg_dimensions_are_sniffed_from_sof_segment(self):
|
||||
from tools.computer_use.cua_backend import _image_dimensions_from_bytes
|
||||
|
||||
raw_jpeg = (
|
||||
b"\xff\xd8" +
|
||||
b"\xff\xe0\x00\x10" + (b"0" * 14)
|
||||
+ b"\xff\xc0\x00\x11\x08"
|
||||
+ b"\x01\x2c" # height: 300
|
||||
+ b"\x01\x90" # width: 400
|
||||
+ b"\x03\x01\x11\x00\x02\x11\x00\x03\x11\x00"
|
||||
+ b"\xff\xd9"
|
||||
)
|
||||
assert _image_dimensions_from_bytes(raw_jpeg) == (400, 300)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anthropic adapter: multimodal tool-result conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -241,6 +241,39 @@ class TestCaptureResponseRoutedToAuxVision:
|
|||
assert observed_path["path"]
|
||||
assert not os.path.exists(observed_path["path"])
|
||||
|
||||
def test_aux_route_creates_missing_cache_dir(self, tmp_path):
|
||||
from tools.computer_use import tool as cu_tool
|
||||
|
||||
cache_dir = tmp_path / "missing" / "cache_vision"
|
||||
cap = _make_capture(mode="som")
|
||||
observed_path = {}
|
||||
|
||||
def _fake_get(*_args, **_kw):
|
||||
return cache_dir
|
||||
|
||||
def _fake_run_async(_coro):
|
||||
return _stub_aux_analysis("description goes here")
|
||||
|
||||
def _fake_vat(image_path, _prompt):
|
||||
observed_path["path"] = image_path
|
||||
assert os.path.exists(image_path)
|
||||
return "<coro>"
|
||||
|
||||
fake_vat = MagicMock(side_effect=_fake_vat)
|
||||
|
||||
with patch.object(cu_tool, "_should_route_through_aux_vision",
|
||||
return_value=True), \
|
||||
patch("hermes_constants.get_hermes_dir", _fake_get), \
|
||||
patch("model_tools._run_async", side_effect=_fake_run_async), \
|
||||
patch("tools.vision_tools.vision_analyze_tool",
|
||||
new_callable=lambda: fake_vat):
|
||||
resp = cu_tool._capture_response(cap)
|
||||
|
||||
assert isinstance(resp, str)
|
||||
assert cache_dir.is_dir()
|
||||
assert observed_path["path"]
|
||||
assert not os.path.exists(observed_path["path"])
|
||||
|
||||
def test_temp_file_cleaned_up_even_when_aux_call_raises(
|
||||
self, tmp_cache_dir,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -160,6 +160,42 @@ class TestRouteDecision:
|
|||
"some-aggregator", "some-vision-model", {}
|
||||
) is True
|
||||
|
||||
def test_user_declared_vision_support_keeps_custom_provider_native(self):
|
||||
"""Local/custom VLMs use config as their tool-result image escape hatch."""
|
||||
from tools.computer_use import vision_routing
|
||||
|
||||
cfg = {
|
||||
"model": {
|
||||
"default": "Qwen3.6-35B-A3B-local-vlm",
|
||||
"provider": "omlx",
|
||||
"supports_vision": True,
|
||||
}
|
||||
}
|
||||
with patch.object(vision_routing,
|
||||
"_provider_accepts_multimodal_tool_result",
|
||||
return_value=False):
|
||||
assert vision_routing.should_route_capture_to_aux_vision(
|
||||
"custom", "Qwen3.6-35B-A3B-local-vlm", cfg
|
||||
) is False
|
||||
|
||||
def test_user_declared_no_vision_routes_custom_provider_to_aux(self):
|
||||
"""An explicit false override should not fall through to native routing."""
|
||||
from tools.computer_use import vision_routing
|
||||
|
||||
cfg = {
|
||||
"model": {
|
||||
"default": "local-text-model",
|
||||
"provider": "omlx",
|
||||
"supports_vision": False,
|
||||
}
|
||||
}
|
||||
with patch.object(vision_routing,
|
||||
"_provider_accepts_multimodal_tool_result",
|
||||
return_value=True):
|
||||
assert vision_routing.should_route_capture_to_aux_vision(
|
||||
"custom", "local-text-model", cfg
|
||||
) is True
|
||||
|
||||
def test_unknown_provider_capabilities_fail_closed(self):
|
||||
"""When tool-result lookup returns None, route to aux (safe default)."""
|
||||
from tools.computer_use import vision_routing
|
||||
|
|
|
|||
|
|
@ -126,6 +126,45 @@ def _parse_elements_from_tree(markdown: str) -> List[UIElement]:
|
|||
return elements
|
||||
|
||||
|
||||
def _image_dimensions_from_bytes(raw: bytes) -> Tuple[int, int]:
|
||||
"""Best-effort PNG/JPEG dimension sniffing without extra dependencies."""
|
||||
if raw.startswith(b"\x89PNG\r\n\x1a\n") and len(raw) >= 24:
|
||||
width = int.from_bytes(raw[16:20], "big")
|
||||
height = int.from_bytes(raw[20:24], "big")
|
||||
if width > 0 and height > 0:
|
||||
return width, height
|
||||
|
||||
if raw.startswith(b"\xff\xd8"):
|
||||
i = 2
|
||||
n = len(raw)
|
||||
while i + 9 < n:
|
||||
if raw[i] != 0xFF:
|
||||
i += 1
|
||||
continue
|
||||
marker = raw[i + 1]
|
||||
i += 2
|
||||
if marker in {0xD8, 0xD9} or 0xD0 <= marker <= 0xD7:
|
||||
continue
|
||||
if i + 2 > n:
|
||||
break
|
||||
segment_len = int.from_bytes(raw[i:i + 2], "big")
|
||||
if segment_len < 2 or i + segment_len > n:
|
||||
break
|
||||
if marker in {
|
||||
0xC0, 0xC1, 0xC2, 0xC3, 0xC5, 0xC6, 0xC7,
|
||||
0xC9, 0xCA, 0xCB, 0xCD, 0xCE, 0xCF,
|
||||
}:
|
||||
if segment_len >= 7:
|
||||
height = int.from_bytes(raw[i + 3:i + 5], "big")
|
||||
width = int.from_bytes(raw[i + 5:i + 7], "big")
|
||||
if width > 0 and height > 0:
|
||||
return width, height
|
||||
break
|
||||
i += segment_len
|
||||
|
||||
return 0, 0
|
||||
|
||||
|
||||
def _split_tree_text(full_text: str) -> Tuple[str, str]:
|
||||
"""Split get_window_state text into (summary_line, tree_markdown)."""
|
||||
lines = full_text.split("\n", 1)
|
||||
|
|
@ -491,7 +530,12 @@ class CuaDriverBackend(ComputerUseBackend):
|
|||
png_bytes_len = 0
|
||||
if png_b64:
|
||||
try:
|
||||
png_bytes_len = len(base64.b64decode(png_b64, validate=False))
|
||||
raw = base64.b64decode(png_b64, validate=False)
|
||||
png_bytes_len = len(raw)
|
||||
detected_width, detected_height = _image_dimensions_from_bytes(raw)
|
||||
if detected_width and detected_height:
|
||||
width = detected_width
|
||||
height = detected_height
|
||||
except Exception:
|
||||
png_bytes_len = len(png_b64) * 3 // 4
|
||||
|
||||
|
|
|
|||
|
|
@ -615,6 +615,7 @@ def _route_capture_through_aux_vision(
|
|||
# MIME sniffing returns the right content-type.
|
||||
ext = ".jpg" if cap.png_b64[:8].startswith("/9j/") else ".png"
|
||||
cache_dir = get_hermes_dir("cache/vision", "temp_vision_images")
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
temp_image_path = cache_dir / f"computer_use_{_uuid.uuid4().hex}{ext}"
|
||||
temp_image_path.write_bytes(raw)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ Behaviour (mirrors ``vision_analyze`` for consistency)
|
|||
``provider``, ``model``, or ``base_url`` non-empty / not ``"auto"``),
|
||||
the screenshot is routed through the aux vision pipeline. Users who
|
||||
pay for a dedicated vision model usually want it used.
|
||||
* Otherwise, if the user explicitly declared the active model vision-capable
|
||||
via ``model.supports_vision`` / provider model config, return ``False``.
|
||||
This is the escape hatch for custom/local OpenAI-compatible VLM routes that
|
||||
are absent from models.dev and provider allowlists.
|
||||
* Otherwise, if the active main model+provider can carry an image inside
|
||||
a tool-result message AND the model reports ``supports_vision=True``
|
||||
in models.dev metadata, return ``False`` (use the multimodal path).
|
||||
|
|
@ -76,10 +80,52 @@ def _explicit_aux_vision_override(cfg: Optional[Dict[str, Any]]) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _lookup_supports_vision(provider: str, model: str) -> Optional[bool]:
|
||||
"""Return models.dev ``supports_vision`` for *(provider, model)* or None."""
|
||||
def _lookup_user_declared_supports_vision(
|
||||
provider: str,
|
||||
model: str,
|
||||
cfg: Optional[Dict[str, Any]],
|
||||
) -> Optional[bool]:
|
||||
"""Return config-declared ``supports_vision`` for the active route."""
|
||||
try:
|
||||
from agent.image_routing import _supports_vision_override
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"computer_use vision_routing: config override lookup import failed: %s",
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
try:
|
||||
return _supports_vision_override(cfg, provider, model)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"computer_use vision_routing: config override lookup failed: %s",
|
||||
exc,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _lookup_supports_vision(
|
||||
provider: str,
|
||||
model: str,
|
||||
cfg: Optional[Dict[str, Any]] = None,
|
||||
) -> Optional[bool]:
|
||||
"""Return config/models.dev ``supports_vision`` for *(provider, model)*."""
|
||||
if not provider or not model:
|
||||
return None
|
||||
try:
|
||||
from agent.image_routing import _lookup_supports_vision as _lookup_image_supports
|
||||
except Exception:
|
||||
_lookup_image_supports = None
|
||||
if _lookup_image_supports is not None:
|
||||
try:
|
||||
return _lookup_image_supports(provider, model, cfg)
|
||||
except Exception as exc: # pragma: no cover - defensive
|
||||
logger.debug(
|
||||
"computer_use vision_routing: image-routing caps lookup failed "
|
||||
"for %s:%s — %s",
|
||||
provider, model, exc,
|
||||
)
|
||||
return None
|
||||
try:
|
||||
from agent.models_dev import get_model_capabilities
|
||||
caps = get_model_capabilities(provider, model)
|
||||
|
|
@ -137,11 +183,17 @@ def should_route_capture_to_aux_vision(
|
|||
if _explicit_aux_vision_override(cfg):
|
||||
return True
|
||||
|
||||
user_declared = _lookup_user_declared_supports_vision(provider, model, cfg)
|
||||
if user_declared is True:
|
||||
return False
|
||||
if user_declared is False:
|
||||
return True
|
||||
|
||||
accepts_tool_image = _provider_accepts_multimodal_tool_result(provider, model)
|
||||
if accepts_tool_image is None or accepts_tool_image is False:
|
||||
return True
|
||||
|
||||
supports_vision = _lookup_supports_vision(provider, model)
|
||||
supports_vision = _lookup_supports_vision(provider, model, cfg)
|
||||
if supports_vision is True:
|
||||
return False
|
||||
return True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue