fix(profiles): stage profile imports to prevent directory clobbering

This commit is contained in:
Aslaaen 2026-04-23 06:41:34 +03:00 committed by Teknium
parent 08cb345e24
commit 51c1d2de16
2 changed files with 82 additions and 15 deletions

View file

@ -863,19 +863,15 @@ def _safe_extract_profile_archive(archive: Path, destination: Path) -> None:
pass pass
def import_profile(archive_path: str, name: Optional[str] = None) -> Path: def _inspect_profile_archive_roots(archive: Path) -> set[str]:
"""Import a profile from a tar.gz archive. """Return the archive's top-level directory names.
If *name* is not given, infers it from the archive's top-level directory. Profile imports expect exactly one root directory. Inspecting the archive
Returns the imported profile directory. before extraction lets us stage the import safely instead of mutating a
live profile tree first and reconciling names later.
""" """
import tarfile import tarfile
archive = Path(archive_path)
if not archive.exists():
raise FileNotFoundError(f"Archive not found: {archive}")
# Peek at the archive to find the top-level directory name
with tarfile.open(archive, "r:gz") as tf: with tarfile.open(archive, "r:gz") as tf:
top_dirs = { top_dirs = {
parts[0] parts[0]
@ -889,13 +885,33 @@ def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
for member in tf.getmembers() for member in tf.getmembers()
if member.isdir() if member.isdir()
} }
return top_dirs
inferred_name = name or (top_dirs.pop() if len(top_dirs) == 1 else None)
def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
"""Import a profile from a tar.gz archive.
If *name* is not given, infers it from the archive's top-level directory.
Returns the imported profile directory.
"""
import tempfile
archive = Path(archive_path)
if not archive.exists():
raise FileNotFoundError(f"Archive not found: {archive}")
top_dirs = _inspect_profile_archive_roots(archive)
archive_root = top_dirs.pop() if len(top_dirs) == 1 else None
inferred_name = name or archive_root
if not inferred_name: if not inferred_name:
raise ValueError( raise ValueError(
"Cannot determine profile name from archive. " "Cannot determine profile name from archive. "
"Specify it explicitly: hermes profile import <archive> --name <name>" "Specify it explicitly: hermes profile import <archive> --name <name>"
) )
if archive_root is None:
raise ValueError(
"Profile archive must contain exactly one top-level directory."
)
# Archives exported from the default profile have "default/" as top-level # Archives exported from the default profile have "default/" as top-level
# dir. Importing as "default" would target ~/.hermes itself — disallow # dir. Importing as "default" would target ~/.hermes itself — disallow
@ -914,12 +930,22 @@ def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
profiles_root = _get_profiles_root() profiles_root = _get_profiles_root()
profiles_root.mkdir(parents=True, exist_ok=True) profiles_root.mkdir(parents=True, exist_ok=True)
_safe_extract_profile_archive(archive, profiles_root) with tempfile.TemporaryDirectory(prefix="hermes_profile_import_") as tmpdir:
staging_root = Path(tmpdir)
_safe_extract_profile_archive(archive, staging_root)
# If the archive extracted under a different name, rename extracted = staging_root / archive_root
extracted = profiles_root / (top_dirs.pop() if top_dirs else inferred_name) if not extracted.is_dir():
if extracted != profile_dir and extracted.exists(): raise ValueError(
extracted.rename(profile_dir) f"Profile archive root is missing or invalid: {archive_root}"
)
final_source = extracted
if archive_root != inferred_name:
final_source = staging_root / inferred_name
extracted.rename(final_source)
shutil.move(str(final_source), str(profile_dir))
return profile_dir return profile_dir

View file

@ -455,6 +455,47 @@ class TestExportImport:
with pytest.raises(FileExistsError): with pytest.raises(FileExistsError):
import_profile(str(archive_path), name="coder") import_profile(str(archive_path), name="coder")
def test_import_with_explicit_name_does_not_mutate_existing_archive_root_profile(
self, profile_env, tmp_path
):
create_profile("victim", no_alias=True)
victim_dir = get_profile_dir("victim")
(victim_dir / "marker.txt").write_text("original")
archive_path = tmp_path / "export" / "victim.tar.gz"
archive_path.parent.mkdir(parents=True, exist_ok=True)
with tarfile.open(archive_path, "w:gz") as tf:
data = b"imported"
info = tarfile.TarInfo("victim/marker.txt")
info.size = len(data)
tf.addfile(info, io.BytesIO(data))
imported = import_profile(str(archive_path), name="renamed")
assert imported == get_profile_dir("renamed")
assert (imported / "marker.txt").read_text() == "imported"
assert (victim_dir / "marker.txt").read_text() == "original"
def test_import_rejects_archive_with_multiple_top_level_directories(
self, profile_env, tmp_path
):
archive_path = tmp_path / "export" / "multi-root.tar.gz"
archive_path.parent.mkdir(parents=True, exist_ok=True)
with tarfile.open(archive_path, "w:gz") as tf:
for member_name, data in (
("alpha/marker.txt", b"a"),
("beta/marker.txt", b"b"),
):
info = tarfile.TarInfo(member_name)
info.size = len(data)
tf.addfile(info, io.BytesIO(data))
with pytest.raises(ValueError, match="exactly one top-level directory"):
import_profile(str(archive_path), name="coder")
assert not get_profile_dir("coder").exists()
def test_import_rejects_traversal_archive_member(self, profile_env, tmp_path): def test_import_rejects_traversal_archive_member(self, profile_env, tmp_path):
archive_path = tmp_path / "export" / "evil.tar.gz" archive_path = tmp_path / "export" / "evil.tar.gz"
archive_path.parent.mkdir(parents=True, exist_ok=True) archive_path.parent.mkdir(parents=True, exist_ok=True)