fix(discord): route attachment downloads through authenticated bot session (#11568)

Three open issues — #8242, #6587, #11345 — all trace to the same root
cause: the image / audio / document download paths in
`DiscordAdapter._handle_message` used plain, unauthenticated HTTP to
fetch `att.url`. That broke in three independent ways:

  #8242  cdn.discordapp.com attachment URLs increasingly require the
         bot session to download; unauthenticated httpx sees 403
         Forbidden, so image/voice analysis fails silently.

  #6587  Some user environments (VPNs, corporate DNS, tunnels) resolve
         cdn.discordapp.com to private-looking IPs. Our is_safe_url()
         guard correctly blocks them as SSRF risks, but the user
         environment is legitimate — image analysis and voice STT die.

  #11345 The document download path skipped is_safe_url() entirely —
         raw aiohttp.ClientSession.get(att.url) with no SSRF check,
         inconsistent with the image/audio branches.

Unified fix: use `discord.Attachment.read()` as the primary download
path on all three branches. `att.read()` routes through discord.py's
own authenticated HTTPClient, so:

  - Discord CDN auth is handled (#8242 resolved).
  - Our is_safe_url() gate isn't consulted on the primary attachment
    path at all — the bot session handles networking internally
    (#6587 resolved). The gate still applies on the URL fallback.
  - All three branches now share the same code path, eliminating the
    document-path SSRF gap (#11345 resolved).

Falls back to the existing cache_*_from_url helpers (image/audio) or an
SSRF-gated aiohttp fetch (documents) when `att.read()` is unavailable
or fails — preserves defense-in-depth for any future payload-schema
drift that could slip a non-CDN URL into att.url.

New helpers on DiscordAdapter:
  - _read_attachment_bytes(att)  — safe att.read() wrapper
  - _cache_discord_image(att, ext)     — primary + URL fallback
  - _cache_discord_audio(att, ext)     — primary + URL fallback
  - _cache_discord_document(att, ext)  — primary + SSRF-gated aiohttp fallback

Tests:
  - tests/gateway/test_discord_attachment_download.py — 12 new cases
    covering all three helpers: primary path, fallback on missing
    .read(), fallback on validator rejection, SSRF guard on document
    fallback, aiohttp fallback happy-path, and an E2E case via
    _handle_message confirming cache_image_from_url is never invoked
    when att.read() succeeds.
  - All 11 existing document-handling tests continue to pass via the
    aiohttp fallback path (their SimpleNamespace attachments have no
    .read(), which triggers the fallback — now SSRF-gated).

Closes #8242, closes #6587, closes #11345.
This commit is contained in:
Teknium 2026-04-17 04:59:03 -07:00 committed by GitHub
parent 24342813fe
commit 53da34a4fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 483 additions and 15 deletions

View file

@ -51,7 +51,9 @@ from gateway.platforms.base import (
ProcessingOutcome,
SendResult,
cache_image_from_url,
cache_image_from_bytes,
cache_audio_from_url,
cache_audio_from_bytes,
cache_document_from_bytes,
SUPPORTED_DOCUMENT_TYPES,
)
@ -2489,6 +2491,124 @@ class DiscordAdapter(BasePlatformAdapter):
return f"{parent_name} / {thread_name}"
return thread_name
# ------------------------------------------------------------------
# Attachment download helpers
#
# Discord attachments (images / audio / documents) are fetched via the
# authenticated bot session whenever the Attachment object exposes
# ``read()``. That sidesteps two classes of bug that hit the older
# plain-HTTP path:
#
# 1. ``cdn.discordapp.com`` URLs increasingly require bot auth on
# download — unauthenticated httpx sees 403 Forbidden.
# (issue #8242)
# 2. Some user environments (VPNs, corporate DNS, tunnels) resolve
# ``cdn.discordapp.com`` to private-looking IPs that our
# ``is_safe_url`` guard classifies as SSRF risks. Routing the
# fetch through discord.py's own HTTP client handles DNS
# internally so our guard isn't consulted for the attachment
# path. (issue #6587)
#
# If ``att.read()`` is unavailable (unexpected object shape / test
# stub) or the bot session fetch fails, we fall back to the existing
# SSRF-gated URL downloaders. The fallback keeps defense-in-depth
# against any future Discord payload-schema drift that could slip a
# non-CDN URL into the ``att.url`` field. (issue #11345)
# ------------------------------------------------------------------
async def _read_attachment_bytes(self, att) -> Optional[bytes]:
"""Read an attachment via discord.py's authenticated bot session.
Returns the raw bytes on success, or ``None`` if ``att`` doesn't
expose a callable ``read()`` or the read itself fails. Callers
should treat ``None`` as a signal to fall back to the URL-based
downloaders.
"""
reader = getattr(att, "read", None)
if reader is None or not callable(reader):
return None
try:
return await reader()
except Exception as e:
logger.warning(
"[Discord] Authenticated attachment read failed for %s: %s",
getattr(att, "filename", None) or getattr(att, "url", "<unknown>"),
e,
)
return None
async def _cache_discord_image(self, att, ext: str) -> str:
    """Store a Discord image attachment in the local cache.

    The bytes are first requested via ``self._read_attachment_bytes``
    (authenticated ``att.read()``, no SSRF gate) and handed to
    ``cache_image_from_bytes``. If no bytes come back, or the bytes are
    rejected, the image is fetched with ``cache_image_from_url``
    (plain httpx, SSRF-gated) instead.

    Args:
        att: Attachment object; ``att.url`` is used by the fallback.
        ext: File extension forwarded to the cache helpers.

    Returns:
        The value returned by whichever cache helper succeeded.
    """
    payload = await self._read_attachment_bytes(att)
    if payload is None:
        # Authenticated read unavailable or failed: go straight to the URL path.
        return await cache_image_from_url(att.url, ext=ext)

    try:
        return cache_image_from_bytes(payload, ext=ext)
    except Exception as exc:
        # Rejected bytes are not fatal; the URL downloader gets a try.
        logger.debug(
            "[Discord] cache_image_from_bytes rejected att.read() data; falling back to URL: %s",
            exc,
        )
    return await cache_image_from_url(att.url, ext=ext)
async def _cache_discord_audio(self, att, ext: str) -> str:
    """Store a Discord audio attachment in the local cache.

    The bytes are first requested via ``self._read_attachment_bytes``
    (authenticated ``att.read()``, no SSRF gate) and handed to
    ``cache_audio_from_bytes``. If no bytes come back, or caching them
    raises, the audio is fetched with ``cache_audio_from_url``
    (plain httpx, SSRF-gated) instead.

    Args:
        att: Attachment object; ``att.url`` is used by the fallback.
        ext: File extension forwarded to the cache helpers.

    Returns:
        The value returned by whichever cache helper succeeded.
    """
    payload = await self._read_attachment_bytes(att)
    if payload is None:
        # Authenticated read unavailable or failed: go straight to the URL path.
        return await cache_audio_from_url(att.url, ext=ext)

    try:
        return cache_audio_from_bytes(payload, ext=ext)
    except Exception as exc:
        # A caching failure is not fatal; the URL downloader gets a try.
        logger.debug(
            "[Discord] cache_audio_from_bytes failed; falling back to URL: %s",
            exc,
        )
    return await cache_audio_from_url(att.url, ext=ext)
async def _cache_discord_document(self, att, ext: str) -> bytes:
    """Return the raw bytes of a Discord document attachment.

    Despite the name, nothing is written to the cache here; the caller
    passes the returned bytes to ``cache_document_from_bytes``.

    The bytes come from ``self._read_attachment_bytes`` (authenticated
    ``att.read()``, no SSRF gate) when that succeeds. Otherwise
    ``att.url`` is downloaded with ``aiohttp``, but only after
    ``is_safe_url`` approves it. That check closes the gap where the old
    document path issued raw ``aiohttp`` requests unchecked (#11345).

    Args:
        att: Attachment object; ``att.url`` is used by the fallback.
        ext: Accepted but not used by this helper.

    Raises:
        ValueError: The fallback URL failed the ``is_safe_url`` check.
        Exception: The fallback request returned a non-200 status.
    """
    payload = await self._read_attachment_bytes(att)
    if payload is None:
        # Fallback: SSRF-gated URL download.
        if not is_safe_url(att.url):
            raise ValueError(
                f"Blocked unsafe attachment URL (SSRF protection): {att.url}"
            )

        # Imported lazily, and only once the URL has passed the gate.
        import aiohttp
        from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp

        session_kwargs, request_kwargs = proxy_kwargs_for_aiohttp(
            resolve_proxy_url(platform_env_var="DISCORD_PROXY")
        )
        async with aiohttp.ClientSession(**session_kwargs) as session:
            async with session.get(
                att.url,
                timeout=aiohttp.ClientTimeout(total=30),
                **request_kwargs,
            ) as resp:
                if resp.status != 200:
                    raise Exception(f"HTTP {resp.status}")
                payload = await resp.read()
    return payload
async def _handle_message(self, message: DiscordMessage) -> None:
"""Handle incoming Discord messages."""
# In server channels (not DMs), require the bot to be @mentioned
@ -2642,7 +2762,7 @@ class DiscordAdapter(BasePlatformAdapter):
ext = "." + content_type.split("/")[-1].split(";")[0]
if ext not in (".jpg", ".jpeg", ".png", ".gif", ".webp"):
ext = ".jpg"
cached_path = await cache_image_from_url(att.url, ext=ext)
cached_path = await self._cache_discord_image(att, ext)
media_urls.append(cached_path)
media_types.append(content_type)
print(f"[Discord] Cached user image: {cached_path}", flush=True)
@ -2656,7 +2776,7 @@ class DiscordAdapter(BasePlatformAdapter):
ext = "." + content_type.split("/")[-1].split(";")[0]
if ext not in (".ogg", ".mp3", ".wav", ".webm", ".m4a"):
ext = ".ogg"
cached_path = await cache_audio_from_url(att.url, ext=ext)
cached_path = await self._cache_discord_audio(att, ext)
media_urls.append(cached_path)
media_types.append(content_type)
print(f"[Discord] Cached user audio: {cached_path}", flush=True)
@ -2687,19 +2807,7 @@ class DiscordAdapter(BasePlatformAdapter):
)
else:
try:
import aiohttp
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
async with aiohttp.ClientSession(**_sess_kw) as session:
async with session.get(
att.url,
timeout=aiohttp.ClientTimeout(total=30),
**_req_kw,
) as resp:
if resp.status != 200:
raise Exception(f"HTTP {resp.status}")
raw_bytes = await resp.read()
raw_bytes = await self._cache_discord_document(att, ext)
cached_path = cache_document_from_bytes(
raw_bytes, att.filename or f"document{ext}"
)

View file

@ -0,0 +1,360 @@
"""Tests for Discord attachment downloads via the authenticated bot session.
Covers the three download paths (image / audio / document) in
``DiscordAdapter._handle_message()`` and the shared ``_cache_discord_*``
helpers. Verifies that:
- ``att.read()`` is preferred over the legacy URL-based downloaders so
that Discord's CDN auth (and user-environment DNS quirks) can't block
media caching. (issues #8242 image 403s, #6587 CDN SSRF false-positives)
- Falls back cleanly to the SSRF-gated ``cache_*_from_url`` helpers
(image/audio) or SSRF-gated aiohttp (documents) when ``att.read()``
isn't available or fails.
- The document fallback path now runs through the SSRF gate for
defense-in-depth. (issue #11345)
"""
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import PlatformConfig
def _ensure_discord_mock():
"""Install a mock discord module when discord.py isn't available."""
if "discord" in sys.modules and hasattr(sys.modules["discord"], "__file__"):
return
discord_mod = MagicMock()
discord_mod.Intents.default.return_value = MagicMock()
discord_mod.Client = MagicMock
discord_mod.File = MagicMock
discord_mod.DMChannel = type("DMChannel", (), {})
discord_mod.Thread = type("Thread", (), {})
discord_mod.ForumChannel = type("ForumChannel", (), {})
discord_mod.ui = SimpleNamespace(View=object, button=lambda *a, **k: (lambda fn: fn), Button=object)
discord_mod.ButtonStyle = SimpleNamespace(success=1, primary=2, secondary=2, danger=3, green=1, grey=2, blurple=2, red=3)
discord_mod.Color = SimpleNamespace(orange=lambda: 1, green=lambda: 2, blue=lambda: 3, red=lambda: 4, purple=lambda: 5)
discord_mod.Interaction = object
discord_mod.Embed = MagicMock
discord_mod.app_commands = SimpleNamespace(
describe=lambda **kwargs: (lambda fn: fn),
choices=lambda **kwargs: (lambda fn: fn),
Choice=lambda **kwargs: SimpleNamespace(**kwargs),
)
ext_mod = MagicMock()
commands_mod = MagicMock()
commands_mod.Bot = MagicMock
ext_mod.commands = commands_mod
sys.modules.setdefault("discord", discord_mod)
sys.modules.setdefault("discord.ext", ext_mod)
sys.modules.setdefault("discord.ext.commands", commands_mod)
_ensure_discord_mock()
from gateway.platforms.discord import DiscordAdapter # noqa: E402
# Minimal valid image / audio / PDF bytes so the cache_*_from_bytes
# validators accept them. cache_image_from_bytes runs _looks_like_image()
# which checks for magic bytes; PNG's magic is sufficient.
# PNG signature followed by zero padding.
_PNG_BYTES = b"\x89PNG\r\n\x1a\n" + b"\x00" * 64
# "OggS" prefix followed by zero padding.
_OGG_BYTES = b"OggS" + b"\x00" * 60
# "%PDF-1.4" header, a placeholder body, and a "%%EOF" trailer.
_PDF_BYTES = b"%PDF-1.4\n" + b"fake pdf body" + b"\n%%EOF"
def _make_adapter() -> DiscordAdapter:
    """Build a ``DiscordAdapter`` from an enabled config with a placeholder token."""
    config = PlatformConfig(enabled=True, token="***")
    return DiscordAdapter(config)
def _make_attachment_with_read(payload: bytes) -> SimpleNamespace:
"""Attachment stub that exposes .read() — the happy-path primary."""
return SimpleNamespace(
url="https://cdn.discordapp.com/attachments/fake/file.png",
filename="file.png",
size=len(payload),
read=AsyncMock(return_value=payload),
)
def _make_attachment_without_read() -> SimpleNamespace:
"""Attachment stub that has no .read() — exercises the URL fallback."""
return SimpleNamespace(
url="https://cdn.discordapp.com/attachments/fake/file.png",
filename="file.png",
size=1024,
)
# ---------------------------------------------------------------------------
# _read_attachment_bytes
# ---------------------------------------------------------------------------
class TestReadAttachmentBytes:
    """Exercises ``DiscordAdapter._read_attachment_bytes`` directly."""

    @pytest.mark.asyncio
    async def test_returns_bytes_on_successful_read(self):
        """A working ``read()`` is awaited once and its bytes are returned."""
        adapter = _make_adapter()
        attachment = _make_attachment_with_read(b"hello world")

        fetched = await adapter._read_attachment_bytes(attachment)

        assert fetched == b"hello world"
        attachment.read.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_returns_none_when_read_missing(self):
        """An attachment without ``read`` yields ``None``."""
        adapter = _make_adapter()
        attachment = _make_attachment_without_read()

        assert await adapter._read_attachment_bytes(attachment) is None

    @pytest.mark.asyncio
    async def test_returns_none_when_read_raises(self):
        """An error raised by ``read()`` is swallowed so callers can fall back."""
        adapter = _make_adapter()
        failing_read = AsyncMock(side_effect=RuntimeError("403 Forbidden"))
        attachment = SimpleNamespace(
            url="https://cdn.discordapp.com/attachments/fake/file.png",
            filename="file.png",
            read=failing_read,
        )

        assert await adapter._read_attachment_bytes(attachment) is None
# ---------------------------------------------------------------------------
# _cache_discord_image
# ---------------------------------------------------------------------------
class TestCacheDiscordImage:
    """Covers the image helper's primary path and both URL fallbacks."""

    @pytest.mark.asyncio
    async def test_prefers_att_read_over_url(self):
        """Bytes from ``att.read()`` are cached directly; no URL fetch happens."""
        adapter = _make_adapter()
        attachment = _make_attachment_with_read(_PNG_BYTES)
        bytes_patch = patch(
            "gateway.platforms.discord.cache_image_from_bytes",
            return_value="/tmp/cached.png",
        )
        url_patch = patch(
            "gateway.platforms.discord.cache_image_from_url",
            new_callable=AsyncMock,
        )

        with bytes_patch as from_bytes, url_patch as from_url:
            cached = await adapter._cache_discord_image(attachment, ".png")

        assert cached == "/tmp/cached.png"
        from_bytes.assert_called_once_with(_PNG_BYTES, ext=".png")
        from_url.assert_not_called()

    @pytest.mark.asyncio
    async def test_falls_back_to_url_when_no_read(self):
        """Without ``.read()`` the existing SSRF-gated URL path is used."""
        adapter = _make_adapter()
        attachment = _make_attachment_without_read()
        bytes_patch = patch("gateway.platforms.discord.cache_image_from_bytes")
        url_patch = patch(
            "gateway.platforms.discord.cache_image_from_url",
            new_callable=AsyncMock,
            return_value="/tmp/from_url.png",
        )

        with bytes_patch as from_bytes, url_patch as from_url:
            cached = await adapter._cache_discord_image(attachment, ".png")

        assert cached == "/tmp/from_url.png"
        from_bytes.assert_not_called()
        from_url.assert_awaited_once_with(attachment.url, ext=".png")

    @pytest.mark.asyncio
    async def test_falls_back_to_url_when_bytes_validator_rejects(self):
        """Rejected ``att.read()`` data triggers the URL downloader.

        When ``cache_image_from_bytes`` raises (for example because the
        payload is an HTML error page), the helper must fall back instead
        of surfacing the validation error to its caller.
        """
        adapter = _make_adapter()
        attachment = _make_attachment_with_read(b"<html>forbidden</html>")
        rejecting_patch = patch(
            "gateway.platforms.discord.cache_image_from_bytes",
            side_effect=ValueError("not a valid image"),
        )
        url_patch = patch(
            "gateway.platforms.discord.cache_image_from_url",
            new_callable=AsyncMock,
            return_value="/tmp/fallback.png",
        )

        with rejecting_patch, url_patch as from_url:
            cached = await adapter._cache_discord_image(attachment, ".png")

        assert cached == "/tmp/fallback.png"
        from_url.assert_awaited_once()
# ---------------------------------------------------------------------------
# _cache_discord_audio
# ---------------------------------------------------------------------------
class TestCacheDiscordAudio:
    """Covers the audio helper's primary path and its no-``read()`` fallback."""

    @pytest.mark.asyncio
    async def test_prefers_att_read_over_url(self):
        """Bytes from ``att.read()`` are cached directly; no URL fetch happens."""
        adapter = _make_adapter()
        attachment = _make_attachment_with_read(_OGG_BYTES)
        bytes_patch = patch(
            "gateway.platforms.discord.cache_audio_from_bytes",
            return_value="/tmp/voice.ogg",
        )
        url_patch = patch(
            "gateway.platforms.discord.cache_audio_from_url",
            new_callable=AsyncMock,
        )

        with bytes_patch as from_bytes, url_patch as from_url:
            cached = await adapter._cache_discord_audio(attachment, ".ogg")

        assert cached == "/tmp/voice.ogg"
        from_bytes.assert_called_once_with(_OGG_BYTES, ext=".ogg")
        from_url.assert_not_called()

    @pytest.mark.asyncio
    async def test_falls_back_to_url_when_no_read(self):
        """Without ``.read()`` the URL downloader is awaited with the same ext."""
        adapter = _make_adapter()
        attachment = _make_attachment_without_read()
        url_patch = patch(
            "gateway.platforms.discord.cache_audio_from_url",
            new_callable=AsyncMock,
            return_value="/tmp/from_url.ogg",
        )

        with url_patch as from_url:
            cached = await adapter._cache_discord_audio(attachment, ".ogg")

        assert cached == "/tmp/from_url.ogg"
        from_url.assert_awaited_once_with(attachment.url, ext=".ogg")
# ---------------------------------------------------------------------------
# _cache_discord_document
# ---------------------------------------------------------------------------
class TestCacheDiscordDocument:
    """Covers the document helper: primary path, SSRF block, aiohttp fallback."""

    @pytest.mark.asyncio
    async def test_prefers_att_read_returns_bytes_directly(self):
        """``att.read()`` bytes are returned as-is and aiohttp is never touched."""
        adapter = _make_adapter()
        attachment = _make_attachment_with_read(_PDF_BYTES)

        with patch("aiohttp.ClientSession") as session_cls:
            fetched = await adapter._cache_discord_document(attachment, ".pdf")

        assert fetched == _PDF_BYTES
        session_cls.assert_not_called()

    @pytest.mark.asyncio
    async def test_fallback_blocked_by_ssrf_guard(self):
        """The document fallback refuses URLs that ``is_safe_url`` rejects.

        Regression guard for #11345: the old aiohttp block skipped the
        SSRF check entirely, so a non-CDN ``att.url`` could have reached
        internal-looking hosts.
        """
        adapter = _make_adapter()
        # No .read attribute, so the helper is forced onto the fallback.
        attachment = _make_attachment_without_read()
        unsafe_patch = patch(
            "gateway.platforms.discord.is_safe_url", return_value=False
        )

        with unsafe_patch as safe_check, patch("aiohttp.ClientSession") as session_cls:
            with pytest.raises(ValueError, match="SSRF"):
                await adapter._cache_discord_document(attachment, ".pdf")

        safe_check.assert_called_once_with(attachment.url)
        # aiohttp must NOT be contacted when the URL is blocked.
        session_cls.assert_not_called()

    @pytest.mark.asyncio
    async def test_fallback_aiohttp_when_safe_url(self):
        """A safe URL with no ``att.read()`` is downloaded through aiohttp."""
        adapter = _make_adapter()
        attachment = _make_attachment_without_read()

        # Response stub: async context manager reporting 200 plus the payload.
        response = AsyncMock()
        response.status = 200
        response.read = AsyncMock(return_value=_PDF_BYTES)
        response.__aenter__ = AsyncMock(return_value=response)
        response.__aexit__ = AsyncMock(return_value=False)

        # Session stub: ``get`` is a plain call that returns the response
        # context manager, so it is a MagicMock rather than an AsyncMock.
        fake_session = AsyncMock()
        fake_session.get = MagicMock(return_value=response)
        fake_session.__aenter__ = AsyncMock(return_value=fake_session)
        fake_session.__aexit__ = AsyncMock(return_value=False)

        safe_patch = patch("gateway.platforms.discord.is_safe_url", return_value=True)
        session_patch = patch("aiohttp.ClientSession", return_value=fake_session)
        with safe_patch, session_patch:
            fetched = await adapter._cache_discord_document(attachment, ".pdf")

        assert fetched == _PDF_BYTES
# ---------------------------------------------------------------------------
# Integration: end-to-end via _handle_message
# ---------------------------------------------------------------------------
class TestHandleMessageUsesAuthenticatedRead:
    """End-to-end check that ``_handle_message`` downloads via ``att.read()``.

    Going through ``att.read()`` means cdn.discordapp.com 403s (#8242)
    and SSRF false-positives on mangled DNS (#6587) no longer block
    media caching.
    """

    @pytest.mark.asyncio
    async def test_image_downloads_via_att_read_not_url(self, monkeypatch):
        """An image attachment with ``.read()`` never hits ``cache_image_from_url``."""
        from datetime import datetime, timezone

        class _FakeDMChannel:
            id = 100
            name = "dm"

        adapter = _make_adapter()
        adapter._client = SimpleNamespace(user=SimpleNamespace(id=999))
        adapter.handle_message = AsyncMock()

        # Patch the DMChannel isinstance check so our fake counts as DM.
        monkeypatch.setattr(
            "gateway.platforms.discord.discord.DMChannel",
            _FakeDMChannel,
        )

        image = SimpleNamespace(
            url="https://cdn.discordapp.com/attachments/fake/file.png",
            filename="file.png",
            content_type="image/png",
            size=len(_PNG_BYTES),
            read=AsyncMock(return_value=_PNG_BYTES),
        )
        # Minimal Discord message stub for _handle_message.
        message = SimpleNamespace(
            id=1,
            content="",
            attachments=[image],
            mentions=[],
            reference=None,
            created_at=datetime.now(timezone.utc),
            channel=_FakeDMChannel(),
            author=SimpleNamespace(id=42, display_name="U", name="U"),
        )

        bytes_patch = patch(
            "gateway.platforms.discord.cache_image_from_bytes",
            return_value="/tmp/img_from_read.png",
        )
        url_patch = patch(
            "gateway.platforms.discord.cache_image_from_url",
            new_callable=AsyncMock,
        )
        with bytes_patch, url_patch as url_download:
            await adapter._handle_message(message)

        url_download.assert_not_called()
        dispatched = adapter.handle_message.call_args[0][0]
        assert dispatched.media_urls == ["/tmp/img_from_read.png"]
        assert dispatched.media_types == ["image/png"]