hermes-agent/tests/test_batch_runner_checkpoint.py
Teknium ef9355455b test: regression coverage for checkpoint dedup and inf/nan coercion
Covers the two bugs salvaged from PR #15161:

- test_batch_runner_checkpoint: TestFinalCheckpointNoDuplicates asserts
  the final aggregated completed_prompts list has no duplicate indices,
  and keeps a sanity anchor test documenting the pre-fix pattern. The
  anchor runs an inline copy of that pattern rather than run() itself,
  so it records the bug for future readers but cannot on its own catch a
  refactor that re-introduces it.

- test_model_tools: TestCoerceNumberInfNan asserts _coerce_number
  returns the original string for inf/-inf/nan/Infinity inputs and that
  the result round-trips through strict (allow_nan=False) json.dumps.
2026-04-24 14:32:21 -07:00

252 lines
9.2 KiB
Python

"""Tests for batch_runner checkpoint behavior — incremental writes, resume, atomicity."""
import json
import os
from pathlib import Path
from threading import Lock
from unittest.mock import patch, MagicMock
import pytest
# batch_runner uses relative imports, ensure project root is on path
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from batch_runner import BatchRunner, _process_batch_worker
@pytest.fixture
def runner(tmp_path):
    """Build a bare BatchRunner whose files all live under ``tmp_path``.

    ``BatchRunner.__new__`` is used instead of the constructor, so
    ``__init__`` never runs. Only ``run_name`` and the checkpoint, output,
    and prompts paths are assigned, and the prompts file is created empty.
    """
    prompts_path = tmp_path / "prompts.jsonl"
    prompts_path.write_text("")

    bare_runner = BatchRunner.__new__(BatchRunner)
    bare_runner.run_name = "test_run"
    bare_runner.checkpoint_file = tmp_path / "checkpoint.json"
    bare_runner.output_file = tmp_path / "output.jsonl"
    bare_runner.prompts_file = prompts_path
    return bare_runner
class TestSaveCheckpoint:
    """Verify _save_checkpoint writes valid, atomic JSON."""

    @staticmethod
    def _read_checkpoint(path):
        """Parse and return the JSON document stored at *path*."""
        return json.loads(path.read_text())

    def test_writes_valid_json(self, runner):
        data = {"run_name": "test", "completed_prompts": [1, 2, 3], "batch_stats": {}}
        runner._save_checkpoint(data)
        result = self._read_checkpoint(runner.checkpoint_file)
        assert result["run_name"] == "test"
        assert result["completed_prompts"] == [1, 2, 3]

    def test_adds_last_updated(self, runner):
        data = {"run_name": "test", "completed_prompts": []}
        runner._save_checkpoint(data)
        result = self._read_checkpoint(runner.checkpoint_file)
        assert "last_updated" in result
        assert result["last_updated"] is not None

    def test_overwrites_previous_checkpoint(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1]})
        runner._save_checkpoint({"run_name": "test", "completed_prompts": [1, 2, 3]})
        result = self._read_checkpoint(runner.checkpoint_file)
        assert result["completed_prompts"] == [1, 2, 3]

    def test_with_lock(self, runner):
        lock = Lock()
        data = {"run_name": "test", "completed_prompts": [42]}
        runner._save_checkpoint(data, lock=lock)
        result = self._read_checkpoint(runner.checkpoint_file)
        assert result["completed_prompts"] == [42]
        # If _save_checkpoint acquires the lock it must release it: threading.Lock
        # is not reentrant, so a leaked lock would block the next caller that
        # shares it (including this same thread).
        assert not lock.locked()

    def test_without_lock(self, runner):
        data = {"run_name": "test", "completed_prompts": [99]}
        runner._save_checkpoint(data, lock=None)
        result = self._read_checkpoint(runner.checkpoint_file)
        assert result["completed_prompts"] == [99]

    def test_creates_parent_dirs(self, tmp_path):
        runner_deep = BatchRunner.__new__(BatchRunner)
        runner_deep.checkpoint_file = tmp_path / "deep" / "nested" / "checkpoint.json"
        data = {"run_name": "test", "completed_prompts": []}
        runner_deep._save_checkpoint(data)
        assert runner_deep.checkpoint_file.exists()
        # The file must hold the saved data, not merely exist.
        assert self._read_checkpoint(runner_deep.checkpoint_file)["completed_prompts"] == []

    def test_no_temp_files_left(self, runner):
        runner._save_checkpoint({"run_name": "test", "completed_prompts": []})
        # Compare against [] (not a count) so a failure lists the leftover names.
        leftovers = [f.name for f in runner.checkpoint_file.parent.iterdir()
                     if ".tmp" in f.name]
        assert leftovers == []
class TestLoadCheckpoint:
    """Verify _load_checkpoint reads existing data or returns defaults."""

    def test_returns_empty_when_no_file(self, runner):
        result = runner._load_checkpoint()
        assert result.get("completed_prompts", []) == []

    def test_loads_existing_checkpoint(self, runner):
        data = {"run_name": "test_run", "completed_prompts": [5, 10, 15],
                "batch_stats": {"0": {"processed": 3}}}
        runner.checkpoint_file.write_text(json.dumps(data))
        result = runner._load_checkpoint()
        assert result["completed_prompts"] == [5, 10, 15]
        assert result["batch_stats"]["0"]["processed"] == 3

    def test_handles_corrupt_json(self, runner):
        runner.checkpoint_file.write_text("{broken json!!")
        result = runner._load_checkpoint()
        # Should return empty/default, not crash. Enforce the "empty" part too:
        # a loader that reported prompts as completed here would presumably make
        # a resumed run skip work it never did.
        assert isinstance(result, dict)
        assert result.get("completed_prompts", []) == []
class TestResumePreservesProgress:
    """Verify that initializing a run with resume=True loads prior checkpoint."""

    @staticmethod
    def _load_checkpoint_like_run(runner):
        """Load the checkpoint, discarding it if it belongs to another run.

        Mirrors the resume logic of ``run()`` (it does not call ``run()``
        itself), kept in one place so both tests stay in sync with it: a
        checkpoint whose ``run_name`` differs from ``runner.run_name`` is
        replaced by a fresh, empty checkpoint for the current run.
        """
        checkpoint_data = runner._load_checkpoint()
        if checkpoint_data.get("run_name") != runner.run_name:
            checkpoint_data = {
                "run_name": runner.run_name,
                "completed_prompts": [],
                "batch_stats": {},
                "last_updated": None,
            }
        return checkpoint_data

    def test_completed_prompts_loaded_from_checkpoint(self, runner):
        # Simulate a prior run that completed prompts 0-4
        prior = {
            "run_name": "test_run",
            "completed_prompts": [0, 1, 2, 3, 4],
            "batch_stats": {"0": {"processed": 5}},
            "last_updated": "2026-01-01T00:00:00",
        }
        runner.checkpoint_file.write_text(json.dumps(prior))
        checkpoint_data = self._load_checkpoint_like_run(runner)
        completed_set = set(checkpoint_data.get("completed_prompts", []))
        assert completed_set == {0, 1, 2, 3, 4}

    def test_different_run_name_starts_fresh(self, runner):
        prior = {
            "run_name": "different_run",
            "completed_prompts": [0, 1, 2],
            "batch_stats": {},
        }
        runner.checkpoint_file.write_text(json.dumps(prior))
        checkpoint_data = self._load_checkpoint_like_run(runner)
        assert checkpoint_data["completed_prompts"] == []
        assert checkpoint_data["run_name"] == "test_run"
class TestBatchWorkerResumeBehavior:
    """Verify how _process_batch_worker records prompts it processes but discards."""

    def test_discarded_no_reasoning_prompts_are_marked_completed(self, tmp_path, monkeypatch):
        # A successful, completed result whose reasoning stats report no reasoning.
        canned_result = {
            "success": True,
            "trajectory": [{"role": "assistant", "content": "x"}],
            "reasoning_stats": {"has_any_reasoning": False},
            "tool_stats": {},
            "metadata": {},
            "completed": True,
            "api_calls": 1,
            "toolsets_used": [],
        }

        def fake_process_single_prompt(*args, **kwargs):
            return canned_result

        monkeypatch.setattr("batch_runner._process_single_prompt", fake_process_single_prompt)

        # NOTE(review): positional layout presumed from usage here — batch number,
        # (index, prompt) pairs, output dir, already-completed set, config; confirm
        # against _process_batch_worker.
        worker_args = (
            1,
            [(0, {"prompt": "hi"})],
            tmp_path,
            set(),
            {"verbose": False},
        )
        outcome = _process_batch_worker(worker_args)

        # Discarded, yet still listed as completed for resume bookkeeping.
        assert outcome["discarded_no_reasoning"] == 1
        assert outcome["completed_prompts"] == [0]

        # Nothing should have been written for the discarded trajectory.
        batch_output = tmp_path / "batch_1.jsonl"
        assert not batch_output.exists() or batch_output.read_text() == ""
class TestFinalCheckpointNoDuplicates:
    """Regression: the final checkpoint must not contain duplicate prompt
    indices.

    Before PR #15161, `run()` populated `completed_prompts_set` incrementally
    as each batch completed, then at the end built `all_completed_prompts =
    list(completed_prompts_set)` AND extended it again with every batch's
    `completed_prompts` — double-counting every index.
    """

    def _simulate_final_aggregation_fixed(self, batch_results):
        """Mirror the fixed code path in batch_runner.run().

        Returns the sorted union of every batch's ``completed_prompts``; a
        batch dict without that key contributes nothing.
        """
        # This is what the fixed code now writes to the checkpoint:
        return sorted({
            index
            for result in batch_results
            for index in result.get("completed_prompts", [])
        })

    def test_no_duplicates_in_final_list(self):
        batch_results = [
            {"completed_prompts": [0, 1, 2]},
            {"completed_prompts": [3, 4]},
            {"completed_prompts": [5]},
        ]
        final = self._simulate_final_aggregation_fixed(batch_results)
        assert final == [0, 1, 2, 3, 4, 5]
        assert len(final) == len(set(final))  # no duplicates

    def test_persisted_checkpoint_has_unique_prompts(self, runner):
        """Write what run()'s fixed aggregation produces to disk; the file
        must load back with no duplicate indices."""
        batch_results = [
            {"completed_prompts": [0, 1]},
            {"completed_prompts": [2, 3]},
        ]
        final = self._simulate_final_aggregation_fixed(batch_results)
        runner._save_checkpoint({
            "run_name": runner.run_name,
            "completed_prompts": final,
            "batch_stats": {},
        })
        loaded = json.loads(runner.checkpoint_file.read_text())
        cp = loaded["completed_prompts"]
        assert len(cp) == len(set(cp))  # no duplicates survived the round-trip
        assert cp == [0, 1, 2, 3]

    def test_old_buggy_pattern_would_have_duplicates(self):
        """Document the bug this PR fixes: the old code shape produced
        duplicates.

        This runs an inline copy of the pre-fix aggregation, not
        ``batch_runner.run()``, so it cannot by itself detect the pattern
        being re-introduced there; it is kept as a readable record of what
        the bug looked like.
        """
        completed_prompts_set = set()
        results = []
        for batch in ({"completed_prompts": [0, 1, 2]},
                      {"completed_prompts": [3, 4]}):
            completed_prompts_set.update(batch["completed_prompts"])
            results.append(batch)
        # Buggy aggregation (pre-fix):
        buggy = list(completed_prompts_set)
        for br in results:
            buggy.extend(br.get("completed_prompts", []))
        # Every index appears exactly twice (a length check alone would also
        # accept one index three times and another once).
        assert len(buggy) == 2 * len(set(buggy))
        assert all(buggy.count(index) == 2 for index in completed_prompts_set)