diff --git a/tests/tools/test_tirith_security.py b/tests/tools/test_tirith_security.py index 6c771c6d482..cb0556cd93c 100644 --- a/tests/tools/test_tirith_security.py +++ b/tests/tools/test_tirith_security.py @@ -1,8 +1,10 @@ """Tests for the tirith security scanning subprocess wrapper.""" +import io import json import os import subprocess +import tarfile import time from unittest.mock import MagicMock, patch @@ -716,6 +718,89 @@ class TestCosignVerification: assert mock_cosign.called # cosign was invoked +class TestInstallArchiveMemberValidation: + def _write_archive(self, tmp_path, member: tarfile.TarInfo, data: bytes | None = None): + archive = tmp_path / "tirith-aarch64-apple-darwin.tar.gz" + checksums = tmp_path / "checksums.txt" + with tarfile.open(archive, "w:gz") as tar: + if data is None: + tar.addfile(member) + else: + tar.addfile(member, io.BytesIO(data)) + checksums.write_text( + "ignored tirith-aarch64-apple-darwin.tar.gz\n", + encoding="utf-8", + ) + return archive, checksums + + def _download_side_effect(self, archive, checksums): + def _download(url, dest, timeout=10): + del timeout + if url.endswith(".tar.gz"): + with open(archive, "rb") as src, open(dest, "wb") as dst: + dst.write(src.read()) + return + if url.endswith("checksums.txt"): + with open(checksums, "rb") as src, open(dest, "wb") as dst: + dst.write(src.read()) + return + raise AssertionError(f"unexpected download URL: {url}") + + return _download + + @patch("tools.tirith_security._verify_checksum", return_value=True) + @patch("tools.tirith_security.shutil.which", return_value=None) + @patch("tools.tirith_security._detect_target", return_value="aarch64-apple-darwin") + def test_install_extracts_regular_tirith_member(self, mock_target, mock_which, + mock_checksum, tmp_path, monkeypatch): + """A valid regular-file tirith member is installed as a plain file.""" + del mock_target, mock_which, mock_checksum + from tools.tirith_security import _install_tirith + + payload = b"#!/bin/sh\nexit 0\n" + member = tarfile.TarInfo("bin/tirith") + member.mode = 0o755 + member.size = len(payload) + archive, checksums = self._write_archive(tmp_path, member, payload) + + hermes_home = tmp_path / "hermes-home" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + with patch("tools.tirith_security._download_file", + side_effect=self._download_side_effect(archive, checksums)): + path, reason = _install_tirith(log_failures=False) + + assert reason == "" + assert path == str(hermes_home / "bin" / "tirith") + assert os.path.isfile(path) + assert not os.path.islink(path) + with open(path, "rb") as f: + assert f.read() == payload + + @patch("tools.tirith_security._verify_checksum", return_value=True) + @patch("tools.tirith_security.shutil.which", return_value=None) + @patch("tools.tirith_security._detect_target", return_value="aarch64-apple-darwin") + def test_install_rejects_non_regular_tirith_member(self, mock_target, mock_which, + mock_checksum, tmp_path, monkeypatch): + """Symlink or hardlink tar members must not be installed as tirith.""" + del mock_target, mock_which, mock_checksum + from tools.tirith_security import _install_tirith + + member = tarfile.TarInfo("bin/tirith") + member.type = tarfile.SYMTYPE + member.linkname = "/bin/sh" + archive, checksums = self._write_archive(tmp_path, member) + + hermes_home = tmp_path / "hermes-home" + monkeypatch.setenv("HERMES_HOME", str(hermes_home)) + with patch("tools.tirith_security._download_file", + side_effect=self._download_side_effect(archive, checksums)): + path, reason = _install_tirith(log_failures=False) + + assert path is None + assert reason == "binary_not_regular_file" + assert not os.path.lexists(hermes_home / "bin" / "tirith") + + # --------------------------------------------------------------------------- # Background install / non-blocking startup (P2) # --------------------------------------------------------------------------- diff --git a/tools/tirith_security.py b/tools/tirith_security.py index 83b222c8887..f40da60e52d 100644 --- a/tools/tirith_security.py +++ b/tools/tirith_security.py @@ -326,6 +326,32 @@ def _verify_checksum(archive_path: str, checksums_path: str, archive_name: str) return True +def _extract_tirith_binary(tar: tarfile.TarFile, dest_dir: str, log) -> tuple[str | None, str]: + """Extract the tirith binary from a release archive into dest_dir.""" + for member in tar.getmembers(): + if member.name == "tirith" or member.name.endswith("/tirith"): + if ".." in member.name: + continue + if not member.isfile(): + log("tirith archive member is not a regular file: %s", member.name) + return None, "binary_not_regular_file" + src_file = tar.extractfile(member) + if src_file is None: + log("tirith binary could not be read from archive") + return None, "binary_extract_failed" + + dest_path = os.path.join(dest_dir, "tirith") + try: + with open(dest_path, "wb") as out: + shutil.copyfileobj(src_file, out) + finally: + src_file.close() + return dest_path, "" + + log("tirith binary not found in archive") + return None, "binary_not_in_archive" + + def _install_tirith(*, log_failures: bool = True) -> tuple[str | None, str]: """Download and install tirith to $HERMES_HOME/bin/tirith. @@ -394,19 +420,10 @@ def _install_tirith(*, log_failures: bool = True) -> tuple[str | None, str]: return None, "checksum_failed" with tarfile.open(archive_path, "r:gz") as tar: - # Extract only the tirith binary (safety: reject paths with ..) - for member in tar.getmembers(): - if member.name == "tirith" or member.name.endswith("/tirith"): - if ".." in member.name: - continue - member.name = "tirith" - tar.extract(member, tmpdir) - break - else: - log("tirith binary not found in archive") - return None, "binary_not_in_archive" + src, reason = _extract_tirith_binary(tar, tmpdir, log) + if src is None: + return None, reason - src = os.path.join(tmpdir, "tirith") dest = os.path.join(_hermes_bin_dir(), "tirith") try: shutil.move(src, dest)