fix(profiles): stage profile imports to prevent directory clobbering

This commit is contained in:
Aslaaen 2026-04-23 06:41:34 +03:00 committed by Teknium
parent 08cb345e24
commit 51c1d2de16
2 changed files with 82 additions and 15 deletions

View file

@ -863,19 +863,15 @@ def _safe_extract_profile_archive(archive: Path, destination: Path) -> None:
pass
def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
"""Import a profile from a tar.gz archive.
def _inspect_profile_archive_roots(archive: Path) -> set[str]:
"""Return the archive's top-level directory names.
If *name* is not given, infers it from the archive's top-level directory.
Returns the imported profile directory.
Profile imports expect exactly one root directory. Inspecting the archive
before extraction lets us stage the import safely instead of mutating a
live profile tree first and reconciling names later.
"""
import tarfile
archive = Path(archive_path)
if not archive.exists():
raise FileNotFoundError(f"Archive not found: {archive}")
# Peek at the archive to find the top-level directory name
with tarfile.open(archive, "r:gz") as tf:
top_dirs = {
parts[0]
@ -889,13 +885,33 @@ def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
for member in tf.getmembers()
if member.isdir()
}
return top_dirs
inferred_name = name or (top_dirs.pop() if len(top_dirs) == 1 else None)
def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
"""Import a profile from a tar.gz archive.
If *name* is not given, infers it from the archive's top-level directory.
Returns the imported profile directory.
"""
import tempfile
archive = Path(archive_path)
if not archive.exists():
raise FileNotFoundError(f"Archive not found: {archive}")
top_dirs = _inspect_profile_archive_roots(archive)
archive_root = top_dirs.pop() if len(top_dirs) == 1 else None
inferred_name = name or archive_root
if not inferred_name:
raise ValueError(
"Cannot determine profile name from archive. "
"Specify it explicitly: hermes profile import <archive> --name <name>"
)
if archive_root is None:
raise ValueError(
"Profile archive must contain exactly one top-level directory."
)
# Archives exported from the default profile have "default/" as top-level
# dir. Importing as "default" would target ~/.hermes itself — disallow
@ -914,12 +930,22 @@ def import_profile(archive_path: str, name: Optional[str] = None) -> Path:
profiles_root = _get_profiles_root()
profiles_root.mkdir(parents=True, exist_ok=True)
_safe_extract_profile_archive(archive, profiles_root)
with tempfile.TemporaryDirectory(prefix="hermes_profile_import_") as tmpdir:
staging_root = Path(tmpdir)
_safe_extract_profile_archive(archive, staging_root)
# If the archive extracted under a different name, rename
extracted = profiles_root / (top_dirs.pop() if top_dirs else inferred_name)
if extracted != profile_dir and extracted.exists():
extracted.rename(profile_dir)
extracted = staging_root / archive_root
if not extracted.is_dir():
raise ValueError(
f"Profile archive root is missing or invalid: {archive_root}"
)
final_source = extracted
if archive_root != inferred_name:
final_source = staging_root / inferred_name
extracted.rename(final_source)
shutil.move(str(final_source), str(profile_dir))
return profile_dir