mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(discord): route attachment downloads through authenticated bot session (#11568)
Three open issues — #8242, #6587, #11345 — all trace to the same root cause: the image / audio / document download paths in `DiscordAdapter._handle_message` used plain, unauthenticated HTTP to fetch `att.url`. That broke in three independent ways: #8242 cdn.discordapp.com attachment URLs increasingly require the bot session to download; unauthenticated httpx sees 403 Forbidden, image/voice analysis fail silently. #6587 Some user environments (VPNs, corporate DNS, tunnels) resolve cdn.discordapp.com to private-looking IPs. Our is_safe_url() guard correctly blocks them as SSRF risks, but the user environment is legitimate — image analysis and voice STT die. #11345 The document download path skipped is_safe_url() entirely — raw aiohttp.ClientSession.get(att.url) with no SSRF check, inconsistent with the image/audio branches. Unified fix: use `discord.Attachment.read()` as the primary download path on all three branches. `att.read()` routes through discord.py's own authenticated HTTPClient, so: - Discord CDN auth is handled (#8242 resolved). - Our is_safe_url() gate isn't consulted for the attachment path at all — the bot session handles networking internally (#6587 resolved). - All three branches now share the same code path, eliminating the document-path SSRF gap (#11345 resolved). Falls back to the existing cache_*_from_url helpers (image/audio) or an SSRF-gated aiohttp fetch (documents) when `att.read()` is unavailable or fails — preserves defense-in-depth for any future payload-schema drift that could slip a non-CDN URL into att.url. New helpers on DiscordAdapter: - _read_attachment_bytes(att) — safe att.read() wrapper - _cache_discord_image(att, ext) — primary + URL fallback - _cache_discord_audio(att, ext) — primary + URL fallback - _cache_discord_document(att, ext) — primary + SSRF-gated aiohttp fallback Tests: - tests/gateway/test_discord_attachment_download.py — 12 new cases covering all three helpers: primary path, fallback on missing .read(), fallback on validator rejection, SSRF guard on document fallback, aiohttp fallback happy-path, and an E2E case via _handle_message confirming cache_image_from_url is never invoked when att.read() succeeds. - All 11 existing document-handling tests continue to pass via the aiohttp fallback path (their SimpleNamespace attachments have no .read(), which triggers the fallback — now SSRF-gated). Closes #8242, closes #6587, closes #11345.
This commit is contained in:
parent
24342813fe
commit
53da34a4fc
2 changed files with 483 additions and 15 deletions
|
|
@ -51,7 +51,9 @@ from gateway.platforms.base import (
|
|||
ProcessingOutcome,
|
||||
SendResult,
|
||||
cache_image_from_url,
|
||||
cache_image_from_bytes,
|
||||
cache_audio_from_url,
|
||||
cache_audio_from_bytes,
|
||||
cache_document_from_bytes,
|
||||
SUPPORTED_DOCUMENT_TYPES,
|
||||
)
|
||||
|
|
@ -2489,6 +2491,124 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
return f"{parent_name} / {thread_name}"
|
||||
return thread_name
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Attachment download helpers
|
||||
#
|
||||
# Discord attachments (images / audio / documents) are fetched via the
|
||||
# authenticated bot session whenever the Attachment object exposes
|
||||
# ``read()``. That sidesteps two classes of bug that hit the older
|
||||
# plain-HTTP path:
|
||||
#
|
||||
# 1. ``cdn.discordapp.com`` URLs increasingly require bot auth on
|
||||
# download — unauthenticated httpx sees 403 Forbidden.
|
||||
# (issue #8242)
|
||||
# 2. Some user environments (VPNs, corporate DNS, tunnels) resolve
|
||||
# ``cdn.discordapp.com`` to private-looking IPs that our
|
||||
# ``is_safe_url`` guard classifies as SSRF risks. Routing the
|
||||
# fetch through discord.py's own HTTP client handles DNS
|
||||
# internally so our guard isn't consulted for the attachment
|
||||
# path. (issue #6587)
|
||||
#
|
||||
# If ``att.read()`` is unavailable (unexpected object shape / test
|
||||
# stub) or the bot session fetch fails, we fall back to the existing
|
||||
# SSRF-gated URL downloaders. The fallback keeps defense-in-depth
|
||||
# against any future Discord payload-schema drift that could slip a
|
||||
# non-CDN URL into the ``att.url`` field. (issue #11345)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _read_attachment_bytes(self, att) -> Optional[bytes]:
|
||||
"""Read an attachment via discord.py's authenticated bot session.
|
||||
|
||||
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
|
||||
expose a callable ``read()`` or the read itself fails. Callers
|
||||
should treat ``None`` as a signal to fall back to the URL-based
|
||||
downloaders.
|
||||
"""
|
||||
reader = getattr(att, "read", None)
|
||||
if reader is None or not callable(reader):
|
||||
return None
|
||||
try:
|
||||
return await reader()
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"[Discord] Authenticated attachment read failed for %s: %s",
|
||||
getattr(att, "filename", None) or getattr(att, "url", "<unknown>"),
|
||||
e,
|
||||
)
|
||||
return None
|
||||
|
||||
async def _cache_discord_image(self, att, ext: str) -> str:
|
||||
"""Cache a Discord image attachment to local disk.
|
||||
|
||||
Primary path: ``att.read()`` + ``cache_image_from_bytes``
|
||||
(authenticated, no SSRF gate).
|
||||
|
||||
Fallback: ``cache_image_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_image_from_bytes(raw_bytes, ext=ext)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"[Discord] cache_image_from_bytes rejected att.read() data; falling back to URL: %s",
|
||||
e,
|
||||
)
|
||||
return await cache_image_from_url(att.url, ext=ext)
|
||||
|
||||
async def _cache_discord_audio(self, att, ext: str) -> str:
|
||||
"""Cache a Discord audio attachment to local disk.
|
||||
|
||||
Primary path: ``att.read()`` + ``cache_audio_from_bytes``
|
||||
(authenticated, no SSRF gate).
|
||||
|
||||
Fallback: ``cache_audio_from_url`` (plain httpx, SSRF-gated).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
try:
|
||||
return cache_audio_from_bytes(raw_bytes, ext=ext)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"[Discord] cache_audio_from_bytes failed; falling back to URL: %s",
|
||||
e,
|
||||
)
|
||||
return await cache_audio_from_url(att.url, ext=ext)
|
||||
|
||||
async def _cache_discord_document(self, att, ext: str) -> bytes:
|
||||
"""Download a Discord document attachment and return the raw bytes.
|
||||
|
||||
Primary path: ``att.read()`` (authenticated, no SSRF gate).
|
||||
|
||||
Fallback: SSRF-gated ``aiohttp`` download. This closes the gap
|
||||
where the old document path made raw ``aiohttp.ClientSession``
|
||||
requests with no safety check (#11345). The caller is responsible
|
||||
for passing the returned bytes to ``cache_document_from_bytes``
|
||||
(and, where applicable, for injecting text content).
|
||||
"""
|
||||
raw_bytes = await self._read_attachment_bytes(att)
|
||||
if raw_bytes is not None:
|
||||
return raw_bytes
|
||||
|
||||
# Fallback: SSRF-gated URL download.
|
||||
if not is_safe_url(att.url):
|
||||
raise ValueError(
|
||||
f"Blocked unsafe attachment URL (SSRF protection): {att.url}"
|
||||
)
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
return await resp.read()
|
||||
|
||||
async def _handle_message(self, message: DiscordMessage) -> None:
|
||||
"""Handle incoming Discord messages."""
|
||||
# In server channels (not DMs), require the bot to be @mentioned
|
||||
|
|
@ -2642,7 +2762,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
|
||||
ext = ".jpg"
|
||||
cached_path = await cache_image_from_url(att.url, ext=ext)
|
||||
cached_path = await self._cache_discord_image(att, ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user image: {cached_path}", flush=True)
|
||||
|
|
@ -2656,7 +2776,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
ext = "." + content_type.split("/")[-1].split(";")[0]
|
||||
if ext not in (".ogg", ".mp3", ".wav", ".webm", ".m4a"):
|
||||
ext = ".ogg"
|
||||
cached_path = await cache_audio_from_url(att.url, ext=ext)
|
||||
cached_path = await self._cache_discord_audio(att, ext)
|
||||
media_urls.append(cached_path)
|
||||
media_types.append(content_type)
|
||||
print(f"[Discord] Cached user audio: {cached_path}", flush=True)
|
||||
|
|
@ -2687,19 +2807,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
)
|
||||
else:
|
||||
try:
|
||||
import aiohttp
|
||||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
async with aiohttp.ClientSession(**_sess_kw) as session:
|
||||
async with session.get(
|
||||
att.url,
|
||||
timeout=aiohttp.ClientTimeout(total=30),
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"HTTP {resp.status}")
|
||||
raw_bytes = await resp.read()
|
||||
raw_bytes = await self._cache_discord_document(att, ext)
|
||||
cached_path = cache_document_from_bytes(
|
||||
raw_bytes, att.filename or f"document{ext}"
|
||||
)
|
||||
|
|
|
|||
360
tests/gateway/test_discord_attachment_download.py
Normal file
360
tests/gateway/test_discord_attachment_download.py
Normal file
|
|
@ -0,0 +1,360 @@
|
|||
"""Tests for Discord attachment downloads via the authenticated bot session.
|
||||
|
||||
Covers the three download paths (image / audio / document) in
|
||||
``DiscordAdapter._handle_message()`` and the shared ``_cache_discord_*``
|
||||
helpers. Verifies that:
|
||||
|
||||
- ``att.read()`` is preferred over the legacy URL-based downloaders so
|
||||
that Discord's CDN auth (and user-environment DNS quirks) can't block
|
||||
media caching. (issues #8242 image 403s, #6587 CDN SSRF false-positives)
|
||||
- Falls back cleanly to the SSRF-gated ``cache_*_from_url`` helpers
|
||||
(image/audio) or SSRF-gated aiohttp (documents) when ``att.read()``
|
||||
isn't available or fails.
|
||||
- The document fallback path now runs through the SSRF gate for
|
||||
defense-in-depth. (issue #11345)
|
||||
"""
|
||||
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _ensure_discord_mock():
|
||||
"""Install a mock discord module when discord.py isn't available."""
|
||||
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
|
||||
return
|
||||
|
||||
discord_mod = MagicMock()
|
||||
discord_mod.Intents.default.return_value = MagicMock()
|
||||
discord_mod.Client = MagicMock
|
||||
discord_mod.File = MagicMock
|
||||
discord_mod.DMChannel = type("DMChannel", (), {})
|
||||
discord_mod.Thread = type("Thread", (), {})
|
||||
discord_mod.ForumChannel = type("ForumChannel", (), {})
|
||||
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
|
||||
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
|
||||
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
|
||||
discord_mod.Interaction = object
|
||||
discord_mod.Embed = MagicMock
|
||||
discord_mod.app_commands = SimpleNamespace(
|
||||
describe=lambda **kwargs: (lambda fn: fn),
|
||||
choices=lambda **kwargs: (lambda fn: fn),
|
||||
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
|
||||
)
|
||||
|
||||
ext_mod = MagicMock()
|
||||
commands_mod = MagicMock()
|
||||
commands_mod.Bot = MagicMock
|
||||
ext_mod.commands = commands_mod
|
||||
|
||||
sys.modules.setdefault("discord", discord_mod)
|
||||
sys.modules.setdefault("discord.ext", ext_mod)
|
||||
sys.modules.setdefault("discord.ext.commands", commands_mod)
|
||||
|
||||
|
||||
_ensure_discord_mock()
|
||||
|
||||
from gateway.platforms.discord import DiscordAdapter # noqa: E402
|
||||
|
||||
|
||||
# Minimal valid image / audio / PDF bytes so the cache_*_from_bytes
|
||||
# validators accept them. cache_image_from_bytes runs _looks_like_image()
|
||||
# which checks for magic bytes; PNG's magic is sufficient.
|
||||
_PNG_BYTES = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64
|
||||
_OGG_BYTES = b"OggS" + b"\x00" * 60
|
||||
_PDF_BYTES = b"%PDF-1.4\n" + b"fake pdf body" + b"\n%%EOF"
|
||||
|
||||
|
||||
def _make_adapter() -> DiscordAdapter:
|
||||
return DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
|
||||
def _make_attachment_with_read(payload: bytes) -> SimpleNamespace:
|
||||
"""Attachment stub that exposes .read() — the happy-path primary."""
|
||||
return SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
size=len(payload),
|
||||
read=AsyncMock(return_value=payload),
|
||||
)
|
||||
|
||||
|
||||
def _make_attachment_without_read() -> SimpleNamespace:
|
||||
"""Attachment stub that has no .read() — exercises the URL fallback."""
|
||||
return SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
size=1024,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_attachment_bytes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadAttachmentBytes:
|
||||
"""Unit tests for the low-level att.read() wrapper."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_bytes_on_successful_read(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(b"hello world")
|
||||
|
||||
result = await adapter._read_attachment_bytes(att)
|
||||
|
||||
assert result == b"hello world"
|
||||
att.read.assert_awaited_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_read_missing(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
result = await adapter._read_attachment_bytes(att)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_read_raises(self):
|
||||
"""Bot-session fetch failures are swallowed so callers fall back."""
|
||||
adapter = _make_adapter()
|
||||
att = SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
read=AsyncMock(side_effect=RuntimeError("403 Forbidden")),
|
||||
)
|
||||
|
||||
result = await adapter._read_attachment_bytes(att)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cache_discord_image
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDiscordImage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefers_att_read_over_url(self):
|
||||
"""Primary path: att.read() bytes → cache_image_from_bytes, no URL fetch."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(_PNG_BYTES)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
return_value="/tmp/cached.png",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
||||
assert result == "/tmp/cached.png"
|
||||
mock_bytes.assert_called_once_with(_PNG_BYTES, ext=".png")
|
||||
mock_url.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_url_when_no_read(self):
|
||||
"""No .read() → URL path is used (existing SSRF-gated behavior)."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/from_url.png",
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
||||
assert result == "/tmp/from_url.png"
|
||||
mock_bytes.assert_not_called()
|
||||
mock_url.assert_awaited_once_with(att.url, ext=".png")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_url_when_bytes_validator_rejects(self):
|
||||
"""If att.read() returns garbage that cache_image_from_bytes rejects
|
||||
(e.g. an HTML error page), fall back to the URL downloader instead
|
||||
of surfacing the validation error to the caller."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(b"<html>forbidden</html>")
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
side_effect=ValueError("not a valid image"),
|
||||
), patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/fallback.png",
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_image(att, ".png")
|
||||
|
||||
assert result == "/tmp/fallback.png"
|
||||
mock_url.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cache_discord_audio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDiscordAudio:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefers_att_read_over_url(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(_OGG_BYTES)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_bytes",
|
||||
return_value="/tmp/voice.ogg",
|
||||
) as mock_bytes, patch(
|
||||
"gateway.platforms.discord.cache_audio_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_audio(att, ".ogg")
|
||||
|
||||
assert result == "/tmp/voice.ogg"
|
||||
mock_bytes.assert_called_once_with(_OGG_BYTES, ext=".ogg")
|
||||
mock_url.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_url_when_no_read(self):
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_audio_from_url",
|
||||
new_callable=AsyncMock,
|
||||
return_value="/tmp/from_url.ogg",
|
||||
) as mock_url:
|
||||
result = await adapter._cache_discord_audio(att, ".ogg")
|
||||
|
||||
assert result == "/tmp/from_url.ogg"
|
||||
mock_url.assert_awaited_once_with(att.url, ext=".ogg")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _cache_discord_document
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCacheDiscordDocument:
|
||||
@pytest.mark.asyncio
|
||||
async def test_prefers_att_read_returns_bytes_directly(self):
|
||||
"""Primary path: att.read() → raw bytes, no aiohttp involvement."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_with_read(_PDF_BYTES)
|
||||
|
||||
with patch("aiohttp.ClientSession") as mock_session:
|
||||
result = await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
assert result == _PDF_BYTES
|
||||
mock_session.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_blocked_by_ssrf_guard(self):
|
||||
"""Document fallback path now honors is_safe_url — was missing before.
|
||||
|
||||
Regression guard for #11345: the old aiohttp block skipped the
|
||||
SSRF check entirely; a non-CDN ``att.url`` could have reached
|
||||
internal-looking hosts. The fallback must now refuse unsafe URLs.
|
||||
"""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read() # no .read → forces fallback
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.is_safe_url", return_value=False
|
||||
) as mock_safe, patch("aiohttp.ClientSession") as mock_session:
|
||||
with pytest.raises(ValueError, match="SSRF"):
|
||||
await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
mock_safe.assert_called_once_with(att.url)
|
||||
# aiohttp must NOT be contacted when the URL is blocked.
|
||||
mock_session.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_aiohttp_when_safe_url(self):
|
||||
"""Safe URL + no att.read() → aiohttp fallback executes."""
|
||||
adapter = _make_adapter()
|
||||
att = _make_attachment_without_read()
|
||||
|
||||
# Build an aiohttp session mock that returns 200 + payload.
|
||||
resp = AsyncMock()
|
||||
resp.status = 200
|
||||
resp.read = AsyncMock(return_value=_PDF_BYTES)
|
||||
resp.__aenter__ = AsyncMock(return_value=resp)
|
||||
resp.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = AsyncMock()
|
||||
session.get = MagicMock(return_value=resp)
|
||||
session.__aenter__ = AsyncMock(return_value=session)
|
||||
session.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.is_safe_url", return_value=True
|
||||
), patch("aiohttp.ClientSession", return_value=session):
|
||||
result = await adapter._cache_discord_document(att, ".pdf")
|
||||
|
||||
assert result == _PDF_BYTES
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: end-to-end via _handle_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHandleMessageUsesAuthenticatedRead:
|
||||
"""E2E: verify _handle_message routes image/audio downloads through
|
||||
att.read() so cdn.discordapp.com 403s (#8242) and SSRF false-positives
|
||||
on mangled DNS (#6587) no longer block media caching.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_image_downloads_via_att_read_not_url(self, monkeypatch):
|
||||
"""Image attachments with .read() never call cache_image_from_url."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
|
||||
adapter.handle_message = AsyncMock()
|
||||
|
||||
with patch(
|
||||
"gateway.platforms.discord.cache_image_from_bytes",
|
||||
return_value="/tmp/img_from_read.png",
|
||||
), patch(
|
||||
"gateway.platforms.discord.cache_image_from_url",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_url_download:
|
||||
att = SimpleNamespace(
|
||||
url="https://cdn.discordapp.com/attachments/fake/file.png",
|
||||
filename="file.png",
|
||||
content_type="image/png",
|
||||
size=len(_PNG_BYTES),
|
||||
read=AsyncMock(return_value=_PNG_BYTES),
|
||||
)
|
||||
# Minimal Discord message stub for _handle_message.
|
||||
from datetime import datetime, timezone
|
||||
|
||||
class _FakeDMChannel:
|
||||
id = 100
|
||||
name = "dm"
|
||||
|
||||
# Patch the DMChannel isinstance check so our fake counts as DM.
|
||||
monkeypatch.setattr(
|
||||
"gateway.platforms.discord.discord.DMChannel",
|
||||
_FakeDMChannel,
|
||||
)
|
||||
chan = _FakeDMChannel()
|
||||
msg = SimpleNamespace(
|
||||
id=1, content="", attachments=[att], mentions=[],
|
||||
reference=None,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
channel=chan,
|
||||
author=SimpleNamespace(id=42, display_name="U", name="U"),
|
||||
)
|
||||
await adapter._handle_message(msg)
|
||||
|
||||
mock_url_download.assert_not_called()
|
||||
event = adapter.handle_message.call_args[0][0]
|
||||
assert event.media_urls == ["/tmp/img_from_read.png"]
|
||||
assert event.media_types == ["image/png"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue