mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
fix(gateway): cap inbound media download size to prevent memory exhaustion
Inbound image/audio/video payloads were buffered fully into process memory before being written to the cache, with no size limit. A large upload (Discord Nitro allows 500 MB) or a remote media URL in an inbound message pointing at a huge file could spike RAM and OOM-kill the gateway.

Enforce a configurable cap in the shared cache helpers (gateway/platforms/base.py) so the protection holds across every platform adapter, not just one:

- cache_image/audio/video_from_bytes reject oversized payloads before writing (video was the gap in the original report — now covered).
- cache_image/audio_from_url stream the body, rejecting on an oversized Content-Length header and re-checking the running total per chunk so an absent/lying header can't smuggle an unbounded body past the cap.
- Discord's _read_attachment_bytes checks att.size up front, so an oversized attachment is rejected before any bytes are pulled into memory.

Configurable via gateway.max_inbound_media_bytes in config.yaml (default 128 MiB; 0 disables). No new env var — non-secret config lives in config.yaml.

Salvaged and extended from @sgaofen's PR #13341 (the original report and the shared-helper approach). Reapplied onto current main (Discord adapter has since moved to plugins/platforms/discord/), the configurable knob moved from an env var to config.yaml, and the video cache helper added.

Co-authored-by: Hermes Agent <noreply@nousresearch.com>
This commit is contained in:
parent
16899ae144
commit
93ea9b04af
5 changed files with 308 additions and 123 deletions
|
|
@ -567,6 +567,96 @@ async def _ssrf_redirect_guard(response):
|
|||
# Default location: {HERMES_HOME}/cache/images/ (legacy: image_cache/)
|
||||
IMAGE_CACHE_DIR = get_hermes_dir("cache/images", "image_cache")
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Inbound media size cap (#13145)
|
||||
#
|
||||
# Inbound image / audio / video payloads are buffered fully into process
|
||||
# memory before being written to the cache directory. With no cap, a single
|
||||
# large upload (Discord Nitro allows 500 MB) — or a remote URL in an inbound
|
||||
# message payload pointing at an arbitrarily large file — can spike RAM and
|
||||
# OOM-kill the gateway. The ``cache_*_from_bytes`` helpers (the shared funnel
|
||||
# every platform reaches eventually) and the ``cache_*_from_url`` downloaders
|
||||
# enforce this cap, so the protection holds regardless of which platform
|
||||
# adapter or code path produced the bytes.
|
||||
#
|
||||
# Configurable via ``gateway.max_inbound_media_bytes`` in config.yaml.
|
||||
# ``0`` disables the cap. Default 128 MiB — generous enough for ordinary
|
||||
# photos/voice notes/short clips while still bounding a hostile upload.
|
||||
# ---------------------------------------------------------------------------
|
||||
DEFAULT_INBOUND_MEDIA_MAX_BYTES = 128 * 1024 * 1024
|
||||
|
||||
|
||||
def get_inbound_media_max_bytes() -> int:
    """Return the max inbound image/audio/video bytes allowed in memory.

    Reads ``gateway.max_inbound_media_bytes`` from config.yaml each time it
    is called.

    Returns:
        The cap in bytes, or ``0`` when the cap is disabled. The result is
        never negative:

        * ``0`` or a negative configured value disables the cap (``0`` is
          returned so callers only ever have to test for "greater than 0").
        * A missing key, an unreadable config, or a value that cannot be
          converted to an integer falls back to
          ``DEFAULT_INBOUND_MEDIA_MAX_BYTES``.
    """
    try:
        from hermes_cli.config import load_config as _load_config
        cfg = _load_config()
    except Exception:
        # A broken config must never take the gateway down; keep the default cap.
        return DEFAULT_INBOUND_MEDIA_MAX_BYTES

    gw = cfg.get("gateway", {}) if isinstance(cfg, dict) else {}
    if not isinstance(gw, dict) or "max_inbound_media_bytes" not in gw:
        return DEFAULT_INBOUND_MEDIA_MAX_BYTES

    try:
        limit = int(gw["max_inbound_media_bytes"])
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers a float infinity, which int() refuses to convert.
        return DEFAULT_INBOUND_MEDIA_MAX_BYTES

    # A negative cap means "disabled". Returning it unchanged would make every
    # ``size > limit`` comparison true and reject all inbound media.
    return max(limit, 0)
|
||||
|
||||
|
||||
def validate_inbound_media_size(
    size: int,
    *,
    media_type: str = "media",
    max_bytes: Optional[int] = None,
) -> None:
    """Raise ``ValueError`` if an inbound media payload exceeds the cap.

    Args:
        size: Payload size in bytes (or the running total of an incremental
            read).
        media_type: Label interpolated into the error message, e.g.
            ``"image"``, ``"audio"`` or ``"video"``.
        max_bytes: Explicit cap in bytes. When ``None`` the configured cap is
            resolved via ``get_inbound_media_max_bytes()``. Passing it lets
            callers resolve the limit once and reuse it across an incremental
            read.

    Raises:
        ValueError: If the cap is enabled and ``size`` is strictly greater
            than it. A payload exactly at the cap is accepted.

    A cap of ``0`` or any negative value disables the check entirely.
    """
    limit = get_inbound_media_max_bytes() if max_bytes is None else max_bytes
    # Compare against 0 rather than relying on truthiness: a negative limit is
    # truthy and would otherwise reject every payload instead of disabling
    # the cap.
    if limit > 0 and size > limit:
        raise ValueError(
            f"Inbound {media_type} payload is too large "
            f"({size} bytes > {limit} bytes)"
        )
|
||||
|
||||
|
||||
async def _read_httpx_body_with_limit(response, *, media_type: str) -> bytes:
    """Collect a streamed httpx response body while enforcing the media cap.

    The cap is resolved once up front and reused for every check. A parseable
    ``Content-Length`` header is validated before any body bytes are read;
    an unparseable one is logged at debug level and ignored. The running
    byte count is then validated after each chunk arrives and before that
    chunk is retained, so a missing or understated header cannot push an
    unbounded body past the cap.

    Args:
        response: A streaming response exposing ``headers.get`` and
            ``aiter_bytes()``.
        media_type: Label forwarded to ``validate_inbound_media_size`` for
            its error message and used in the debug log.

    Returns:
        The complete body as a single ``bytes`` object.

    Raises:
        ValueError: Propagated from ``validate_inbound_media_size`` when the
            declared or received size exceeds the cap.
    """
    cap = get_inbound_media_max_bytes()

    # Phase 1: cheap rejection based on what the server claims to send.
    raw_length = response.headers.get("content-length")
    declared = None
    if raw_length:
        try:
            declared = int(raw_length)
        except ValueError:
            logger.debug(
                "Ignoring invalid Content-Length for inbound %s: %r",
                media_type, raw_length,
            )
    if declared is not None:
        validate_inbound_media_size(declared, media_type=media_type, max_bytes=cap)

    # Phase 2: trust nothing; count what actually arrives.
    received = 0
    pieces = []
    async for piece in response.aiter_bytes():
        received += len(piece)
        validate_inbound_media_size(received, media_type=media_type, max_bytes=cap)
        pieces.append(piece)
    return b"".join(pieces)
|
||||
|
||||
|
||||
def get_image_cache_dir() -> Path:
|
||||
"""Return the image cache directory, creating it if it doesn't exist."""
|
||||
|
|
@ -606,6 +696,7 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
|
|||
ValueError: If *data* does not look like a valid image (e.g. an HTML
|
||||
error page returned by the upstream server).
|
||||
"""
|
||||
validate_inbound_media_size(len(data), media_type="image")
|
||||
if not _looks_like_image(data):
|
||||
snippet = data[:80].decode("utf-8", errors="replace")
|
||||
raise ValueError(
|
||||
|
|
@ -651,15 +742,19 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
|
|||
) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
async with client.stream(
|
||||
"GET",
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "image/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_image_from_bytes(response.content, ext)
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
content = await _read_httpx_body_with_limit(
|
||||
response, media_type="image",
|
||||
)
|
||||
return cache_image_from_bytes(content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
|
|
@ -726,6 +821,7 @@ def cache_audio_from_bytes(data: bytes, ext: str = ".ogg") -> str:
|
|||
Returns:
|
||||
Absolute path to the cached audio file as a string.
|
||||
"""
|
||||
validate_inbound_media_size(len(data), media_type="audio")
|
||||
cache_dir = get_audio_cache_dir()
|
||||
filename = f"audio_{uuid.uuid4().hex[:12]}{ext}"
|
||||
filepath = cache_dir / filename
|
||||
|
|
@ -765,15 +861,19 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
|
|||
) as client:
|
||||
for attempt in range(retries + 1):
|
||||
try:
|
||||
response = await client.get(
|
||||
async with client.stream(
|
||||
"GET",
|
||||
url,
|
||||
headers={
|
||||
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
|
||||
"Accept": "audio/*,*/*;q=0.8",
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return cache_audio_from_bytes(response.content, ext)
|
||||
) as response:
|
||||
response.raise_for_status()
|
||||
content = await _read_httpx_body_with_limit(
|
||||
response, media_type="audio",
|
||||
)
|
||||
return cache_audio_from_bytes(content, ext)
|
||||
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
|
||||
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
|
||||
raise
|
||||
|
|
@ -818,6 +918,7 @@ def get_video_cache_dir() -> Path:
|
|||
|
||||
def cache_video_from_bytes(data: bytes, ext: str = ".mp4") -> str:
|
||||
"""Save raw video bytes to the cache and return the absolute file path."""
|
||||
validate_inbound_media_size(len(data), media_type="video")
|
||||
cache_dir = get_video_cache_dir()
|
||||
filename = f"video_{uuid.uuid4().hex[:12]}{ext}"
|
||||
filepath = cache_dir / filename
|
||||
|
|
|
|||
|
|
@ -2474,6 +2474,16 @@ DEFAULT_CONFIG = {
|
|||
"enabled": False,
|
||||
},
|
||||
|
||||
# Maximum bytes for an inbound image / audio / video payload the
|
||||
# gateway will buffer into memory and cache to disk. Inbound media is
|
||||
# read fully into RAM before being written, so an unbounded upload
|
||||
# (Discord Nitro allows 500 MB) or a remote media URL pointing at a
|
||||
# huge file can spike memory and OOM-kill the gateway on constrained
|
||||
# deployments. Enforced in the shared cache helpers
|
||||
# (gateway/platforms/base.py), so the cap holds across every platform
|
||||
# adapter. ``0`` disables the cap. Default 128 MiB.
|
||||
"max_inbound_media_bytes": 134217728,
|
||||
|
||||
# When false (default), any file path the agent emits is delivered
|
||||
# as a native attachment as long as it isn't under the credential /
|
||||
# system-path denylist (/etc, /proc, ~/.ssh, ~/.aws, ~/.hermes/.env,
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ from gateway.platforms.base import (
|
|||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
validate_inbound_media_size,
|
||||
)
|
||||
from tools.url_safety import is_safe_url
|
||||
|
||||
|
|
@ -5052,19 +5053,32 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# non-CDN URL into the ``att.url`` field. (issue #11345)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _read_attachment_bytes(self, att) -> Optional[bytes]:
|
||||
async def _read_attachment_bytes(
|
||||
self,
|
||||
att,
|
||||
*,
|
||||
media_type: str = "media",
|
||||
) -> Optional[bytes]:
|
||||
"""Read an attachment via discord.py's authenticated bot session.
|
||||
|
||||
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
|
||||
expose a callable ``read()`` or the read itself fails. Callers
|
||||
should treat ``None`` as a signal to fall back to the URL-based
|
||||
downloaders.
|
||||
|
||||
Oversized attachments (per ``gateway.max_inbound_media_bytes``) raise
|
||||
``ValueError`` BEFORE the bytes are pulled into memory when Discord
|
||||
reports the size up front, so a hostile upload can't OOM the gateway.
|
||||
"""
|
||||
attachment_size = getattr(att, "size", None)
|
||||
if attachment_size:
|
||||
validate_inbound_media_size(int(attachment_size), media_type=media_type)
|
||||
|
||||
reader = getattr(att, "read", None)
|
||||
if reader is None or not callable(reader):
|
||||
return None
|
||||
try:
|
||||
return await reader()
|
||||
raw_bytes = await reader()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Authenticated attachment read failed for %s: %s",
|
||||
|
|
@ -5072,6 +5086,8 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
e,
|
||||
)
|
||||
return None
|
||||
validate_inbound_media_size(len(raw_bytes), media_type=media_type)
|
||||
return raw_bytes
|
||||
|
||||
async def _cache_discord_image(self, att, ext: str) -> str:
|
||||
"""Cache a Discord image attachment to local disk.
|
||||
|
|
@ -5081,7 +5097,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
raw_bytes = await self._read_attachment_bytes(att, media_type="image")
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_image_from_bytes(raw_bytes, ext=ext)
|
||||
|
|
@ -5100,7 +5116,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
raw_bytes = await self._read_attachment_bytes(att, media_type="audio")
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_audio_from_bytes(raw_bytes, ext=ext)
|
||||
|
|
@ -5122,7 +5138,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
for passing the returned bytes to ``cache_document_from_bytes``
|
||||
(and, where applicable, for injecting text content).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
raw_bytes = await self._read_attachment_bytes(att, media_type="document")
|
||||
if raw_bytes is not None:
|
||||
return raw_bytes
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,56 @@ def _make_timeout_error() -> httpx.TimeoutException:
|
|||
return httpx.TimeoutException("timed out")
|
||||
|
||||
|
||||
def _make_stream_response(content: bytes = b"\xff\xd8\xff fake media"):
|
||||
"""Build a mock httpx response suitable for ``client.stream()`` usage.
|
||||
|
||||
Exposes ``raise_for_status``, an empty ``headers`` mapping (no
|
||||
Content-Length), and an ``aiter_bytes`` async iterator yielding the body
|
||||
in one chunk — matching how ``_read_httpx_body_with_limit`` consumes it.
|
||||
"""
|
||||
resp = MagicMock()
|
||||
resp.raise_for_status = MagicMock()
|
||||
resp.headers = {}
|
||||
|
||||
async def _aiter():
|
||||
yield content
|
||||
|
||||
resp.aiter_bytes = lambda: _aiter()
|
||||
return resp
|
||||
|
||||
|
||||
def _make_stream_client(*, responses=None, side_effect=None):
|
||||
"""Build a mock httpx client whose ``.stream()`` is an async CM.
|
||||
|
||||
``responses`` is a list of response objects (or exceptions) returned on
|
||||
successive ``.stream()`` calls; ``side_effect`` is a single exception
|
||||
raised on every call. The returned client also supports being used as an
|
||||
``async with`` context manager (``httpx.AsyncClient(...)``).
|
||||
"""
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
call_state = {"i": 0}
|
||||
|
||||
def _stream(method, url, **kwargs):
|
||||
idx = call_state["i"]
|
||||
call_state["i"] += 1
|
||||
if side_effect is not None:
|
||||
raise side_effect
|
||||
item = responses[idx]
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__ = AsyncMock(return_value=item)
|
||||
cm.__aexit__ = AsyncMock(return_value=False)
|
||||
return cm
|
||||
|
||||
mock_client.stream = MagicMock(side_effect=_stream)
|
||||
mock_client._call_state = call_state
|
||||
return mock_client
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cache_image_from_bytes (base.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -85,14 +135,9 @@ class TestCacheImageFromUrl:
|
|||
"""A clean 200 response caches the image and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"\xff\xd8\xff fake jpeg"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_stream_response(b"\xff\xd8\xff fake jpeg")]
|
||||
)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
|
|
@ -103,23 +148,15 @@ class TestCacheImageFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
mock_client.get.assert_called_once()
|
||||
mock_client.stream.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"\xff\xd8\xff image data"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_timeout_error(), _make_stream_response(b"\xff\xd8\xff image data")]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
|
|
@ -132,23 +169,16 @@ class TestCacheImageFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
assert mock_client.stream.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"\xff\xd8\xff image data"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_http_status_error(429), _make_stream_response(b"\xff\xd8\xff image data")]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -160,16 +190,13 @@ class TestCacheImageFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".jpg")
|
||||
assert mock_client.get.call_count == 2
|
||||
assert mock_client.stream.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client = _make_stream_client(side_effect=_make_timeout_error())
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -183,17 +210,14 @@ class TestCacheImageFromUrl:
|
|||
asyncio.run(run())
|
||||
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
assert mock_client.stream.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client = _make_stream_client(side_effect=_make_http_status_error(404))
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -207,7 +231,7 @@ class TestCacheImageFromUrl:
|
|||
asyncio.run(run())
|
||||
|
||||
# Only 1 attempt, no sleep
|
||||
assert mock_client.get.call_count == 1
|
||||
assert mock_client.stream.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
|
|
@ -223,14 +247,9 @@ class TestCacheAudioFromUrl:
|
|||
"""A clean 200 response caches the audio and returns a path."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"\x00\x01 fake audio"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=fake_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_stream_response(b"\x00\x01 fake audio")]
|
||||
)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client):
|
||||
|
|
@ -241,23 +260,15 @@ class TestCacheAudioFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".ogg")
|
||||
mock_client.get.assert_called_once()
|
||||
mock_client.stream.assert_called_once()
|
||||
|
||||
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A timeout on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
fake_response = MagicMock()
|
||||
fake_response.content = b"audio data"
|
||||
fake_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_timeout_error(), fake_response]
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_timeout_error(), _make_stream_response(b"audio data")]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
|
||||
async def run():
|
||||
|
|
@ -270,23 +281,16 @@ class TestCacheAudioFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".ogg")
|
||||
assert mock_client.get.call_count == 2
|
||||
assert mock_client.stream.call_count == 2
|
||||
mock_sleep.assert_called_once()
|
||||
|
||||
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A 429 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"audio data"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(429), ok_response]
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_http_status_error(429), _make_stream_response(b"audio data")]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -298,22 +302,15 @@ class TestCacheAudioFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".ogg")
|
||||
assert mock_client.get.call_count == 2
|
||||
assert mock_client.stream.call_count == 2
|
||||
|
||||
def test_retries_on_500_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A 500 response on the first attempt is retried; second attempt succeeds."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"audio data"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(
|
||||
side_effect=[_make_http_status_error(500), ok_response]
|
||||
mock_client = _make_stream_client(
|
||||
responses=[_make_http_status_error(500), _make_stream_response(b"audio data")]
|
||||
)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -325,16 +322,13 @@ class TestCacheAudioFromUrl:
|
|||
|
||||
path = asyncio.run(run())
|
||||
assert path.endswith(".ogg")
|
||||
assert mock_client.get.call_count == 2
|
||||
assert mock_client.stream.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""Timeout on every attempt raises after all retries are consumed."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client = _make_stream_client(side_effect=_make_timeout_error())
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -348,17 +342,14 @@ class TestCacheAudioFromUrl:
|
|||
asyncio.run(run())
|
||||
|
||||
# 3 total calls: initial + 2 retries
|
||||
assert mock_client.get.call_count == 3
|
||||
assert mock_client.stream.call_count == 3
|
||||
|
||||
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
|
||||
"""A 404 (non-retryable) is raised immediately without any retry."""
|
||||
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
|
||||
|
||||
mock_sleep = AsyncMock()
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_client = _make_stream_client(side_effect=_make_http_status_error(404))
|
||||
|
||||
async def run():
|
||||
with patch("httpx.AsyncClient", return_value=mock_client), \
|
||||
|
|
@ -372,7 +363,7 @@ class TestCacheAudioFromUrl:
|
|||
asyncio.run(run())
|
||||
|
||||
# Only 1 attempt, no sleep
|
||||
assert mock_client.get.call_count == 1
|
||||
assert mock_client.stream.call_count == 1
|
||||
mock_sleep.assert_not_called()
|
||||
|
||||
|
||||
|
|
@ -415,12 +406,18 @@ class TestSSRFRedirectGuard:
|
|||
)
|
||||
mock_client, captured, factory = self._make_client_capturing_hooks()
|
||||
|
||||
async def fake_get(_url, **kwargs):
|
||||
# Simulate httpx calling the response event hooks
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp)
|
||||
def fake_stream(method, _url, **kwargs):
|
||||
async def _aenter(*a):
|
||||
# Simulate httpx invoking the response event hooks on the stream.
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp)
|
||||
return redirect_resp
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__ = AsyncMock(side_effect=_aenter)
|
||||
cm.__aexit__ = AsyncMock(return_value=False)
|
||||
return cm
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
mock_client.stream = MagicMock(side_effect=fake_stream)
|
||||
|
||||
def fake_safe(url):
|
||||
return url == "https://public.example.com/image.png"
|
||||
|
|
@ -445,11 +442,17 @@ class TestSSRFRedirectGuard:
|
|||
)
|
||||
mock_client, captured, factory = self._make_client_capturing_hooks()
|
||||
|
||||
async def fake_get(_url, **kwargs):
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp)
|
||||
def fake_stream(method, _url, **kwargs):
|
||||
async def _aenter(*a):
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp)
|
||||
return redirect_resp
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__ = AsyncMock(side_effect=_aenter)
|
||||
cm.__aexit__ = AsyncMock(return_value=False)
|
||||
return cm
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
mock_client.stream = MagicMock(side_effect=fake_stream)
|
||||
|
||||
def fake_safe(url):
|
||||
return url == "https://public.example.com/voice.ogg"
|
||||
|
|
@ -473,24 +476,24 @@ class TestSSRFRedirectGuard:
|
|||
"https://cdn.example.com/real-image.png"
|
||||
)
|
||||
|
||||
ok_response = MagicMock()
|
||||
ok_response.content = b"\xff\xd8\xff fake jpeg"
|
||||
ok_response.raise_for_status = MagicMock()
|
||||
ok_response = _make_stream_response(b"\xff\xd8\xff fake jpeg")
|
||||
ok_response.is_redirect = False
|
||||
|
||||
mock_client, captured, factory = self._make_client_capturing_hooks()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def fake_get(_url, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# First call triggers redirect hook, second returns data
|
||||
async def _aenter(*a):
|
||||
# Public redirect passes the guard; body then streams normally.
|
||||
for hook in captured["event_hooks"]["response"]:
|
||||
await hook(redirect_resp if call_count == 1 else ok_response)
|
||||
await hook(redirect_resp)
|
||||
return ok_response
|
||||
|
||||
mock_client.get = AsyncMock(side_effect=fake_get)
|
||||
def fake_stream(method, _url, **kwargs):
|
||||
cm = AsyncMock()
|
||||
cm.__aenter__ = AsyncMock(side_effect=_aenter)
|
||||
cm.__aexit__ = AsyncMock(return_value=False)
|
||||
return cm
|
||||
|
||||
mock_client.stream = MagicMock(side_effect=fake_stream)
|
||||
|
||||
async def run():
|
||||
with patch("tools.url_safety.is_safe_url", return_value=True), \
|
||||
|
|
|
|||
|
|
@ -10,13 +10,68 @@ from gateway.platforms.base import (
|
|||
BasePlatformAdapter,
|
||||
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
|
||||
MessageEvent,
|
||||
cache_audio_from_bytes,
|
||||
cache_image_from_bytes,
|
||||
cache_video_from_bytes,
|
||||
safe_url_for_log,
|
||||
utf16_len,
|
||||
validate_inbound_media_size,
|
||||
_log_safe_path,
|
||||
_prefix_within_utf16_limit,
|
||||
)
|
||||
|
||||
|
||||
class TestInboundMediaSizeCap:
    """gateway.max_inbound_media_bytes caps inbound media buffered into RAM (#13145)."""

    # PNG signature followed by filler bytes: 72 bytes total.
    _PNG = b"\x89PNG\r\n\x1a\n" + b"x" * 64

    def test_default_cap_is_128_mib(self, monkeypatch):
        # An empty config carries no ``gateway.max_inbound_media_bytes``
        # override, so the resolver itself must fall back to the default.
        # The loader is patched on its module because the resolver imports
        # it at call time.
        import gateway.platforms.base as base
        monkeypatch.setattr("hermes_cli.config.load_config", lambda: {})
        assert base.DEFAULT_INBOUND_MEDIA_MAX_BYTES == 128 * 1024 * 1024
        assert base.get_inbound_media_max_bytes() == base.DEFAULT_INBOUND_MEDIA_MAX_BYTES

    def test_image_bytes_rejected_when_oversized(self, monkeypatch):
        # 72-byte payload against a 16-byte cap.
        import gateway.platforms.base as base
        monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 16)
        with pytest.raises(ValueError, match="Inbound image payload is too large"):
            cache_image_from_bytes(self._PNG, ext=".png")

    def test_audio_bytes_rejected_when_oversized(self, monkeypatch):
        # 8-byte payload against a 4-byte cap.
        import gateway.platforms.base as base
        monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 4)
        with pytest.raises(ValueError, match="Inbound audio payload is too large"):
            cache_audio_from_bytes(b"x" * 8, ext=".ogg")

    def test_video_bytes_rejected_when_oversized(self, monkeypatch):
        # Video was the gap in the original report — verify it's covered.
        import gateway.platforms.base as base
        monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 4)
        with pytest.raises(ValueError, match="Inbound video payload is too large"):
            cache_video_from_bytes(b"x" * 8, ext=".mp4")

    def test_legit_image_accepted_under_cap(self, monkeypatch):
        # A payload well under the cap is written to the cache unmodified.
        import gateway.platforms.base as base
        monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 128 * 1024 * 1024)
        path = cache_image_from_bytes(self._PNG, ext=".png")
        assert os.path.exists(path)
        assert os.path.getsize(path) == len(self._PNG)

    def test_cap_of_zero_disables_check(self, monkeypatch):
        import gateway.platforms.base as base
        monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 0)
        # A would-be-oversized video passes through when the cap is disabled.
        path = cache_video_from_bytes(b"x" * 5000, ext=".mp4")
        assert os.path.exists(path)

    def test_validate_helper_respects_explicit_max_bytes(self):
        # max_bytes arg overrides the configured cap.
        validate_inbound_media_size(100, media_type="image", max_bytes=200)  # ok
        with pytest.raises(ValueError, match="too large"):
            validate_inbound_media_size(300, media_type="image", max_bytes=200)
|
||||
|
||||
|
||||
class TestSecretCaptureGuidance:
|
||||
def test_gateway_secret_capture_message_points_to_local_setup(self):
|
||||
message = GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue