"""Tests for utils.atomic_json_write — crash-safe JSON file writes."""

import json
import os
from pathlib import Path
from unittest.mock import patch

import pytest

from utils import atomic_json_write


def _read_json(path):
    """Parse *path* as JSON, always decoding as UTF-8.

    Decoding explicitly keeps the tests independent of the platform's default
    text encoding.
    """
    return json.loads(Path(path).read_text(encoding="utf-8"))


def _leftover_temp_files(directory):
    """Return entries of *directory* whose name contains ``.tmp``."""
    return [entry for entry in directory.iterdir() if ".tmp" in entry.name]


class TestAtomicJsonWrite:
    """Core atomic write behavior."""

    def test_writes_valid_json(self, tmp_path):
        target = tmp_path / "data.json"
        data = {"key": "value", "nested": {"a": 1}}
        atomic_json_write(target, data)

        assert _read_json(target) == data

    def test_creates_parent_directories(self, tmp_path):
        target = tmp_path / "deep" / "nested" / "dir" / "data.json"
        atomic_json_write(target, {"ok": True})

        assert target.exists()
        assert _read_json(target)["ok"] is True

    def test_overwrites_existing_file(self, tmp_path):
        target = tmp_path / "data.json"
        target.write_text('{"old": true}')

        atomic_json_write(target, {"new": True})

        assert _read_json(target) == {"new": True}

    def test_preserves_original_on_serialization_error(self, tmp_path):
        target = tmp_path / "data.json"
        original = {"preserved": True}
        target.write_text(json.dumps(original))

        # Try to write non-serializable data — should fail
        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # Original file should be untouched
        assert _read_json(target) == original

    def test_no_leftover_temp_files_on_success(self, tmp_path):
        target = tmp_path / "data.json"
        atomic_json_write(target, [1, 2, 3])

        # No .tmp files should be left behind
        assert len(_leftover_temp_files(tmp_path)) == 0
        assert target.exists()

    def test_no_leftover_temp_files_on_failure(self, tmp_path):
        target = tmp_path / "data.json"

        with pytest.raises(TypeError):
            atomic_json_write(target, {"bad": object()})

        # No temp files should be left behind
        assert len(_leftover_temp_files(tmp_path)) == 0

    def test_accepts_string_path(self, tmp_path):
        target = str(tmp_path / "string_path.json")
        atomic_json_write(target, {"string": True})

        assert _read_json(target) == {"string": True}

    def test_writes_list_data(self, tmp_path):
        target = tmp_path / "list.json"
        data = [1, "two", {"three": 3}]
        atomic_json_write(target, data)

        assert _read_json(target) == data

    def test_empty_list(self, tmp_path):
        target = tmp_path / "empty.json"
        atomic_json_write(target, [])

        assert _read_json(target) == []

    def test_custom_indent(self, tmp_path):
        target = tmp_path / "custom.json"
        atomic_json_write(target, {"a": 1}, indent=4)

        text = target.read_text(encoding="utf-8")
        # Anchor on the preceding newline so only an indent of exactly four
        # spaces matches; a bare substring would also accept deeper indents.
        assert '\n    "a"' in text

    def test_unicode_content(self, tmp_path):
        target = tmp_path / "unicode.json"
        data = {"emoji": "🎉", "japanese": "日本語"}
        atomic_json_write(target, data)

        result = _read_json(target)
        assert result["emoji"] == "🎉"
        assert result["japanese"] == "日本語"

    def test_concurrent_writes_dont_corrupt(self, tmp_path):
        """Multiple rapid writes should each produce valid JSON."""
        import threading

        target = tmp_path / "concurrent.json"
        errors = []

        def writer(n):
            try:
                atomic_json_write(target, {"writer": n, "data": list(range(100))})
            except Exception as e:
                errors.append(e)

        threads = [threading.Thread(target=writer, args=(i,)) for i in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert not errors
        # File should contain valid JSON from one of the writers
        result = _read_json(target)
        assert "writer" in result
        assert len(result["data"]) == 100
import json
import sys
from multiprocessing import Lock
from pathlib import Path

import pytest

# batch_runner uses relative imports, ensure project root is on path
sys.path.insert(0, str(Path(__file__).parent.parent))

from batch_runner import BatchRunner


def _read_checkpoint(runner):
    """Parse the runner's checkpoint file as JSON, decoding as UTF-8."""
    return json.loads(runner.checkpoint_file.read_text(encoding="utf-8"))


def _load_or_reset(runner):
    """Load the checkpoint, starting fresh when it belongs to another run.

    This is the run-name check the resume tests apply after
    ``_load_checkpoint``; the tests describe it as what ``run()`` does, which
    is not verified here.
    """
    checkpoint_data = runner._load_checkpoint()
    if checkpoint_data.get("run_name") != runner.run_name:
        checkpoint_data = {
            "run_name": runner.run_name,
            "completed_prompts": [],
            "batch_stats": {},
            "last_updated": None,
        }
    return checkpoint_data


@pytest.fixture
def runner(tmp_path):
    """Create a BatchRunner with all paths pointing at tmp_path."""
    prompts_file = tmp_path / "prompts.jsonl"
    prompts_file.write_text("")
    output_file = tmp_path / "output.jsonl"
    checkpoint_file = tmp_path / "checkpoint.json"
    # __new__ bypasses __init__, so only the attributes set below exist.
    r = BatchRunner.__new__(BatchRunner)
    r.run_name = "test_run"
    r.checkpoint_file = checkpoint_file
    r.output_file = output_file
    r.prompts_file = prompts_file
    return r


class TestSaveCheckpoint:
    """Verify _save_checkpoint writes valid, atomic JSON."""

    def test_writes_valid_json(self, runner):
        data = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
        runner._save_checkpoint(data)

        result = _read_checkpoint(runner)
        assert result["run_name"] == "test"
        assert result["completed_prompts"] == [1, 2, 3]

    def test_adds_last_updated(self, runner):
        data = {"run_name": "test", "completed_prompts": []}
        runner._save_checkpoint(data)

        result = _read_checkpoint(runner)
        assert "last_updated" in result
        assert result["last_updated"] is not None

    def test_overwrites_previous_checkpoint(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1, 2, 3]})

        assert _read_checkpoint(runner)["completed_prompts"] == [1, 2, 3]

    def test_with_lock(self, runner):
        lock = Lock()
        data = {"run_name": "test", "completed_prompts": [42]}
        runner._save_checkpoint(data, lock=lock)

        assert _read_checkpoint(runner)["completed_prompts"] == [42]

    def test_without_lock(self, runner):
        data = {"run_name": "test", "completed_prompts": [99]}
        runner._save_checkpoint(data, lock=None)

        assert _read_checkpoint(runner)["completed_prompts"] == [99]

    def test_creates_parent_dirs(self, tmp_path):
        runner_deep = BatchRunner.__new__(BatchRunner)
        runner_deep.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"

        data = {"run_name": "test", "completed_prompts": []}
        runner_deep._save_checkpoint(data)

        assert runner_deep.checkpoint_file.exists()

    def test_no_temp_files_left(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})

        tmp_files = [f for f in runner.checkpoint_file.parent.iterdir()
                     if ".tmp" in f.name]
        assert len(tmp_files) == 0


class TestLoadCheckpoint:
    """Verify _load_checkpoint reads existing data or returns defaults."""

    def test_returns_empty_when_no_file(self, runner):
        result = runner._load_checkpoint()
        assert result.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        data = {"run_name": "test_run", "completed_prompts": [5, 10, 15],
                "batch_stats": {"0": {"processed": 3}}}
        runner.checkpoint_file.write_text(json.dumps(data))

        result = runner._load_checkpoint()
        assert result["completed_prompts"] == [5, 10, 15]
        assert result["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        runner.checkpoint_file.write_text("{broken json!!")

        result = runner._load_checkpoint()
        # Should return empty/default, not crash
        assert isinstance(result, dict)


class TestResumePreservesProgress:
    """Verify that initializing a run with resume=True loads prior checkpoint."""

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        # Simulate a prior run that completed prompts 0-4
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = _load_or_reset(runner)

        completed_set = set(checkpoint_data.get("completed_prompts", []))
        assert completed_set == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        prior = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))

        checkpoint_data = _load_or_reset(runner)

        assert checkpoint_data["completed_prompts"] == []
        assert checkpoint_data["run_name"] == "test_run"
"""Shared utility functions for hermes-agent."""

import json
import os
import tempfile
from pathlib import Path
from typing import Any, Union


def atomic_json_write(path: Union[str, Path], data: Any, *, indent: int = 2) -> None:
    """Write JSON data to a file atomically.

    The data is serialized into a temporary file created in the target's
    directory, flushed and fsynced, and then moved over the target with
    ``os.replace``. The target is therefore never left partially written: if
    anything fails before the replace, the temporary file is removed and the
    previous version of the target remains intact.

    ``tempfile.mkstemp`` creates its file accessible only by the owner. When
    the target already exists, its permission bits are copied onto the
    temporary file before the replace, so rewriting a file does not silently
    change who can read it. A newly created target keeps the ``mkstemp``
    default.

    Args:
        path: Target file path (will be created or overwritten). Missing
            parent directories are created.
        data: JSON-serializable data to write.
        indent: JSON indentation (default 2).

    Raises:
        TypeError: If ``data`` contains a value ``json`` cannot serialize.
        OSError: If the temporary file cannot be created, written, or moved
            into place.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Capture the current permission bits (if the target exists) before the
    # replace discards them.
    try:
        previous_mode = target.stat().st_mode & 0o7777
    except OSError:
        previous_mode = None

    # The temp file must live in the target's directory so os.replace stays on
    # one filesystem.
    fd, tmp_path = tempfile.mkstemp(
        dir=str(target.parent),
        prefix=f".{target.stem}_",
        suffix=".tmp",
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=indent, ensure_ascii=False)
            f.flush()
            os.fsync(f.fileno())
        if previous_mode is not None:
            try:
                os.chmod(tmp_path, previous_mode)
            except OSError:
                # Best effort: failing to restore the mode must not turn a
                # successful write into an error.
                pass
        os.replace(tmp_path, target)
    except BaseException:
        # Covers KeyboardInterrupt/SystemExit too, so no temp file is left
        # behind; the original exception is always re-raised.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise