diff --git a/batch_runner_threaded.py b/batch_runner_threaded.py new file mode 100644 index 0000000000..c193e6579f --- /dev/null +++ b/batch_runner_threaded.py @@ -0,0 +1,1228 @@ +#!/usr/bin/env python3 +""" +Batch Agent Runner + +This module provides parallel batch processing capabilities for running the agent +across multiple prompts from a dataset. It includes: +- Dataset loading and batching +- Parallel batch processing with threading (shared slot pool for sandbox backends) +- Checkpointing for fault tolerance and resumption +- Trajectory saving in the proper format (from/value pairs) +- Tool usage statistics aggregation across all batches + +Threaded variant: uses ThreadPoolExecutor instead of multiprocessing.Pool so all +workers share a single _SlotPoolManager for N:M sandbox multiplexing. + +When TERMINAL_ENV=slot_pool, all threads share the same sandbox pool โ€” one set of +sandboxes serves all concurrent prompts via slot-based multiplexing. + +Usage: + # With slot pool (shared sandbox backend): + TERMINAL_ENV=slot_pool TERMINAL_SLOT_BACKEND=modal \ + python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + + # Resume an interrupted run + python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + + # Use a specific toolset distribution + python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen +""" + +import json +import logging +import os +import time +from pathlib import Path +from typing import List, Dict, Any, Optional, Tuple +from datetime import datetime +from concurrent.futures import ThreadPoolExecutor, as_completed +import threading +import traceback + +from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn +from rich.console import Console +import fire + +from run_agent import AIAgent +from toolset_distributions import ( + get_distribution, + list_distributions, + sample_toolsets_from_distribution, + validate_distribution +) +from model_tools import TOOL_TO_TOOLSET_MAP + + +# Global configuration for worker processes +_WORKER_CONFIG = {} + +# All possible tools - auto-derived from the master mapping in model_tools.py. +# This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP. +# Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for +# filtering corrupted entries during trajectory combination. +ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys()) + +# Default stats for tools that weren't used +DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0} + + +def _normalize_tool_stats(tool_stats: Dict[str, Dict[str, int]]) -> Dict[str, Dict[str, int]]: + """ + Normalize tool_stats to include all possible tools with consistent schema. + + This ensures HuggingFace datasets can load the JSONL without schema mismatch errors. + Tools that weren't used get zero counts. 
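+
+    Illustrative example (tool names here are hypothetical; the real key set
+    comes from TOOL_TO_TOOLSET_MAP):
+
+        >>> stats = _normalize_tool_stats({"terminal": {"count": 2, "success": 1, "failure": 1}})
+        >>> stats["terminal"]
+        {'count': 2, 'success': 1, 'failure': 1}
+        >>> all(v == DEFAULT_TOOL_STATS for k, v in stats.items() if k != "terminal")
+        True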
+
+    Args:
+        tool_stats (Dict): Raw tool statistics from extraction
+
+    Returns:
+        Dict: Normalized tool statistics with all tools present
+    """
+    normalized = {}
+
+    # Add all possible tools with defaults
+    for tool in ALL_POSSIBLE_TOOLS:
+        if tool in tool_stats:
+            normalized[tool] = tool_stats[tool].copy()
+        else:
+            normalized[tool] = DEFAULT_TOOL_STATS.copy()
+
+    # Also include any unexpected tools (in case new tools are added)
+    for tool, stats in tool_stats.items():
+        if tool not in normalized:
+            normalized[tool] = stats.copy()
+
+    return normalized
+
+
+def _normalize_tool_error_counts(tool_error_counts: Dict[str, int]) -> Dict[str, int]:
+    """
+    Normalize tool_error_counts to include all possible tools.
+
+    Args:
+        tool_error_counts (Dict): Raw error counts mapping
+
+    Returns:
+        Dict: Normalized error counts with all tools present
+    """
+    normalized = {}
+
+    # Add all possible tools with zero defaults
+    for tool in ALL_POSSIBLE_TOOLS:
+        normalized[tool] = tool_error_counts.get(tool, 0)
+
+    # Also include any unexpected tools
+    for tool, count in tool_error_counts.items():
+        if tool not in normalized:
+            normalized[tool] = count
+
+    return normalized
+
+
+def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
+    """
+    Extract tool usage statistics from message history.
+
+    Args:
+        messages (List[Dict]): Message history
+
+    Returns:
+        Dict: Tool statistics with call counts and success/failure totals
+    """
+    tool_stats = {}
+
+    # Track tool calls and their results
+    tool_calls_map = {}  # Map tool_call_id to tool name
+
+    for msg in messages:
+        # Track tool calls from assistant messages
+        if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
+            for tool_call in msg["tool_calls"]:
+                tool_name = tool_call["function"]["name"]
+                tool_call_id = tool_call["id"]
+
+                # Initialize stats for this tool if they don't exist yet
+                if tool_name not in tool_stats:
+                    tool_stats[tool_name] = {
+                        "count": 0,
+                        "success": 0,
+                        "failure": 0
+                    }
+
+                tool_stats[tool_name]["count"] += 1
+                tool_calls_map[tool_call_id] = tool_name
+
+        # Track tool responses
+        elif msg["role"] == "tool":
+            tool_call_id = msg.get("tool_call_id", "")
+            content = msg.get("content", "")
+
+            # Determine if tool call was successful
+            is_success = True
+            try:
+                # Try to parse as JSON and check for actual error values
+                content_json = json.loads(content) if isinstance(content, str) else content
+
+                if isinstance(content_json, dict):
+                    # Check if error field exists AND has a non-null value
+                    if "error" in content_json and content_json["error"] is not None:
+                        is_success = False
+
+                    # Special handling for terminal tool responses
+                    # Terminal wraps its response in a "content" field
+                    if "content" in content_json and isinstance(content_json["content"], dict):
+                        inner_content = content_json["content"]
+                        # Check for actual error (non-null error field)
+                        # Note: non-zero exit codes are not failures - the model can self-correct
+                        if inner_content.get("error") is not None:
+                            is_success = False
+
+                    # Check for "success": false pattern used by some tools
+                    if content_json.get("success") is False:
+                        is_success = False
+
+            except (json.JSONDecodeError, TypeError):
+                # Not JSON: check if content is empty or explicitly states an error
+                # Note: We avoid simple substring matching to prevent false positives
+                if not content:
+                    is_success = False
+                # Only mark as failure if it explicitly starts with "Error:" or "ERROR:"
+                elif content.strip().lower().startswith("error:"):
+                    is_success = False
+
+            # Update success/failure count
+            if tool_call_id in tool_calls_map:
+                tool_name = tool_calls_map[tool_call_id]
+                if is_success:
+                    tool_stats[tool_name]["success"] += 1
+                else:
+                    tool_stats[tool_name]["failure"] += 1
+
+    return tool_stats
+
+
+def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
+    """
+    Count how many assistant turns have reasoning vs no reasoning.
+
+    Checks for a "<scratchpad>" block in content or a non-empty 'reasoning'
+    field (native thinking tokens). Returns counts for tracking reasoning
+    coverage.
+
+    Args:
+        messages: Message history
+
+    Returns:
+        Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning'
+    """
+    total = 0
+    with_reasoning = 0
+
+    for msg in messages:
+        if msg.get("role") != "assistant":
+            continue
+        total += 1
+
+        content = msg.get("content", "") or ""
+        has_scratchpad = "<scratchpad>" in content
+        has_native_reasoning = bool((msg.get("reasoning") or "").strip())
+
+        if has_scratchpad or has_native_reasoning:
+            with_reasoning += 1
+
+    return {
+        "total_assistant_turns": total,
+        "turns_with_reasoning": with_reasoning,
+        "turns_without_reasoning": total - with_reasoning,
+        "has_any_reasoning": with_reasoning > 0,
+    }
+
+
+def _process_single_prompt(
+    prompt_index: int,
+    prompt_data: Dict[str, Any],
+    batch_num: int,
+    config: Dict[str, Any]
+) -> Dict[str, Any]:
+    """
+    Process a single prompt with the agent.
+
+    Args:
+        prompt_index (int): Index of prompt in dataset
+        prompt_data (Dict): Prompt data containing 'prompt' field
+        batch_num (int): Batch number
+        config (Dict): Configuration dict with agent parameters
+
+    Returns:
+        Dict: Result containing trajectory, stats, and metadata
+    """
+    prompt = prompt_data["prompt"]
+
+    try:
+        # Sample toolsets from distribution for this prompt
+        selected_toolsets = sample_toolsets_from_distribution(config["distribution"])
+
+        if config.get("verbose"):
+            print(f"  Prompt {prompt_index}: Using toolsets {selected_toolsets}")
+
+        # Initialize agent with sampled toolsets and log prefix for identification
+        log_prefix = f"[B{batch_num}:P{prompt_index}]"
+        agent = AIAgent(
+            base_url=config.get("base_url"),
+            api_key=config.get("api_key"),
+            model=config["model"],
+            max_iterations=config["max_iterations"],
+            enabled_toolsets=selected_toolsets,
+            save_trajectories=False,  # We handle saving ourselves
+            verbose_logging=config.get("verbose", False),
+            ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
+            log_prefix_chars=config.get("log_prefix_chars", 100),
+            log_prefix=log_prefix,
+            providers_allowed=config.get("providers_allowed"),
+            providers_ignored=config.get("providers_ignored"),
+            providers_order=config.get("providers_order"),
+            provider_sort=config.get("provider_sort"),
+            max_tokens=config.get("max_tokens"),
+            reasoning_config=config.get("reasoning_config"),
+            prefill_messages=config.get("prefill_messages"),
+        )
+
+        # Run the agent with task_id to ensure each task gets its own isolated VM
+        result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
+
+        # Extract tool usage statistics
+        tool_stats = _extract_tool_stats(result["messages"])
+
+        # Extract reasoning coverage stats
+        reasoning_stats = _extract_reasoning_stats(result["messages"])
+
+        # Convert to trajectory format (using existing method)
+        trajectory = agent._convert_to_trajectory_format(
+            result["messages"],
+            prompt,
+            result["completed"]
+        )
+
+        return {
+            "success": True,
+            "prompt_index": prompt_index,
+            "trajectory": trajectory,
+            "tool_stats": tool_stats,
+            "reasoning_stats": reasoning_stats,
+            "completed": 
result["completed"], + "partial": result.get("partial", False), + "api_calls": result["api_calls"], + "toolsets_used": selected_toolsets, + "metadata": { + "batch_num": batch_num, + "timestamp": datetime.now().isoformat(), + "model": config["model"] + } + } + + except Exception as e: + print(f"โŒ Error processing prompt {prompt_index}: {e}") + if config.get("verbose"): + traceback.print_exc() + + return { + "success": False, + "prompt_index": prompt_index, + "error": str(e), + "trajectory": None, + "tool_stats": {}, + "toolsets_used": [], + "metadata": { + "batch_num": batch_num, + "timestamp": datetime.now().isoformat() + } + } + + +def _process_batch_worker(args: Tuple) -> Dict[str, Any]: + """ + Worker function to process a single batch of prompts. + + Args: + args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config) + + Returns: + Dict: Batch results with statistics + """ + batch_num, batch_data, output_dir, completed_prompts_set, config = args + + output_dir = Path(output_dir) + print(f"\n๐Ÿ”„ Batch {batch_num}: Starting ({len(batch_data)} prompts)") + + # Output file for this batch + batch_output_file = output_dir / f"batch_{batch_num}.jsonl" + + # Filter out already completed prompts + prompts_to_process = [ + (idx, data) for idx, data in batch_data + if idx not in completed_prompts_set + ] + + if not prompts_to_process: + print(f"โœ… Batch {batch_num}: Already completed (skipping)") + return { + "batch_num": batch_num, + "processed": 0, + "skipped": len(batch_data), + "tool_stats": {}, + "completed_prompts": [] + } + + print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)") + + # Initialize aggregated stats for this batch + batch_tool_stats = {} + batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0} + completed_in_batch = [] + discarded_no_reasoning = 0 + + # Process each prompt sequentially in this batch + for prompt_index, prompt_data in prompts_to_process: + # Process the prompt + result = _process_single_prompt( + prompt_index, + prompt_data, + batch_num, + config + ) + + # Save trajectory if successful + if result["success"] and result["trajectory"]: + # Discard samples with zero reasoning across all turns + reasoning = result.get("reasoning_stats", {}) + if not reasoning.get("has_any_reasoning", True): + print(f" ๐Ÿšซ Prompt {prompt_index} discarded (no reasoning in any turn)") + discarded_no_reasoning += 1 + continue + + # Get and normalize tool stats for consistent schema across all entries + raw_tool_stats = result.get("tool_stats", {}) + tool_stats = _normalize_tool_stats(raw_tool_stats) + + # Create normalized tool_error_counts mapping tool names to their failure counts + raw_error_counts = { + tool_name: stats.get("failure", 0) + for tool_name, stats in raw_tool_stats.items() + } + tool_error_counts = _normalize_tool_error_counts(raw_error_counts) + + trajectory_entry = { + "prompt_index": prompt_index, + "conversations": result["trajectory"], + "metadata": result["metadata"], + "completed": result["completed"], + "partial": result.get("partial", False), # True if stopped due to invalid tool calls + "api_calls": result["api_calls"], + "toolsets_used": result["toolsets_used"], + "tool_stats": tool_stats, # Full stats: {tool: {count, success, failure}} - normalized + "tool_error_counts": tool_error_counts # Simple: {tool: failure_count} - normalized + } + + # Append to batch output file + with open(batch_output_file, 'a', 
encoding='utf-8') as f:
+                f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
+
+        # Aggregate tool statistics
+        for tool_name, stats in result.get("tool_stats", {}).items():
+            if tool_name not in batch_tool_stats:
+                batch_tool_stats[tool_name] = {
+                    "count": 0,
+                    "success": 0,
+                    "failure": 0
+                }
+
+            batch_tool_stats[tool_name]["count"] += stats["count"]
+            batch_tool_stats[tool_name]["success"] += stats["success"]
+            batch_tool_stats[tool_name]["failure"] += stats["failure"]
+
+        # Aggregate reasoning stats
+        for key in batch_reasoning_stats:
+            batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)
+
+        # Only mark as completed if successfully saved (failed prompts can be retried on resume)
+        if result["success"] and result["trajectory"]:
+            completed_in_batch.append(prompt_index)
+            status = "⚠️ partial" if result.get("partial") else "✅"
+            print(f"  {status} Prompt {prompt_index} completed")
+        else:
+            print(f"  ❌ Prompt {prompt_index} failed (will retry on resume)")
+
+    print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
+
+    return {
+        "batch_num": batch_num,
+        "processed": len(prompts_to_process),
+        "skipped": len(batch_data) - len(prompts_to_process),
+        "tool_stats": batch_tool_stats,
+        "reasoning_stats": batch_reasoning_stats,
+        "discarded_no_reasoning": discarded_no_reasoning,
+        "completed_prompts": completed_in_batch
+    }
+
+
+class BatchRunner:
+    """
+    Manages batch processing of agent prompts with checkpointing and statistics.
+    """
+
+    def __init__(
+        self,
+        dataset_file: str,
+        batch_size: int,
+        run_name: str,
+        distribution: str = "default",
+        max_iterations: int = 10,
+        base_url: str = None,
+        api_key: str = None,
+        model: str = "claude-opus-4-20250514",
+        num_workers: int = 4,
+        verbose: bool = False,
+        ephemeral_system_prompt: str = None,
+        log_prefix_chars: int = 100,
+        providers_allowed: List[str] = None,
+        providers_ignored: List[str] = None,
+        providers_order: List[str] = None,
+        provider_sort: str = None,
+        max_tokens: int = None,
+        reasoning_config: Dict[str, Any] = None,
+        prefill_messages: List[Dict[str, Any]] = None,
+        max_samples: int = None,
+    ):
+        """
+        Initialize the batch runner.
+
+        Args:
+            dataset_file (str): Path to the dataset JSONL file with 'prompt' field
+            batch_size (int): Number of prompts per batch
+            run_name (str): Name for this run (used for checkpointing and output)
+            distribution (str): Toolset distribution to use (default: "default")
+            max_iterations (int): Max iterations per agent run
+            base_url (str): Base URL for model API
+            api_key (str): API key for model
+            model (str): Model name to use
+            num_workers (int): Number of parallel worker threads
+            verbose (bool): Enable verbose logging
+            ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
+            log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
+            providers_allowed (List[str]): OpenRouter providers to allow (optional)
+            providers_ignored (List[str]): OpenRouter providers to ignore (optional)
+            providers_order (List[str]): OpenRouter providers to try in order (optional)
+            provider_sort (str): Sort providers by price/throughput/latency (optional)
+            max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
+            reasoning_config (Dict): OpenRouter reasoning config override (e.g. 
{"effort": "none"} to disable thinking) + prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming) + max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set) + """ + self.dataset_file = Path(dataset_file) + self.batch_size = batch_size + self.run_name = run_name + self.distribution = distribution + self.max_iterations = max_iterations + self.base_url = base_url + self.api_key = api_key + self.model = model + self.num_workers = num_workers + self.verbose = verbose + self.ephemeral_system_prompt = ephemeral_system_prompt + self.log_prefix_chars = log_prefix_chars + self.providers_allowed = providers_allowed + self.providers_ignored = providers_ignored + self.providers_order = providers_order + self.provider_sort = provider_sort + self.max_tokens = max_tokens + self.reasoning_config = reasoning_config + self.prefill_messages = prefill_messages + self.max_samples = max_samples + + # Validate distribution + if not validate_distribution(distribution): + raise ValueError(f"Unknown distribution: {distribution}. Available: {list(list_distributions().keys())}") + + # Setup output directory + self.output_dir = Path("data") / run_name + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Checkpoint file + self.checkpoint_file = self.output_dir / "checkpoint.json" + + # Statistics file + self.stats_file = self.output_dir / "statistics.json" + + # Load dataset (and optionally truncate to max_samples) + self.dataset = self._load_dataset() + if self.max_samples and self.max_samples < len(self.dataset): + full_count = len(self.dataset) + self.dataset = self.dataset[:self.max_samples] + print(f"โœ‚๏ธ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)") + + # Create batches + self.batches = self._create_batches() + + print(f"๐Ÿ“Š Batch Runner Initialized") + print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") + print(f" Batch size: {self.batch_size}") + print(f" Total batches: {len(self.batches)}") + print(f" Run name: {self.run_name}") + print(f" Distribution: {self.distribution}") + print(f" Output directory: {self.output_dir}") + print(f" Workers: {self.num_workers}") + if self.ephemeral_system_prompt: + prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt + print(f" ๐Ÿ”’ Ephemeral system prompt: '{prompt_preview}'") + + def _load_dataset(self) -> List[Dict[str, Any]]: + """ + Load dataset from JSONL file. + + Returns: + List[Dict]: List of dataset entries + """ + if not self.dataset_file.exists(): + raise FileNotFoundError(f"Dataset file not found: {self.dataset_file}") + + dataset = [] + with open(self.dataset_file, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.strip() + if not line: + continue + + try: + entry = json.loads(line) + if 'prompt' not in entry: + print(f"โš ๏ธ Warning: Line {line_num} missing 'prompt' field, skipping") + continue + dataset.append(entry) + except json.JSONDecodeError as e: + print(f"โš ๏ธ Warning: Invalid JSON on line {line_num}: {e}") + continue + + if not dataset: + raise ValueError(f"No valid entries found in dataset file: {self.dataset_file}") + + return dataset + + def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]: + """ + Split dataset into batches with indices. 
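+
+        Illustrative shape (with batch_size=2 and a three-entry dataset
+        e0..e2): [[(0, e0), (1, e1)], [(2, e2)]]. Each entry keeps its
+        original dataset index so checkpointing and resume stay stable.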
+ + Returns: + List of batches, where each batch is a list of (index, entry) tuples + """ + batches = [] + for i in range(0, len(self.dataset), self.batch_size): + batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)] + batches.append(batch) + + return batches + + def _load_checkpoint(self) -> Dict[str, Any]: + """ + Load checkpoint data if it exists. + + Returns: + Dict: Checkpoint data with completed prompt indices + """ + if not self.checkpoint_file.exists(): + return { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + try: + with open(self.checkpoint_file, 'r', encoding='utf-8') as f: + return json.load(f) + except Exception as e: + print(f"โš ๏ธ Warning: Failed to load checkpoint: {e}") + return { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[threading.Lock] = None): + """ + Save checkpoint data. + + Args: + checkpoint_data (Dict): Checkpoint data to save + lock (Lock): Optional lock for thread-safe access + """ + checkpoint_data["last_updated"] = datetime.now().isoformat() + + if lock: + with lock: + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + else: + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + + def _scan_completed_prompts_by_content(self) -> set: + """ + Scan all batch files and extract completed prompts by their actual content. + + This provides a more robust resume mechanism that matches on prompt text + rather than indices, allowing recovery even if indices don't match. + + Returns: + set: Set of prompt texts that have been successfully processed + """ + completed_prompts = set() + batch_files = sorted(self.output_dir.glob("batch_*.jsonl")) + + if not batch_files: + return completed_prompts + + print(f"๐Ÿ“‚ Scanning {len(batch_files)} batch files for completed prompts...") + + for batch_file in batch_files: + try: + with open(batch_file, 'r', encoding='utf-8') as f: + for line in f: + try: + entry = json.loads(line.strip()) + + # Skip failed entries - we want to retry these + if entry.get("failed", False): + continue + + # Extract the human/user prompt from conversations + conversations = entry.get("conversations", []) + for msg in conversations: + if msg.get("from") == "human": + prompt_text = msg.get("value", "").strip() + if prompt_text: + completed_prompts.add(prompt_text) + break # Only need the first human message + except json.JSONDecodeError: + continue + except Exception as e: + print(f" โš ๏ธ Warning: Error reading {batch_file.name}: {e}") + + return completed_prompts + + def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]: + """ + Filter the dataset to exclude prompts that have already been completed. 
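+
+        Matching is by exact (stripped) prompt text rather than by index, e.g.:
+
+            dataset   = [{"prompt": "write a haiku"}, {"prompt": "sum 2+2"}]
+            completed = {"write a haiku"}
+            returns     ([(1, {"prompt": "sum 2+2"})], [0])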
+ + Args: + completed_prompts: Set of prompt texts that have been completed + + Returns: + Tuple of (filtered_dataset, skipped_indices) + """ + filtered_dataset = [] + skipped_indices = [] + + for idx, entry in enumerate(self.dataset): + # Extract prompt from the dataset entry + prompt_text = entry.get("prompt", "").strip() + + # Also check conversations format + if not prompt_text: + conversations = entry.get("conversations", []) + for msg in conversations: + role = msg.get("role") or msg.get("from") + if role in ("user", "human"): + prompt_text = (msg.get("content") or msg.get("value", "")).strip() + break + + if prompt_text in completed_prompts: + skipped_indices.append(idx) + else: + # Keep original index for tracking + filtered_dataset.append((idx, entry)) + + return filtered_dataset, skipped_indices + + def run(self, resume: bool = False): + """ + Run the batch processing pipeline. + + Args: + resume (bool): Whether to resume from checkpoint + """ + print("\n" + "=" * 70) + print("๐Ÿš€ Starting Batch Processing") + print("=" * 70) + + # Smart resume: scan batch files by content to find completed prompts + completed_prompt_texts = set() + if resume: + completed_prompt_texts = self._scan_completed_prompts_by_content() + if completed_prompt_texts: + print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching") + + # Filter dataset to only include unprocessed prompts + if resume and completed_prompt_texts: + filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts) + + if not filtered_entries: + print("\nโœ… All prompts have already been processed!") + return + + # Recreate batches from filtered entries (keeping original indices for tracking) + batches_to_process = [] + for i in range(0, len(filtered_entries), self.batch_size): + batch = filtered_entries[i:i + self.batch_size] + batches_to_process.append(batch) + + self.batches = batches_to_process + + # Print prominent resume summary + print("\n" + "=" * 70) + print("๐Ÿ“Š RESUME SUMMARY") + print("=" * 70) + print(f" Original dataset size: {len(self.dataset):,} prompts") + print(f" Already completed: {len(skipped_indices):,} prompts") + print(f" โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€") + print(f" ๐ŸŽฏ RESUMING WITH: {len(filtered_entries):,} prompts") + print(f" New batches created: {len(batches_to_process)}") + print("=" * 70 + "\n") + + # Initialize checkpoint data (needed for saving at the end) + checkpoint_data = { + "run_name": self.run_name, + "completed_prompts": [], + "batch_stats": {}, + "last_updated": None + } + + # Prepare configuration for workers + config = { + "distribution": self.distribution, + "model": self.model, + "max_iterations": self.max_iterations, + "base_url": self.base_url, + "api_key": self.api_key, + "verbose": self.verbose, + "ephemeral_system_prompt": self.ephemeral_system_prompt, + "log_prefix_chars": self.log_prefix_chars, + "providers_allowed": self.providers_allowed, + "providers_ignored": self.providers_ignored, + "providers_order": self.providers_order, + "provider_sort": self.provider_sort, + "max_tokens": self.max_tokens, + "reasoning_config": self.reasoning_config, + "prefill_messages": self.prefill_messages, + } + + # For backward compatibility, still track by index (but this is secondary to content matching) + completed_prompts_set = set() + + # Aggregate statistics across all batches + total_tool_stats = {} + + start_time = time.time() + + 
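+        # Concurrency sketch: each worker thread takes one batch at a time and
+        # runs its prompts sequentially, so at most num_workers prompts are in
+        # flight at once. With TERMINAL_ENV=slot_pool, every thread shares the
+        # same _SlotPoolManager, so one fixed set of sandboxes is multiplexed
+        # across those in-flight prompts (see module docstring).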
print(f"\n๐Ÿงต Initializing {self.num_workers} worker threads...") + print(f" (Threads share sandbox pool โ€” all workers use same _SlotPoolManager)") + + # Create tasks for each batch + tasks = [ + ( + batch_num, + batch_data, + self.output_dir, # No pickling needed โ€” threads share memory + completed_prompts_set, + config + ) + for batch_num, batch_data in enumerate(self.batches) + ] + + print(f"โœ… Created {len(tasks)} batch tasks") + print(f"๐Ÿš€ Starting threaded batch processing...\n") + + # Process batches in parallel using threads (shared _SlotPoolManager) + results = [] + console = Console(force_terminal=True) + with Progress( + SpinnerColumn(), + TextColumn("[bold blue]๐Ÿ“ฆ Batches"), + BarColumn(bar_width=40), + MofNCompleteColumn(), + TextColumn("โ€ข"), + TimeRemainingColumn(), + console=console, + refresh_per_second=2, + transient=False, + redirect_stdout=False, + redirect_stderr=False, + ) as progress: + progress_task = progress.add_task("Processing", total=len(tasks)) + + # Temporarily suppress DEBUG logging to avoid bar interference + root_logger = logging.getLogger() + original_level = root_logger.level + root_logger.setLevel(logging.WARNING) + + try: + with ThreadPoolExecutor(max_workers=self.num_workers) as executor: + futures = { + executor.submit(_process_batch_worker, task_args): task_args[0] + for task_args in tasks + } + + for future in as_completed(futures): + batch_num = futures[future] + try: + result = future.result() + results.append(result) + except Exception as e: + print(f"โŒ Batch {batch_num} failed: {e}") + results.append({ + "batch_num": batch_num, + "processed": 0, + "skipped": 0, + "tool_stats": {}, + "completed_prompts": [], + }) + progress.update(progress_task, advance=1) + finally: + root_logger.setLevel(original_level) + + # Aggregate all batch statistics and update checkpoint + all_completed_prompts = list(completed_prompts_set) + total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0} + + for batch_result in results: + # Add newly completed prompts + all_completed_prompts.extend(batch_result.get("completed_prompts", [])) + + # Aggregate tool stats + for tool_name, stats in batch_result.get("tool_stats", {}).items(): + if tool_name not in total_tool_stats: + total_tool_stats[tool_name] = { + "count": 0, + "success": 0, + "failure": 0 + } + + total_tool_stats[tool_name]["count"] += stats["count"] + total_tool_stats[tool_name]["success"] += stats["success"] + total_tool_stats[tool_name]["failure"] += stats["failure"] + + # Aggregate reasoning stats + for key in total_reasoning_stats: + total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0) + + # Save final checkpoint + checkpoint_data["completed_prompts"] = all_completed_prompts + self._save_checkpoint(checkpoint_data) + + # Calculate success rates + for tool_name in total_tool_stats: + stats = total_tool_stats[tool_name] + total_calls = stats["success"] + stats["failure"] + if total_calls > 0: + stats["success_rate"] = round(stats["success"] / total_calls * 100, 2) + stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2) + else: + stats["success_rate"] = 0.0 + stats["failure_rate"] = 0.0 + + # Combine ALL batch files in directory into a single trajectories.jsonl file + # This includes both old batches (from previous runs) and new batches (from resume) + # Also filter out corrupted entries (where model generated invalid tool names) + combined_file = self.output_dir / "trajectories.jsonl" + 
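+        # Illustrative shape of one trajectories.jsonl line (keys mirror the
+        # trajectory_entry built in _process_batch_worker; the assistant-side
+        # role name in "conversations" comes from _convert_to_trajectory_format):
+        #   {"prompt_index": 0,
+        #    "conversations": [{"from": "human", "value": "<prompt>"}, ...],
+        #    "metadata": {...}, "completed": true, "partial": false,
+        #    "api_calls": 3, "toolsets_used": [...],
+        #    "tool_stats": {...}, "tool_error_counts": {...}}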
print(f"\n๐Ÿ“ฆ Combining ALL batch files into {combined_file.name}...") + + # Valid tools auto-derived from model_tools.py โ€” no manual updates needed + VALID_TOOLS = ALL_POSSIBLE_TOOLS + + total_entries = 0 + filtered_entries = 0 + batch_files_found = 0 + + # Find ALL batch files in the output directory (handles resume merging old + new) + all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl")) + + with open(combined_file, 'w', encoding='utf-8') as outfile: + for batch_file in all_batch_files: + batch_files_found += 1 + batch_num = batch_file.stem.split("_")[1] # Extract batch number for logging + + with open(batch_file, 'r', encoding='utf-8') as infile: + for line in infile: + total_entries += 1 + try: + data = json.loads(line) + tool_stats = data.get('tool_stats', {}) + + # Check for invalid tool names (model hallucinations) + invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS] + + if invalid_tools: + filtered_entries += 1 + invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0] + print(f" โš ๏ธ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'") + continue + + outfile.write(line) + except json.JSONDecodeError: + filtered_entries += 1 + print(f" โš ๏ธ Filtering invalid JSON entry (batch {batch_num})") + + if filtered_entries > 0: + print(f"โš ๏ธ Filtered {filtered_entries} corrupted entries out of {total_entries} total") + print(f"โœ… Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)") + + # Save final statistics + final_stats = { + "run_name": self.run_name, + "distribution": self.distribution, + "total_prompts": len(self.dataset), + "total_batches": len(self.batches), + "batch_size": self.batch_size, + "model": self.model, + "completed_at": datetime.now().isoformat(), + "duration_seconds": round(time.time() - start_time, 2), + "tool_statistics": total_tool_stats, + "reasoning_statistics": total_reasoning_stats, + } + + with open(self.stats_file, 'w', encoding='utf-8') as f: + json.dump(final_stats, f, indent=2, ensure_ascii=False) + + # Print summary + print("\n" + "=" * 70) + print("๐Ÿ“Š BATCH PROCESSING COMPLETE") + print("=" * 70) + print(f"โœ… Prompts processed this run: {sum(r.get('processed', 0) for r in results)}") + print(f"โœ… Total trajectories in merged file: {total_entries - filtered_entries}") + print(f"โœ… Total batch files merged: {batch_files_found}") + print(f"โฑ๏ธ Total duration: {round(time.time() - start_time, 2)}s") + print(f"\n๐Ÿ“ˆ Tool Usage Statistics:") + print("-" * 70) + + if total_tool_stats: + # Sort by count descending + sorted_tools = sorted( + total_tool_stats.items(), + key=lambda x: x[1]["count"], + reverse=True + ) + + print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") + print("-" * 70) + for tool_name, stats in sorted_tools: + print( + f"{tool_name:<25} " + f"{stats['count']:<10} " + f"{stats['success']:<10} " + f"{stats['failure']:<10} " + f"{stats['success_rate']:.1f}%" + ) + else: + print("No tool calls were made during this run.") + + # Print reasoning coverage stats + total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results) + + print(f"\n๐Ÿง  Reasoning Coverage:") + print("-" * 70) + total_turns = total_reasoning_stats["total_assistant_turns"] + with_reasoning = total_reasoning_stats["turns_with_reasoning"] + without_reasoning = total_reasoning_stats["turns_without_reasoning"] + if total_turns > 0: + pct_with = 
round(with_reasoning / total_turns * 100, 1)
+            pct_without = round(without_reasoning / total_turns * 100, 1)
+            print(f"  Total assistant turns: {total_turns:,}")
+            print(f"  With reasoning: {with_reasoning:,} ({pct_with}%)")
+            print(f"  Without reasoning: {without_reasoning:,} ({pct_without}%)")
+        else:
+            print("  No assistant turns recorded.")
+        if total_discarded > 0:
+            print(f"  🚫 Samples discarded (zero reasoning): {total_discarded:,}")
+
+        print(f"\n💾 Results saved to: {self.output_dir}")
+        print(f"   - Trajectories: trajectories.jsonl (combined)")
+        print(f"   - Individual batches: batch_*.jsonl (for debugging)")
+        print(f"   - Statistics: {self.stats_file.name}")
+        print(f"   - Checkpoint: {self.checkpoint_file.name}")
+
+
+def main(
+    dataset_file: str = None,
+    batch_size: int = None,
+    run_name: str = None,
+    distribution: str = "default",
+    model: str = "anthropic/claude-sonnet-4-20250514",
+    api_key: str = None,
+    base_url: str = "https://openrouter.ai/api/v1",
+    max_turns: int = 10,
+    num_workers: int = 4,
+    resume: bool = False,
+    verbose: bool = False,
+    list_distributions: bool = False,
+    ephemeral_system_prompt: str = None,
+    log_prefix_chars: int = 100,
+    providers_allowed: str = None,
+    providers_ignored: str = None,
+    providers_order: str = None,
+    provider_sort: str = None,
+    max_tokens: int = None,
+    reasoning_effort: str = None,
+    reasoning_disabled: bool = False,
+    prefill_messages_file: str = None,
+    max_samples: int = None,
+):
+    """
+    Run batch processing of agent prompts from a dataset.
+
+    Args:
+        dataset_file (str): Path to JSONL file with 'prompt' field in each entry
+        batch_size (int): Number of prompts per batch
+        run_name (str): Name for this run (used for output and checkpointing)
+        distribution (str): Toolset distribution to use (default: "default")
+        model (str): Model name to use (default: "anthropic/claude-sonnet-4-20250514")
+        api_key (str): API key for model authentication
+        base_url (str): Base URL for model API
+        max_turns (int): Maximum number of tool calling iterations per prompt (default: 10)
+        num_workers (int): Number of parallel worker threads (default: 4)
+        resume (bool): Resume from checkpoint if run was interrupted (default: False)
+        verbose (bool): Enable verbose logging (default: False)
+        list_distributions (bool): List available toolset distributions and exit
+        ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
+        log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
+        providers_allowed (str): Comma-separated list of OpenRouter providers to allow (e.g. "anthropic,openai")
+        providers_ignored (str): Comma-separated list of OpenRouter providers to ignore (e.g. "together,deepinfra")
+        providers_order (str): Comma-separated list of OpenRouter providers to try in order (e.g. "anthropic,openai,google")
+        provider_sort (str): Sort providers by "price", "throughput", or "latency" (OpenRouter only)
+        max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
+        reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
+        reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
+        prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts; see the illustrative shape at the end of this file)
+        max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
+
+    Examples:
+        # Basic usage
+        python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
+
+        # Resume interrupted run
+        python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
+
+        # Use specific distribution
+        python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
+
+        # With disabled reasoning and max tokens
+        python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+            --reasoning_disabled --max_tokens=128000
+
+        # With prefill messages from file
+        python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+            --prefill_messages_file=configs/prefill_opus.json
+
+        # List available distributions
+        python batch_runner_threaded.py --list_distributions
+    """
+    # Handle list distributions
+    if list_distributions:
+        from toolset_distributions import list_distributions as get_all_dists, print_distribution_info
+
+        print("📊 Available Toolset Distributions")
+        print("=" * 70)
+
+        all_dists = get_all_dists()
+        for dist_name in sorted(all_dists.keys()):
+            print_distribution_info(dist_name)
+
+        print("\n💡 Usage:")
+        print("  python batch_runner_threaded.py --dataset_file=data.jsonl --batch_size=10 \\")
+        print("    --run_name=my_run --distribution=<name>")
+        return
+
+    # Validate required arguments
+    if not dataset_file:
+        print("❌ Error: --dataset_file is required")
+        return
+
+    if not batch_size or batch_size < 1:
+        print("❌ Error: --batch_size must be a positive integer")
+        return
+
+    if not run_name:
+        print("❌ Error: --run_name is required")
+        return
+
+    # Parse provider preferences (comma-separated strings to lists)
+    providers_allowed_list = [p.strip() for p in providers_allowed.split(",")] if providers_allowed else None
+    providers_ignored_list = [p.strip() for p in providers_ignored.split(",")] if providers_ignored else None
+    providers_order_list = [p.strip() for p in providers_order.split(",")] if providers_order else None
+
+    # Build reasoning_config from CLI flags
+    # --reasoning_disabled takes priority, then --reasoning_effort, then default (xhigh)
+    reasoning_config = None
+    if reasoning_disabled:
+        # Completely disable reasoning/thinking tokens
+        reasoning_config = {"effort": "none"}
+        print("🧠 Reasoning: DISABLED (effort=none)")
+    elif reasoning_effort:
+        # Use specified effort level
+        valid_efforts = ["xhigh", "high", "medium", "low", "minimal", "none"]
+        if reasoning_effort not in valid_efforts:
+            print(f"❌ Error: --reasoning_effort must be one of: {', '.join(valid_efforts)}")
+            return
+        reasoning_config = {"enabled": True, "effort": reasoning_effort}
+        print(f"🧠 Reasoning effort: {reasoning_effort}")
+
+    # Load prefill messages from JSON file if provided
+    prefill_messages = None
+    if prefill_messages_file:
+        try:
+            with open(prefill_messages_file, 'r', encoding='utf-8') as f:
+                prefill_messages = json.load(f)
+            if not isinstance(prefill_messages, list):
+                print("❌ Error: prefill_messages_file must contain a JSON array of messages")
+                return
+            print(f"💬 Loaded {len(prefill_messages)} prefill messages from {prefill_messages_file}")
+        except Exception as e:
+            print(f"❌ Error loading prefill messages: {e}")
+            return
+
+    # Initialize and run batch runner
+    try:
+        runner = BatchRunner(
+            dataset_file=dataset_file,
+            batch_size=batch_size,
+            run_name=run_name,
+            distribution=distribution,
+            max_iterations=max_turns,
+            base_url=base_url,
+            api_key=api_key,
+            model=model,
+            num_workers=num_workers,
+            verbose=verbose,
+            ephemeral_system_prompt=ephemeral_system_prompt,
+            log_prefix_chars=log_prefix_chars,
+            providers_allowed=providers_allowed_list,
+            providers_ignored=providers_ignored_list,
+            providers_order=providers_order_list,
+            provider_sort=provider_sort,
+            max_tokens=max_tokens,
+            reasoning_config=reasoning_config,
+            prefill_messages=prefill_messages,
+            max_samples=max_samples,
+        )
+
+        runner.run(resume=resume)
+
+    except Exception as e:
+        print(f"\n❌ Fatal error: {e}")
+        if verbose:
+            traceback.print_exc()
+        # Raise instead of returning so the process exit code is actually non-zero
+        raise SystemExit(1)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
+
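+
+# Illustrative shape of a --prefill_messages_file (the JSON-array check in
+# main() is the only enforced structure; the roles/content below are examples):
+#   [
+#     {"role": "user", "content": "Here is an example task..."},
+#     {"role": "assistant", "content": "Understood; I'll follow that pattern."}
+#   ]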