# ---------------------------------------------------------------------------
# File: tests/tools/test_osv_check.py
# ---------------------------------------------------------------------------
"""Tests for OSV malware check on MCP extension packages."""

import json
from unittest.mock import MagicMock, patch

import pytest

from tools.osv_check import (
    _infer_ecosystem,
    _parse_npm_package,
    _parse_package_from_args,
    _parse_pypi_package,
    _query_osv,
    check_package_for_malware,
)

# Dotted path of the function every mocked test replaces.
_URLOPEN = "tools.osv_check.urllib.request.urlopen"


def _fake_osv_response(payload):
    """Build a context-manager mock whose ``read()`` yields *payload* as JSON bytes."""
    response = MagicMock()
    response.read.return_value = json.dumps(payload).encode()
    response.__enter__ = lambda s: s
    response.__exit__ = MagicMock(return_value=False)
    return response


def _live_query(package, ecosystem):
    """Query the real OSV API, skipping the calling test if it is unreachable.

    Only transport/decoding failures are turned into a skip. Assertions are
    made by the caller *outside* this helper so that a genuine assertion
    failure fails the test instead of being reported as a skip.
    """
    try:
        return _query_osv(package, ecosystem)
    except (OSError, ValueError) as exc:
        pytest.skip(f"OSV API unreachable: {exc}")


class TestInferEcosystem:
    def test_npx(self):
        assert _infer_ecosystem("npx") == "npm"
        assert _infer_ecosystem("/usr/bin/npx") == "npm"

    def test_uvx(self):
        assert _infer_ecosystem("uvx") == "PyPI"
        assert _infer_ecosystem("/home/user/.local/bin/uvx") == "PyPI"

    def test_pipx(self):
        assert _infer_ecosystem("pipx") == "PyPI"

    def test_unknown(self):
        assert _infer_ecosystem("node") is None
        assert _infer_ecosystem("python") is None
        assert _infer_ecosystem("/bin/bash") is None


class TestParseNpmPackage:
    def test_simple(self):
        assert _parse_npm_package("react") == ("react", None)

    def test_with_version(self):
        assert _parse_npm_package("react@18.3.1") == ("react", "18.3.1")

    def test_scoped(self):
        assert _parse_npm_package("@modelcontextprotocol/server-filesystem") == (
            "@modelcontextprotocol/server-filesystem",
            None,
        )

    def test_scoped_with_version(self):
        assert _parse_npm_package("@scope/pkg@1.2.3") == ("@scope/pkg", "1.2.3")

    def test_latest_ignored(self):
        assert _parse_npm_package("react@latest") == ("react", None)


class TestParsePypiPackage:
    def test_simple(self):
        assert _parse_pypi_package("requests") == ("requests", None)

    def test_with_version(self):
        assert _parse_pypi_package("requests==2.32.3") == ("requests", "2.32.3")

    def test_with_extras(self):
        assert _parse_pypi_package("mcp[cli]==1.2.3") == ("mcp", "1.2.3")

    def test_extras_no_version(self):
        assert _parse_pypi_package("mcp[cli]") == ("mcp", None)


class TestParsePackageFromArgs:
    def test_npm_skips_flags(self):
        name, ver = _parse_package_from_args(["-y", "@scope/pkg@1.0"], "npm")
        assert name == "@scope/pkg"
        assert ver == "1.0"

    def test_pypi_skips_flags(self):
        # ``--from`` starts with "-" and is not itself a package; the
        # requirement that follows it ("mcp[cli]") is the package.
        name, ver = _parse_package_from_args(["--from", "mcp[cli]"], "PyPI")
        assert name == "mcp"
        assert ver is None

    def test_empty_args(self):
        assert _parse_package_from_args([], "npm") == (None, None)

    def test_only_flags(self):
        assert _parse_package_from_args(["-y", "--yes"], "npm") == (None, None)


class TestCheckPackageForMalware:
    def test_clean_package(self):
        """Clean package returns None (allow)."""
        response = _fake_osv_response({"vulns": []})
        with patch(_URLOPEN, return_value=response):
            result = check_package_for_malware(
                "npx", ["-y", "@modelcontextprotocol/server-filesystem"]
            )
        assert result is None

    def test_malware_blocked(self):
        """Known malware package returns error string."""
        response = _fake_osv_response(
            {
                "vulns": [
                    {"id": "MAL-2023-7938", "summary": "Malicious code in evil-pkg"},
                    # Regular CVE: must be filtered out of the message.
                    {"id": "CVE-2023-1234", "summary": "Regular vulnerability"},
                ]
            }
        )
        with patch(_URLOPEN, return_value=response):
            result = check_package_for_malware("npx", ["evil-pkg"])
        assert result is not None
        assert "BLOCKED" in result
        assert "MAL-2023-7938" in result
        assert "CVE-2023-1234" not in result  # regular CVEs filtered

    def test_network_error_fails_open(self):
        """Network errors allow the package (fail-open)."""
        with patch(_URLOPEN, side_effect=ConnectionError("timeout")):
            result = check_package_for_malware("npx", ["some-package"])
        assert result is None

    def test_non_npx_skipped(self):
        """Non-npx/uvx commands are skipped entirely."""
        result = check_package_for_malware("node", ["server.js"])
        assert result is None

    def test_uvx_pypi(self):
        """uvx commands check PyPI ecosystem."""
        response = _fake_osv_response({"vulns": []})
        with patch(_URLOPEN, return_value=response) as mock_url:
            check_package_for_malware("uvx", ["mcp-server-fetch"])
        # Verify the request body that was handed to urlopen.
        call_data = json.loads(mock_url.call_args[0][0].data)
        assert call_data["package"]["ecosystem"] == "PyPI"
        assert call_data["package"]["name"] == "mcp-server-fetch"


class TestLiveOsvQuery:
    """Live integration test against the real OSV API. Skipped if offline."""

    def test_known_malware_package(self):
        """node-hide-console-windows has a real MAL- advisory."""
        result = _live_query("node-hide-console-windows", "npm")
        assert len(result) >= 1
        assert all(vuln["id"].startswith("MAL-") for vuln in result)

    def test_clean_package(self):
        """react should have zero MAL- advisories."""
        result = _live_query("react", "npm")
        assert len(result) == 0


# ---------------------------------------------------------------------------
# File: tools/mcp_tool.py  (index 88bb6fd73..2e1b9217f)
# Patch hunk @@ -833,6 +833,15 @@ class MCPServerTask:
# Interior of a method whose ``def`` line is outside this view; reproduced
# here as a comment ("+" marks the lines the patch adds).
#
#           safe_env = _build_safe_env(user_env)
#           command, safe_env = _resolve_stdio_command(command, safe_env)
#  +
#  +        # Check package against OSV malware database before spawning
#  +        from tools.osv_check import check_package_for_malware
#  +        malware_error = check_package_for_malware(command, args)
#  +        if malware_error:
#  +            raise ValueError(
#  +                f"MCP server '{self.name}': {malware_error}"
#  +            )
#  +
#           server_params = StdioServerParameters(
#               command=command,
#               args=args,
#
# NOTE(review): check_package_for_malware performs a synchronous HTTP request
# (timeout 10 s in tools/osv_check.py). If the enclosing method is a coroutine
# this would block the event loop while it waits -- confirm against the full
# method and consider running the check in a worker thread.
# ---------------------------------------------------------------------------
# File: tools/osv_check.py  (new file) -- module docstring, first part:
#
#   OSV malware check for MCP extension packages.
#
#   Before launching an MCP server via npx/uvx, queries the OSV (Open Source
#   Vulnerabilities) API to check if the package has any known malware
#   advisories (MAL-* IDs). Regular CVEs are ignored -- only confirmed malware
#   is blocked.
#
#   The API is free, public, and maintained by Google. Typical latency is
#   ~300ms. Fail-open: network errors allow the package to proceed.
# ---------------------------------------------------------------------------
# Inspired by Block/goose's extension malware check.

import json
import logging
import os
import re
import urllib.request
from typing import Optional, Tuple

logger = logging.getLogger(__name__)

_OSV_ENDPOINT = os.getenv("OSV_ENDPOINT", "https://api.osv.dev/v1/query")
_TIMEOUT = 10  # seconds
_MAX_PAGES = 5  # upper bound on paginated OSV responses followed per check

# Launcher-script extensions stripped before the command name is matched.
_LAUNCHER_SUFFIXES = (".cmd", ".exe", ".bat", ".ps1")

_ECOSYSTEM_BY_LAUNCHER = {"npx": "npm", "uvx": "PyPI", "pipx": "PyPI"}

# Options whose value names the package to install. When one is present the
# first positional argument is the executable to run, not the package.
_PACKAGE_OPTIONS = {
    "npm": frozenset({"--package", "-p"}),
    "PyPI": frozenset({"--from", "--spec"}),
}

# Other options that consume the following token, so that token must not be
# mistaken for the package. NOTE(review): hand-maintained and not exhaustive;
# verify against the npx / uvx / pipx CLI references when extending.
_VALUE_OPTIONS = {
    "npm": frozenset({
        "-c", "--call", "--registry", "--cache", "--userconfig",
        "-w", "--workspace",
    }),
    "PyPI": frozenset({
        "--with", "-w", "--with-editable", "--with-requirements",
        "--python", "-p", "--index", "--default-index", "--index-url", "-i",
        "--extra-index-url", "--find-links", "-f", "--constraints", "-c",
        "--overrides", "--env-file", "--cache-dir", "--directory",
        "--project", "--config-file", "--pip-args",
    }),
}

# A version is forwarded to OSV only when it looks like one exact release.
# Dist-tags ("latest", "next") and ranges ("^1.0.0", "2.*") match no affected
# version in OSV, so sending them would silently turn the check into a no-op.
_NPM_EXACT_VERSION = re.compile(r"^\d+(?:\.\d+)+(?:[-+][0-9A-Za-z.+-]+)?$")
_PYPI_EXACT_VERSION = re.compile(r"^[0-9][0-9A-Za-z.!+_-]*$")

_NPM_SCOPED = re.compile(r"^(@[^/]+/[^@]+)(?:@(.+))?$")
_PYPI_REQUIREMENT = re.compile(
    r"^([A-Za-z0-9._-]+)(?:\[[^\]]*\])?\s*"
    r"(?:(===|==|~=|!=|<=|>=|<|>|@)\s*(.*))?$"
)


def check_package_for_malware(
    command: str, args: list
) -> Optional[str]:
    """Check if an MCP server package has known malware advisories.

    Inspects the *command* (e.g. ``npx``, ``uvx``) and *args* to infer the
    package name and ecosystem, then queries the OSV API for MAL-* advisories.

    Returns:
        An error message string if malware is found, or None if clean/unknown.
        Returns None (allow) on network errors or unrecognized commands.
    """
    ecosystem, package, version = _resolve_package(command, args)
    if not ecosystem or not package:
        return None  # not npx/uvx/pipx, or no package could be identified

    try:
        malware = _query_osv(package, ecosystem, version)
    except Exception as exc:
        # Fail-open: network errors, timeouts, parse failures → allow
        logger.debug("OSV check failed for %s/%s (allowing): %s", ecosystem, package, exc)
        return None

    if not malware:
        return None
    return _format_block_message(package, ecosystem, malware)


def _resolve_package(
    command: str, args: list
) -> Tuple[Optional[str], Optional[str], Optional[str]]:
    """Return ``(ecosystem, package, version)`` for a launcher invocation.

    ``ecosystem`` is None when *command* is not a recognized launcher;
    ``package`` is None when no package token can be found in *args*.
    """
    ecosystem = _infer_ecosystem(command)
    if not ecosystem:
        return None, None, None

    remaining = list(args or [])
    if _launcher_name(command) == "pipx":
        # pipx takes a subcommand (``pipx run <pkg>``); without this the
        # literal word "run" would be looked up as the package.
        for index, arg in enumerate(remaining):
            if isinstance(arg, str) and not arg.startswith("-"):
                if arg == "run":
                    del remaining[index]
                break

    package, version = _parse_package_from_args(remaining, ecosystem)
    return ecosystem, package, version


def _format_block_message(package: str, ecosystem: str, malware: list) -> str:
    """Build the user-facing error for *malware* (at most three advisories)."""
    shown = malware[:3]
    ids = ", ".join(str(m["id"]) for m in shown)
    # ``summary`` may be absent or null; fall back to the advisory id.
    summaries = "; ".join(
        str(m.get("summary") or m["id"])[:100] for m in shown
    )
    return (
        f"BLOCKED: Package '{package}' ({ecosystem}) has known malware "
        f"advisories: {ids}. Details: {summaries}"
    )


def _launcher_name(command: str) -> str:
    """Return the lower-cased executable name of *command* without its
    directory or launcher extension; ``""`` when *command* is not a string."""
    if not isinstance(command, str):
        return ""
    # Split on both separators so Windows paths are handled on any host OS.
    base = re.split(r"[\\/]", command)[-1].lower()
    for suffix in _LAUNCHER_SUFFIXES:
        if base.endswith(suffix):
            return base[: -len(suffix)]
    return base


def _infer_ecosystem(command: str) -> Optional[str]:
    """Infer package ecosystem from the command name.

    Returns ``"npm"`` for npx, ``"PyPI"`` for uvx/pipx, otherwise None.
    """
    return _ECOSYSTEM_BY_LAUNCHER.get(_launcher_name(command))


def _parse_package_from_args(
    args: list, ecosystem: str
) -> Tuple[Optional[str], Optional[str]]:
    """Extract package name and optional version from command args.

    The package is the value of the first explicit package option
    (``--package``/``-p`` for npm, ``--from``/``--spec`` for PyPI, in either
    ``--opt value`` or ``--opt=value`` form) when one appears before the first
    positional argument; otherwise it is the first positional argument.
    Values of known value-taking options are skipped. Scanning stops at the
    first positional argument because everything after it belongs to the
    launched program.

    Limitation: only one package is returned, so additional packages named by
    repeated ``--package`` or by ``--with`` are not checked.

    Returns (package_name, version) or (None, None) if not parseable.
    """
    if not args:
        return None, None

    package_options = _PACKAGE_OPTIONS.get(ecosystem, frozenset())
    value_options = _VALUE_OPTIONS.get(ecosystem, frozenset())

    explicit = None    # value of a package-selecting option
    positional = None  # first non-option token
    pending = None     # what the next token is: "package", "skip" or None
    for arg in args:
        if not isinstance(arg, str):
            pending = None
            continue
        if pending is not None:
            if pending == "package" and explicit is None:
                explicit = arg
            pending = None
            continue
        if arg.startswith("-"):
            flag, has_value, value = arg.partition("=")
            if flag in package_options:
                if not has_value:
                    pending = "package"
                elif explicit is None and value:
                    explicit = value
            elif flag in value_options and not has_value:
                pending = "skip"
            continue
        positional = arg
        break

    package_token = explicit or positional
    if not package_token:
        return None, None

    if ecosystem == "npm":
        return _parse_npm_package(package_token)
    if ecosystem == "PyPI":
        return _parse_pypi_package(package_token)
    return package_token, None


def _exact_version(candidate: Optional[str], pattern) -> Optional[str]:
    """Return *candidate* stripped if it fully matches *pattern*, else None."""
    if not candidate:
        return None
    candidate = candidate.strip()
    return candidate if pattern.match(candidate) else None


def _parse_npm_package(token: str) -> Tuple[Optional[str], Optional[str]]:
    """Parse npm package: @scope/name@version or name@version.

    The version is returned only when it is an exact release number; for
    dist-tags and ranges it is None so that OSV is queried for the package
    as a whole.
    """
    if token.startswith("@"):
        # Scoped: the leading "@" belongs to the name, not to a version.
        match = _NPM_SCOPED.match(token)
        if not match:
            return token, None
        name, candidate = match.group(1), match.group(2)
    else:
        name, separator, candidate = token.rpartition("@")
        if not separator:
            return token, None
    return name, _exact_version(candidate, _NPM_EXACT_VERSION)


def _parse_pypi_package(token: str) -> Tuple[Optional[str], Optional[str]]:
    """Parse PyPI package: name, name[extras], name==version, name@version.

    Extras are dropped. A version is returned only for ``==``/``===``/``@``
    followed by an exact release; any other specifier (``>=``, ``~=``, a
    wildcard, a URL after ``@``) yields the bare name with version None.
    Tokens that are not a requirement at all are returned unchanged.
    """
    match = _PYPI_REQUIREMENT.match(token)
    if not match:
        return token, None
    name, operator, remainder = match.groups()
    if operator in ("==", "===", "@"):
        return name, _exact_version(remainder, _PYPI_EXACT_VERSION)
    return name, None


def _normalize_pypi_name(name: str) -> str:
    """Normalize a PyPI project name per PEP 503 (lower-case, runs of
    ``-``, ``_`` and ``.`` collapsed to a single ``-``)."""
    return re.sub(r"[-_.]+", "-", name).lower()


def _query_osv(
    package: str, ecosystem: str, version: Optional[str] = None
) -> list:
    """Query the OSV API for MAL-* advisories. Returns list of malware vulns.

    PyPI names are normalized before the lookup so that spelling variants
    (``Foo_Bar`` vs ``foo-bar``) resolve to the same project. Paginated
    responses are followed via ``next_page_token`` for at most ``_MAX_PAGES``
    requests.

    Raises:
        Whatever ``urllib.request.urlopen`` / ``json.loads`` raise, and
        ValueError if the response body is not a JSON object. The caller
        treats any exception as "allow".
    """
    name = _normalize_pypi_name(package) if ecosystem == "PyPI" else package
    payload = {"package": {"name": name, "ecosystem": ecosystem}}
    if version:
        payload["version"] = version

    malware = []
    for _ in range(_MAX_PAGES):
        req = urllib.request.Request(
            _OSV_ENDPOINT,
            data=json.dumps(payload).encode("utf-8"),
            headers={
                "Content-Type": "application/json",
                "User-Agent": "hermes-agent-osv-check/1.0",
            },
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=_TIMEOUT) as resp:
            result = json.loads(resp.read())
        if not isinstance(result, dict):
            raise ValueError("unexpected OSV response (not a JSON object)")

        # Only malware advisories — ignore regular CVEs. ``vulns`` may be
        # missing or null when there are no matches.
        for vuln in result.get("vulns") or []:
            if isinstance(vuln, dict) and str(vuln.get("id") or "").startswith("MAL-"):
                malware.append(vuln)

        page_token = result.get("next_page_token")
        if not page_token:
            break
        payload["page_token"] = page_token
    return malware