mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-26 06:01:49 +00:00
fix(skills-hub): cover remaining SSRF fetch paths after #10029
This commit is contained in:
parent
af9df46525
commit
0c5c4d1b8d
3 changed files with 135 additions and 25 deletions
|
|
@ -560,6 +560,11 @@ class TestFindSkillInRepoTree:
|
||||||
|
|
||||||
|
|
||||||
class TestWellKnownSkillSource:
|
class TestWellKnownSkillSource:
|
||||||
|
@pytest.fixture(autouse=True)
def _allow_public_skill_fetches(self, monkeypatch):
    """Replace the URL-safety and website-policy checks with permissive stubs.

    Applied automatically to every test in this class: ``is_safe_url`` always
    answers ``True`` and ``check_website_access`` always answers ``None``.
    """
    permissive_stubs = (
        ("tools.skills_hub.is_safe_url", lambda _url: True),
        ("tools.skills_hub.check_website_access", lambda _url: None),
    )
    for dotted_target, stub in permissive_stubs:
        monkeypatch.setattr(dotted_target, stub)
|
||||||
|
|
||||||
def _source(self):
|
def _source(self):
|
||||||
return WellKnownSkillSource()
|
return WellKnownSkillSource()
|
||||||
|
|
||||||
|
|
@ -675,6 +680,11 @@ class TestWellKnownSkillSource:
|
||||||
|
|
||||||
|
|
||||||
class TestUrlSource:
|
class TestUrlSource:
|
||||||
|
@pytest.fixture(autouse=True)
def _allow_public_skill_fetches(self, monkeypatch):
    """Autouse fixture: make both fetch guards approve every URL.

    ``tools.skills_hub.is_safe_url`` is swapped for a stub returning ``True``
    and ``tools.skills_hub.check_website_access`` for one returning ``None``.
    """
    always_safe = lambda _url: True  # noqa: E731 - tiny stub, mirrors the guard's call shape
    never_blocked = lambda _url: None  # noqa: E731
    monkeypatch.setattr("tools.skills_hub.is_safe_url", always_safe)
    monkeypatch.setattr("tools.skills_hub.check_website_access", never_blocked)
|
||||||
|
|
||||||
def _source(self):
|
def _source(self):
|
||||||
return UrlSource()
|
return UrlSource()
|
||||||
|
|
||||||
|
|
@ -753,6 +763,13 @@ class TestUrlSource:
|
||||||
mock_get.side_effect = httpx.HTTPError("boom")
|
mock_get.side_effect = httpx.HTTPError("boom")
|
||||||
assert self._source().inspect("https://example.com/SKILL.md") is None
|
assert self._source().inspect("https://example.com/SKILL.md") is None
|
||||||
|
|
||||||
|
@patch("tools.skills_hub.httpx.get")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url", return_value=False)
def test_inspect_blocks_private_url(self, _mock_safe, _mock_policy, mock_get):
    """inspect() yields None and issues no HTTP request when is_safe_url says no."""
    loopback_url = "http://127.0.0.1/SKILL.md"

    outcome = self._source().inspect(loopback_url)

    assert outcome is None
    mock_get.assert_not_called()
|
||||||
|
|
||||||
@patch("tools.skills_hub.httpx.get")
|
@patch("tools.skills_hub.httpx.get")
|
||||||
def test_inspect_flags_awaiting_name_when_unresolvable(self, mock_get):
|
def test_inspect_flags_awaiting_name_when_unresolvable(self, mock_get):
|
||||||
# No frontmatter name + a URL path that can't produce a valid slug
|
# No frontmatter name + a URL path that can't produce a valid slug
|
||||||
|
|
@ -855,6 +872,24 @@ class TestUrlSource:
|
||||||
mock_get.return_value = MagicMock(status_code=404)
|
mock_get.return_value = MagicMock(status_code=404)
|
||||||
assert self._source().fetch("https://example.com/SKILL.md") is None
|
assert self._source().fetch("https://example.com/SKILL.md") is None
|
||||||
|
|
||||||
|
@patch("tools.skills_hub.httpx.get")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url", side_effect=[True, False])
def test_fetch_blocks_redirect_to_private_url(self, _mock_safe, _mock_policy, mock_get):
    """A 302 whose target is rejected by is_safe_url must not be followed."""
    # is_safe_url answers True on its first call and False on its second.
    mock_get.return_value = MagicMock(
        status_code=302,
        headers={"location": "http://127.0.0.1/private/SKILL.md"},
    )

    fetched = self._source().fetch("https://example.com/SKILL.md")

    assert fetched is None
    # Exactly one request went out, so the redirect target was never requested.
    assert mock_get.call_count == 1
|
||||||
|
|
||||||
|
@patch("tools.skills_hub.httpx.get")
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url", return_value=False)
def test_fetch_blocks_private_url(self, _mock_safe, _mock_policy, mock_get):
    """fetch() yields None and issues no HTTP request when is_safe_url says no."""
    loopback_url = "http://127.0.0.1/SKILL.md"

    fetched = self._source().fetch(loopback_url)

    assert fetched is None
    mock_get.assert_not_called()
|
||||||
|
|
||||||
@patch("tools.skills_hub.httpx.get")
|
@patch("tools.skills_hub.httpx.get")
|
||||||
def test_fetch_skips_non_matching_identifier(self, mock_get):
|
def test_fetch_skips_non_matching_identifier(self, mock_get):
|
||||||
assert self._source().fetch("owner/repo/skill") is None
|
assert self._source().fetch("owner/repo/skill") is None
|
||||||
|
|
|
||||||
|
|
@ -7,10 +7,11 @@ from tools.skills_hub import ClawHubSource, SkillMeta
|
||||||
|
|
||||||
|
|
||||||
class _MockResponse:
|
class _MockResponse:
|
||||||
def __init__(self, status_code=200, json_data=None, text=""):
|
def __init__(self, status_code=200, json_data=None, text="", headers=None):
|
||||||
self.status_code = status_code
|
self.status_code = status_code
|
||||||
self._json_data = json_data
|
self._json_data = json_data
|
||||||
self.text = text
|
self.text = text
|
||||||
|
self.headers = headers or {}
|
||||||
|
|
||||||
def json(self):
|
def json(self):
|
||||||
return self._json_data
|
return self._json_data
|
||||||
|
|
@ -19,6 +20,14 @@ class _MockResponse:
|
||||||
class TestClawHubSource(unittest.TestCase):
|
class TestClawHubSource(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.src = ClawHubSource()
|
self.src = ClawHubSource()
|
||||||
|
self._safe_patcher = patch("tools.skills_hub.is_safe_url", return_value=True)
|
||||||
|
self._policy_patcher = patch("tools.skills_hub.check_website_access", return_value=None)
|
||||||
|
self._safe_patcher.start()
|
||||||
|
self._policy_patcher.start()
|
||||||
|
|
||||||
|
def tearDown(self):
    """Stop the policy patcher, then the URL-safety patcher."""
    for active_patcher in (self._policy_patcher, self._safe_patcher):
        active_patcher.stop()
|
||||||
|
|
||||||
@patch("tools.skills_hub._write_index_cache")
|
@patch("tools.skills_hub._write_index_cache")
|
||||||
@patch("tools.skills_hub._read_index_cache", return_value=None)
|
@patch("tools.skills_hub._read_index_cache", return_value=None)
|
||||||
|
|
@ -255,6 +264,40 @@ class TestClawHubSource(unittest.TestCase):
|
||||||
self.assertIsNotNone(bundle)
|
self.assertIsNotNone(bundle)
|
||||||
self.assertEqual(bundle.files["SKILL.md"], "# Skill")
|
self.assertEqual(bundle.files["SKILL.md"], "# Skill")
|
||||||
|
|
||||||
|
@patch("tools.skills_hub.check_website_access", return_value=None)
@patch("tools.skills_hub.is_safe_url")
@patch("tools.skills_hub.httpx.get")
def test_fetch_blocks_private_raw_url(self, mock_get, mock_safe, _mock_policy):
    """fetch() returns None when a version file's rawUrl points at loopback."""

    def ok(payload):
        # Build a fresh 200 response per call so no payload is shared between requests.
        return _MockResponse(status_code=200, json_data=payload)

    def fake_get(url, *args, **kwargs):
        # The three suffixes are mutually exclusive, so the order of checks is irrelevant.
        if url.endswith("/download"):
            return _MockResponse(status_code=404)
        if url.endswith("/skills/caldav-calendar/versions/1.0.1"):
            return ok(
                {
                    "files": [
                        {"path": "SKILL.md", "rawUrl": "http://127.0.0.1/private-skill"},
                    ]
                }
            )
        if url.endswith("/skills/caldav-calendar"):
            return ok(
                {
                    "slug": "caldav-calendar",
                    "latestVersion": {"version": "1.0.1"},
                }
            )
        return _MockResponse(status_code=404, json_data={})

    mock_get.side_effect = fake_get
    # Only URLs on the loopback host are reported as unsafe.
    mock_safe.side_effect = lambda url: not url.startswith("http://127.0.0.1/")

    bundle = self.src.fetch("caldav-calendar")

    self.assertIsNone(bundle)
    # Presumably the metadata, download and version-listing requests; a fourth
    # call would mean the loopback rawUrl was actually requested.
    self.assertEqual(mock_get.call_count, 3)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ from datetime import datetime, timezone
|
||||||
from pathlib import Path, PurePosixPath
|
from pathlib import Path, PurePosixPath
|
||||||
from hermes_constants import get_hermes_home
|
from hermes_constants import get_hermes_home
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
from urllib.parse import urlparse, urlunparse
|
from urllib.parse import urljoin, urlparse, urlunparse
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import yaml
|
import yaml
|
||||||
|
|
@ -35,6 +35,8 @@ import yaml
|
||||||
from tools.skills_guard import (
|
from tools.skills_guard import (
|
||||||
ScanResult, content_hash, TRUSTED_REPOS,
|
ScanResult, content_hash, TRUSTED_REPOS,
|
||||||
)
|
)
|
||||||
|
from tools.url_safety import is_safe_url
|
||||||
|
from tools.website_policy import check_website_access
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -55,6 +57,9 @@ INDEX_CACHE_DIR = HUB_DIR / "index-cache"
|
||||||
# Cache duration for remote index fetches
|
# Cache duration for remote index fetches
|
||||||
INDEX_CACHE_TTL = 3600 # 1 hour
|
INDEX_CACHE_TTL = 3600 # 1 hour
|
||||||
|
|
||||||
|
_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
||||||
|
_MAX_SKILL_FETCH_REDIRECTS = 5
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Data models
|
# Data models
|
||||||
|
|
@ -118,6 +123,43 @@ def _validate_category_name(category: str) -> str:
|
||||||
return _normalize_bundle_path(category, field_name="category", allow_nested=False)
|
return _normalize_bundle_path(category, field_name="category", allow_nested=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _guarded_http_get(url: str, *, timeout: int = 20) -> Optional[httpx.Response]:
    """Fetch *url* with SSRF and redirect-target validation.

    Redirects are followed manually (``follow_redirects=False``) so that every
    hop, not only the initial URL, passes ``is_safe_url`` and
    ``check_website_access`` before a request is issued for it.

    Args:
        url: The URL to fetch.
        timeout: Timeout forwarded to ``httpx.get`` for each hop.

    Returns:
        The first non-redirect response (whatever its status code), or
        ``None`` when any URL in the chain is blocked, the request fails, a
        redirect carries no usable ``Location`` header, or the chain exceeds
        ``_MAX_SKILL_FETCH_REDIRECTS`` redirects.
    """
    current_url = url

    # One initial request plus at most _MAX_SKILL_FETCH_REDIRECTS redirect hops.
    for _ in range(_MAX_SKILL_FETCH_REDIRECTS + 1):
        if not is_safe_url(current_url):
            logger.warning("Blocked unsafe Skills Hub URL: %s", current_url)
            return None

        blocked = check_website_access(current_url)
        if blocked:
            logger.info(
                "Blocked Skills Hub fetch for %s by rule %s",
                blocked["host"],
                blocked["rule"],
            )
            return None

        try:
            resp = httpx.get(current_url, timeout=timeout, follow_redirects=False)
        except (httpx.HTTPError, httpx.InvalidURL) as exc:
            # InvalidURL does not derive from HTTPError; without it a malformed
            # URL would raise into callers that expect None on failure.
            logger.debug("Skills Hub fetch failed for %s: %s", current_url, exc)
            return None

        if resp.status_code not in _REDIRECT_STATUS_CODES:
            return resp

        location = getattr(resp, "headers", {}).get("location")
        if not location:
            return None
        try:
            # Location may be relative; resolve it against the URL just fetched.
            current_url = urljoin(current_url, location)
        except ValueError as exc:
            # urljoin rejects some malformed targets (e.g. an unbalanced IPv6
            # bracket); treat a server-supplied bad Location as a failed fetch.
            logger.debug(
                "Skills Hub redirect from %s has an invalid Location: %s",
                current_url,
                exc,
            )
            return None

    logger.warning("Skills Hub fetch exceeded redirect limit for %s", url)
    return None
|
||||||
|
|
||||||
|
|
||||||
def _validate_bundle_rel_path(rel_path: str) -> str:
|
def _validate_bundle_rel_path(rel_path: str) -> str:
|
||||||
return _normalize_bundle_path(rel_path, field_name="bundle file path", allow_nested=True)
|
return _normalize_bundle_path(rel_path, field_name="bundle file path", allow_nested=True)
|
||||||
|
|
||||||
|
|
@ -887,12 +929,12 @@ class WellKnownSkillSource(SkillSource):
|
||||||
if isinstance(cached, dict) and isinstance(cached.get("skills"), list):
|
if isinstance(cached, dict) and isinstance(cached.get("skills"), list):
|
||||||
return cached
|
return cached
|
||||||
|
|
||||||
|
resp = _guarded_http_get(index_url, timeout=20)
|
||||||
|
if resp is None or resp.status_code != 200:
|
||||||
|
return None
|
||||||
try:
|
try:
|
||||||
resp = httpx.get(index_url, timeout=20, follow_redirects=True)
|
|
||||||
if resp.status_code != 200:
|
|
||||||
return None
|
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
except (httpx.HTTPError, json.JSONDecodeError):
|
except json.JSONDecodeError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
skills = data.get("skills", []) if isinstance(data, dict) else []
|
skills = data.get("skills", []) if isinstance(data, dict) else []
|
||||||
|
|
@ -918,12 +960,9 @@ class WellKnownSkillSource(SkillSource):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_text(url: str) -> Optional[str]:
|
def _fetch_text(url: str) -> Optional[str]:
|
||||||
try:
|
resp = _guarded_http_get(url, timeout=20)
|
||||||
resp = httpx.get(url, timeout=20, follow_redirects=True)
|
if resp is not None and resp.status_code == 200:
|
||||||
if resp.status_code == 200:
|
return resp.text
|
||||||
return resp.text
|
|
||||||
except httpx.HTTPError:
|
|
||||||
return None
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -1045,13 +1084,9 @@ class UrlSource(SkillSource):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _fetch_text(url: str) -> Optional[str]:
|
def _fetch_text(url: str) -> Optional[str]:
|
||||||
try:
|
resp = _guarded_http_get(url, timeout=20)
|
||||||
resp = httpx.get(url, timeout=20, follow_redirects=True)
|
if resp is not None and resp.status_code == 200:
|
||||||
if resp.status_code == 200:
|
return resp.text
|
||||||
return resp.text
|
|
||||||
except httpx.HTTPError as exc:
|
|
||||||
logger.debug("UrlSource fetch failed for %s: %s", url, exc)
|
|
||||||
return None
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Skill names must look like identifiers: lowercase letters/digits with
|
# Skill names must look like identifiers: lowercase letters/digits with
|
||||||
|
|
@ -2051,12 +2086,9 @@ class ClawHubSource(SkillSource):
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def _fetch_text(self, url: str) -> Optional[str]:
|
def _fetch_text(self, url: str) -> Optional[str]:
|
||||||
try:
|
resp = _guarded_http_get(url, timeout=20)
|
||||||
resp = httpx.get(url, timeout=20)
|
if resp is not None and resp.status_code == 200:
|
||||||
if resp.status_code == 200:
|
return resp.text
|
||||||
return resp.text
|
|
||||||
except httpx.HTTPError:
|
|
||||||
return None
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue