change(ci): slice files in matrix job

Avoid duplicating work and avoid repeating file discovery in each matrix job.
This commit is contained in:
ethernet 2026-06-26 21:52:44 -04:00
parent 1a75387fa8
commit dd0e4ab81a
2 changed files with 144 additions and 92 deletions

View file

@ -493,38 +493,27 @@ def _save_durations(
path.write_text(json.dumps(data, indent=2, sort_keys=True) + "\n")
def _slice_files(
def _compute_lpt_slices(
files: List[Path],
slice_index: int,
slice_count: int,
durations: dict[str, float],
repo_root: Path,
) -> List[Path]:
"""Return the subset of *files* belonging to slice *slice_index*.
) -> List[List[Path]]:
"""Distribute files across N slices using LPT (Longest Processing Time first).
Uses **Longest Processing Time first** (LPT) distribution: sort files
by estimated duration descending, then greedily assign each file to
the slice with the smallest accumulated time so far. This minimizes
the makespan (max slice duration) and keeps CI jobs balanced.
Sorts files by estimated duration descending, then greedily assigns each
file to the slice with the smallest accumulated time so far. This
minimizes the makespan (max slice duration) and keeps CI jobs balanced.
Files with no cached duration get a default estimate of 2.0s (roughly
the P50 from profiling). This means first-time ``--slice`` runs
(no cache) still get reasonable distribution, and new files don't
all land in one slice.
the P50 from profiling). This means first-time runs (no cache) still
get reasonable distribution, and new files don't all land in one slice.
``slice_index`` is 1-indexed (1..slice_count) for ergonomics
``--slice 1/4`` reads more naturally than ``--slice 0/4``.
Returns a list of N file-lists, one per slice (0-indexed).
"""
if slice_count < 2:
return files
if not (1 <= slice_index <= slice_count):
print(
f"error: --slice index must be 1..{slice_count}, got {slice_index}",
file=sys.stderr,
)
sys.exit(2)
return [files]
# Build (file, estimated_duration) pairs.
default_dur = 2.0
file_durs: List[Tuple[Path, float]] = []
for f in files:
@ -541,15 +530,47 @@ def _slice_files(
bucket_totals: List[float] = [0.0] * slice_count
for f, dur in file_durs:
# Find the least-loaded bucket.
min_idx = min(range(slice_count), key=lambda i: bucket_totals[i])
bucket_files[min_idx].append(f)
bucket_totals[min_idx] += dur
# Print slice summary for visibility.
return bucket_files
def _slice_files(
files: List[Path],
slice_index: int,
slice_count: int,
durations: dict[str, float],
repo_root: Path,
) -> List[Path]:
"""Return the subset of *files* belonging to slice *slice_index*.
Uses :func:`_compute_lpt_slices` for LPT distribution.
``slice_index`` is 1-indexed (1..slice_count) for ergonomics
``--slice 1/4`` reads more naturally than ``--slice 0/4``.
"""
if slice_count < 2:
return files
if not (1 <= slice_index <= slice_count):
print(
f"error: --slice index must be 1..{slice_count}, got {slice_index}",
file=sys.stderr,
)
sys.exit(2)
bucket_files = _compute_lpt_slices(files, slice_count, durations, repo_root)
target = bucket_files[slice_index - 1]
target_dur = bucket_totals[slice_index - 1]
total_dur = sum(bucket_totals)
target_dur = sum(
durations.get(_format_file(f, repo_root), 2.0) for f in target
)
total_dur = sum(
durations.get(_format_file(f, repo_root), 2.0)
for bucket in bucket_files
for f in bucket
)
print(
f"Slice {slice_index}/{slice_count}: {len(target)} files "
f"(~{target_dur:.0f}s estimated of {total_dur:.0f}s total)",
@ -604,6 +625,27 @@ def main() -> int:
"Env: HERMES_TEST_SLICE (format: I/N)."
),
)
# --generate-slices: the help text must describe the JSON exactly as it is
# emitted — a top-level "slice" key whose entries carry a 1-based "index"
# and a "files" value that is a single colon-separated string (the same
# format --files accepts), not a JSON array.
parser.add_argument(
    "--generate-slices",
    metavar="N",
    type=int,
    help=(
        "Discover test files, distribute them across N slices using "
        "LPT on cached durations, and print a JSON matrix to stdout "
        "then exit (no tests run). The JSON has the shape "
        "'{\"slice\": [{\"index\": 1, \"files\": \"tests/foo.py:tests/bar.py\"}, ...]}' "
        "(each 'files' value is a colon-separated string that can be "
        "passed straight to --files) so the CI generate job can feed it "
        "directly into a matrix."
    ),
)
parser.add_argument(
"--files",
metavar="LIST",
help=(
"Explicit colon-separated list of test files to run. Bypasses "
"discovery entirely — used by CI matrix jobs that receive their "
"file list from the generate job."
),
)
parser.add_argument(
"paths_positional",
nargs="*",
@ -642,26 +684,48 @@ def main() -> int:
repo_root = Path(__file__).resolve().parent.parent
# Resolve discovery roots: positional path args override --paths if any
# were supplied, otherwise --paths (which itself defaults to 'tests').
if args.paths_positional:
# Positionals can be directories OR explicit .py files. Either is
# fine — _discover_files handles both via rglob('test_*.py') for
# dirs and direct inclusion for files.
roots = [repo_root / p for p in args.paths_positional]
# --files: explicit file list from the CI generate job — skip discovery.
if args.files:
files = [repo_root / f for f in args.files.split(":") if f.strip()]
roots = []
else:
roots = [repo_root / p for p in args.paths.split(":") if p]
# Resolve discovery roots: positional path args override --paths if any
# were supplied, otherwise --paths (which itself defaults to 'tests').
if args.paths_positional:
roots = [repo_root / p for p in args.paths_positional]
else:
roots = [repo_root / p for p in args.paths.split(":") if p]
if args.include_integration:
# Caller takes responsibility — typically used via explicit -k filter.
global _SKIP_PARTS # noqa: PLW0603 — config knob
_SKIP_PARTS = set()
if args.include_integration:
# Caller takes responsibility — typically used via explicit -k filter.
global _SKIP_PARTS # noqa: PLW0603 — config knob
_SKIP_PARTS = set()
files = _discover_files(roots)
files = _discover_files(roots)
if not files:
print(f"No test files discovered under {[str(r) for r in roots]}", file=sys.stderr)
print(f"No test files to run", file=sys.stderr)
return 1
# --generate-slices: compute LPT distribution and emit JSON, then exit.
if args.generate_slices is not None:
durations = _load_durations(repo_root)
slices = _compute_lpt_slices(
files, args.generate_slices, durations, repo_root
)
matrix = {
"slice": [
{
"index": i + 1,
"files": ":".join(_format_file(f, repo_root) for f in bucket),
}
for i, bucket in enumerate(slices)
]
}
# Print to stdout so the CI step can capture it with $().
print(json.dumps(matrix))
return 0
# Count individual tests per file
test_counts = _approximately_count_tests(files, repo_root)
approx_total_tests = sum(test_counts.values())
@ -675,12 +739,19 @@ def main() -> int:
test_counts = {f: test_counts[f] for f in files if f in test_counts}
approx_total_tests = sum(test_counts.values())
print(
f"Discovered {len(files)} test files (~{approx_total_tests} tests) under "
f"{[str(r.relative_to(repo_root)) if r.is_relative_to(repo_root) else str(r) for r in roots]}; "
f"running with -j {args.jobs}",
flush=True,
)
if roots:
roots_str = [str(r.relative_to(repo_root)) if r.is_relative_to(repo_root) else str(r) for r in roots]
print(
f"Discovered {len(files)} test files (~{approx_total_tests} tests) under "
f"{roots_str}; running with -j {args.jobs}",
flush=True,
)
else:
print(
f"Running {len(files)} test files (~{approx_total_tests} tests) "
f"with -j {args.jobs}",
flush=True,
)
# Capture and print on completion (out-of-order is fine — keeps the
# terminal clean rather than interleaving N parallel pytest outputs).