diff --git a/hermes_cli/profiles.py b/hermes_cli/profiles.py index 779728adc..bf6de16df 100644 --- a/hermes_cli/profiles.py +++ b/hermes_cli/profiles.py @@ -863,19 +863,15 @@ def _safe_extract_profile_archive(archive: Path, destination: Path) -> None: pass -def import_profile(archive_path: str, name: Optional[str] = None) -> Path: - """Import a profile from a tar.gz archive. +def _inspect_profile_archive_roots(archive: Path) -> set[str]: + """Return the archive's top-level directory names. - If *name* is not given, infers it from the archive's top-level directory. - Returns the imported profile directory. + Profile imports expect exactly one root directory. Inspecting the archive + before extraction lets us stage the import safely instead of mutating a + live profile tree first and reconciling names later. """ import tarfile - archive = Path(archive_path) - if not archive.exists(): - raise FileNotFoundError(f"Archive not found: {archive}") - - # Peek at the archive to find the top-level directory name with tarfile.open(archive, "r:gz") as tf: top_dirs = { parts[0] @@ -889,13 +885,33 @@ def import_profile(archive_path: str, name: Optional[str] = None) -> Path: for member in tf.getmembers() if member.isdir() } + return top_dirs - inferred_name = name or (top_dirs.pop() if len(top_dirs) == 1 else None) + +def import_profile(archive_path: str, name: Optional[str] = None) -> Path: + """Import a profile from a tar.gz archive. + + If *name* is not given, infers it from the archive's top-level directory. + Returns the imported profile directory. + """ + import tempfile + + archive = Path(archive_path) + if not archive.exists(): + raise FileNotFoundError(f"Archive not found: {archive}") + + top_dirs = _inspect_profile_archive_roots(archive) + archive_root = top_dirs.pop() if len(top_dirs) == 1 else None + inferred_name = name or archive_root if not inferred_name: raise ValueError( "Cannot determine profile name from archive. " "Specify it explicitly: hermes profile import --name " ) + if archive_root is None: + raise ValueError( + "Profile archive must contain exactly one top-level directory." + ) # Archives exported from the default profile have "default/" as top-level # dir. Importing as "default" would target ~/.hermes itself — disallow @@ -914,12 +930,22 @@ def import_profile(archive_path: str, name: Optional[str] = None) -> Path: profiles_root = _get_profiles_root() profiles_root.mkdir(parents=True, exist_ok=True) - _safe_extract_profile_archive(archive, profiles_root) + with tempfile.TemporaryDirectory(prefix="hermes_profile_import_") as tmpdir: + staging_root = Path(tmpdir) + _safe_extract_profile_archive(archive, staging_root) - # If the archive extracted under a different name, rename - extracted = profiles_root / (top_dirs.pop() if top_dirs else inferred_name) - if extracted != profile_dir and extracted.exists(): - extracted.rename(profile_dir) + extracted = staging_root / archive_root + if not extracted.is_dir(): + raise ValueError( + f"Profile archive root is missing or invalid: {archive_root}" + ) + + final_source = extracted + if archive_root != inferred_name: + final_source = staging_root / inferred_name + extracted.rename(final_source) + + shutil.move(str(final_source), str(profile_dir)) return profile_dir diff --git a/tests/hermes_cli/test_profiles.py b/tests/hermes_cli/test_profiles.py index 9c2dafb97..7e181c1a8 100644 --- a/tests/hermes_cli/test_profiles.py +++ b/tests/hermes_cli/test_profiles.py @@ -455,6 +455,47 @@ class TestExportImport: with pytest.raises(FileExistsError): import_profile(str(archive_path), name="coder") + def test_import_with_explicit_name_does_not_mutate_existing_archive_root_profile( + self, profile_env, tmp_path + ): + create_profile("victim", no_alias=True) + victim_dir = get_profile_dir("victim") + (victim_dir / "marker.txt").write_text("original") + + archive_path = tmp_path / "export" / "victim.tar.gz" + archive_path.parent.mkdir(parents=True, exist_ok=True) + with tarfile.open(archive_path, "w:gz") as tf: + data = b"imported" + info = tarfile.TarInfo("victim/marker.txt") + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + imported = import_profile(str(archive_path), name="renamed") + + assert imported == get_profile_dir("renamed") + assert (imported / "marker.txt").read_text() == "imported" + assert (victim_dir / "marker.txt").read_text() == "original" + + def test_import_rejects_archive_with_multiple_top_level_directories( + self, profile_env, tmp_path + ): + archive_path = tmp_path / "export" / "multi-root.tar.gz" + archive_path.parent.mkdir(parents=True, exist_ok=True) + + with tarfile.open(archive_path, "w:gz") as tf: + for member_name, data in ( + ("alpha/marker.txt", b"a"), + ("beta/marker.txt", b"b"), + ): + info = tarfile.TarInfo(member_name) + info.size = len(data) + tf.addfile(info, io.BytesIO(data)) + + with pytest.raises(ValueError, match="exactly one top-level directory"): + import_profile(str(archive_path), name="coder") + + assert not get_profile_dir("coder").exists() + def test_import_rejects_traversal_archive_member(self, profile_env, tmp_path): archive_path = tmp_path / "export" / "evil.tar.gz" archive_path.parent.mkdir(parents=True, exist_ok=True)