From ef9355455b4141e694a03768a90ebc6ad4691e90 Mon Sep 17 00:00:00 2001
From: Teknium
Date: Fri, 24 Apr 2026 14:26:30 -0700
Subject: [PATCH] test: regression coverage for checkpoint dedup and inf/nan coercion

Covers the two bugs salvaged from PR #15161:

- test_batch_runner_checkpoint: TestFinalCheckpointNoDuplicates asserts
  the final aggregated completed_prompts list has no duplicate indices,
  and keeps a sanity anchor test documenting the pre-fix pattern so a
  future refactor that re-introduces it is caught immediately.
- test_model_tools: TestCoerceNumberInfNan asserts _coerce_number
  returns the original string for inf/-inf/nan/Infinity inputs and that
  the result round-trips through strict (allow_nan=False) json.dumps.
---
 tests/test_batch_runner_checkpoint.py | 64 +++++++++++++++++++++++++++
 tests/test_model_tools.py             | 43 ++++++++++++++++++
 2 files changed, 107 insertions(+)

diff --git a/tests/test_batch_runner_checkpoint.py b/tests/test_batch_runner_checkpoint.py
index 440e421cc..526c09556 100644
--- a/tests/test_batch_runner_checkpoint.py
+++ b/tests/test_batch_runner_checkpoint.py
@@ -186,3 +186,67 @@ class TestBatchWorkerResumeBehavior:
         assert result["discarded_no_reasoning"] == 1
         assert result["completed_prompts"] == [0]
         assert not batch_file.exists() or batch_file.read_text() == ""
+
+
+class TestFinalCheckpointNoDuplicates:
+    """Regression: the final checkpoint must not contain duplicate prompt
+    indices.
+
+    Before PR #15161, `run()` populated `completed_prompts_set` incrementally
+    as each batch completed, then at the end built `all_completed_prompts =
+    list(completed_prompts_set)` AND extended it again with every batch's
+    `completed_prompts` — double-counting every index.
+    """
+
+    def _simulate_final_aggregation_fixed(self, batch_results):
+        """Mirror the fixed code path in batch_runner.run()."""
+        completed_prompts_set = set()
+        for result in batch_results:
+            completed_prompts_set.update(result.get("completed_prompts", []))
+        # This is what the fixed code now writes to the checkpoint:
+        return sorted(completed_prompts_set)
+
+    def test_no_duplicates_in_final_list(self):
+        batch_results = [
+            {"completed_prompts": [0, 1, 2]},
+            {"completed_prompts": [3, 4]},
+            {"completed_prompts": [5]},
+        ]
+        final = self._simulate_final_aggregation_fixed(batch_results)
+        assert final == [0, 1, 2, 3, 4, 5]
+        assert len(final) == len(set(final))  # no duplicates
+
+    def test_persisted_checkpoint_has_unique_prompts(self, runner):
+        """Write what run()'s fixed aggregation produces to disk; the file
+        must load back with no duplicate indices."""
+        batch_results = [
+            {"completed_prompts": [0, 1]},
+            {"completed_prompts": [2, 3]},
+        ]
+        final = self._simulate_final_aggregation_fixed(batch_results)
+        runner._save_checkpoint({
+            "run_name": runner.run_name,
+            "completed_prompts": final,
+            "batch_stats": {},
+        })
+        loaded = json.loads(runner.checkpoint_file.read_text())
+        cp = loaded["completed_prompts"]
+        assert cp == sorted(set(cp))
+        assert len(cp) == 4
+
+    def test_old_buggy_pattern_would_have_duplicates(self):
+        """Document the bug this PR fixes: the old code shape produced
+        duplicates. Kept as a sanity anchor so a future refactor that
+        re-introduces the pattern is immediately visible."""
+        completed_prompts_set = set()
+        results = []
+        for batch in ({"completed_prompts": [0, 1, 2]},
+                      {"completed_prompts": [3, 4]}):
+            completed_prompts_set.update(batch["completed_prompts"])
+            results.append(batch)
+        # Buggy aggregation (pre-fix):
+        buggy = list(completed_prompts_set)
+        for br in results:
+            buggy.extend(br.get("completed_prompts", []))
+        # Every index appears twice
+        assert len(buggy) == 2 * len(set(buggy))
diff --git a/tests/test_model_tools.py b/tests/test_model_tools.py
index 12654e350..9c2764daf 100644
--- a/tests/test_model_tools.py
+++ b/tests/test_model_tools.py
@@ -231,3 +231,46 @@ class TestBackwardCompat:
     def test_tool_to_toolset_map(self):
         assert isinstance(TOOL_TO_TOOLSET_MAP, dict)
         assert len(TOOL_TO_TOOLSET_MAP) > 0
+
+
+# =========================================================================
+# _coerce_number — inf / nan must fall through to the original string
+# (regression: fix: eliminate duplicate checkpoint entries and JSON-unsafe coercion)
+# =========================================================================
+
+class TestCoerceNumberInfNan:
+    """_coerce_number must honor its documented contract ("Returns original
+    string on failure") for inf/nan inputs, because float('inf') and
+    float('nan') are not JSON-compliant under strict serialization."""
+
+    def test_inf_returns_original_string(self):
+        from model_tools import _coerce_number
+        assert _coerce_number("inf") == "inf"
+
+    def test_negative_inf_returns_original_string(self):
+        from model_tools import _coerce_number
+        assert _coerce_number("-inf") == "-inf"
+
+    def test_nan_returns_original_string(self):
+        from model_tools import _coerce_number
+        assert _coerce_number("nan") == "nan"
+
+    def test_infinity_spelling_returns_original_string(self):
+        from model_tools import _coerce_number
+        # Python's float() parses "Infinity" too — still not JSON-safe.
+        assert _coerce_number("Infinity") == "Infinity"
+
+    def test_coerced_result_is_strict_json_safe(self):
+        """Whatever _coerce_number returns for inf/nan must round-trip
+        through strict (allow_nan=False) json.dumps without raising."""
+        from model_tools import _coerce_number
+        for s in ("inf", "-inf", "nan", "Infinity"):
+            result = _coerce_number(s)
+            json.dumps({"x": result}, allow_nan=False)  # must not raise
+
+    def test_normal_numbers_still_coerce(self):
+        """Guard against over-correction — real numbers still coerce."""
+        from model_tools import _coerce_number
+        assert _coerce_number("42") == 42
+        assert _coerce_number("3.14") == 3.14
+        assert _coerce_number("1e3") == 1000