fix(skills-hub): cover remaining SSRF fetch paths after #10029

This commit is contained in:
heathley 2026-05-09 23:18:49 +03:00 committed by Teknium
parent af9df46525
commit 0c5c4d1b8d
3 changed files with 135 additions and 25 deletions

View file

@ -560,6 +560,11 @@ class TestFindSkillInRepoTree:
class TestWellKnownSkillSource: class TestWellKnownSkillSource:
@pytest.fixture(autouse=True)
def _allow_public_skill_fetches(self, monkeypatch):
    """Replace the URL-safety and website-policy checks with permissive stubs.

    ``is_safe_url`` is stubbed to return True and ``check_website_access`` to
    return None for every URL, for each test in this class (autouse).
    """
    permissive_stubs = {
        "tools.skills_hub.is_safe_url": lambda _url: True,
        "tools.skills_hub.check_website_access": lambda _url: None,
    }
    for target, stub in permissive_stubs.items():
        monkeypatch.setattr(target, stub)
def _source(self): def _source(self):
return WellKnownSkillSource() return WellKnownSkillSource()
@ -675,6 +680,11 @@ class TestWellKnownSkillSource:
class TestUrlSource: class TestUrlSource:
@pytest.fixture(autouse=True)
def _allow_public_skill_fetches(self, monkeypatch):
    """Replace the URL-safety and website-policy checks with permissive stubs.

    ``is_safe_url`` is stubbed to return True and ``check_website_access`` to
    return None for every URL, for each test in this class (autouse).
    """
    permissive_stubs = {
        "tools.skills_hub.is_safe_url": lambda _url: True,
        "tools.skills_hub.check_website_access": lambda _url: None,
    }
    for target, stub in permissive_stubs.items():
        monkeypatch.setattr(target, stub)
def _source(self): def _source(self):
return UrlSource() return UrlSource()
@ -753,6 +763,13 @@ class TestUrlSource:
mock_get.side_effect = httpx.HTTPError("boom") mock_get.side_effect = httpx.HTTPError("boom")
assert self._source().inspect("https://example.com/SKILL.md") is None assert self._source().inspect("https://example.com/SKILL.md") is None
@patch("tools.skills_hub.httpx.get")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url", return_value=False)
def test_inspect_blocks_private_url(self, _mock_safe, _mock_policy, mock_get):
    """inspect() returns None without calling httpx.get when is_safe_url rejects the URL."""
    outcome = self._source().inspect("http://127.0.0.1/SKILL.md")
    assert outcome is None
    mock_get.assert_not_called()
@patch("tools.skills_hub.httpx.get") @patch("tools.skills_hub.httpx.get")
def test_inspect_flags_awaiting_name_when_unresolvable(self, mock_get): def test_inspect_flags_awaiting_name_when_unresolvable(self, mock_get):
# No frontmatter name + a URL path that can't produce a valid slug # No frontmatter name + a URL path that can't produce a valid slug
@ -855,6 +872,24 @@ class TestUrlSource:
mock_get.return_value = MagicMock(status_code=404) mock_get.return_value = MagicMock(status_code=404)
assert self._source().fetch("https://example.com/SKILL.md") is None assert self._source().fetch("https://example.com/SKILL.md") is None
@patch("tools.skills_hub.httpx.get")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url", side_effect=[True, False])
def test_fetch_blocks_redirect_to_private_url(self, _mock_safe, _mock_policy, mock_get):
    """fetch() returns None after one request when the redirect target is rejected.

    ``is_safe_url`` is scripted to accept the first URL it is asked about and
    reject the second, so only the initial ``httpx.get`` call may happen.
    """
    mock_get.return_value = MagicMock(
        status_code=302,
        headers={"location": "http://127.0.0.1/private/SKILL.md"},
    )

    fetched = self._source().fetch("https://example.com/SKILL.md")

    assert fetched is None
    assert mock_get.call_count == 1
@patch("tools.skills_hub.httpx.get")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url", return_value=False)
def test_fetch_blocks_private_url(self, _mock_safe, _mock_policy, mock_get):
    """fetch() returns None without calling httpx.get when is_safe_url rejects the URL."""
    fetched = self._source().fetch("http://127.0.0.1/SKILL.md")
    assert fetched is None
    mock_get.assert_not_called()
@patch("tools.skills_hub.httpx.get") @patch("tools.skills_hub.httpx.get")
def test_fetch_skips_non_matching_identifier(self, mock_get): def test_fetch_skips_non_matching_identifier(self, mock_get):
assert self._source().fetch("owner/repo/skill") is None assert self._source().fetch("owner/repo/skill") is None

View file

@ -7,10 +7,11 @@ from tools.skills_hub import ClawHubSource, SkillMeta
class _MockResponse: class _MockResponse:
def __init__(self, status_code=200, json_data=None, text=""): def __init__(self, status_code=200, json_data=None, text="", headers=None):
self.status_code = status_code self.status_code = status_code
self._json_data = json_data self._json_data = json_data
self.text = text self.text = text
self.headers = headers or {}
def json(self): def json(self):
return self._json_data return self._json_data
@ -19,6 +20,14 @@ class _MockResponse:
class TestClawHubSource(unittest.TestCase): class TestClawHubSource(unittest.TestCase):
def setUp(self):
    """Create the source under test and stub the URL-safety and policy checks.

    ``is_safe_url`` is patched to return True and ``check_website_access`` to
    return None; the patchers are kept on ``self`` so they can be stopped later.
    """
    self.src = ClawHubSource()
    self._safe_patcher = patch("tools.skills_hub.is_safe_url", return_value=True)
    self._policy_patcher = patch("tools.skills_hub.check_website_access", return_value=None)
    for patcher in (self._safe_patcher, self._policy_patcher):
        patcher.start()
def tearDown(self):
    """Stop the policy patcher, then the URL-safety patcher."""
    for patcher in (self._policy_patcher, self._safe_patcher):
        patcher.stop()
@patch("tools.skills_hub._write_index_cache") @patch("tools.skills_hub._write_index_cache")
@patch("tools.skills_hub._read_index_cache", return_value=None) @patch("tools.skills_hub._read_index_cache", return_value=None)
@ -255,6 +264,40 @@ class TestClawHubSource(unittest.TestCase):
self.assertIsNotNone(bundle) self.assertIsNotNone(bundle)
self.assertEqual(bundle.files["SKILL.md"], "# Skill") self.assertEqual(bundle.files["SKILL.md"], "# Skill")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url")
@patch("tools.skills_hub.httpx.get")
def test_fetch_blocks_private_raw_url(self, mock_get, mock_safe, _mock_policy):
    """fetch() yields no bundle when a file's rawUrl is rejected by is_safe_url.

    The fake ``httpx.get`` serves skill metadata, a 404 for the download
    endpoint, and version metadata whose only file points at 127.0.0.1.
    Exactly three ``httpx.get`` calls are expected.
    """

    def fake_get(url, *args, **kwargs):
        # Suffix matching is order-sensitive; the first matching route wins.
        routes = (
            (
                "/skills/caldav-calendar",
                lambda: _MockResponse(
                    status_code=200,
                    json_data={
                        "slug": "caldav-calendar",
                        "latestVersion": {"version": "1.0.1"},
                    },
                ),
            ),
            ("/download", lambda: _MockResponse(status_code=404)),
            (
                "/skills/caldav-calendar/versions/1.0.1",
                lambda: _MockResponse(
                    status_code=200,
                    json_data={
                        "files": [
                            {"path": "SKILL.md", "rawUrl": "http://127.0.0.1/private-skill"},
                        ]
                    },
                ),
            ),
        )
        for suffix, build_response in routes:
            if url.endswith(suffix):
                return build_response()
        return _MockResponse(status_code=404, json_data={})

    def reject_loopback(url):
        return not url.startswith("http://127.0.0.1/")

    mock_get.side_effect = fake_get
    mock_safe.side_effect = reject_loopback

    bundle = self.src.fetch("caldav-calendar")

    self.assertIsNone(bundle)
    self.assertEqual(mock_get.call_count, 3)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View file

@ -27,7 +27,7 @@ from datetime import datetime, timezone
from pathlib import Path, PurePosixPath from pathlib import Path, PurePosixPath
from hermes_constants import get_hermes_home from hermes_constants import get_hermes_home
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse, urlunparse from urllib.parse import urljoin, urlparse, urlunparse
import httpx import httpx
import yaml import yaml
@ -35,6 +35,8 @@ import yaml
from tools.skills_guard import ( from tools.skills_guard import (
ScanResult, content_hash, TRUSTED_REPOS, ScanResult, content_hash, TRUSTED_REPOS,
) )
from tools.url_safety import is_safe_url
from tools.website_policy import check_website_access
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -55,6 +57,9 @@ INDEX_CACHE_DIR = HUB_DIR / "index-cache"
# Cache duration for remote index fetches # Cache duration for remote index fetches
INDEX_CACHE_TTL = 3600 # 1 hour INDEX_CACHE_TTL = 3600 # 1 hour
# HTTP status codes that _guarded_http_get treats as redirects to follow manually.
_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
# Upper bound on redirect hops _guarded_http_get follows before giving up.
_MAX_SKILL_FETCH_REDIRECTS = 5
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Data models # Data models
@ -118,6 +123,43 @@ def _validate_category_name(category: str) -> str:
return _normalize_bundle_path(category, field_name="category", allow_nested=False) return _normalize_bundle_path(category, field_name="category", allow_nested=False)
def _guarded_http_get(url: str, *, timeout: int = 20) -> Optional[httpx.Response]:
    """Fetch *url* with SSRF and redirect-target validation.

    Redirects are followed manually (``follow_redirects=False``) so that every
    hop -- the initial URL and each ``Location`` target -- is passed through
    ``is_safe_url`` and ``check_website_access`` before a request is issued.

    Args:
        url: Address to fetch.
        timeout: Timeout forwarded to each ``httpx.get`` call.

    Returns:
        The first non-redirect response, whatever its status code, or ``None``
        when a hop is rejected by either check, the request fails, a redirect
        carries no usable ``Location`` header, or the redirect limit
        (``_MAX_SKILL_FETCH_REDIRECTS``) is exceeded.
    """
    current_url = url
    # One initial request plus at most _MAX_SKILL_FETCH_REDIRECTS redirect hops.
    for _ in range(_MAX_SKILL_FETCH_REDIRECTS + 1):
        if not is_safe_url(current_url):
            logger.warning("Blocked unsafe Skills Hub URL: %s", current_url)
            return None

        blocked = check_website_access(current_url)
        if blocked:
            logger.info(
                "Blocked Skills Hub fetch for %s by rule %s",
                blocked["host"],
                blocked["rule"],
            )
            return None

        try:
            resp = httpx.get(current_url, timeout=timeout, follow_redirects=False)
        except (httpx.HTTPError, httpx.InvalidURL) as exc:
            # httpx.InvalidURL is not an httpx.HTTPError subclass, so it has to
            # be named explicitly or a malformed URL would escape to the caller.
            logger.debug("Skills Hub fetch failed for %s: %s", current_url, exc)
            return None

        if resp.status_code not in _REDIRECT_STATUS_CODES:
            return resp

        location = getattr(resp, "headers", {}).get("location")
        if not location:
            return None
        try:
            # The Location header is server-controlled; urljoin raises
            # ValueError on a malformed netloc (e.g. an unbalanced "[").
            current_url = urljoin(current_url, location)
        except ValueError as exc:
            logger.debug(
                "Skills Hub redirect from %s has unusable Location %r: %s",
                current_url,
                location,
                exc,
            )
            return None

    logger.warning("Skills Hub fetch exceeded redirect limit for %s", url)
    return None
def _validate_bundle_rel_path(rel_path: str) -> str: def _validate_bundle_rel_path(rel_path: str) -> str:
return _normalize_bundle_path(rel_path, field_name="bundle file path", allow_nested=True) return _normalize_bundle_path(rel_path, field_name="bundle file path", allow_nested=True)
@ -887,12 +929,12 @@ class WellKnownSkillSource(SkillSource):
if isinstance(cached, dict) and isinstance(cached.get("skills"), list): if isinstance(cached, dict) and isinstance(cached.get("skills"), list):
return cached return cached
resp = _guarded_http_get(index_url, timeout=20)
if resp is None or resp.status_code != 200:
return None
try: try:
resp = httpx.get(index_url, timeout=20, follow_redirects=True)
if resp.status_code != 200:
return None
data = resp.json() data = resp.json()
except (httpx.HTTPError, json.JSONDecodeError): except json.JSONDecodeError:
return None return None
skills = data.get("skills", []) if isinstance(data, dict) else [] skills = data.get("skills", []) if isinstance(data, dict) else []
@ -918,12 +960,9 @@ class WellKnownSkillSource(SkillSource):
@staticmethod
def _fetch_text(url: str) -> Optional[str]:
    """Return the response text for *url*, or None unless the guarded fetch yields HTTP 200."""
    response = _guarded_http_get(url, timeout=20)
    if response is None or response.status_code != 200:
        return None
    return response.text
@staticmethod @staticmethod
@ -1045,13 +1084,9 @@ class UrlSource(SkillSource):
@staticmethod
def _fetch_text(url: str) -> Optional[str]:
    """Return the response text for *url*, or None unless the guarded fetch yields HTTP 200."""
    response = _guarded_http_get(url, timeout=20)
    if response is None or response.status_code != 200:
        return None
    return response.text
# Skill names must look like identifiers: lowercase letters/digits with # Skill names must look like identifiers: lowercase letters/digits with
@ -2051,12 +2086,9 @@ class ClawHubSource(SkillSource):
return files return files
def _fetch_text(self, url: str) -> Optional[str]:
    """Return the response text for *url*, or None unless the guarded fetch yields HTTP 200."""
    response = _guarded_http_get(url, timeout=20)
    if response is None or response.status_code != 200:
        return None
    return response.text