mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Three safety gaps in vision_analyze_tool: 1. Local files accepted without checking if they're actually images — a renamed text file would get base64-encoded and sent to the model. Now validates magic bytes (PNG, JPEG, GIF, BMP, WebP, SVG). 2. No website policy enforcement on image URLs — blocked domains could be fetched via the vision tool. Now checks before download. 3. No redirect check — if an allowed URL redirected to a blocked domain, the download would proceed. Now re-checks the final URL. Fixed one test that needed _validate_image_url mocked to bypass DNS resolution on the fake blocked.test domain (is_safe_url does DNS checks that were added after the original PR). Co-authored-by: GutSlabs <GutSlabs@users.noreply.github.com>
This commit is contained in:
parent
b60cfd6ce6
commit
5e67fc8c40
2 changed files with 113 additions and 1 deletions
|
|
@ -354,6 +354,78 @@ class TestErrorLoggingExcInfo:
|
|||
assert warning_records[0].exc_info is not None
|
||||
|
||||
|
||||
class TestVisionSafetyGuards:
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_non_image_file_rejected_before_llm_call(self, tmp_path):
|
||||
secret = tmp_path / "secret.txt"
|
||||
secret.write_text("TOP-SECRET=1\n", encoding="utf-8")
|
||||
|
||||
with patch("tools.vision_tools.async_call_llm", new_callable=AsyncMock) as mock_llm:
|
||||
result = json.loads(await vision_analyze_tool(str(secret), "extract text"))
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Only real image files are supported" in result["error"]
|
||||
mock_llm.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_remote_url_short_circuits_before_download(self):
|
||||
blocked = {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools.check_website_access", return_value=blocked),
|
||||
patch("tools.vision_tools._validate_image_url", return_value=True),
|
||||
patch("tools.vision_tools._download_image", new_callable=AsyncMock) as mock_download,
|
||||
):
|
||||
result = json.loads(await vision_analyze_tool("https://blocked.test/cat.png", "describe"))
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Blocked by website policy" in result["error"]
|
||||
mock_download.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_blocks_redirected_final_url(self, tmp_path):
|
||||
from tools.vision_tools import _download_image
|
||||
|
||||
def fake_check(url):
|
||||
if url == "https://allowed.test/cat.png":
|
||||
return None
|
||||
if url == "https://blocked.test/final.png":
|
||||
return {
|
||||
"host": "blocked.test",
|
||||
"rule": "blocked.test",
|
||||
"source": "config",
|
||||
"message": "Blocked by website policy",
|
||||
}
|
||||
raise AssertionError(f"unexpected URL checked: {url}")
|
||||
|
||||
class FakeResponse:
|
||||
url = "https://blocked.test/final.png"
|
||||
content = b"\x89PNG\r\n\x1a\n" + b"\x00" * 16
|
||||
|
||||
def raise_for_status(self):
|
||||
return None
|
||||
|
||||
with (
|
||||
patch("tools.vision_tools.check_website_access", side_effect=fake_check),
|
||||
patch("tools.vision_tools.httpx.AsyncClient") as mock_client_cls,
|
||||
pytest.raises(PermissionError, match="Blocked by website policy"),
|
||||
):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client.get = AsyncMock(return_value=FakeResponse())
|
||||
mock_client_cls.return_value = mock_client
|
||||
|
||||
await _download_image("https://allowed.test/cat.png", tmp_path / "cat.png", max_retries=1)
|
||||
|
||||
assert not (tmp_path / "cat.png").exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# check_vision_requirements & get_debug_session_info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ from urllib.parse import urlparse
|
|||
import httpx
|
||||
from agent.auxiliary_client import async_call_llm, extract_content_or_reasoning
|
||||
from tools.debug_helpers import DebugSession
|
||||
from tools.website_policy import check_website_access
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -76,6 +77,28 @@ def _validate_image_url(url: str) -> bool:
|
|||
return True
|
||||
|
||||
|
||||
def _detect_image_mime_type(image_path: Path) -> Optional[str]:
|
||||
"""Return a MIME type when the file looks like a supported image."""
|
||||
with image_path.open("rb") as f:
|
||||
header = f.read(64)
|
||||
|
||||
if header.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
if header.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
if header.startswith((b"GIF87a", b"GIF89a")):
|
||||
return "image/gif"
|
||||
if header.startswith(b"BM"):
|
||||
return "image/bmp"
|
||||
if len(header) >= 12 and header[:4] == b"RIFF" and header[8:12] == b"WEBP":
|
||||
return "image/webp"
|
||||
if image_path.suffix.lower() == ".svg":
|
||||
head = image_path.read_text(encoding="utf-8", errors="ignore")[:4096].lower()
|
||||
if "<svg" in head:
|
||||
return "image/svg+xml"
|
||||
return None
|
||||
|
||||
|
||||
async def _download_image(image_url: str, destination: Path, max_retries: int = 3) -> Path:
|
||||
"""
|
||||
Download an image from a URL to a local destination (async) with retry logic.
|
||||
|
|
@ -115,6 +138,10 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
|||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
blocked = check_website_access(image_url)
|
||||
if blocked:
|
||||
raise PermissionError(blocked["message"])
|
||||
|
||||
# Download the image with appropriate headers using async httpx
|
||||
# Enable follow_redirects to handle image CDNs that redirect (e.g., Imgur, Picsum)
|
||||
# SSRF: event_hooks validates each redirect target against private IP ranges
|
||||
|
|
@ -131,6 +158,11 @@ async def _download_image(image_url: str, destination: Path, max_retries: int =
|
|||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
final_url = str(response.url)
|
||||
blocked = check_website_access(final_url)
|
||||
if blocked:
|
||||
raise PermissionError(blocked["message"])
|
||||
|
||||
# Save the image content
|
||||
destination.write_bytes(response.content)
|
||||
|
|
@ -257,6 +289,7 @@ async def vision_analyze_tool(
|
|||
# Track whether we should clean up the file after processing.
|
||||
# Local files (e.g. from the image cache) should NOT be deleted.
|
||||
should_cleanup = True
|
||||
detected_mime_type = None
|
||||
|
||||
try:
|
||||
from tools.interrupt import is_interrupted
|
||||
|
|
@ -275,6 +308,9 @@ async def vision_analyze_tool(
|
|||
should_cleanup = False # Don't delete cached/local files
|
||||
elif _validate_image_url(image_url):
|
||||
# Remote URL -- download to a temporary location
|
||||
blocked = check_website_access(image_url)
|
||||
if blocked:
|
||||
raise PermissionError(blocked["message"])
|
||||
logger.info("Downloading image from URL...")
|
||||
temp_dir = Path("./temp_vision_images")
|
||||
temp_image_path = temp_dir / f"temp_image_{uuid.uuid4()}.jpg"
|
||||
|
|
@ -289,10 +325,14 @@ async def vision_analyze_tool(
|
|||
image_size_bytes = temp_image_path.stat().st_size
|
||||
image_size_kb = image_size_bytes / 1024
|
||||
logger.info("Image ready (%.1f KB)", image_size_kb)
|
||||
|
||||
detected_mime_type = _detect_image_mime_type(temp_image_path)
|
||||
if not detected_mime_type:
|
||||
raise ValueError("Only real image files are supported for vision analysis.")
|
||||
|
||||
# Convert image to base64 data URL
|
||||
logger.info("Converting image to base64...")
|
||||
image_data_url = _image_to_base64_data_url(temp_image_path)
|
||||
image_data_url = _image_to_base64_data_url(temp_image_path, mime_type=detected_mime_type)
|
||||
# Calculate size in KB for better readability
|
||||
data_size_kb = len(image_data_url) / 1024
|
||||
logger.info("Image converted to base64 (%.1f KB)", data_size_kb)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue