diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 0decffa68..ebe15b880 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -160,7 +160,7 @@ GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE = ( ) -def _safe_url_for_log(url: str, max_len: int = 80) -> str: +def safe_url_for_log(url: str, max_len: int = 80) -> str: """Return a URL string safe for logs (no query/fragment/userinfo).""" if max_len <= 0: return "" @@ -197,6 +197,23 @@ def _safe_url_for_log(url: str, max_len: int = 80) -> str: return f"{safe[:max_len - 3]}..." +async def _ssrf_redirect_guard(response): + """Re-validate each redirect target to prevent redirect-based SSRF. + + Without this, an attacker can host a public URL that 302-redirects to + http://169.254.169.254/ and bypass the pre-flight is_safe_url() check. + + Must be async because httpx.AsyncClient awaits response event hooks. + """ + if response.is_redirect and response.next_request: + redirect_url = str(response.next_request.url) + from tools.url_safety import is_safe_url + if not is_safe_url(redirect_url): + raise ValueError( + f"Blocked redirect to private/internal address: {safe_url_for_log(redirect_url)}" + ) + + # --------------------------------------------------------------------------- # Image cache utilities # @@ -281,7 +298,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> """ from tools.url_safety import is_safe_url if not is_safe_url(url): - raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}") + raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}") import asyncio import httpx @@ -289,7 +306,11 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> _log = _logging.getLogger(__name__) last_exc = None - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + async with httpx.AsyncClient( + timeout=30.0, + follow_redirects=True, + event_hooks={"response": [_ssrf_redirect_guard]}, + ) as client: for attempt in range(retries + 1): try: response = await client.get( @@ -311,7 +332,7 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) -> "Media cache retry %d/%d for %s (%.1fs): %s", attempt + 1, retries, - _safe_url_for_log(url), + safe_url_for_log(url), wait, exc, ) @@ -396,7 +417,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> """ from tools.url_safety import is_safe_url if not is_safe_url(url): - raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}") + raise ValueError(f"Blocked unsafe URL (SSRF protection): {safe_url_for_log(url)}") import asyncio import httpx @@ -404,7 +425,11 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> _log = _logging.getLogger(__name__) last_exc = None - async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: + async with httpx.AsyncClient( + timeout=30.0, + follow_redirects=True, + event_hooks={"response": [_ssrf_redirect_guard]}, + ) as client: for attempt in range(retries + 1): try: response = await client.get( @@ -426,7 +451,7 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) -> "Audio cache retry %d/%d for %s (%.1fs): %s", attempt + 1, retries, - _safe_url_for_log(url), + safe_url_for_log(url), wait, exc, ) @@ -1525,7 +1550,7 @@ class BasePlatformAdapter(ABC): logger.info( "[%s] Sending image: %s (alt=%s)", self.name, - _safe_url_for_log(image_url), + safe_url_for_log(image_url), alt_text[:30] if alt_text else "", ) # Route animated GIFs through send_animation for proper playback diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index f45d87050..361f74882 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -39,7 +39,7 @@ from gateway.platforms.base import ( MessageType, SendResult, SUPPORTED_DOCUMENT_TYPES, - _safe_url_for_log, + safe_url_for_log, cache_document_from_bytes, ) @@ -686,7 +686,7 @@ class SlackAdapter(BasePlatformAdapter): except Exception as e: # pragma: no cover - defensive logging logger.warning( "[Slack] Failed to upload image from URL %s, falling back to text: %s", - _safe_url_for_log(image_url), + safe_url_for_log(image_url), e, exc_info=True, ) diff --git a/tests/gateway/test_media_download_retry.py b/tests/gateway/test_media_download_retry.py index 8a5e16953..5b5add26c 100644 --- a/tests/gateway/test_media_download_retry.py +++ b/tests/gateway/test_media_download_retry.py @@ -376,6 +376,134 @@ class TestCacheAudioFromUrl: mock_sleep.assert_not_called() +# --------------------------------------------------------------------------- +# SSRF redirect guard tests (base.py) +# --------------------------------------------------------------------------- + + +class TestSSRFRedirectGuard: + """cache_image_from_url / cache_audio_from_url must reject redirects + that land on private/internal hosts (e.g. cloud metadata endpoint).""" + + def _make_redirect_response(self, target_url: str): + """Build a mock httpx response that looks like a redirect.""" + resp = MagicMock() + resp.is_redirect = True + resp.next_request = MagicMock(url=target_url) + return resp + + def _make_client_capturing_hooks(self): + """Return (mock_client, captured_kwargs dict) where captured_kwargs + will contain the kwargs passed to httpx.AsyncClient().""" + captured = {} + mock_client = AsyncMock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=False) + + def factory(*args, **kwargs): + captured.update(kwargs) + return mock_client + + return mock_client, captured, factory + + def test_image_blocks_private_redirect(self, tmp_path, monkeypatch): + """cache_image_from_url rejects a redirect to a private IP.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + redirect_resp = self._make_redirect_response( + "http://169.254.169.254/latest/meta-data" + ) + mock_client, captured, factory = self._make_client_capturing_hooks() + + async def fake_get(_url, **kwargs): + # Simulate httpx calling the response event hooks + for hook in captured["event_hooks"]["response"]: + await hook(redirect_resp) + + mock_client.get = AsyncMock(side_effect=fake_get) + + def fake_safe(url): + return url == "https://public.example.com/image.png" + + async def run(): + with patch("tools.url_safety.is_safe_url", side_effect=fake_safe), \ + patch("httpx.AsyncClient", side_effect=factory): + from gateway.platforms.base import cache_image_from_url + await cache_image_from_url( + "https://public.example.com/image.png", ext=".png" + ) + + with pytest.raises(ValueError, match="Blocked redirect"): + asyncio.run(run()) + + def test_audio_blocks_private_redirect(self, tmp_path, monkeypatch): + """cache_audio_from_url rejects a redirect to a private IP.""" + monkeypatch.setattr("gateway.platforms.base.AUDIO_CACHE_DIR", tmp_path / "audio") + + redirect_resp = self._make_redirect_response( + "http://10.0.0.1/internal/secrets" + ) + mock_client, captured, factory = self._make_client_capturing_hooks() + + async def fake_get(_url, **kwargs): + for hook in captured["event_hooks"]["response"]: + await hook(redirect_resp) + + mock_client.get = AsyncMock(side_effect=fake_get) + + def fake_safe(url): + return url == "https://public.example.com/voice.ogg" + + async def run(): + with patch("tools.url_safety.is_safe_url", side_effect=fake_safe), \ + patch("httpx.AsyncClient", side_effect=factory): + from gateway.platforms.base import cache_audio_from_url + await cache_audio_from_url( + "https://public.example.com/voice.ogg", ext=".ogg" + ) + + with pytest.raises(ValueError, match="Blocked redirect"): + asyncio.run(run()) + + def test_safe_redirect_allowed(self, tmp_path, monkeypatch): + """A redirect to a public IP is allowed through.""" + monkeypatch.setattr("gateway.platforms.base.IMAGE_CACHE_DIR", tmp_path / "img") + + redirect_resp = self._make_redirect_response( + "https://cdn.example.com/real-image.png" + ) + + ok_response = MagicMock() + ok_response.content = b"\xff\xd8\xff fake jpeg" + ok_response.raise_for_status = MagicMock() + ok_response.is_redirect = False + + mock_client, captured, factory = self._make_client_capturing_hooks() + + call_count = 0 + + async def fake_get(_url, **kwargs): + nonlocal call_count + call_count += 1 + # First call triggers redirect hook, second returns data + for hook in captured["event_hooks"]["response"]: + await hook(redirect_resp if call_count == 1 else ok_response) + return ok_response + + mock_client.get = AsyncMock(side_effect=fake_get) + + async def run(): + with patch("tools.url_safety.is_safe_url", return_value=True), \ + patch("httpx.AsyncClient", side_effect=factory): + from gateway.platforms.base import cache_image_from_url + return await cache_image_from_url( + "https://public.example.com/image.png", ext=".jpg" + ) + + path = asyncio.run(run()) + assert path.endswith(".jpg") + + # --------------------------------------------------------------------------- # Slack mock setup (mirrors existing test_slack.py approach) # --------------------------------------------------------------------------- diff --git a/tests/gateway/test_platform_base.py b/tests/gateway/test_platform_base.py index 43dd17bd8..f2d133ea2 100644 --- a/tests/gateway/test_platform_base.py +++ b/tests/gateway/test_platform_base.py @@ -8,7 +8,7 @@ from gateway.platforms.base import ( GATEWAY_SECRET_CAPTURE_UNSUPPORTED_MESSAGE, MessageEvent, MessageType, - _safe_url_for_log, + safe_url_for_log, ) @@ -25,7 +25,7 @@ class TestSafeUrlForLog: "https://user:pass@example.com/private/path/image.png" "?X-Amz-Signature=supersecret&token=abc#frag" ) - result = _safe_url_for_log(url) + result = safe_url_for_log(url) assert result == "https://example.com/.../image.png" assert "supersecret" not in result assert "token=abc" not in result @@ -33,15 +33,15 @@ class TestSafeUrlForLog: def test_truncates_long_values(self): long_url = "https://example.com/" + ("a" * 300) - result = _safe_url_for_log(long_url, max_len=40) + result = safe_url_for_log(long_url, max_len=40) assert len(result) == 40 assert result.endswith("...") def test_handles_small_and_non_positive_max_len(self): url = "https://example.com/very/long/path/file.png?token=secret" - assert _safe_url_for_log(url, max_len=3) == "..." - assert _safe_url_for_log(url, max_len=2) == ".." - assert _safe_url_for_log(url, max_len=0) == "" + assert safe_url_for_log(url, max_len=3) == "..." + assert safe_url_for_log(url, max_len=2) == ".." + assert safe_url_for_log(url, max_len=0) == "" # --------------------------------------------------------------------------- diff --git a/tools/url_safety.py b/tools/url_safety.py index ae610d0f7..3dc57ca45 100644 --- a/tools/url_safety.py +++ b/tools/url_safety.py @@ -10,9 +10,10 @@ Limitations (documented, not fixable at pre-flight level): can return a public IP for the check, then a private IP for the actual connection. Fixing this requires connection-level validation (e.g. Python's Champion library or an egress proxy like Stripe's Smokescreen). - - Redirect-based bypass in vision_tools is mitigated by an httpx event - hook that re-validates each redirect target. Web tools use third-party - SDKs (Firecrawl/Tavily) where redirect handling is on their servers. + - Redirect-based bypass is mitigated by httpx event hooks that re-validate + each redirect target in vision_tools, gateway platform adapters, and + media cache helpers. Web tools use third-party SDKs (Firecrawl/Tavily) + where redirect handling is on their servers. """ import ipaddress