From 60d3b8cbce5adb1710bf14503c79a9f34fcc67f9 Mon Sep 17 00:00:00 2001 From: LeonSGP43 Date: Wed, 24 Jun 2026 09:23:40 +0800 Subject: [PATCH] fix(docker): restore config backups after failed boot migration --- scripts/docker_config_migrate.py | 38 ++++++++++-- tests/tools/test_docker_config_migrate.py | 70 +++++++++++++++++++++++ 2 files changed, 103 insertions(+), 5 deletions(-) diff --git a/scripts/docker_config_migrate.py b/scripts/docker_config_migrate.py index a0c83ed1247..b563cdb8405 100644 --- a/scripts/docker_config_migrate.py +++ b/scripts/docker_config_migrate.py @@ -28,18 +28,28 @@ def _backup_path(path: Path, stamp: str) -> Path: raise RuntimeError(f"could not choose a backup path for {path}") -def _backup_existing(paths: Iterable[Path]) -> list[Path]: +def _backup_existing(paths: Iterable[Path]) -> dict[Path, Path]: stamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ") - backups: list[Path] = [] + backups: dict[Path, Path] = {} for path in paths: if not path.is_file(): continue dest = _backup_path(path, stamp) shutil.copy2(path, dest) - backups.append(dest) + backups[path] = dest return backups +def _restore_backups(backups: dict[Path, Path]) -> list[Path]: + restored: list[Path] = [] + for original, backup in backups.items(): + if not backup.is_file(): + continue + shutil.copy2(backup, original) + restored.append(original) + return restored + + def main() -> int: if env_var_enabled("HERMES_SKIP_CONFIG_MIGRATION"): print("[config-migrate] HERMES_SKIP_CONFIG_MIGRATION is set; skipping config migration") @@ -50,12 +60,30 @@ def main() -> int: return 0 backups = _backup_existing((get_config_path(), get_env_path())) - backup_text = ", ".join(str(path) for path in backups) if backups else "none" + backup_text = ", ".join(str(path) for path in backups.values()) if backups else "none" print( f"[config-migrate] Migrating config schema {current_ver} -> {latest_ver}; " f"backups: {backup_text}" ) - migrate_config(interactive=False, quiet=False) + try: + migrate_config(interactive=False, quiet=False) + except Exception: + restored = _restore_backups(backups) + if restored: + print( + "[config-migrate] Migration failed; restored " + + ", ".join(str(path) for path in restored) + ) + raise + + post_ver, _ = check_config_version() + if post_ver < latest_ver: + restored = _restore_backups(backups) + restored_text = ", ".join(str(path) for path in restored) if restored else "none" + raise RuntimeError( + f"migration did not advance config version to {latest_ver} " + f"(still {post_ver}); restored: {restored_text}" + ) return 0 diff --git a/tests/tools/test_docker_config_migrate.py b/tests/tools/test_docker_config_migrate.py index a7fe193818d..05e4f38bcc6 100644 --- a/tests/tools/test_docker_config_migrate.py +++ b/tests/tools/test_docker_config_migrate.py @@ -1,10 +1,12 @@ from __future__ import annotations +import importlib.util import os import subprocess import sys from pathlib import Path +import pytest import yaml from hermes_cli.config import DEFAULT_CONFIG @@ -13,6 +15,14 @@ REPO_ROOT = Path(__file__).resolve().parents[2] SCRIPT = REPO_ROOT / "scripts" / "docker_config_migrate.py" +def _load_script_module(): + spec = importlib.util.spec_from_file_location("docker_config_migrate_test_module", SCRIPT) + assert spec and spec.loader + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def _run_migration(hermes_home: Path, **env_overrides: str) -> subprocess.CompletedProcess[str]: env = os.environ.copy() env.update( @@ -132,3 +142,63 @@ def test_docker_config_migrate_skip_env_leaves_config_unchanged(tmp_path: Path) assert "skipping config migration" in proc.stdout assert config_path.read_text(encoding="utf-8") == original assert not list(tmp_path.glob("*.bak-*")) + + +def test_docker_config_migrate_restores_backups_after_failed_migration( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module = _load_script_module() + config_path = tmp_path / "config.yaml" + env_path = tmp_path / ".env" + original_config = yaml.safe_dump({"_config_version": 11, "gateway": {"provider": "telegram"}}) + original_env = "TELEGRAM_BOT_TOKEN=test-token\n" + config_path.write_text(original_config, encoding="utf-8") + env_path.write_text(original_env, encoding="utf-8") + + monkeypatch.setattr(module, "check_config_version", lambda: (11, DEFAULT_CONFIG["_config_version"])) + monkeypatch.setattr(module, "get_config_path", lambda: config_path) + monkeypatch.setattr(module, "get_env_path", lambda: env_path) + + def _failing_migrate(*, interactive: bool, quiet: bool): + config_path.write_text("gateway: {}\n", encoding="utf-8") + env_path.write_text("", encoding="utf-8") + raise RuntimeError("boom") + + monkeypatch.setattr(module, "migrate_config", _failing_migrate) + + with pytest.raises(RuntimeError, match="boom"): + module.main() + + assert config_path.read_text(encoding="utf-8") == original_config + assert env_path.read_text(encoding="utf-8") == original_env + assert list(tmp_path.glob("config.yaml.bak-*")) + assert list(tmp_path.glob(".env.bak-*")) + + +def test_docker_config_migrate_restores_backups_when_version_does_not_advance( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + module = _load_script_module() + config_path = tmp_path / "config.yaml" + env_path = tmp_path / ".env" + original_config = yaml.safe_dump({"_config_version": 11, "gateway": {"provider": "telegram"}}) + original_env = "TELEGRAM_BOT_TOKEN=test-token\n" + config_path.write_text(original_config, encoding="utf-8") + env_path.write_text(original_env, encoding="utf-8") + + calls = iter([(11, DEFAULT_CONFIG["_config_version"]), (11, DEFAULT_CONFIG["_config_version"])]) + monkeypatch.setattr(module, "check_config_version", lambda: next(calls)) + monkeypatch.setattr(module, "get_config_path", lambda: config_path) + monkeypatch.setattr(module, "get_env_path", lambda: env_path) + + def _non_advancing_migrate(*, interactive: bool, quiet: bool): + config_path.write_text("gateway: {}\n", encoding="utf-8") + env_path.write_text("", encoding="utf-8") + + monkeypatch.setattr(module, "migrate_config", _non_advancing_migrate) + + with pytest.raises(RuntimeError, match="did not advance config version"): + module.main() + + assert config_path.read_text(encoding="utf-8") == original_config + assert env_path.read_text(encoding="utf-8") == original_env