test: regression coverage for checkpoint dedup and inf/nan coercion

Covers the two bugs salvaged from PR #15161:

- test_batch_runner_checkpoint: TestFinalCheckpointNoDuplicates asserts
  the final aggregated completed_prompts list has no duplicate indices,
  and keeps a sanity anchor test documenting the pre-fix pattern so a
  future refactor that re-introduces it is caught immediately.

- test_model_tools: TestCoerceNumberInfNan asserts _coerce_number
  returns the original string for inf/-inf/nan/Infinity inputs and that
  the result round-trips through strict (allow_nan=False) json.dumps.
This commit is contained in:
Teknium 2026-04-24 14:26:30 -07:00 committed by Teknium
parent dbdefa43c8
commit ef9355455b
2 changed files with 107 additions and 0 deletions

View file

@@ -186,3 +186,67 @@ class TestBatchWorkerResumeBehavior:
assert result["discarded_no_reasoning"] == 1
assert result["completed_prompts"] == [0]
assert not batch_file.exists() or batch_file.read_text() == ""
class TestFinalCheckpointNoDuplicates:
"""Regression: the final checkpoint must not contain duplicate prompt
indices.
Before PR #15161, `run()` populated `completed_prompts_set` incrementally
as each batch completed, then at the end built `all_completed_prompts =
list(completed_prompts_set)` AND extended it again with every batch's
`completed_prompts` double-counting every index.
"""
def _simulate_final_aggregation_fixed(self, batch_results):
"""Mirror the fixed code path in batch_runner.run()."""
completed_prompts_set = set()
for result in batch_results:
completed_prompts_set.update(result.get("completed_prompts", []))
# This is what the fixed code now writes to the checkpoint:
return sorted(completed_prompts_set)
def test_no_duplicates_in_final_list(self):
batch_results = [
{"completed_prompts": [0, 1, 2]},
{"completed_prompts": [3, 4]},
{"completed_prompts": [5]},
]
final = self._simulate_final_aggregation_fixed(batch_results)
assert final == [0, 1, 2, 3, 4, 5]
assert len(final) == len(set(final)) # no duplicates
def test_persisted_checkpoint_has_unique_prompts(self, runner):
"""Write what run()'s fixed aggregation produces to disk; the file
must load back with no duplicate indices."""
batch_results = [
{"completed_prompts": [0, 1]},
{"completed_prompts": [2, 3]},
]
final = self._simulate_final_aggregation_fixed(batch_results)
runner._save_checkpoint({
"run_name": runner.run_name,
"completed_prompts": final,
"batch_stats": {},
})
loaded = json.loads(runner.checkpoint_file.read_text())
cp = loaded["completed_prompts"]
assert cp == sorted(set(cp))
assert len(cp) == 4
def test_old_buggy_pattern_would_have_duplicates(self):
"""Document the bug this PR fixes: the old code shape produced
duplicates. Kept as a sanity anchor so a future refactor that
re-introduces the pattern is immediately visible."""
completed_prompts_set = set()
results = []
for batch in ({"completed_prompts": [0, 1, 2]},
{"completed_prompts": [3, 4]}):
completed_prompts_set.update(batch["completed_prompts"])
results.append(batch)
# Buggy aggregation (pre-fix):
buggy = list(completed_prompts_set)
for br in results:
buggy.extend(br.get("completed_prompts", []))
# Every index appears twice
assert len(buggy) == 2 * len(set(buggy))