From 93ea9b04aff2f1992b31b86a267303fecc227995 Mon Sep 17 00:00:00 2001 From: sgaofen <135070653+sgaofen@users.noreply.github.com> Date: Sun, 21 Jun 2026 11:36:39 -0700 Subject: [PATCH] fix(gateway): cap inbound media download size to prevent memory exhaustion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Inbound image/audio/video payloads were buffered fully into process memory before being written to the cache, with no size limit. A large upload (Discord Nitro allows 500 MB) or a remote media URL in an inbound message pointing at a huge file could spike RAM and OOM-kill the gateway. Enforce a configurable cap in the shared cache helpers (gateway/platforms/ base.py) so the protection holds across every platform adapter, not one: - cache_image/audio/video_from_bytes reject oversized payloads before writing (video was the gap in the original report — now covered). - cache_image/audio_from_url stream the body, rejecting on an oversized Content-Length header and re-checking the running total per chunk so an absent/lying header can't smuggle an unbounded body past the cap. - Discord's _read_attachment_bytes checks att.size up front, so an oversized attachment is rejected before any bytes are pulled into memory. Configurable via gateway.max_inbound_media_bytes in config.yaml (default 128 MiB; 0 disables). No new env var — non-secret config lives in config.yaml. Salvaged and extended from @sgaofen's PR #13341 (the original report and the shared-helper approach). Reapplied onto current main (Discord adapter has since moved to plugins/platforms/discord/), the configurable knob moved from an env var to config.yaml, and the video cache helper added. Co-authored-by: Hermes Agent --- gateway/platforms/base.py | 117 ++++++++++- hermes_cli/config.py | 10 + plugins/platforms/discord/adapter.py | 26 ++- tests/gateway/test_media_download_retry.py | 223 +++++++++++---------- tests/gateway/test_platform_base.py | 55 +++++ 5 files changed, 308 insertions(+), 123 deletions(-) diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 8c447a7a2bf..fe1039f2579 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -567,6 +567,96 @@ async def _ssrf_redirect_guard(response): # Default location: {HERMES_HOME}/cache/images/ (legacy: image_cache/) IMAGE_CACHE_DIR = get_hermes_dir("cache/images", "image_cache") +# --------------------------------------------------------------------------- +# Inbound media size cap (#13145) +# +# Inbound image / audio / video payloads are buffered fully into process +# memory before being written to the cache directory. With no cap, a single +# large upload (Discord Nitro allows 500 MB) — or a remote URL in an inbound +# message payload pointing at an arbitrarily large file — can spike RAM and +# OOM-kill the gateway. The ``cache_*_from_bytes`` helpers (the shared funnel +# every platform reaches eventually) and the ``cache_*_from_url`` downloaders +# enforce this cap, so the protection holds regardless of which platform +# adapter or code path produced the bytes. +# +# Configurable via ``gateway.max_inbound_media_bytes`` in config.yaml. +# ``0`` disables the cap. Default 128 MiB — generous enough for ordinary +# photos/voice notes/short clips while still bounding a hostile upload. +# --------------------------------------------------------------------------- +DEFAULT_INBOUND_MEDIA_MAX_BYTES = 128 * 1024 * 1024 + + +def get_inbound_media_max_bytes() -> int: + """Return the max inbound image/audio/video bytes allowed in memory. + + Reads ``gateway.max_inbound_media_bytes`` from config.yaml. ``0`` (or a + negative / unparseable value) disables the cap. Non-fatal if config is + unreadable — falls back to the default. + """ + try: + from hermes_cli.config import load_config as _load_config + cfg = _load_config() + except Exception: + return DEFAULT_INBOUND_MEDIA_MAX_BYTES + gw = cfg.get("gateway", {}) if isinstance(cfg, dict) else {} + if not isinstance(gw, dict) or "max_inbound_media_bytes" not in gw: + return DEFAULT_INBOUND_MEDIA_MAX_BYTES + try: + return int(gw["max_inbound_media_bytes"]) + except (TypeError, ValueError): + return DEFAULT_INBOUND_MEDIA_MAX_BYTES + + +def validate_inbound_media_size( + size: int, + *, + media_type: str = "media", + max_bytes: Optional[int] = None, +) -> None: + """Raise ``ValueError`` if an inbound media payload exceeds the cap. + + A ``max_bytes`` of ``0`` (or the configured cap resolving to ``0``) + disables the check entirely. Passing ``max_bytes`` lets callers resolve + the limit once and reuse it across an incremental read. + """ + limit = get_inbound_media_max_bytes() if max_bytes is None else max_bytes + if limit and size > limit: + raise ValueError( + f"Inbound {media_type} payload is too large " + f"({size} bytes > {limit} bytes)" + ) + + +async def _read_httpx_body_with_limit(response, *, media_type: str) -> bytes: + """Read an httpx streaming response body without exceeding the media cap. + + Rejects early on an oversized ``Content-Length`` header, then re-checks + the running total as chunks arrive so a lying/absent header can't smuggle + an unbounded body past the cap. + """ + max_bytes = get_inbound_media_max_bytes() + content_length = response.headers.get("content-length") + if content_length: + try: + declared_size = int(content_length) + except ValueError: + logger.debug( + "Ignoring invalid Content-Length for inbound %s: %r", + media_type, content_length, + ) + else: + validate_inbound_media_size( + declared_size, media_type=media_type, max_bytes=max_bytes, + ) + + chunks: list[bytes] = [] + total = 0 + async for chunk in response.aiter_bytes(): + total += len(chunk) + validate_inbound_media_size(total, media_type=media_type, max_bytes=max_bytes) + chunks.append(chunk) + return b"".join(chunks) + def get_image_cache_dir() -> Path: """Return the image cache directory, creating it if it doesn't exist.""" @@ -606,6 +696,7 @@ def cache_image_from_bytes(data: bytes, ext: str = ".jpg") -> str: ValueError: If *data* does not look like a valid image (e.g. an HTML error page returned by the upstream server). """ + validate_inbound_media_size(len(data), media_type="image") if not _looks_like_image(data): snippet = data[:80].decode("utf-8", errors="replace") raise ValueError( @@ -651,15 +742,19 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> ) as client: for attempt in range(retries + 1): try: - response = await client.get( + async with client.stream( + "GET", url, headers={ "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", "Accept": "image/*,*/*;q=0.8", }, - ) - response.raise_for_status() - return cache_image_from_bytes(response.content, ext) + ) as response: + response.raise_for_status() + content = await _read_httpx_body_with_limit( + response, media_type="image", + ) + return cache_image_from_bytes(content, ext) except (httpx.TimeoutException, httpx.HTTPStatusError) as exc: if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429: raise @@ -726,6 +821,7 @@ def cache_audio_from_bytes(data: bytes, ext: str = ".ogg") -> str: Returns: Absolute path to the cached audio file as a string. """ + validate_inbound_media_size(len(data), media_type="audio") cache_dir = get_audio_cache_dir() filename = f"audio_{uuid.uuid4().hex[:12]}{ext}" filepath = cache_dir / filename @@ -765,15 +861,19 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> ) as client: for attempt in range(retries + 1): try: - response = await client.get( + async with client.stream( + "GET", url, headers={ "User-Agent": "Mozilla/5.0 (compatible; HermesAgent/1.0)", "Accept": "audio/*,*/*;q=0.8", }, - ) - response.raise_for_status() - return cache_audio_from_bytes(response.content, ext) + ) as response: + response.raise_for_status() + content = await _read_httpx_body_with_limit( + response, media_type="audio", + ) + return cache_audio_from_bytes(content, ext) except (httpx.TimeoutException, httpx.HTTPStatusError) as exc: if isinstance(exc, httpx.HTTPStatusError) and exc.response.status_code < 429: raise @@ -818,6 +918,7 @@ def get_video_cache_dir() -> Path: def cache_video_from_bytes(data: bytes, ext: str = ".mp4") -> str: """Save raw video bytes to the cache and return the absolute file path.""" + validate_inbound_media_size(len(data), media_type="video") cache_dir = get_video_cache_dir() filename = f"video_{uuid.uuid4().hex[:12]}{ext}" filepath = cache_dir / filename diff --git a/hermes_cli/config.py b/hermes_cli/config.py index 0605ab83569..b833b94836a 100644 --- a/hermes_cli/config.py +++ b/hermes_cli/config.py @@ -2474,6 +2474,16 @@ DEFAULT_CONFIG = { "enabled": False, }, + # Maximum bytes for an inbound image / audio / video payload the + # gateway will buffer into memory and cache to disk. Inbound media is + # read fully into RAM before being written, so an unbounded upload + # (Discord Nitro allows 500 MB) or a remote media URL pointing at a + # huge file can spike memory and OOM-kill the gateway on constrained + # deployments. Enforced in the shared cache helpers + # (gateway/platforms/base.py), so the cap holds across every platform + # adapter. ``0`` disables the cap. Default 128 MiB. + "max_inbound_media_bytes": 134217728, + # When false (default), any file path the agent emits is delivered # as a native attachment as long as it isn't under the credential / # system-path denylist (/etc, /proc, ~/.ssh, ~/.aws, ~/.hermes/.env, diff --git a/plugins/platforms/discord/adapter.py b/plugins/platforms/discord/adapter.py index accede61a23..1fc6692eac5 100644 --- a/plugins/platforms/discord/adapter.py +++ b/plugins/platforms/discord/adapter.py @@ -116,6 +116,7 @@ from gateway.platforms.base import ( cache_audio_from_bytes, cache_document_from_bytes, SUPPORTED_DOCUMENT_TYPES, + validate_inbound_media_size, ) from tools.url_safety import is_safe_url @@ -5052,19 +5053,32 @@ class DiscordAdapter(BasePlatformAdapter): # non-CDN URL into the ``att.url`` field. (issue #11345) # ------------------------------------------------------------------ - async def _read_attachment_bytes(self, att) -> Optional[bytes]: + async def _read_attachment_bytes( + self, + att, + *, + media_type: str = "media", + ) -> Optional[bytes]: """Read an attachment via discord.py's authenticated bot session. Returns the raw bytes on success, or ``None`` if ``att`` doesn't expose a callable ``read()`` or the read itself fails. Callers should treat ``None`` as a signal to fall back to the URL-based downloaders. + + Oversized attachments (per ``gateway.max_inbound_media_bytes``) raise + ``ValueError`` BEFORE the bytes are pulled into memory when Discord + reports the size up front, so a hostile upload can't OOM the gateway. """ + attachment_size = getattr(att, "size", None) + if attachment_size: + validate_inbound_media_size(int(attachment_size), media_type=media_type) + reader = getattr(att, "read", None) if reader is None or not callable(reader): return None try: - return await reader() + raw_bytes = await reader() except Exception as e: logger.warning( "[Discord] Authenticated attachment read failed for %s: %s", @@ -5072,6 +5086,8 @@ class DiscordAdapter(BasePlatformAdapter): e, ) return None + validate_inbound_media_size(len(raw_bytes), media_type=media_type) + return raw_bytes async def _cache_discord_image(self, att, ext: str) -> str: """Cache a Discord image attachment to local disk. @@ -5081,7 +5097,7 @@ class DiscordAdapter(BasePlatformAdapter): Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated). """ - raw_bytes = await self._read_attachment_bytes(att) + raw_bytes = await self._read_attachment_bytes(att, media_type="image") if raw_bytes is not None: try: return cache_image_from_bytes(raw_bytes, ext=ext) @@ -5100,7 +5116,7 @@ class DiscordAdapter(BasePlatformAdapter): Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated). """ - raw_bytes = await self._read_attachment_bytes(att) + raw_bytes = await self._read_attachment_bytes(att, media_type="audio") if raw_bytes is not None: try: return cache_audio_from_bytes(raw_bytes, ext=ext) @@ -5122,7 +5138,7 @@ class DiscordAdapter(BasePlatformAdapter): for passing the returned bytes to ``cache_document_from_bytes`` (and, where applicable, for injecting text content). """ - raw_bytes = await self._read_attachment_bytes(att) + raw_bytes = await self._read_attachment_bytes(att, media_type="document") if raw_bytes is not None: return raw_bytes diff --git a/tests/gateway/test_media_download_retry.py b/tests/gateway/test_media_download_retry.py index 2cdc8a32b46..a473a049353 100644 --- a/tests/gateway/test_media_download_retry.py +++ b/tests/gateway/test_media_download_retry.py @@ -34,6 +34,56 @@ def _make_timeout_error() -> httpx.TimeoutException: return httpx.TimeoutException("timed out") +def _make_stream_response(content: bytes = b"\xff\xd8\xff fake media"): + """Build a mock httpx response suitable for ``client.stream()`` usage. + + Exposes ``raise_for_status``, an empty ``headers`` mapping (no + Content-Length), and an ``aiter_bytes`` async iterator yielding the body + in one chunk — matching how ``_read_httpx_body_with_limit`` consumes it. + """ + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.headers = {} + + async def _aiter(): + yield content + + resp.aiter_bytes = lambda: _aiter() + return resp + + +def _make_stream_client(*, responses=None, side_effect=None): + """Build a mock httpx client whose ``.stream()`` is an async CM. + + ``responses`` is a list of response objects (or exceptions) returned on + successive ``.stream()`` calls; ``side_effect`` is a single exception + raised on every call. The returned client also supports being used as an + ``async with`` context manager (``httpx.AsyncClient(...)``). + """ + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + call_state = {"i": 0} + + def _stream(method, url, **kwargs): + idx = call_state["i"] + call_state["i"] += 1 + if side_effect is not None: + raise side_effect + item = responses[idx] + if isinstance(item, Exception): + raise item + cm = AsyncMock() + cm.__aenter__ = AsyncMock(return_value=item) + cm.__aexit__ = AsyncMock(return_value=False) + return cm + + mock_client.stream = MagicMock(side_effect=_stream) + mock_client._call_state = call_state + return mock_client + + # --------------------------------------------------------------------------- # cache_image_from_bytes (base.py) # --------------------------------------------------------------------------- @@ -85,14 +135,9 @@ class TestCacheImageFromUrl: """A clean 200 response caches the image and returns a path.""" monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") - fake_response = MagicMock() - fake_response.content = b"\xff\xd8\xff fake jpeg" - fake_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock(return_value=fake_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client = _make_stream_client( + responses=[_make_stream_response(b"\xff\xd8\xff fake jpeg")] + ) async def run(): with patch("httpx.AsyncClient", return_value=mock_client): @@ -103,23 +148,15 @@ class TestCacheImageFromUrl: path = asyncio.run(run()) assert path.endswith(".jpg") - mock_client.get.assert_called_once() + mock_client.stream.assert_called_once() def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch): """A timeout on the first attempt is retried; second attempt succeeds.""" monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") - fake_response = MagicMock() - fake_response.content = b"\xff\xd8\xff image data" - fake_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock( - side_effect=[_make_timeout_error(), fake_response] + mock_client = _make_stream_client( + responses=[_make_timeout_error(), _make_stream_response(b"\xff\xd8\xff image data")] ) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_sleep = AsyncMock() async def run(): @@ -132,23 +169,16 @@ class TestCacheImageFromUrl: path = asyncio.run(run()) assert path.endswith(".jpg") - assert mock_client.get.call_count == 2 + assert mock_client.stream.call_count == 2 mock_sleep.assert_called_once() def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch): """A 429 response on the first attempt is retried; second attempt succeeds.""" monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") - ok_response = MagicMock() - ok_response.content = b"\xff\xd8\xff image data" - ok_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock( - side_effect=[_make_http_status_error(429), ok_response] + mock_client = _make_stream_client( + responses=[_make_http_status_error(429), _make_stream_response(b"\xff\xd8\xff image data")] ) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -160,16 +190,13 @@ class TestCacheImageFromUrl: path = asyncio.run(run()) assert path.endswith(".jpg") - assert mock_client.get.call_count == 2 + assert mock_client.stream.call_count == 2 def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch): """Timeout on every attempt raises after all retries are consumed.""" monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") - mock_client = AsyncMock() - mock_client.get = AsyncMock(side_effect=_make_timeout_error()) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client = _make_stream_client(side_effect=_make_timeout_error()) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -183,17 +210,14 @@ class TestCacheImageFromUrl: asyncio.run(run()) # 3 total calls: initial + 2 retries - assert mock_client.get.call_count == 3 + assert mock_client.stream.call_count == 3 def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch): """A 404 (non-retryable) is raised immediately without any retry.""" monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") mock_sleep = AsyncMock() - mock_client = AsyncMock() - mock_client.get = AsyncMock(side_effect=_make_http_status_error(404)) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client = _make_stream_client(side_effect=_make_http_status_error(404)) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -207,7 +231,7 @@ class TestCacheImageFromUrl: asyncio.run(run()) # Only 1 attempt, no sleep - assert mock_client.get.call_count == 1 + assert mock_client.stream.call_count == 1 mock_sleep.assert_not_called() @@ -223,14 +247,9 @@ class TestCacheAudioFromUrl: """A clean 200 response caches the audio and returns a path.""" monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") - fake_response = MagicMock() - fake_response.content = b"\x00\x01 fake audio" - fake_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock(return_value=fake_response) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client = _make_stream_client( + responses=[_make_stream_response(b"\x00\x01 fake audio")] + ) async def run(): with patch("httpx.AsyncClient", return_value=mock_client): @@ -241,23 +260,15 @@ class TestCacheAudioFromUrl: path = asyncio.run(run()) assert path.endswith(".ogg") - mock_client.get.assert_called_once() + mock_client.stream.assert_called_once() def test_retries_on_timeout_then_succeeds(self, _mock_safe, tmp_path, monkeypatch): """A timeout on the first attempt is retried; second attempt succeeds.""" monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") - fake_response = MagicMock() - fake_response.content = b"audio data" - fake_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock( - side_effect=[_make_timeout_error(), fake_response] + mock_client = _make_stream_client( + responses=[_make_timeout_error(), _make_stream_response(b"audio data")] ) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) - mock_sleep = AsyncMock() async def run(): @@ -270,23 +281,16 @@ class TestCacheAudioFromUrl: path = asyncio.run(run()) assert path.endswith(".ogg") - assert mock_client.get.call_count == 2 + assert mock_client.stream.call_count == 2 mock_sleep.assert_called_once() def test_retries_on_429_then_succeeds(self, _mock_safe, tmp_path, monkeypatch): """A 429 response on the first attempt is retried; second attempt succeeds.""" monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") - ok_response = MagicMock() - ok_response.content = b"audio data" - ok_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock( - side_effect=[_make_http_status_error(429), ok_response] + mock_client = _make_stream_client( + responses=[_make_http_status_error(429), _make_stream_response(b"audio data")] ) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -298,22 +302,15 @@ class TestCacheAudioFromUrl: path = asyncio.run(run()) assert path.endswith(".ogg") - assert mock_client.get.call_count == 2 + assert mock_client.stream.call_count == 2 def test_retries_on_500_then_succeeds(self, _mock_safe, tmp_path, monkeypatch): """A 500 response on the first attempt is retried; second attempt succeeds.""" monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") - ok_response = MagicMock() - ok_response.content = b"audio data" - ok_response.raise_for_status = MagicMock() - - mock_client = AsyncMock() - mock_client.get = AsyncMock( - side_effect=[_make_http_status_error(500), ok_response] + mock_client = _make_stream_client( + responses=[_make_http_status_error(500), _make_stream_response(b"audio data")] ) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -325,16 +322,13 @@ class TestCacheAudioFromUrl: path = asyncio.run(run()) assert path.endswith(".ogg") - assert mock_client.get.call_count == 2 + assert mock_client.stream.call_count == 2 def test_raises_after_max_retries_exhausted(self, _mock_safe, tmp_path, monkeypatch): """Timeout on every attempt raises after all retries are consumed.""" monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") - mock_client = AsyncMock() - mock_client.get = AsyncMock(side_effect=_make_timeout_error()) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client = _make_stream_client(side_effect=_make_timeout_error()) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -348,17 +342,14 @@ class TestCacheAudioFromUrl: asyncio.run(run()) # 3 total calls: initial + 2 retries - assert mock_client.get.call_count == 3 + assert mock_client.stream.call_count == 3 def test_non_retryable_4xx_raises_immediately(self, _mock_safe, tmp_path, monkeypatch): """A 404 (non-retryable) is raised immediately without any retry.""" monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") mock_sleep = AsyncMock() - mock_client = AsyncMock() - mock_client.get = AsyncMock(side_effect=_make_http_status_error(404)) - mock_client.__aenter__ = AsyncMock(return_value=mock_client) - mock_client.__aexit__ = AsyncMock(return_value=False) + mock_client = _make_stream_client(side_effect=_make_http_status_error(404)) async def run(): with patch("httpx.AsyncClient", return_value=mock_client), \ @@ -372,7 +363,7 @@ class TestCacheAudioFromUrl: asyncio.run(run()) # Only 1 attempt, no sleep - assert mock_client.get.call_count == 1 + assert mock_client.stream.call_count == 1 mock_sleep.assert_not_called() @@ -415,12 +406,18 @@ class TestSSRFRedirectGuard: ) mock_client, captured, factory = self._make_client_capturing_hooks() - async def fake_get(_url, **kwargs): - # Simulate httpx calling the response event hooks - for hook in captured["event_hooks"]["response"]: - await hook(redirect_resp) + def fake_stream(method, _url, **kwargs): + async def _aenter(*a): + # Simulate httpx invoking the response event hooks on the stream. + for hook in captured["event_hooks"]["response"]: + await hook(redirect_resp) + return redirect_resp + cm = AsyncMock() + cm.__aenter__ = AsyncMock(side_effect=_aenter) + cm.__aexit__ = AsyncMock(return_value=False) + return cm - mock_client.get = AsyncMock(side_effect=fake_get) + mock_client.stream = MagicMock(side_effect=fake_stream) def fake_safe(url): return url == "https://public.example.com/image.png" @@ -445,11 +442,17 @@ class TestSSRFRedirectGuard: ) mock_client, captured, factory = self._make_client_capturing_hooks() - async def fake_get(_url, **kwargs): - for hook in captured["event_hooks"]["response"]: - await hook(redirect_resp) + def fake_stream(method, _url, **kwargs): + async def _aenter(*a): + for hook in captured["event_hooks"]["response"]: + await hook(redirect_resp) + return redirect_resp + cm = AsyncMock() + cm.__aenter__ = AsyncMock(side_effect=_aenter) + cm.__aexit__ = AsyncMock(return_value=False) + return cm - mock_client.get = AsyncMock(side_effect=fake_get) + mock_client.stream = MagicMock(side_effect=fake_stream) def fake_safe(url): return url == "https://public.example.com/voice.ogg" @@ -473,24 +476,24 @@ class TestSSRFRedirectGuard: "https://cdn.example.com/real-image.png" ) - ok_response = MagicMock() - ok_response.content = b"\xff\xd8\xff fake jpeg" - ok_response.raise_for_status = MagicMock() + ok_response = _make_stream_response(b"\xff\xd8\xff fake jpeg") ok_response.is_redirect = False mock_client, captured, factory = self._make_client_capturing_hooks() - call_count = 0 - - async def fake_get(_url, **kwargs): - nonlocal call_count - call_count += 1 - # First call triggers redirect hook, second returns data + async def _aenter(*a): + # Public redirect passes the guard; body then streams normally. for hook in captured["event_hooks"]["response"]: - await hook(redirect_resp if call_count == 1 else ok_response) + await hook(redirect_resp) return ok_response - mock_client.get = AsyncMock(side_effect=fake_get) + def fake_stream(method, _url, **kwargs): + cm = AsyncMock() + cm.__aenter__ = AsyncMock(side_effect=_aenter) + cm.__aexit__ = AsyncMock(return_value=False) + return cm + + mock_client.stream = MagicMock(side_effect=fake_stream) async def run(): with patch("tools.url_safety.is_safe_url", return_value=True), \ diff --git a/tests/gateway/test_platform_base.py b/tests/gateway/test_platform_base.py index 3f8ecd93231..3a4f85a5e41 100644 --- a/tests/gateway/test_platform_base.py +++ b/tests/gateway/test_platform_base.py @@ -10,13 +10,68 @@ from gateway.platforms.base import ( BasePlatformAdapter, GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE, MessageEvent, + cache_audio_from_bytes, + cache_image_from_bytes, + cache_video_from_bytes, safe_url_for_log, utf16_len, + validate_inbound_media_size, _log_safe_path, _prefix_within_utf16_limit, ) +class TestInboundMediaSizeCap: + """gateway.max_inbound_media_bytes caps inbound media buffered into RAM (#13145).""" + + _PNG = b"\x89PNG\r\n\x1a\n" + b"x" * 64 + + def test_default_cap_is_128_mib(self, monkeypatch): + # No config override -> default. Patch loader to return empty config. + import gateway.platforms.base as base + monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: base.DEFAULT_INBOUND_MEDIA_MAX_BYTES) + assert base.DEFAULT_INBOUND_MEDIA_MAX_BYTES == 128 * 1024 * 1024 + + def test_image_bytes_rejected_when_oversized(self, monkeypatch): + import gateway.platforms.base as base + monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 16) + with pytest.raises(ValueError, match="Inbound image payload is too large"): + cache_image_from_bytes(self._PNG, ext=".png") + + def test_audio_bytes_rejected_when_oversized(self, monkeypatch): + import gateway.platforms.base as base + monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 4) + with pytest.raises(ValueError, match="Inbound audio payload is too large"): + cache_audio_from_bytes(b"x" * 8, ext=".ogg") + + def test_video_bytes_rejected_when_oversized(self, monkeypatch): + # Video was the gap in the original report — verify it's covered. + import gateway.platforms.base as base + monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 4) + with pytest.raises(ValueError, match="Inbound video payload is too large"): + cache_video_from_bytes(b"x" * 8, ext=".mp4") + + def test_legit_image_accepted_under_cap(self, monkeypatch): + import gateway.platforms.base as base + monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 128 * 1024 * 1024) + path = cache_image_from_bytes(self._PNG, ext=".png") + assert os.path.exists(path) + assert os.path.getsize(path) == len(self._PNG) + + def test_cap_of_zero_disables_check(self, monkeypatch): + import gateway.platforms.base as base + monkeypatch.setattr(base, "get_inbound_media_max_bytes", lambda: 0) + # A would-be-oversized video passes through when the cap is disabled. + path = cache_video_from_bytes(b"x" * 5000, ext=".mp4") + assert os.path.exists(path) + + def test_validate_helper_respects_explicit_max_bytes(self): + # max_bytes arg overrides the configured cap. + validate_inbound_media_size(100, media_type="image", max_bytes=200) # ok + with pytest.raises(ValueError, match="too large"): + validate_inbound_media_size(300, media_type="image", max_bytes=200) + + class TestSecretCaptureGuidance: def test_gateway_secret_capture_message_points_to_local_setup(self): message = GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE