fix(gateway): cap inbound media download size to prevent memory exhaustion

Inbound image/audio/video payloads were buffered fully into process memory
before being written to the cache, with no size limit. A large upload
(Discord Nitro allows 500 MB) or a remote media URL in an inbound message
pointing at a huge file could spike RAM and OOM-kill the gateway.

Enforce a configurable cap in the shared cache helpers (gateway/platforms/
base.py) so the protection holds across every platform adapter, not one:

- cache_image/audio/video_from_bytes reject oversized payloads before writing
  (video was the gap in the original report — now covered).
- cache_image/audio_from_url stream the body, rejecting on an oversized
  Content-Length header and re-checking the running total per chunk so an
  absent/lying header can't smuggle an unbounded body past the cap.
- Discord's _read_attachment_bytes checks att.size up front, so an oversized
  attachment is rejected before any bytes are pulled into memory.

Configurable via gateway.max_inbound_media_bytes in config.yaml (default
128 MiB; 0 disables). No new env var — non-secret config lives in config.yaml.

Salvaged and extended from @sgaofen's PR #13341 (the original report and the
shared-helper approach). Reapplied onto current main (Discord adapter has
since moved to plugins/platforms/discord/), the configurable knob moved from
an env var to config.yaml, and the video cache helper added.

Co-authored-by: Hermes Agent <noreply@nousresearch.com>
This commit is contained in:
sgaofen 2026-06-21 11:36:39 -07:00 committed by Teknium
parent 16899ae144
commit 93ea9b04af
5 changed files with 308 additions and 123 deletions

View file

@ -567,6 +567,96 @@ async def _ssrf_redirect_guard(response):
# Default location: {HERMES_HOME}/cache/images/ (legacy: image_cache/)
IMAGE_CACHE_DIR = get_hermes_dir("cache/images", "image_cache")
# ---------------------------------------------------------------------------
# Inbound media size cap (#13145)
#
# Inbound image / audio / video payloads are buffered fully into process
# memory before being written to the cache directory. With no cap, a single
# large upload (Discord Nitro allows 500 MB) — or a remote URL in an inbound
# message payload pointing at an arbitrarily large file — can spike RAM and
# OOM-kill the gateway. The ``cache_*_from_bytes`` helpers (the shared funnel
# every platform reaches eventually) and the ``cache_*_from_url`` downloaders
# enforce this cap, so the protection holds regardless of which platform
# adapter or code path produced the bytes.
#
# Configurable via ``gateway.max_inbound_media_bytes`` in config.yaml.
# ``0`` disables the cap. Default 128 MiB — generous enough for ordinary
# photos/voice notes/short clips while still bounding a hostile upload.
# ---------------------------------------------------------------------------
DEFAULT_INBOUND_MEDIA_MAX_BYTES = 128 * 1024 * 1024
def get_inbound_media_max_bytes() -> int:
"""Return the max inbound image/audio/video bytes allowed in memory.
Reads ``gateway.max_inbound_media_bytes`` from config.yaml. ``0`` (or a
negative / unparseable value) disables the cap. Non-fatal if config is
unreadable falls back to the default.
"""
try:
from hermes_cli.config import load_config as _load_config
cfg = _load_config()
except Exception:
return DEFAULT_INBOUND_MEDIA_MAX_BYTES
gw = cfg.get("gateway", {}) if isinstance(cfg, dict) else {}
if not isinstance(gw, dict) or "max_inbound_media_bytes" not in gw:
return DEFAULT_INBOUND_MEDIA_MAX_BYTES
try:
return int(gw["max_inbound_media_bytes"])
except (TypeError, ValueError):
return DEFAULT_INBOUND_MEDIA_MAX_BYTES
def validate_inbound_media_size(
size: int,
*,
media_type: str = "media",
max_bytes: Optional[int] = None,
) -> None:
"""Raise ``ValueError`` if an inbound media payload exceeds the cap.
A ``max_bytes`` of ``0`` (or the configured cap resolving to ``0``)
disables the check entirely. Passing ``max_bytes`` lets callers resolve
the limit once and reuse it across an incremental read.
"""
limit = get_inbound_media_max_bytes() if max_bytes is None else max_bytes
if limit and size > limit:
raise ValueError(
f"Inbound {media_type} payload is too large "
f"({size} bytes > {limit} bytes)"
)
async def _read_httpx_body_with_limit(response, *, media_type: str) -> bytes:
"""Read an httpx streaming response body without exceeding the media cap.
Rejects early on an oversized ``Content-Length`` header, then re-checks
the running total as chunks arrive so a lying/absent header can't smuggle
an unbounded body past the cap.
"""
max_bytes = get_inbound_media_max_bytes()
content_length = response.headers.get("content-length")
if content_length:
try:
declared_size = int(content_length)
except ValueError:
logger.debug(
"Ignoring invalid Content-Length for inbound %s: %r",
media_type, content_length,
)
else:
validate_inbound_media_size(
declared_size, media_type=media_type, max_bytes=max_bytes,
)
chunks: list[bytes] = []
total = 0
async for chunk in response.aiter_bytes():
total += len(chunk)
validate_inbound_media_size(total, media_type=media_type, max_bytes=max_bytes)
chunks.append(chunk)
return b"".join(chunks)
def get_image_cache_dir() -> Path:
"""Return the image cache directory, creating it if it doesn't exist."""
@ -606,6 +696,7 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str:
ValueError: If *data* does not look like a valid image (e.g. an HTML
error page returned by the upstream server).
"""
validate_inbound_media_size(len(data), media_type="image")
if not _looks_like_image(data):
snippet = data[:80].decode("utf-8", errors="replace")
raise ValueError(
@ -651,15 +742,19 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
) as client:
for attempt in range(retries + 1):
try:
response = await client.get(
async with client.stream(
"GET",
url,
headers={
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
"Accept": "image/*,*/*;q=0.8",
},
)
response.raise_for_status()
return cache_image_from_bytes(response.content, ext)
) as response:
response.raise_for_status()
content = await _read_httpx_body_with_limit(
response, media_type="image",
)
return cache_image_from_bytes(content, ext)
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
raise
@ -726,6 +821,7 @@ def cache_audio_from_bytes(data: bytes, ext: str = ".ogg") -> str:
Returns:
Absolute path to the cached audio file as a string.
"""
validate_inbound_media_size(len(data), media_type="audio")
cache_dir = get_audio_cache_dir()
filename = f"audio_{uuid.uuid4().hex[:12]}{ext}"
filepath = cache_dir / filename
@ -765,15 +861,19 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
) as client:
for attempt in range(retries + 1):
try:
response = await client.get(
async with client.stream(
"GET",
url,
headers={
"User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)",
"Accept": "audio/*,*/*;q=0.8",
},
)
response.raise_for_status()
return cache_audio_from_bytes(response.content, ext)
) as response:
response.raise_for_status()
content = await _read_httpx_body_with_limit(
response, media_type="audio",
)
return cache_audio_from_bytes(content, ext)
except (httpx.TimeoutException, httpx.HTTPStatusError) as exc:
if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429:
raise
@ -818,6 +918,7 @@ def get_video_cache_dir() -> Path:
def cache_video_from_bytes(data: bytes, ext: str = ".mp4") -> str:
"""Save raw video bytes to the cache and return the absolute file path."""
validate_inbound_media_size(len(data), media_type="video")
cache_dir = get_video_cache_dir()
filename = f"video_{uuid.uuid4().hex[:12]}{ext}"
filepath = cache_dir / filename

View file

@ -2474,6 +2474,16 @@ DEFAULT_CONFIG = {
"enabled": False,
},
# Maximum bytes for an inbound image / audio / video payload the
# gateway will buffer into memory and cache to disk. Inbound media is
# read fully into RAM before being written, so an unbounded upload
# (Discord Nitro allows 500 MB) or a remote media URL pointing at a
# huge file can spike memory and OOM-kill the gateway on constrained
# deployments. Enforced in the shared cache helpers
# (gateway/platforms/base.py), so the cap holds across every platform
# adapter. ``0`` disables the cap. Default 128 MiB.
"max_inbound_media_bytes": 134217728,
# When false (default), any file path the agent emits is delivered
# as a native attachment as long as it isn't under the credential /
# system-path denylist (/etc, /proc, ~/.ssh, ~/.aws, ~/.hermes/.env,

View file

@ -116,6 +116,7 @@ from gateway.platforms.base import (
cache_audio_from_bytes,
cache_document_from_bytes,
SUPPORTED_DOCUMENT_TYPES,
validate_inbound_media_size,
)
from tools.url_safety import is_safe_url
@ -5052,19 +5053,32 @@ class DiscordAdapter(BasePlatformAdapter):
# non-CDN URL into the ``att.url`` field. (issue #11345)
# ------------------------------------------------------------------
async def _read_attachment_bytes(self, att) -> Optional[bytes]:
async def _read_attachment_bytes(
self,
att,
*,
media_type: str = "media",
) -> Optional[bytes]:
"""Read an attachment via discord.py's authenticated bot session.
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
expose a callable ``read()`` or the read itself fails. Callers
should treat ``None`` as a signal to fall back to the URL-based
downloaders.
Oversized attachments (per ``gateway.max_inbound_media_bytes``) raise
``ValueError`` BEFORE the bytes are pulled into memory when Discord
reports the size up front, so a hostile upload can't OOM the gateway.
"""
attachment_size = getattr(att, "size", None)
if attachment_size:
validate_inbound_media_size(int(attachment_size), media_type=media_type)
reader = getattr(att, "read", None)
if reader is None or not callable(reader):
return None
try:
return await reader()
raw_bytes = await reader()
except Exception as e:
logger.warning(
"[Discord] Authenticated attachment read failed for %s: %s",
@ -5072,6 +5086,8 @@ class DiscordAdapter(BasePlatformAdapter):
e,
)
return None
validate_inbound_media_size(len(raw_bytes), media_type=media_type)
return raw_bytes
async def _cache_discord_image(self, att, ext: str) -> str:
"""Cache a Discord image attachment to local disk.
@ -5081,7 +5097,7 @@ class DiscordAdapter(BasePlatformAdapter):
Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated).
"""
raw_bytes = await self._read_attachment_bytes(att)
raw_bytes = await self._read_attachment_bytes(att, media_type="image")
if raw_bytes is not None:
try:
return cache_image_from_bytes(raw_bytes, ext=ext)
@ -5100,7 +5116,7 @@ class DiscordAdapter(BasePlatformAdapter):
Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated).
"""
raw_bytes = await self._read_attachment_bytes(att)
raw_bytes = await self._read_attachment_bytes(att, media_type="audio")
if raw_bytes is not None:
try:
return cache_audio_from_bytes(raw_bytes, ext=ext)
@ -5122,7 +5138,7 @@ class DiscordAdapter(BasePlatformAdapter):
for passing the returned bytes to ``cache_document_from_bytes``
(and, where applicable, for injecting text content).
"""
raw_bytes = await self._read_attachment_bytes(att)
raw_bytes = await self._read_attachment_bytes(att, media_type="document")
if raw_bytes is not None:
return raw_bytes

View file

@ -34,6 +34,56 @@ def _make_timeout_error() -> httpx.TimeoutException:
return httpx.TimeoutException("timed out")
def _make_stream_response(content: bytes = b"\xff\xd8\xff fake media"):
"""Build a mock httpx response suitable for ``client.stream()`` usage.
Exposes ``raise_for_status``, an empty ``headers`` mapping (no
Content-Length), and an ``aiter_bytes`` async iterator yielding the body
in one chunk matching how ``_read_httpx_body_with_limit`` consumes it.
"""
resp = MagicMock()
resp.raise_for_status = MagicMock()
resp.headers = {}
async def _aiter():
yield content
resp.aiter_bytes = lambda: _aiter()
return resp
def _make_stream_client(*, responses=None, side_effect=None):
"""Build a mock httpx client whose ``.stream()`` is an async CM.
``responses`` is a list of response objects (or exceptions) returned on
successive ``.stream()`` calls; ``side_effect`` is a single exception
raised on every call. The returned client also supports being used as an
``async with`` context manager (``httpx.AsyncClient(...)``).
"""
mock_client = AsyncMock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
call_state = {"i": 0}
def _stream(method, url, **kwargs):
idx = call_state["i"]
call_state["i"] += 1
if side_effect is not None:
raise side_effect
item = responses[idx]
if isinstance(item, Exception):
raise item
cm = AsyncMock()
cm.__aenter__ = AsyncMock(return_value=item)
cm.__aexit__ = AsyncMock(return_value=False)
return cm
mock_client.stream = MagicMock(side_effect=_stream)
mock_client._call_state = call_state
return mock_client
# ---------------------------------------------------------------------------
# cache_image_from_bytes (base.py)
# ---------------------------------------------------------------------------
@ -85,14 +135,9 @@ class TestCacheImageFromUrl:
"""A clean 200 response caches the image and returns a path."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
fake_response = MagicMock()
fake_response.content = b"\xff\xd8\xff fake jpeg"
fake_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=fake_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client = _make_stream_client(
responses=[_make_stream_response(b"\xff\xd8\xff fake jpeg")]
)
async def run():
with patch("httpx.AsyncClient", return_value=mock_client):
@ -103,23 +148,15 @@ class TestCacheImageFromUrl:
path = asyncio.run(run())
assert path.endswith(".jpg")
mock_client.get.assert_called_once()
mock_client.stream.assert_called_once()
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
"""A timeout on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
fake_response = MagicMock()
fake_response.content = b"\xff\xd8\xff image data"
fake_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[_make_timeout_error(), fake_response]
mock_client = _make_stream_client(
responses=[_make_timeout_error(), _make_stream_response(b"\xff\xd8\xff image data")]
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_sleep = AsyncMock()
async def run():
@ -132,23 +169,16 @@ class TestCacheImageFromUrl:
path = asyncio.run(run())
assert path.endswith(".jpg")
assert mock_client.get.call_count == 2
assert mock_client.stream.call_count == 2
mock_sleep.assert_called_once()
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
"""A 429 response on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
ok_response = MagicMock()
ok_response.content = b"\xff\xd8\xff image data"
ok_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[_make_http_status_error(429), ok_response]
mock_client = _make_stream_client(
responses=[_make_http_status_error(429), _make_stream_response(b"\xff\xd8\xff image data")]
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -160,16 +190,13 @@ class TestCacheImageFromUrl:
path = asyncio.run(run())
assert path.endswith(".jpg")
assert mock_client.get.call_count == 2
assert mock_client.stream.call_count == 2
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
"""Timeout on every attempt raises after all retries are consumed."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client = _make_stream_client(side_effect=_make_timeout_error())
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -183,17 +210,14 @@ class TestCacheImageFromUrl:
asyncio.run(run())
# 3 total calls: initial + 2 retries
assert mock_client.get.call_count == 3
assert mock_client.stream.call_count == 3
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
"""A 404 (non-retryable) is raised immediately without any retry."""
monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
mock_sleep = AsyncMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client = _make_stream_client(side_effect=_make_http_status_error(404))
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -207,7 +231,7 @@ class TestCacheImageFromUrl:
asyncio.run(run())
# Only 1 attempt, no sleep
assert mock_client.get.call_count == 1
assert mock_client.stream.call_count == 1
mock_sleep.assert_not_called()
@ -223,14 +247,9 @@ class TestCacheAudioFromUrl:
"""A clean 200 response caches the audio and returns a path."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
fake_response = MagicMock()
fake_response.content = b"\x00\x01 fake audio"
fake_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=fake_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client = _make_stream_client(
responses=[_make_stream_response(b"\x00\x01 fake audio")]
)
async def run():
with patch("httpx.AsyncClient", return_value=mock_client):
@ -241,23 +260,15 @@ class TestCacheAudioFromUrl:
path = asyncio.run(run())
assert path.endswith(".ogg")
mock_client.get.assert_called_once()
mock_client.stream.assert_called_once()
def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
"""A timeout on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
fake_response = MagicMock()
fake_response.content = b"audio data"
fake_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[_make_timeout_error(), fake_response]
mock_client = _make_stream_client(
responses=[_make_timeout_error(), _make_stream_response(b"audio data")]
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_sleep = AsyncMock()
async def run():
@ -270,23 +281,16 @@ class TestCacheAudioFromUrl:
path = asyncio.run(run())
assert path.endswith(".ogg")
assert mock_client.get.call_count == 2
assert mock_client.stream.call_count == 2
mock_sleep.assert_called_once()
def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
"""A 429 response on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
ok_response = MagicMock()
ok_response.content = b"audio data"
ok_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[_make_http_status_error(429), ok_response]
mock_client = _make_stream_client(
responses=[_make_http_status_error(429), _make_stream_response(b"audio data")]
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -298,22 +302,15 @@ class TestCacheAudioFromUrl:
path = asyncio.run(run())
assert path.endswith(".ogg")
assert mock_client.get.call_count == 2
assert mock_client.stream.call_count == 2
def test_retries_on_500_then_succeeds(self, _mock_safe, tmp_path, monkeypatch):
"""A 500 response on the first attempt is retried; second attempt succeeds."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
ok_response = MagicMock()
ok_response.content = b"audio data"
ok_response.raise_for_status = MagicMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(
side_effect=[_make_http_status_error(500), ok_response]
mock_client = _make_stream_client(
responses=[_make_http_status_error(500), _make_stream_response(b"audio data")]
)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -325,16 +322,13 @@ class TestCacheAudioFromUrl:
path = asyncio.run(run())
assert path.endswith(".ogg")
assert mock_client.get.call_count == 2
assert mock_client.stream.call_count == 2
def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch):
"""Timeout on every attempt raises after all retries are consumed."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=_make_timeout_error())
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client = _make_stream_client(side_effect=_make_timeout_error())
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -348,17 +342,14 @@ class TestCacheAudioFromUrl:
asyncio.run(run())
# 3 total calls: initial + 2 retries
assert mock_client.get.call_count == 3
assert mock_client.stream.call_count == 3
def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch):
"""A 404 (non-retryable) is raised immediately without any retry."""
monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
mock_sleep = AsyncMock()
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=_make_http_status_error(404))
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
mock_client = _make_stream_client(side_effect=_make_http_status_error(404))
async def run():
with patch("httpx.AsyncClient", return_value=mock_client), \
@ -372,7 +363,7 @@ class TestCacheAudioFromUrl:
asyncio.run(run())
# Only 1 attempt, no sleep
assert mock_client.get.call_count == 1
assert mock_client.stream.call_count == 1
mock_sleep.assert_not_called()
@ -415,12 +406,18 @@ class TestSSRFRedirectGuard:
)
mock_client, captured, factory = self._make_client_capturing_hooks()
async def fake_get(_url, **kwargs):
# Simulate httpx calling the response event hooks
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp)
def fake_stream(method, _url, **kwargs):
async def _aenter(*a):
# Simulate httpx invoking the response event hooks on the stream.
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp)
return redirect_resp
cm = AsyncMock()
cm.__aenter__ = AsyncMock(side_effect=_aenter)
cm.__aexit__ = AsyncMock(return_value=False)
return cm
mock_client.get = AsyncMock(side_effect=fake_get)
mock_client.stream = MagicMock(side_effect=fake_stream)
def fake_safe(url):
return url == "https://public.example.com/image.png"
@ -445,11 +442,17 @@ class TestSSRFRedirectGuard:
)
mock_client, captured, factory = self._make_client_capturing_hooks()
async def fake_get(_url, **kwargs):
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp)
def fake_stream(method, _url, **kwargs):
async def _aenter(*a):
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp)
return redirect_resp
cm = AsyncMock()
cm.__aenter__ = AsyncMock(side_effect=_aenter)
cm.__aexit__ = AsyncMock(return_value=False)
return cm
mock_client.get = AsyncMock(side_effect=fake_get)
mock_client.stream = MagicMock(side_effect=fake_stream)
def fake_safe(url):
return url == "https://public.example.com/voice.ogg"
@ -473,24 +476,24 @@ class TestSSRFRedirectGuard:
"https://cdn.example.com/real-image.png"
)
ok_response = MagicMock()
ok_response.content = b"\xff\xd8\xff fake jpeg"
ok_response.raise_for_status = MagicMock()
ok_response = _make_stream_response(b"\xff\xd8\xff fake jpeg")
ok_response.is_redirect = False
mock_client, captured, factory = self._make_client_capturing_hooks()
call_count = 0
async def fake_get(_url, **kwargs):
nonlocal call_count
call_count += 1
# First call triggers redirect hook, second returns data
async def _aenter(*a):
# Public redirect passes the guard; body then streams normally.
for hook in captured["event_hooks"]["response"]:
await hook(redirect_resp if call_count == 1 else ok_response)
await hook(redirect_resp)
return ok_response
mock_client.get = AsyncMock(side_effect=fake_get)
def fake_stream(method, _url, **kwargs):
cm = AsyncMock()
cm.__aenter__ = AsyncMock(side_effect=_aenter)
cm.__aexit__ = AsyncMock(return_value=False)
return cm
mock_client.stream = MagicMock(side_effect=fake_stream)
async def run():
with patch("tools.url_safety.is_safe_url", return_value=True), \

View file

@ -10,13 +10,68 @@ from gateway.platforms.base import (
BasePlatformAdapter,
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
MessageEvent,
cache_audio_from_bytes,
cache_image_from_bytes,
cache_video_from_bytes,
safe_url_for_log,
utf16_len,
validate_inbound_media_size,
_log_safe_path,
_prefix_within_utf16_limit,
)
class TestInboundMediaSizeCap:
"""gateway.max_inbound_media_bytes caps inbound media buffered into RAM (#13145)."""
_PNG = b"\x89PNG\r\n\x1a\n" + b"x" * 64
def test_default_cap_is_128_mib(self, monkeypatch):
# No config override -> default. Patch loader to return empty config.
import gateway.platforms.base as base
monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: base.DEFAULT_INBOUND_MEDIA_MAX_BYTES)
assert base.DEFAULT_INBOUND_MEDIA_MAX_BYTES == 128 * 1024 * 1024
def test_image_bytes_rejected_when_oversized(self, monkeypatch):
import gateway.platforms.base as base
monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 16)
with pytest.raises(ValueError, match="Inbound image payload is too large"):
cache_image_from_bytes(self._PNG, ext=".png")
def test_audio_bytes_rejected_when_oversized(self, monkeypatch):
import gateway.platforms.base as base
monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 4)
with pytest.raises(ValueError, match="Inbound audio payload is too large"):
cache_audio_from_bytes(b"x" * 8, ext=".ogg")
def test_video_bytes_rejected_when_oversized(self, monkeypatch):
# Video was the gap in the original report — verify it's covered.
import gateway.platforms.base as base
monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 4)
with pytest.raises(ValueError, match="Inbound video payload is too large"):
cache_video_from_bytes(b"x" * 8, ext=".mp4")
def test_legit_image_accepted_under_cap(self, monkeypatch):
import gateway.platforms.base as base
monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 128 * 1024 * 1024)
path = cache_image_from_bytes(self._PNG, ext=".png")
assert os.path.exists(path)
assert os.path.getsize(path) == len(self._PNG)
def test_cap_of_zero_disables_check(self, monkeypatch):
import gateway.platforms.base as base
monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 0)
# A would-be-oversized video passes through when the cap is disabled.
path = cache_video_from_bytes(b"x" * 5000, ext=".mp4")
assert os.path.exists(path)
def test_validate_helper_respects_explicit_max_bytes(self):
# max_bytes arg overrides the configured cap.
validate_inbound_media_size(100, media_type="image", max_bytes=200) # ok
with pytest.raises(ValueError, match="too large"):
validate_inbound_media_size(300, media_type="image", max_bytes=200)
class TestSecretCaptureGuidance:
def test_gateway_secret_capture_message_points_to_local_setup(self):
message = GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE