fix(security): consolidated security hardening — SSRF, timing attack, tar traversal, credential leakage (#5944)

Salvaged from PRs #5800 (memosr), #5806 (memosr), #5915 (Ruzzgar), #5928 (Awsh1).

Changes:
- Use hmac.compare_digest for API key comparison (timing attack prevention)
- Apply provider env var blocklist to Docker containers (credential leakage)
- Replace tar.extractall() with safe extraction in TerminalBench2 (CVE-2007-4559)
- Add SSRF protection via is_safe_url to ALL platform adapters:
  base.py (cache_image_from_url, cache_audio_from_url),
  discord, slack, telegram, matrix, mattermost, feishu, wecom
  (Signal and WhatsApp protected via base.py helpers)
- Update tests: mock is_safe_url in Mattermost download tests
- Add security tests for tar extraction (traversal, symlinks, safe files)
This commit is contained in:
Teknium 2026-04-07 17:28:37 -07:00 committed by GitHub
parent b1a66d55b4
commit 469cd16fe0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 284 additions and 11 deletions

View file

@ -44,7 +44,7 @@ import tempfile
import time
import uuid
from collections import defaultdict
from pathlib import Path
from pathlib import Path, PurePosixPath, PureWindowsPath
from typing import Any, Dict, List, Optional, Tuple, Union
# Ensure repo root is on sys.path for imports
@ -148,6 +148,62 @@ MODAL_INCOMPATIBLE_TASKS = {
# Tar extraction helper
# =============================================================================
def _normalize_tar_member_parts(member_name: str) -> list:
"""Return safe path components for a tar member or raise ValueError."""
normalized_name = member_name.replace("\\", "/")
posix_path = PurePosixPath(normalized_name)
windows_path = PureWindowsPath(member_name)
if (
not normalized_name
or posix_path.is_absolute()
or windows_path.is_absolute()
or windows_path.drive
):
raise ValueError(f"Unsafe archive member path: {member_name}")
parts = [part for part in posix_path.parts if part not in ("", ".")]
if not parts or any(part == ".." for part in parts):
raise ValueError(f"Unsafe archive member path: {member_name}")
return parts
def _safe_extract_tar(tar: tarfile.TarFile, target_dir: Path) -> None:
"""Extract a tar archive without allowing traversal or link entries."""
target_dir.mkdir(parents=True, exist_ok=True)
target_root = target_dir.resolve()
for member in tar.getmembers():
parts = _normalize_tar_member_parts(member.name)
target = target_dir.joinpath(*parts)
target_real = target.resolve(strict=False)
try:
target_real.relative_to(target_root)
except ValueError as exc:
raise ValueError(f"Unsafe archive member path: {member.name}") from exc
if member.isdir():
target_real.mkdir(parents=True, exist_ok=True)
continue
if not member.isfile():
raise ValueError(f"Unsupported archive member type: {member.name}")
target_real.parent.mkdir(parents=True, exist_ok=True)
extracted = tar.extractfile(member)
if extracted is None:
raise ValueError(f"Cannot read archive member: {member.name}")
with extracted, open(target_real, "wb") as dst:
shutil.copyfileobj(extracted, dst)
try:
os.chmod(target_real, member.mode & 0o777)
except OSError:
pass
def _extract_base64_tar(b64_data: str, target_dir: Path):
"""Extract a base64-encoded tar.gz archive into target_dir."""
if not b64_data:
@ -155,7 +211,7 @@ def _extract_base64_tar(b64_data: str, target_dir: Path):
raw = base64.b64decode(b64_data)
buf = io.BytesIO(raw)
with tarfile.open(fileobj=buf, mode="r:gz") as tar:
tar.extractall(path=str(target_dir))
_safe_extract_tar(tar, target_dir)
# =============================================================================

View file

@ -20,6 +20,7 @@ Requires:
"""
import asyncio
import hmac
import json
import logging
import os
@ -370,7 +371,7 @@ class APIServerAdapter(BasePlatformAdapter):
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Bearer "):
token = auth_header[7:].strip()
if token == self._api_key:
if hmac.compare_digest(token, self._api_key):
return None # Auth OK
return web.json_response(

View file

@ -124,7 +124,14 @@ async def cache_image_from_url(url: str, ext: str = ".jpg", retries: int = 2) ->
Returns:
Absolute path to the cached image file as a string.
Raises:
ValueError: If the URL targets a private/internal network (SSRF protection).
"""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
import asyncio
import httpx
import logging as _logging
@ -232,7 +239,14 @@ async def cache_audio_from_url(url: str, ext: str = ".ogg", retries: int = 2) ->
Returns:
Absolute path to the cached audio file as a string.
Raises:
ValueError: If the URL targets a private/internal network (SSRF protection).
"""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {_safe_url_for_log(url)}")
import asyncio
import httpx
import logging as _logging

View file

@ -55,6 +55,7 @@ from gateway.platforms.base import (
cache_document_from_bytes,
SUPPORTED_DOCUMENT_TYPES,
)
from tools.url_safety import is_safe_url
def _clean_discord_id(entry: str) -> str:
@ -1285,6 +1286,10 @@ class DiscordAdapter(BasePlatformAdapter):
if not self._client:
return SendResult(success=False, error="Not connected")
if not is_safe_url(image_url):
logger.warning("[%s] Blocked unsafe image URL during Discord send_image", self.name)
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try:
import aiohttp

View file

@ -2109,6 +2109,10 @@ class FeishuAdapter(BasePlatformAdapter):
default_ext: str,
preferred_name: str,
) -> tuple[str, str]:
from tools.url_safety import is_safe_url
if not is_safe_url(file_url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {file_url[:80]}")
import httpx
async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:

View file

@ -586,6 +586,11 @@ class MatrixAdapter(BasePlatformAdapter):
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Download an image URL and upload it to Matrix."""
from tools.url_safety import is_safe_url
if not is_safe_url(image_url):
logger.warning("Matrix: blocked unsafe image URL (SSRF protection)")
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try:
# Try aiohttp first (always available), fall back to httpx
try:

View file

@ -407,6 +407,11 @@ class MattermostAdapter(BasePlatformAdapter):
kind: str = "file",
) -> SendResult:
"""Download a URL and upload it as a file attachment."""
from tools.url_safety import is_safe_url
if not is_safe_url(url):
logger.warning("Mattermost: blocked unsafe URL (SSRF protection)")
return await self.send(chat_id, f"{caption or ''}\n{url}".strip(), reply_to)
import asyncio
import aiohttp

View file

@ -595,6 +595,11 @@ class SlackAdapter(BasePlatformAdapter):
if not self._app:
return SendResult(success=False, error="Not connected")
from tools.url_safety import is_safe_url
if not is_safe_url(image_url):
logger.warning("[Slack] Blocked unsafe image URL (SSRF protection)")
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try:
import httpx

View file

@ -1632,7 +1632,12 @@ class TelegramAdapter(BasePlatformAdapter):
"""
if not self._bot:
return SendResult(success=False, error="Not connected")
from tools.url_safety import is_safe_url
if not is_safe_url(image_url):
logger.warning("[%s] Blocked unsafe image URL (SSRF protection)", self.name)
return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata)
try:
# Telegram can send photos directly from URLs (up to ~5MB)
_photo_thread = metadata.get("thread_id") if metadata else None

View file

@ -910,6 +910,10 @@ class WeComAdapter(BasePlatformAdapter):
url: str,
max_bytes: int,
) -> Tuple[bytes, Dict[str, str]]:
from tools.url_safety import is_safe_url
if not is_safe_url(url):
raise ValueError(f"Blocked unsafe URL (SSRF protection): {url[:80]}")
if not HTTPX_AVAILABLE:
raise RuntimeError("httpx is required for WeCom media download")

View file

@ -0,0 +1,164 @@
"""Security tests for Terminal-Bench 2 archive extraction."""
import base64
import importlib
import io
import sys
import tarfile
import types
import pytest
def _stub_module(name: str, **attrs):
module = types.ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
def _load_terminalbench_module(monkeypatch):
class _EvalHandlingEnum:
STOP_TRAIN = "stop_train"
class _APIServerConfig:
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
class _AgentResult:
pass
class _HermesAgentLoop:
pass
class _HermesAgentBaseEnv:
pass
class _HermesAgentEnvConfig:
pass
class _ToolContext:
pass
stub_modules = {
"atroposlib": _stub_module("atroposlib"),
"atroposlib.envs": _stub_module("atroposlib.envs"),
"atroposlib.envs.base": _stub_module(
"atroposlib.envs.base",
EvalHandlingEnum=_EvalHandlingEnum,
),
"atroposlib.envs.server_handling": _stub_module("atroposlib.envs.server_handling"),
"atroposlib.envs.server_handling.server_manager": _stub_module(
"atroposlib.envs.server_handling.server_manager",
APIServerConfig=_APIServerConfig,
),
"environments.agent_loop": _stub_module(
"environments.agent_loop",
AgentResult=_AgentResult,
HermesAgentLoop=_HermesAgentLoop,
),
"environments.hermes_base_env": _stub_module(
"environments.hermes_base_env",
HermesAgentBaseEnv=_HermesAgentBaseEnv,
HermesAgentEnvConfig=_HermesAgentEnvConfig,
),
"environments.tool_context": _stub_module(
"environments.tool_context",
ToolContext=_ToolContext,
),
"tools.terminal_tool": _stub_module(
"tools.terminal_tool",
register_task_env_overrides=lambda *args, **kwargs: None,
clear_task_env_overrides=lambda *args, **kwargs: None,
cleanup_vm=lambda *args, **kwargs: None,
),
}
stub_modules["atroposlib"].envs = stub_modules["atroposlib.envs"]
stub_modules["atroposlib.envs"].base = stub_modules["atroposlib.envs.base"]
stub_modules["atroposlib.envs"].server_handling = stub_modules["atroposlib.envs.server_handling"]
stub_modules["atroposlib.envs.server_handling"].server_manager = stub_modules[
"atroposlib.envs.server_handling.server_manager"
]
for name, module in stub_modules.items():
monkeypatch.setitem(sys.modules, name, module)
module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
sys.modules.pop(module_name, None)
return importlib.import_module(module_name)
def _build_tar_b64(entries):
buf = io.BytesIO()
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
for entry in entries:
kind = entry["kind"]
info = tarfile.TarInfo(entry["name"])
if kind == "dir":
info.type = tarfile.DIRTYPE
tar.addfile(info)
continue
if kind == "file":
data = entry["data"].encode("utf-8")
info.size = len(data)
tar.addfile(info, io.BytesIO(data))
continue
if kind == "symlink":
info.type = tarfile.SYMTYPE
info.linkname = entry["target"]
tar.addfile(info)
continue
raise ValueError(f"Unknown tar entry kind: {kind}")
return base64.b64encode(buf.getvalue()).decode("ascii")
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
module = _load_terminalbench_module(monkeypatch)
archive = _build_tar_b64(
[
{"kind": "dir", "name": "nested"},
{"kind": "file", "name": "nested/hello.txt", "data": "hello"},
]
)
target = tmp_path / "extract"
module._extract_base64_tar(archive, target)
assert (target / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello"
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
module = _load_terminalbench_module(monkeypatch)
archive = _build_tar_b64(
[
{"kind": "file", "name": "../escape.txt", "data": "owned"},
]
)
target = tmp_path / "extract"
with pytest.raises(ValueError, match="Unsafe archive member path"):
module._extract_base64_tar(archive, target)
assert not (tmp_path / "escape.txt").exists()
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
module = _load_terminalbench_module(monkeypatch)
archive = _build_tar_b64(
[
{"kind": "symlink", "name": "link", "target": "../../escape.txt"},
]
)
target = tmp_path / "extract"
with pytest.raises(ValueError, match="Unsupported archive member type"):
module._extract_base64_tar(archive, target)
assert not (target / "link").exists()

View file

@ -504,7 +504,8 @@ class TestMattermostFileUpload:
self.adapter._session = MagicMock()
@pytest.mark.asyncio
async def test_send_image_downloads_and_uploads(self):
@patch("tools.url_safety.is_safe_url", return_value=True)
async def test_send_image_downloads_and_uploads(self, _mock_safe):
"""send_image should download the URL, upload via /api/v4/files, then post."""
# Mock the download (GET)
mock_dl_resp = AsyncMock()

View file

@ -596,10 +596,11 @@ def _make_aiohttp_resp(status: int, content: bytes = b"file bytes",
return resp
@patch("tools.url_safety.is_safe_url", return_value=True)
class TestMattermostSendUrlAsFile:
"""Tests for MattermostAdapter._send_url_as_file"""
def test_success_on_first_attempt(self):
def test_success_on_first_attempt(self, _mock_safe):
"""200 on first attempt → file uploaded and post created."""
adapter = _make_mm_adapter()
resp = _make_aiohttp_resp(200)
@ -616,7 +617,7 @@ class TestMattermostSendUrlAsFile:
adapter._upload_file.assert_called_once()
adapter._api_post.assert_called_once()
def test_retries_on_429_then_succeeds(self):
def test_retries_on_429_then_succeeds(self, _mock_safe):
"""429 on first attempt is retried; 200 on second attempt succeeds."""
adapter = _make_mm_adapter()
@ -637,7 +638,7 @@ class TestMattermostSendUrlAsFile:
assert adapter._session.get.call_count == 2
mock_sleep.assert_called_once()
def test_retries_on_500_then_succeeds(self):
def test_retries_on_500_then_succeeds(self, _mock_safe):
"""5xx on first attempt is retried; 200 on second attempt succeeds."""
adapter = _make_mm_adapter()
@ -655,7 +656,7 @@ class TestMattermostSendUrlAsFile:
assert result.success
assert adapter._session.get.call_count == 2
def test_falls_back_to_text_after_max_retries_on_5xx(self):
def test_falls_back_to_text_after_max_retries_on_5xx(self, _mock_safe):
"""Three consecutive 500s exhaust retries; falls back to send() with URL text."""
adapter = _make_mm_adapter()
@ -674,7 +675,7 @@ class TestMattermostSendUrlAsFile:
text_arg = adapter.send.call_args[0][1]
assert "http://cdn.example.com/img.png" in text_arg
def test_falls_back_on_client_error(self):
def test_falls_back_on_client_error(self, _mock_safe):
"""aiohttp.ClientError on every attempt falls back to send() with URL."""
import aiohttp
@ -699,7 +700,7 @@ class TestMattermostSendUrlAsFile:
text_arg = adapter.send.call_args[0][1]
assert "http://cdn.example.com/img.png" in text_arg
def test_non_retryable_404_falls_back_immediately(self):
def test_non_retryable_404_falls_back_immediately(self, _mock_safe):
"""404 is non-retryable (< 500, != 429); send() is called right away."""
adapter = _make_mm_adapter()

View file

@ -18,6 +18,7 @@ import uuid
from typing import Optional
from tools.environments.base import BaseEnvironment
from tools.environments.local import _HERMES_PROVIDER_ENV_BLOCKLIST
from tools.interrupt import is_interrupted
logger = logging.getLogger(__name__)
@ -510,6 +511,8 @@ class DockerEnvironment(BaseEnvironment):
forward_keys |= get_all_passthrough()
except Exception:
pass
# Strip Hermes-managed secrets so they never leak into the container.
forward_keys -= _HERMES_PROVIDER_ENV_BLOCKLIST
hermes_env = _load_hermes_env_vars() if forward_keys else {}
for key in sorted(forward_keys):
value = os.getenv(key)