fix: make safe_url_for_log public, add SSRF redirect guards to base.py cache helpers

Follow-up to Dusk1e's PR #7120 (Slack send_image redirect guard):
- Rename _safe_url_for_log -> safe_url_for_log (drop underscore) since
  it is now imported cross-module by the Slack adapter
- Add _ssrf_redirect_guard httpx event hook to cache_image_from_url()
  and cache_audio_from_url() in base.py — same pattern as vision_tools
  and the Slack adapter fix
- Update url_safety.py docstring to reflect broader coverage
- Add regression tests for image/audio redirect blocking + safe passthrough
This commit is contained in:
Teknium 2026-04-10 05:02:17 -07:00 committed by Teknium
parent 714809634f
commit 7663c98c1e
5 changed files with 173 additions and 19 deletions

View file

@ -160,7 +160,7 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = (
)
def _safe_url_for_log(url: str, max_len: int = 80) -> str:
def safe_url_for_log(url: str, max_len: int = 80) -> str:
"""Return a URL string safe for logs (no query/fragment/userinfo)."""
if max_len <= 0:
return ""
@ -197,6 +197,23 @@ def _safe_url_for_log(url: str, max_len: int = 80) -> str:
return f"{safe[:max_len - 3]}..."
async def _ssrf_redirect_guard(response):
    """Re-validate each redirect target to prevent redirect-based SSRF.

    Without this, an attacker can host a public URL that 302-redirects to
    http://169.254.169.254/ and bypass the pre-flight is_safe_url() check.

    Must be async because httpx.AsyncClient awaits response event hooks.

    Args:
        response: The response object handed to the httpx "response" hook.

    Raises:
        ValueError: If the response is a redirect whose target is rejected
            by is_safe_url().
    """
    if not response.is_redirect:
        return

    # httpx only populates ``next_request`` when the client is NOT following
    # redirects itself, and it does so after the response hooks have run.
    # With ``follow_redirects=True`` the attribute is still unset here, so
    # relying on it alone would make this guard a no-op. Prefer it when it is
    # present, otherwise resolve the target from the Location header.
    # NOTE(review): confirm hook ordering against the installed httpx version.
    pending_request = response.next_request
    if pending_request:
        redirect_url = str(pending_request.url)
    else:
        location = response.headers.get("location")
        if not location:
            # 3xx without a Location header (e.g. 304): nothing to follow.
            return
        from urllib.parse import urljoin

        # Location may be relative; resolve it against the URL just requested.
        redirect_url = urljoin(str(response.request.url), location)

    from tools.url_safety import is_safe_url

    if not is_safe_url(redirect_url):
        raise ValueError(
            f"Blocked redirect to private/internal address: {safe_url_for_log(redirect_url)}"
        )
# ---------------------------------------------------------------------------
# Image cache utilities
#
@ -281,7 +298,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
"""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
import asyncio
import httpx
@ -289,7 +306,11 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
_log = _logging.getLogger(__name__)
last_exc = None
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
for attempt in range(retries + 1):
try:
response = await client.get(
@ -311,7 +332,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
"Media cache retry %d/%d for %s (%.1fs): %s",
attempt + 1,
retries,
_safe_url_for_log(url),
safe_url_for_log(url),
wait,
exc,
)
@ -396,7 +417,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
"""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}")
import asyncio
import httpx
@ -404,7 +425,11 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
_log = _logging.getLogger(__name__)
last_exc = None
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
async with httpx.AsyncClient(
timeout=30.0,
follow_redirects=True,
event_hooks={"response": [_ssrf_redirect_guard]},
) as client:
for attempt in range(retries + 1):
try:
response = await client.get(
@ -426,7 +451,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
"Audio cache retry %d/%d for %s (%.1fs): %s",
attempt + 1,
retries,
_safe_url_for_log(url),
safe_url_for_log(url),
wait,
exc,
)
@ -1525,7 +1550,7 @@ class BasePlatformAdapter(ABC):
logger.info(
"[%s] Sending image: %s (alt=%s)",
self.name,
_safe_url_for_log(image_url),
safe_url_for_log(image_url),
alt_text[:30] if alt_text else "",
)
# Route animated GIFs through send_animation for proper playback

View file

@ -39,7 +39,7 @@ from gateway.platforms.base import (
MessageType,
SendResult,
SUPPORTED_DOCUMENT_TYPES,
_safe_url_for_log,
safe_url_for_log,
cache_document_from_bytes,
)
@ -686,7 +686,7 @@ class SlackAdapter(BasePlatformAdapter):
except Exception as e: # pragma: no cover - defensive logging
logger.warning(
"[Slack] Failed to upload image from URL %s, falling back to text: %s",
_safe_url_for_log(image_url),
safe_url_for_log(image_url),
e,
exc_info=True,
)

View file

@ -376,6 +376,134 @@ class TestCacheAudioFromUrl:
mock_sleep.assert_not_called()
# ---------------------------------------------------------------------------
# SSRF redirect guard tests (base.py)
# ---------------------------------------------------------------------------
class TestSSRFRedirectGuard:
    """Redirect-based SSRF regression tests for the media cache helpers.

    cache_image_from_url / cache_audio_from_url must reject redirects that
    land on private/internal hosts (e.g. a cloud metadata endpoint) while
    still letting redirects to public hosts through.
    """

    def _make_redirect_response(self, target_url: str):
        """Build a mock response that reports a redirect to *target_url*."""
        redirect = MagicMock()
        redirect.is_redirect = True
        redirect.next_request = MagicMock(url=target_url)
        return redirect

    def _make_client_capturing_hooks(self):
        """Return ``(mock_client, captured, factory)``.

        ``factory`` is a stand-in for ``httpx.AsyncClient``: it records the
        keyword arguments it is called with into ``captured`` (so tests can
        reach the registered event hooks) and returns ``mock_client``.
        """
        captured = {}
        mock_client = AsyncMock()
        mock_client.__aenter__ = AsyncMock(return_value=mock_client)
        mock_client.__aexit__ = AsyncMock(return_value=False)

        def factory(*args, **kwargs):
            captured.update(kwargs)
            return mock_client

        return mock_client, captured, factory

    def test_image_blocks_private_redirect(self, tmp_path, monkeypatch):
        """cache_image_from_url rejects a redirect to a private IP."""
        monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
        blocked = self._make_redirect_response(
            "http://169.254.169.254/latest/meta-data"
        )
        client, client_kwargs, client_factory = self._make_client_capturing_hooks()

        async def fake_get(_url, **kwargs):
            # Stand in for httpx invoking the registered response hooks.
            for hook in client_kwargs["event_hooks"]["response"]:
                await hook(blocked)

        client.get = AsyncMock(side_effect=fake_get)

        def only_origin_is_safe(url):
            return url == "https://public.example.com/image.png"

        async def run():
            with patch("tools.url_safety.is_safe_url", side_effect=only_origin_is_safe), \
                    patch("httpx.AsyncClient", side_effect=client_factory):
                from gateway.platforms.base import cache_image_from_url
                await cache_image_from_url(
                    "https://public.example.com/image.png", ext=".png"
                )

        with pytest.raises(ValueError, match="Blocked redirect"):
            asyncio.run(run())

    def test_audio_blocks_private_redirect(self, tmp_path, monkeypatch):
        """cache_audio_from_url rejects a redirect to a private IP."""
        monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio")
        blocked = self._make_redirect_response(
            "http://10.0.0.1/internal/secrets"
        )
        client, client_kwargs, client_factory = self._make_client_capturing_hooks()

        async def fake_get(_url, **kwargs):
            # Stand in for httpx invoking the registered response hooks.
            for hook in client_kwargs["event_hooks"]["response"]:
                await hook(blocked)

        client.get = AsyncMock(side_effect=fake_get)

        def only_origin_is_safe(url):
            return url == "https://public.example.com/voice.ogg"

        async def run():
            with patch("tools.url_safety.is_safe_url", side_effect=only_origin_is_safe), \
                    patch("httpx.AsyncClient", side_effect=client_factory):
                from gateway.platforms.base import cache_audio_from_url
                await cache_audio_from_url(
                    "https://public.example.com/voice.ogg", ext=".ogg"
                )

        with pytest.raises(ValueError, match="Blocked redirect"):
            asyncio.run(run())

    def test_safe_redirect_allowed(self, tmp_path, monkeypatch):
        """A redirect to a public IP is allowed through."""
        monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img")
        public_redirect = self._make_redirect_response(
            "https://cdn.example.com/real-image.png"
        )

        final_response = MagicMock()
        final_response.content = b"\xff\xd8\xff fake jpeg"
        final_response.raise_for_status = MagicMock()
        final_response.is_redirect = False

        client, client_kwargs, client_factory = self._make_client_capturing_hooks()
        invocations = 0

        async def fake_get(_url, **kwargs):
            nonlocal invocations
            invocations += 1
            # The first invocation feeds the redirect response to the hooks;
            # any later invocation would feed the final response. Every
            # invocation hands the final response back to the caller.
            seen_by_hooks = public_redirect if invocations == 1 else final_response
            for hook in client_kwargs["event_hooks"]["response"]:
                await hook(seen_by_hooks)
            return final_response

        client.get = AsyncMock(side_effect=fake_get)

        async def run():
            with patch("tools.url_safety.is_safe_url", return_value=True), \
                    patch("httpx.AsyncClient", side_effect=client_factory):
                from gateway.platforms.base import cache_image_from_url
                return await cache_image_from_url(
                    "https://public.example.com/image.png", ext=".jpg"
                )

        cached_path = asyncio.run(run())
        assert cached_path.endswith(".jpg")
# ---------------------------------------------------------------------------
# Slack mock setup (mirrors existing test_slack.py approach)
# ---------------------------------------------------------------------------

View file

@ -8,7 +8,7 @@ from gateway.platforms.base import (
GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE,
MessageEvent,
MessageType,
_safe_url_for_log,
safe_url_for_log,
)
@ -25,7 +25,7 @@ class TestSafeUrlForLog:
"https://user:pass@example.com/private/path/image.png"
"?X-Amz-Signature=supersecret&token=abc#frag"
)
result = _safe_url_for_log(url)
result = safe_url_for_log(url)
assert result == "https://example.com/.../image.png"
assert "supersecret" not in result
assert "token=abc" not in result
@ -33,15 +33,15 @@ class TestSafeUrlForLog:
def test_truncates_long_values(self):
long_url = "https://example.com/" + ("a" * 300)
result = _safe_url_for_log(long_url, max_len=40)
result = safe_url_for_log(long_url, max_len=40)
assert len(result) == 40
assert result.endswith("...")
def test_handles_small_and_non_positive_max_len(self):
url = "https://example.com/very/long/path/file.png?token=secret"
assert _safe_url_for_log(url, max_len=3) == "..."
assert _safe_url_for_log(url, max_len=2) == ".."
assert _safe_url_for_log(url, max_len=0) == ""
assert safe_url_for_log(url, max_len=3) == "..."
assert safe_url_for_log(url, max_len=2) == ".."
assert safe_url_for_log(url, max_len=0) == ""
# ---------------------------------------------------------------------------

View file

@ -10,9 +10,10 @@ Limitations (documented, not fixable at pre-flight level):
can return a public IP for the check, then a private IP for the actual
connection. Fixing this requires connection-level validation (e.g.
Python's Advocate library or an egress proxy like Stripe's Smokescreen).
- Redirect-based bypass in vision_tools is mitigated by an httpx event
hook that re-validates each redirect target. Web tools use third-party
SDKs (Firecrawl/Tavily) where redirect handling is on their servers.
- Redirect-based bypass is mitigated by httpx event hooks that re-validate
each redirect target in vision_tools, gateway platform adapters, and
media cache helpers. Web tools use third-party SDKs (Firecrawl/Tavily)
where redirect handling is on their servers.
"""
import ipaddress