fix(security): consolidated security hardening — SSRF, timing attack, tar traversal, credential leakage (#5944)

Salvaged from PRs #5800 (memosr), #5806 (memosr), #5915 (Ruzzgar), #5928 (Awsh1).

Changes:
- Use hmac.compare_digest for API key comparison (timing attack prevention)
- Apply provider env var blocklist to Docker containers (credential leakage)
- Replace tar.extractall() with safe extraction in TerminalBench2 (CVE-2007-4559)
- Add SSRF protection via is_safe_url to ALL platform adapters:
  base.py (cache_image_from_url, cache_audio_from_url),
  discord, slack, telegram, matrix, mattermost, feishu, wecom
  (Signal and WhatsApp protected via base.py helpers)
- Update tests: mock is_safe_url in Mattermost download tests
- Add security tests for tar extraction (traversal, symlinks, safe files)
This commit is contained in:
Teknium 2026-04-07 17:28:37 -07:00 committed by GitHub
parent b1a66d55b4
commit 469cd16fe0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 284 additions and 11 deletions

View file

@ -44,7 +44,7 @@ import tempfile
import time import time
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
# Ensure repo root is on sys.path for imports # Ensure repo root is on sys.path for imports
@ -148,6 +148,62 @@ MODAL_INCOMPATIBLE_TASKS = {
# Tar extraction helper # Tar extraction helper
# ============================================================================= # =============================================================================
def _normalize_tar_member_parts(member_name: str) -> list:
    """Split a tar member name into validated relative path components.

    Backslashes are treated as separators, and the name is inspected both
    as a POSIX path and as a Windows path.

    Args:
        member_name: The raw ``name`` of a tar archive member.

    Returns:
        The non-empty list of path components, with empty and ``.``
        components dropped.

    Raises:
        ValueError: If the name is empty, is absolute under POSIX or
            Windows rules, carries a Windows drive, contains a ``..``
            component, or has no components left after normalization.
    """
    as_posix = PurePosixPath(member_name.replace("\\", "/"))
    as_windows = PureWindowsPath(member_name)

    is_rooted = (
        as_posix.is_absolute()
        or as_windows.is_absolute()
        or bool(as_windows.drive)
    )
    if member_name == "" or is_rooted:
        raise ValueError(f"Unsafe archive member path: {member_name}")

    components = []
    for piece in as_posix.parts:
        if piece == "..":
            # Any parent-directory hop is rejected outright, even if the
            # path as a whole would still land inside the target.
            raise ValueError(f"Unsafe archive member path: {member_name}")
        if piece not in ("", "."):
            components.append(piece)

    if not components:
        raise ValueError(f"Unsafe archive member path: {member_name}")
    return components
def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
    """Extract a tar archive without allowing traversal or link entries.

    Every member is validated before anything is written, so an archive
    that contains an unsafe entry is rejected as a whole rather than being
    partially extracted up to the offending member.

    Args:
        tar: An open tar archive to read members from.
        target_dir: Directory to extract into; created if it is missing.

    Raises:
        ValueError: If a member name is unsafe (see
            ``_normalize_tar_member_parts``), if its resolved destination
            falls outside ``target_dir``, if a member is neither a
            directory nor a regular file (symlinks, hard links, devices,
            FIFOs), or if a regular file's data cannot be read.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    target_root = target_dir.resolve()

    # Pass 1: validate all members up front. Nothing inside target_dir is
    # created until the entire archive has been accepted.
    planned = []
    for member in tar.getmembers():
        parts = _normalize_tar_member_parts(member.name)
        destination = target_dir.joinpath(*parts).resolve(strict=False)
        try:
            # resolve() follows symlinks already present on disk, so this
            # also catches escapes through a pre-existing link.
            destination.relative_to(target_root)
        except ValueError as exc:
            raise ValueError(f"Unsafe archive member path: {member.name}") from exc
        if not (member.isdir() or member.isfile()):
            raise ValueError(f"Unsupported archive member type: {member.name}")
        planned.append((member, destination))

    # Pass 2: materialize directories and regular files.
    for member, destination in planned:
        if member.isdir():
            destination.mkdir(parents=True, exist_ok=True)
            continue

        destination.parent.mkdir(parents=True, exist_ok=True)
        extracted = tar.extractfile(member)
        if extracted is None:
            raise ValueError(f"Cannot read archive member: {member.name}")
        with extracted, open(destination, "wb") as dst:
            shutil.copyfileobj(extracted, dst)

        try:
            # Keep only the rwx permission bits; setuid/setgid/sticky bits
            # from the archive are deliberately dropped.
            os.chmod(destination, member.mode & 0o777)
        except OSError:
            # Best effort: failing to apply permissions must not abort
            # extraction.
            pass
def _extract_base64_tar(b64_data: str, target_dir: Path): def _extract_base64_tar(b64_data: str, target_dir: Path):
"""Extract a base64-encoded tar.gz archive into target_dir.""" """Extract a base64-encoded tar.gz archive into target_dir."""
if not b64_data: if not b64_data:
@ -155,7 +211,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
raw = base64.b64decode(b64_data) raw = base64.b64decode(b64_data)
buf = io.BytesIO(raw) buf = io.BytesIO(raw)
with tarfile.open(fileobj=buf, mode="r:gz") as tar: with tarfile.open(fileobj=buf, mode="r:gz") as tar:
tar.extractall(path=str(target_dir)) _safe_extract_tar(tar, target_dir)
# ============================================================================= # =============================================================================

View file

@ -20,6 +20,7 @@ Requires:
""" """
import asyncio import asyncio
import hmac
import json import json
import logging import logging
import os import os
@ -370,7 +371,7 @@ class APIServerAdapter(BasePlatformAdapter):
auth_header = request.headers.get("Authorization", "") auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "): if auth_header.startswith("Bearer "):
token = auth_header[7:].strip() token = auth_header[7:].strip()
if token == self._api_key: if hmac.compare_digest(token, self._api_key):
return None # Auth OK return None # Auth OK
return web.json_response( return web.json_response(

View file

@ -124,7 +124,14 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
Returns: Returns:
Absolute path to the cached image file as a string. Absolute path to the cached image file as a string.
Raises:
ValueError: If the URL targets a private/internal network (SSRF protection).
""" """
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
import asyncio import asyncio
import httpx import httpx
import logging as _logging import logging as _logging
@ -232,7 +239,14 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
Returns: Returns:
Absolute path to the cached audio file as a string. Absolute path to the cached audio file as a string.
Raises:
ValueError: If the URL targets a private/internal network (SSRF protection).
""" """
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
import asyncio import asyncio
import httpx import httpx
import logging as _logging import logging as _logging

View file

@ -55,6 +55,7 @@ from gateway.platforms.base import (
cache_document_from_bytes, cache_document_from_bytes,
SUPPORTED_DOCUMENT_TYPES, SUPPORTED_DOCUMENT_TYPES,
) )
from tools.url_safety import is_safe_url
def _clean_discord_id(entry: str) -> str: def _clean_discord_id(entry: str) -> str:
@ -1285,6 +1286,10 @@ class DiscordAdapter(BasePlatformAdapter):
if not self._client: if not self._client:
return SendResult(success=False, error="Not connected") return SendResult(success=False, error="Not connected")
if not is_safe_url(image_url):
logger.warning("[%s] Blocked unsafe image URL during Discord send_image", self.name)
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try: try:
import aiohttp import aiohttp

View file

@ -2109,6 +2109,10 @@ class FeishuAdapter(BasePlatformAdapter):
default_ext: str, default_ext: str,
preferred_name: str, preferred_name: str,
) -> tuple[str, str]: ) -> tuple[str, str]:
from tools.url_safety import is_safe_url
if not is_safe_url(file_url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {file_url[:80]}")
import httpx import httpx
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:

View file

@ -586,6 +586,11 @@ class MatrixAdapter(BasePlatformAdapter):
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
) -> SendResult: ) -> SendResult:
"""Download an image URL and upload it to Matrix.""" """Download an image URL and upload it to Matrix."""
from tools.url_safety import is_safe_url
if not is_safe_url(image_url):
logger.warning("Matrix: blocked unsafe image URL (SSRF protection)")
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try: try:
# Try aiohttp first (always available), fall back to httpx # Try aiohttp first (always available), fall back to httpx
try: try:

View file

@ -407,6 +407,11 @@ class MattermostAdapter(BasePlatformAdapter):
kind: str = "file", kind: str = "file",
) -> SendResult: ) -> SendResult:
"""Download a URL and upload it as a file attachment.""" """Download a URL and upload it as a file attachment."""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
logger.warning("Mattermost: blocked unsafe URL (SSRF protection)")
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
import asyncio import asyncio
import aiohttp import aiohttp

View file

@ -595,6 +595,11 @@ class SlackAdapter(BasePlatformAdapter):
if not self._app: if not self._app:
return SendResult(success=False, error="Not connected") return SendResult(success=False, error="Not connected")
from tools.url_safety import is_safe_url
if not is_safe_url(image_url):
logger.warning("[Slack] Blocked unsafe image URL (SSRF protection)")
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try: try:
import httpx import httpx

View file

@ -1633,6 +1633,11 @@ class TelegramAdapter(BasePlatformAdapter):
if not self._bot: if not self._bot:
return SendResult(success=False, error="Not connected") return SendResult(success=False, error="Not connected")
from tools.url_safety import is_safe_url
if not is_safe_url(image_url):
logger.warning("[%s] Blocked unsafe image URL (SSRF protection)", self.name)
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try: try:
# Telegram can send photos directly from URLs (up to ~5MB) # Telegram can send photos directly from URLs (up to ~5MB)
_photo_thread = metadata.get("thread_id") if metadata else None _photo_thread = metadata.get("thread_id") if metadata else None

View file

@ -910,6 +910,10 @@ class WeComAdapter(BasePlatformAdapter):
url: str, url: str,
max_bytes: int, max_bytes: int,
) -> Tuple[bytes, Dict[str, str]]: ) -> Tuple[bytes, Dict[str, str]]:
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {url[:80]}")
if not HTTPX_AVAILABLE: if not HTTPX_AVAILABLE:
raise RuntimeError("httpx is required for WeCom media download") raise RuntimeError("httpx is required for WeCom media download")

View file

@ -0,0 +1,164 @@
"""Security tests for Terminal-Bench 2 archive extraction."""
import base64
import importlib
import io
import sys
import tarfile
import types
import pytest
def _stub_module(name: str, **attrs):
    """Create a bare module object called *name* with *attrs* set on it.

    Args:
        name: Module name given to the new ``types.ModuleType`` instance.
        **attrs: Attribute name/value pairs assigned onto the module.

    Returns:
        The newly created module object.
    """
    stub = types.ModuleType(name)
    for attr_name in attrs:
        setattr(stub, attr_name, attrs[attr_name])
    return stub
def _load_terminalbench_module(monkeypatch):
    """Import the TerminalBench2 environment module against stub dependencies.

    Placeholder modules for ``atroposlib``, several ``environments.*``
    modules and ``tools.terminal_tool`` are registered in ``sys.modules``
    through *monkeypatch*, any cached copy of the target module is dropped,
    and the module is then imported afresh.

    Args:
        monkeypatch: The pytest ``monkeypatch`` fixture used to install the
            stubs into ``sys.modules``.

    Returns:
        The freshly imported
        ``environments.benchmarks.terminalbench_2.terminalbench2_env`` module.
    """

    class _EvalHandlingEnum:
        STOP_TRAIN = "stop_train"

    class _APIServerConfig:
        def __init__(self, *args, **kwargs):
            self.args = args
            self.kwargs = kwargs

    class _AgentResult:
        pass

    class _HermesAgentLoop:
        pass

    class _HermesAgentBaseEnv:
        pass

    class _HermesAgentEnvConfig:
        pass

    class _ToolContext:
        pass

    # Attributes each stub module must expose, keyed by dotted module name.
    stub_attrs = {
        "atroposlib": {},
        "atroposlib.envs": {},
        "atroposlib.envs.base": {"EvalHandlingEnum": _EvalHandlingEnum},
        "atroposlib.envs.server_handling": {},
        "atroposlib.envs.server_handling.server_manager": {
            "APIServerConfig": _APIServerConfig,
        },
        "environments.agent_loop": {
            "AgentResult": _AgentResult,
            "HermesAgentLoop": _HermesAgentLoop,
        },
        "environments.hermes_base_env": {
            "HermesAgentBaseEnv": _HermesAgentBaseEnv,
            "HermesAgentEnvConfig": _HermesAgentEnvConfig,
        },
        "environments.tool_context": {"ToolContext": _ToolContext},
        "tools.terminal_tool": {
            "register_task_env_overrides": lambda *args, **kwargs: None,
            "clear_task_env_overrides": lambda *args, **kwargs: None,
            "cleanup_vm": lambda *args, **kwargs: None,
        },
    }
    stubs = {
        dotted: _stub_module(dotted, **attrs)
        for dotted, attrs in stub_attrs.items()
    }

    # Attach each atroposlib child stub to its parent package stub so that
    # attribute access such as ``atroposlib.envs.base`` resolves.
    parent_child_links = (
        ("atroposlib", "envs"),
        ("atroposlib.envs", "base"),
        ("atroposlib.envs", "server_handling"),
        ("atroposlib.envs.server_handling", "server_manager"),
    )
    for parent, child in parent_child_links:
        setattr(stubs[parent], child, stubs[f"{parent}.{child}"])

    for dotted, stub in stubs.items():
        monkeypatch.setitem(sys.modules, dotted, stub)

    target_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
    # Drop any cached copy so the import below re-executes the module.
    sys.modules.pop(target_name, None)
    return importlib.import_module(target_name)
def _build_tar_b64(entries):
    """Build a gzip-compressed tar archive in memory and return it as base64.

    Args:
        entries: Iterable of dicts. Each needs ``"kind"`` (``"dir"``,
            ``"file"`` or ``"symlink"``) and ``"name"``; ``"file"`` entries
            also need ``"data"`` (text, stored UTF-8 encoded) and
            ``"symlink"`` entries need ``"target"``.

    Returns:
        The archive bytes encoded as an ASCII base64 string.

    Raises:
        ValueError: If an entry has an unrecognized ``"kind"``.
    """
    payload = io.BytesIO()
    with tarfile.open(fileobj=payload, mode="w:gz") as archive:
        for spec in entries:
            entry_kind = spec["kind"]
            header = tarfile.TarInfo(spec["name"])
            if entry_kind == "dir":
                header.type = tarfile.DIRTYPE
                archive.addfile(header)
            elif entry_kind == "file":
                content = spec["data"].encode("utf-8")
                header.size = len(content)
                archive.addfile(header, io.BytesIO(content))
            elif entry_kind == "symlink":
                header.type = tarfile.SYMTYPE
                header.linkname = spec["target"]
                archive.addfile(header)
            else:
                raise ValueError(f"Unknown tar entry kind: {entry_kind}")
    encoded = base64.b64encode(payload.getvalue())
    return encoded.decode("ascii")
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
    """A directory entry plus a nested regular file extract to the expected path."""
    tb_module = _load_terminalbench_module(monkeypatch)
    safe_entries = [
        {"kind": "dir", "name": "nested"},
        {"kind": "file", "name": "nested/hello.txt", "data": "hello"},
    ]
    encoded_archive = _build_tar_b64(safe_entries)
    destination = tmp_path / "extract"

    tb_module._extract_base64_tar(encoded_archive, destination)

    extracted_file = destination / "nested" / "hello.txt"
    assert extracted_file.read_text(encoding="utf-8") == "hello"
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
    """A member named ``../escape.txt`` is refused and nothing lands outside the target."""
    tb_module = _load_terminalbench_module(monkeypatch)
    traversal_entries = [
        {"kind": "file", "name": "../escape.txt", "data": "owned"},
    ]
    encoded_archive = _build_tar_b64(traversal_entries)
    destination = tmp_path / "extract"

    with pytest.raises(ValueError, match="Unsafe archive member path"):
        tb_module._extract_base64_tar(encoded_archive, destination)

    # The file would have been written next to the extraction directory.
    escaped_file = tmp_path / "escape.txt"
    assert not escaped_file.exists()
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
    """A symlink member is refused as an unsupported type and is never created."""
    tb_module = _load_terminalbench_module(monkeypatch)
    symlink_entries = [
        {"kind": "symlink", "name": "link", "target": "../../escape.txt"},
    ]
    encoded_archive = _build_tar_b64(symlink_entries)
    destination = tmp_path / "extract"

    with pytest.raises(ValueError, match="Unsupported archive member type"):
        tb_module._extract_base64_tar(encoded_archive, destination)

    link_path = destination / "link"
    assert not link_path.exists()

View file

@ -504,7 +504,8 @@ class TestMattermostFileUpload:
self.adapter._session = MagicMock() self.adapter._session = MagicMock()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_send_image_downloads_and_uploads(self): @patch("tools.url_safety.is_safe_url", return_value=True)
async def test_send_image_downloads_and_uploads(self, _mock_safe):
"""send_image should download the URL, upload via /api/v4/files, then post.""" """send_image should download the URL, upload via /api/v4/files, then post."""
# Mock the download (GET) # Mock the download (GET)
mock_dl_resp = AsyncMock() mock_dl_resp = AsyncMock()

View file

@ -596,10 +596,11 @@ def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
return resp return resp
@patch("tools.url_safety.is_safe_url", return_value=True)
class TestMattermostSendUrlAsFile: class TestMattermostSendUrlAsFile:
"""Tests for MattermostAdapter._send_url_as_file""" """Tests for MattermostAdapter._send_url_as_file"""
def test_success_on_first_attempt(self): def test_success_on_first_attempt(self, _mock_safe):
"""200 on first attempt → file uploaded and post created.""" """200 on first attempt → file uploaded and post created."""
adapter = _make_mm_adapter() adapter = _make_mm_adapter()
resp = _make_aiohttp_resp(200) resp = _make_aiohttp_resp(200)
@ -616,7 +617,7 @@ class TestMattermostSendUrlAsFile:
adapter._upload_file.assert_called_once() adapter._upload_file.assert_called_once()
adapter._api_post.assert_called_once() adapter._api_post.assert_called_once()
def test_retries_on_429_then_succeeds(self): def test_retries_on_429_then_succeeds(self, _mock_safe):
"""429 on first attempt is retried; 200 on second attempt succeeds.""" """429 on first attempt is retried; 200 on second attempt succeeds."""
adapter = _make_mm_adapter() adapter = _make_mm_adapter()
@ -637,7 +638,7 @@ class TestMattermostSendUrlAsFile:
assert adapter._session.get.call_count == 2 assert adapter._session.get.call_count == 2
mock_sleep.assert_called_once() mock_sleep.assert_called_once()
def test_retries_on_500_then_succeeds(self): def test_retries_on_500_then_succeeds(self, _mock_safe):
"""5xx on first attempt is retried; 200 on second attempt succeeds.""" """5xx on first attempt is retried; 200 on second attempt succeeds."""
adapter = _make_mm_adapter() adapter = _make_mm_adapter()
@ -655,7 +656,7 @@ class TestMattermostSendUrlAsFile:
assert result.success assert result.success
assert adapter._session.get.call_count == 2 assert adapter._session.get.call_count == 2
def test_falls_back_to_text_after_max_retries_on_5xx(self): def test_falls_back_to_text_after_max_retries_on_5xx(self, _mock_safe):
"""Three consecutive 500s exhaust retries; falls back to send() with URL text.""" """Three consecutive 500s exhaust retries; falls back to send() with URL text."""
adapter = _make_mm_adapter() adapter = _make_mm_adapter()
@ -674,7 +675,7 @@ class TestMattermostSendUrlAsFile:
text_arg = adapter.send.call_args[0][1] text_arg = adapter.send.call_args[0][1]
assert "http://cdn.example.com/img.png" in text_arg assert "http://cdn.example.com/img.png" in text_arg
def test_falls_back_on_client_error(self): def test_falls_back_on_client_error(self, _mock_safe):
"""aiohttp.ClientError on every attempt falls back to send() with URL.""" """aiohttp.ClientError on every attempt falls back to send() with URL."""
import aiohttp import aiohttp
@ -699,7 +700,7 @@ class TestMattermostSendUrlAsFile:
text_arg = adapter.send.call_args[0][1] text_arg = adapter.send.call_args[0][1]
assert "http://cdn.example.com/img.png" in text_arg assert "http://cdn.example.com/img.png" in text_arg
def test_non_retryable_404_falls_back_immediately(self): def test_non_retryable_404_falls_back_immediately(self, _mock_safe):
"""404 is non-retryable (< 500, != 429); send() is called right away.""" """404 is non-retryable (< 500, != 429); send() is called right away."""
adapter = _make_mm_adapter() adapter = _make_mm_adapter()

View file

@ -18,6 +18,7 @@ import uuid
from typing import Optional from typing import Optional
from tools.environments.base import BaseEnvironment from tools.environments.base import BaseEnvironment
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
from tools.interrupt import is_interrupted from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -510,6 +511,8 @@ class DockerEnvironment(BaseEnvironment):
forward_keys |= get_all_passthrough() forward_keys |= get_all_passthrough()
except Exception: except Exception:
pass pass
# Strip Hermes-managed secrets so they never leak into the container.
forward_keys -= _HERMES_PROVIDER_ENV_BLOCKLIST
hermes_env = _load_hermes_env_vars() if forward_keys else {} hermes_env = _load_hermes_env_vars() if forward_keys else {}
for key in sorted(forward_keys): for key in sorted(forward_keys):
value = os.getenv(key) value = os.getenv(key)