diff --git a/batch_runner.py b/batch_runner.py index ba3109f9317..30c8cfad3c2 100644 --- a/batch_runner.py +++ b/batch_runner.py @@ -4,25 +4,25 @@ Batch Agent Runner This module provides parallel batch processing capabilities for running the agent across multiple prompts from a dataset. It includes: -- Dataset loading and batching -- Parallel batch processing with multiprocessing +- Dataset loading +- Concurrent processing with asyncio (Producer-Consumer pattern) - Checkpointing for fault tolerance and resumption - Trajectory saving in the proper format (from/value pairs) -- Tool usage statistics aggregation across all batches +- Tool usage statistics aggregation across all prompts - Cluster failure detection and graceful shutdown (morph, firecrawl, API errors) - Configurable failure thresholds with automatic data consolidation Usage: - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + python batch_runner.py --dataset_file=data.jsonl --run_name=my_run # Resume an interrupted run - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume # Use a specific toolset distribution - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen + python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --distribution=image_gen # Configure tool failure thresholds - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ + python batch_runner.py --dataset_file=data.jsonl --run_name=my_run \\ --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 """ @@ -30,10 +30,10 @@ import json import logging import os import time +import asyncio from pathlib import Path -from typing import List, Dict, Any, Optional, Tuple +from typing import List, Dict, Any, Optional, Tuple, Set from datetime import datetime -from multiprocessing import Pool, Manager, Lock import traceback import re @@ -49,9 +49,6 @@ from toolset_distributions import ( from safe_print import safe_print -# Global configuration for worker processes -_WORKER_CONFIG = {} - # Canonical names for the terminal tool (old & new implementations) _TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"} @@ -295,10 +292,9 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i return tool_stats -def _process_single_prompt( +async def _process_single_prompt( prompt_index: int, prompt_data: Dict[str, Any], - batch_num: int, config: Dict[str, Any] ) -> Dict[str, Any]: """ @@ -307,7 +303,6 @@ def _process_single_prompt( Args: prompt_index (int): Index of prompt in dataset prompt_data (Dict): Prompt data containing 'prompt' field - batch_num (int): Batch number config (Dict): Configuration dict with agent parameters Returns: @@ -336,7 +331,7 @@ def _process_single_prompt( ) # Run the agent with task_id to ensure each task gets its own isolated VM - result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}") + result = await agent.run_conversation(prompt, task_id=f"task_{prompt_index}") # Extract tool usage statistics tool_stats = _extract_tool_stats(result["messages"]) @@ -365,7 +360,6 @@ def _process_single_prompt( "api_calls": result["api_calls"], "toolsets_used": selected_toolsets, "metadata": { - "batch_num": batch_num, "timestamp": datetime.now().isoformat(), "model": config["model"] } @@ -389,132 +383,38 @@ def _process_single_prompt( 
"tool_stats": {}, "toolsets_used": [], "metadata": { - "batch_num": batch_num, "timestamp": datetime.now().isoformat() } } -def _process_batch_worker(args: Tuple) -> Dict[str, Any]: +async def worker( + work_queue: asyncio.Queue, + result_queue: asyncio.Queue, + config: Dict[str, Any] +): """ - Worker function to process a single batch of prompts. - - Args: - args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config) - - Returns: - Dict: Batch results with statistics + Consumer worker that processes prompts from the work queue. """ - batch_num, batch_data, output_dir, completed_prompts_set, config = args - - output_dir = Path(output_dir) - print(f"\nšŸ”„ Batch {batch_num}: Starting ({len(batch_data)} prompts)") - - # Output file for this batch - batch_output_file = output_dir / f"batch_{batch_num}.jsonl" - - # Filter out already completed prompts - prompts_to_process = [ - (idx, data) for idx, data in batch_data - if idx not in completed_prompts_set - ] - - if not prompts_to_process: - print(f"āœ… Batch {batch_num}: Already completed (skipping)") - return { - "batch_num": batch_num, - "processed": 0, - "skipped": len(batch_data), - "tool_stats": {}, - "completed_prompts": [] - } - - print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)") - - # Initialize aggregated stats for this batch - batch_tool_stats = {} - batch_profiling_stats = [] # Collect profiling stats from each prompt - completed_in_batch = [] - all_tool_errors = [] # Track all tool errors in this batch - exception_errors = [] # Track top-level exceptions - - # Process each prompt sequentially in this batch - for prompt_index, prompt_data in prompts_to_process: - # Process the prompt - result = _process_single_prompt( - prompt_index, - prompt_data, - batch_num, - config - ) - - # Track tool errors from the conversation - if result.get("tool_errors"): - for tool_error in result["tool_errors"]: - all_tool_errors.append({ - "prompt_index": prompt_index, - "tool_name": tool_error["tool_name"], - "error_message": tool_error["error_message"], - "full_content": tool_error.get("full_content", ""), - "error_type": tool_error.get("error_type", "Other") - }) - - # Track top-level exceptions (not tool errors) - if not result["success"]: - exception_errors.append({ - "prompt_index": prompt_index, - "error": result.get("error", "Unknown error"), - "traceback": result.get("traceback", "") - }) - safe_print(f"[bold red]āŒ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}") - - # Save trajectory if successful - if result["success"] and result["trajectory"]: - trajectory_entry = { - "prompt_index": prompt_index, - "conversations": result["trajectory"], - "metadata": result["metadata"], - "completed": result["completed"], - "api_calls": result["api_calls"], - "toolsets_used": result["toolsets_used"] - } - - # Append to batch output file - with open(batch_output_file, 'a', encoding='utf-8') as f: - f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") - - # Aggregate tool statistics - for tool_name, stats in result.get("tool_stats", {}).items(): - if tool_name not in batch_tool_stats: - batch_tool_stats[tool_name] = { - "count": 0, - "success": 0, - "failure": 0 - } - - batch_tool_stats[tool_name]["count"] += stats["count"] - batch_tool_stats[tool_name]["success"] += stats["success"] - batch_tool_stats[tool_name]["failure"] += stats["failure"] - - # Collect profiling statistics - if result.get("profiling_stats"): - 
batch_profiling_stats.append(result["profiling_stats"]) - - completed_in_batch.append(prompt_index) - print(f" āœ… Prompt {prompt_index} completed") - - print(f"āœ… Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)") - - return { - "batch_num": batch_num, - "processed": len(prompts_to_process), - "skipped": len(batch_data) - len(prompts_to_process), - "tool_stats": batch_tool_stats, - "profiling_stats": batch_profiling_stats, - "completed_prompts": completed_in_batch, - "tool_errors": all_tool_errors, - "exception_errors": exception_errors - } + while True: + try: + task = await work_queue.get() + if task is None: + # Sentinel to stop worker + work_queue.task_done() + break + + prompt_index, prompt_data = task + + result = await _process_single_prompt(prompt_index, prompt_data, config) + + await result_queue.put(result) + work_queue.task_done() + + except Exception as e: + print(f"Error in worker: {e}") + if 'task' in locals() and task is not None: + work_queue.task_done() class BatchRunner: @@ -525,7 +425,6 @@ class BatchRunner: def __init__( self, dataset_file: str, - batch_size: int, run_name: str, distribution: str = "default", max_iterations: int = 10, @@ -546,14 +445,13 @@ class BatchRunner: Args: dataset_file (str): Path to the dataset JSONL file with 'prompt' field - batch_size (int): Number of prompts per batch run_name (str): Name for this run (used for checkpointing and output) distribution (str): Toolset distribution to use (default: "default") max_iterations (int): Max iterations per agent run base_url (str): Base URL for model API api_key (str): API key for model model (str): Model name to use - num_workers (int): Number of parallel workers + num_workers (int): Number of parallel workers (default: 4) verbose (bool): Enable verbose logging ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) @@ -563,7 +461,6 @@ class BatchRunner: min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10) """ self.dataset_file = Path(dataset_file) - self.batch_size = batch_size self.run_name = run_name self.distribution = distribution self.max_iterations = max_iterations @@ -596,16 +493,14 @@ class BatchRunner: # Errors file self.errors_file = self.output_dir / "errors.json" + # Trajectories file + self.trajectories_file = self.output_dir / "trajectories.jsonl" + # Load dataset self.dataset = self._load_dataset() - # Create batches - self.batches = self._create_batches() - safe_print("[bold cyan]šŸ“Š Batch Runner Initialized[/bold cyan]") safe_print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") - safe_print(f" Batch size: {self.batch_size}") - safe_print(f" Total batches: {len(self.batches)}") safe_print(f" Run name: {self.run_name}") safe_print(f" Distribution: {self.distribution}") safe_print(f" Output directory: {self.output_dir}") @@ -651,20 +546,6 @@ class BatchRunner: return dataset - def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]: - """ - Split dataset into batches with indices. 
- - Returns: - List of batches, where each batch is a list of (index, entry) tuples - """ - batches = [] - for i in range(0, len(self.dataset), self.batch_size): - batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)] - batches.append(batch) - - return batches - def _load_checkpoint(self) -> Dict[str, Any]: """ Load checkpoint data if it exists. @@ -676,7 +557,6 @@ class BatchRunner: return { "run_name": self.run_name, "completed_prompts": [], - "batch_stats": {}, "last_updated": None } @@ -688,61 +568,34 @@ class BatchRunner: return { "run_name": self.run_name, "completed_prompts": [], - "batch_stats": {}, "last_updated": None } - def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None): + def _save_checkpoint(self, checkpoint_data: Dict[str, Any]): """ Save checkpoint data. Args: checkpoint_data (Dict): Checkpoint data to save - lock (Lock): Optional lock for thread-safe access """ checkpoint_data["last_updated"] = datetime.now().isoformat() + with open(self.checkpoint_file, 'w', encoding='utf-8') as f: + json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) - if lock: - with lock: - with open(self.checkpoint_file, 'w', encoding='utf-8') as f: - json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) - else: - with open(self.checkpoint_file, 'w', encoding='utf-8') as f: - json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) - - def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]], - start_time: float, tool_errors_by_tool: Dict[str, List[Dict]], - exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None, - profiling_stats_list: List[Dict] = None): + def _save_final_stats( + self, + num_processed: int, + tool_stats: Dict[str, Dict[str, int]], + start_time: float, + tool_errors_by_tool: Dict[str, List[Dict]], + exception_errors: List[Dict], + early_exit: bool = False, + exit_reason: str = None, + profiling_stats_list: List[Dict] = None + ): """ - Consolidate batch data into trajectories.jsonl and save statistics. - - Args: - num_batches (int): Number of batches processed - tool_stats (Dict): Aggregated tool statistics - start_time (float): Start time of the run - tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent - exception_errors (List): Top-level exceptions - early_exit (bool): Whether this is an early exit - exit_reason (str): Reason for early exit - profiling_stats_list (List[Dict]): List of profiling statistics from each conversation + Save final statistics and errors. 
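+
+        Args:
+            num_processed (int): Number of prompts processed in this run
+            tool_stats (Dict): Aggregated tool statistics
+            start_time (float): Start time of the run
+            tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
+            exception_errors (List): Top-level exceptions
+            early_exit (bool): Whether this is an early exit
+            exit_reason (str): Reason for early exit
+            profiling_stats_list (List[Dict]): List of profiling statistics from each conversation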
""" - # Combine all batch files into a single trajectories.jsonl file - combined_file = self.output_dir / "trajectories.jsonl" - safe_print(f"\n[cyan]šŸ“¦ Combining batch files into {combined_file.name}...[/cyan]") - - entries_written = 0 - with open(combined_file, 'w', encoding='utf-8') as outfile: - for batch_num in range(num_batches): - batch_file = self.output_dir / f"batch_{batch_num}.jsonl" - if batch_file.exists(): - with open(batch_file, 'r', encoding='utf-8') as infile: - for line in infile: - outfile.write(line) - entries_written += 1 - - safe_print(f"[green]āœ… Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]") - # Calculate success rates for tool stats for tool_name in tool_stats: stats = tool_stats[tool_name] @@ -794,17 +647,21 @@ class BatchRunner: # Aggregate profiling statistics if available aggregated_profiling_stats = None if profiling_stats_list: - from profiling import aggregate_profiling_stats - aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list) + try: + from profiling import aggregate_profiling_stats, print_aggregated_statistics + aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list) + + # Display aggregated profiling statistics + print_aggregated_statistics(aggregated_profiling_stats, detailed=True) + except ImportError: + pass - # Save final statistics (without detailed errors) + # Save final statistics final_stats = { "run_name": self.run_name, "distribution": self.distribution, "total_prompts": len(self.dataset), - "total_batches": len(self.batches), - "batches_processed": num_batches, - "batch_size": self.batch_size, + "processed": num_processed, "model": self.model, "completed_at": datetime.now().isoformat(), "duration_seconds": round(time.time() - start_time, 2), @@ -817,18 +674,9 @@ class BatchRunner: with open(self.stats_file, 'w', encoding='utf-8') as f: json.dump(final_stats, f, indent=2, ensure_ascii=False) - # Display aggregated profiling statistics - if aggregated_profiling_stats: - from profiling import print_aggregated_statistics - print_aggregated_statistics(aggregated_profiling_stats, detailed=True) - - - def run(self, resume: bool = False): + async def _run_async(self, resume: bool = False): """ - Run the batch processing pipeline. - - Args: - resume (bool): Whether to resume from checkpoint + Async implementation of the batch runner pipeline. 
""" print("\n" + "=" * 70) print("šŸš€ Starting Batch Processing") @@ -838,15 +686,32 @@ class BatchRunner: checkpoint_data = self._load_checkpoint() if resume else { "run_name": self.run_name, "completed_prompts": [], - "batch_stats": {}, "last_updated": None } if resume and checkpoint_data.get("completed_prompts"): print(f"šŸ“‚ Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)") - # Prepare configuration for workers - config = { + completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) + + # Prepare queues + work_queue = asyncio.Queue() + result_queue = asyncio.Queue() + + # Enqueue prompts to process + prompts_to_process = [] + for idx, entry in enumerate(self.dataset): + if idx not in completed_prompts_set: + prompts_to_process.append((idx, entry)) + work_queue.put_nowait((idx, entry)) + + total_to_process = len(prompts_to_process) + if total_to_process == 0: + print("āœ… All prompts already completed.") + return + + # Worker configuration + worker_config = { "distribution": self.distribution, "model": self.model, "max_iterations": self.max_iterations, @@ -857,120 +722,119 @@ class BatchRunner: "log_prefix_chars": self.log_prefix_chars } - # Get completed prompts set - completed_prompts_set = set(checkpoint_data.get("completed_prompts", [])) + # Start workers + workers = [] + for _ in range(min(self.num_workers, total_to_process)): + w = asyncio.create_task(worker(work_queue, result_queue, worker_config)) + workers.append(w) + + print(f" Processing {total_to_process} prompts with {len(workers)} workers...") - # Aggregate statistics across all batches + # Aggregate statistics total_tool_stats = {} - all_profiling_stats = [] # Collect all profiling stats for aggregation - tool_errors_by_tool = {} # {tool_name: [list of k most recent errors]} + all_profiling_stats = [] + tool_errors_by_tool = {} all_exception_errors = [] - all_completed_prompts = list(completed_prompts_set) - total_processed = len(completed_prompts_set) total_tool_errors = 0 early_exit = False exit_reason = None - + processed_count = 0 + start_time = time.time() - - # Process batches in parallel - with Pool(processes=self.num_workers) as pool: - # Create tasks for each batch - tasks = [ - ( - batch_num, - batch_data, - str(self.output_dir), # Convert Path to string for pickling - completed_prompts_set, - config - ) - for batch_num, batch_data in enumerate(self.batches) - ] - - # Process batches in parallel and check tool failure threshold as results come in - # imap_unordered allows parallel processing while getting results as they complete - batch_num = 0 - try: - for result in pool.imap_unordered(_process_batch_worker, tasks): - # Update statistics - all_completed_prompts.extend(result.get("completed_prompts", [])) - total_processed += result.get("processed", 0) - - # Aggregate tool stats - for tool_name, stats in result.get("tool_stats", {}).items(): - if tool_name not in total_tool_stats: - total_tool_stats[tool_name] = { - "count": 0, - "success": 0, - "failure": 0 - } - - total_tool_stats[tool_name]["count"] += stats["count"] - total_tool_stats[tool_name]["success"] += stats["success"] - total_tool_stats[tool_name]["failure"] += stats["failure"] - - # Collect profiling stats from this batch - if result.get("profiling_stats"): - all_profiling_stats.extend(result["profiling_stats"]) - - # Aggregate tool errors (keep k most recent per tool) - for tool_error in result.get("tool_errors", []): - tool_name = tool_error["tool_name"] - if tool_name not in 
tool_errors_by_tool: - tool_errors_by_tool[tool_name] = [] - - # Add error and keep only k most recent - tool_errors_by_tool[tool_name].append(tool_error) - if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors: - tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:] - - total_tool_errors += 1 - - # Track exception errors - all_exception_errors.extend(result.get("exception_errors", [])) - - # Check tool failure thresholds - # Calculate total tool calls (not prompts) - total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values()) - - # Check absolute count threshold - if total_tool_errors >= self.max_tool_failures: + + # Process results as they arrive + try: + while processed_count < total_to_process: + result = await result_queue.get() + processed_count += 1 + + prompt_index = result["prompt_index"] + + # Track exceptions + if not result["success"]: + safe_print(f"[bold red]āŒ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}") + all_exception_errors.append({ + "prompt_index": prompt_index, + "error": result.get("error", "Unknown error"), + "traceback": result.get("traceback", "") + }) + else: + print(f" āœ… Prompt {prompt_index} completed") + + # Save trajectory immediately + if result.get("trajectory"): + trajectory_entry = { + "prompt_index": prompt_index, + "conversations": result["trajectory"], + "metadata": result["metadata"], + "completed": result["completed"], + "api_calls": result["api_calls"], + "toolsets_used": result["toolsets_used"] + } + with open(self.trajectories_file, 'a', encoding='utf-8') as f: + f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") + + # Aggregate tool stats + for tool_name, stats in result.get("tool_stats", {}).items(): + if tool_name not in total_tool_stats: + total_tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0} + + total_tool_stats[tool_name]["count"] += stats["count"] + total_tool_stats[tool_name]["success"] += stats["success"] + total_tool_stats[tool_name]["failure"] += stats["failure"] + + # Collect profiling stats + if result.get("profiling_stats"): + all_profiling_stats.append(result["profiling_stats"]) + + # Aggregate tool errors + for tool_error in result.get("tool_errors", []): + tool_name = tool_error["tool_name"] + if tool_name not in tool_errors_by_tool: + tool_errors_by_tool[tool_name] = [] + + tool_errors_by_tool[tool_name].append(tool_error) + # Keep only k most recent + if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors: + tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:] + + total_tool_errors += 1 + + # Update checkpoint + completed_prompts_set.add(prompt_index) + checkpoint_data["completed_prompts"] = list(completed_prompts_set) + self._save_checkpoint(checkpoint_data) + + # Check failure thresholds + total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values()) + + if total_tool_errors >= self.max_tool_failures: + early_exit = True + exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})" + break + + if total_tool_calls >= self.min_tool_calls_for_rate: + tool_failure_rate = total_tool_errors / total_tool_calls + if tool_failure_rate >= self.max_tool_failure_rate: early_exit = True - exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})" - safe_print(f"\n[bold red]šŸ›‘ STOPPING: {exit_reason}[/bold red]") - pool.terminate() # Stop all workers immediately + 
exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%})" break - - # Check rate threshold (only if we have enough tool calls to trust the rate) - if total_tool_calls >= self.min_tool_calls_for_rate: - tool_failure_rate = total_tool_errors / total_tool_calls - - if tool_failure_rate >= self.max_tool_failure_rate: - early_exit = True - exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)" - safe_print(f"\n[bold red]šŸ›‘ STOPPING: {exit_reason}[/bold red]") - pool.terminate() # Stop all workers immediately - break - - # Update checkpoint after each batch completes - checkpoint_data["completed_prompts"] = all_completed_prompts - self._save_checkpoint(checkpoint_data) - - batch_num += 1 - except KeyboardInterrupt: - safe_print("\n[bold yellow]āš ļø Interrupted by user, stopping workers...[/bold yellow]") - pool.terminate() - early_exit = True - exit_reason = "Interrupted by user" - - # Save final checkpoint - checkpoint_data["completed_prompts"] = all_completed_prompts - self._save_checkpoint(checkpoint_data) - - # Consolidate data and save statistics - num_batches_processed = batch_num + 1 if early_exit else len(self.batches) - self._consolidate_data( - num_batches_processed, + + except asyncio.CancelledError: + early_exit = True + exit_reason = "Run cancelled" + finally: + # Stop all workers + for _ in range(len(workers)): + work_queue.put_nowait(None) + await asyncio.gather(*workers, return_exceptions=True) + + if early_exit: + safe_print(f"\n[bold red]šŸ›‘ STOPPING: {exit_reason}[/bold red]") + + # Save final statistics + self._save_final_stats( + processed_count, total_tool_stats, start_time, tool_errors_by_tool, @@ -980,164 +844,28 @@ class BatchRunner: all_profiling_stats ) - # Print summary + # Summary output safe_print("\n" + "=" * 70) - if early_exit: - safe_print("[bold yellow]āš ļø BATCH PROCESSING STOPPED EARLY[/bold yellow]") - safe_print(f"[yellow]Reason: {exit_reason}[/yellow]") - else: - safe_print("[bold green]šŸ“Š BATCH PROCESSING COMPLETE[/bold green]") - safe_print("=" * 70) - - safe_print(f"āœ… Total prompts processed: {total_processed}") - safe_print(f"āœ… Batches completed: {num_batches_processed}/{len(self.batches)}") + safe_print(f"āœ… Total prompts processed: {processed_count}/{total_to_process}") safe_print(f"ā±ļø Total duration: {round(time.time() - start_time, 2)}s") - - # Tool error summary + if tool_errors_by_tool: - total_errors = sum(len(errors) for errors in tool_errors_by_tool.values()) - safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]") - safe_print("[red]-[/red]" * 70) - - # Sort tools by error count - sorted_tools = sorted( - tool_errors_by_tool.items(), - key=lambda x: len(x[1]), - reverse=True - ) - - for tool_name, errors in sorted_tools: - # Count unique error messages - unique_errors = {} - for error in errors: - error_msg = error["error_message"][:100] # Truncate for grouping - if error_msg not in unique_errors: - unique_errors[error_msg] = [] - unique_errors[error_msg].append(error) - - safe_print(f"\n [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)") - - # Show up to 3 most recent unique error types - for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]): - error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..." 
- safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})") - - # Show one example with prompt index and full content prefix - example = instances[-1] # Most recent - safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]") - - # Show full content prefix (first 200 chars) - full_content = example.get('full_content', '') - if full_content and full_content != error_preview: - content_preview = full_content[:200] - if len(full_content) > 200: - content_preview += "..." - # Show with prefix indicator - safe_print(f" [dim]Content: {content_preview}[/dim]") - - if len(unique_errors) > 3: - safe_print(f" [dim]... and {len(unique_errors) - 3} more error types[/dim]") - - tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0 - safe_print(f"\n [red]Tool failure rate: {tool_failure_rate:.2%}[/red]") - - # Exception errors - if all_exception_errors: - safe_print(f"\n[bold red]šŸ’„ Top-level Exceptions: {len(all_exception_errors)}[/bold red]") - safe_print("[red]-[/red]" * 70) - for error in all_exception_errors[:self.keep_recent_errors]: - error_msg = error["error"] - error_preview = error_msg[:150] - if len(error_msg) > 150: - error_preview += "..." - safe_print(f" [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]") - - # Show traceback prefix if available - traceback_text = error.get("traceback", "") - if traceback_text: - # Show last 3 lines of traceback for context - tb_lines = traceback_text.strip().split('\n') - relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines - for line in relevant_lines: - safe_print(f" [dim]{line}[/dim]") - - safe_print(f"\n[cyan]šŸ“ˆ Tool Usage Statistics:[/cyan]") - safe_print("-" * 70) - - if total_tool_stats: - # Sort by count descending - sorted_tools = sorted( - total_tool_stats.items(), - key=lambda x: x[1]["count"], - reverse=True - ) - - safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") - safe_print("-" * 70) - for tool_name, stats in sorted_tools: - safe_print( - f"{tool_name:<25} " - f"{stats['count']:<10} " - f"{stats['success']:<10} " - f"{stats['failure']:<10} " - f"{stats.get('success_rate', 0):.1f}%" - ) - else: - safe_print("No tool calls were made during this run.") - - # Display failure type breakdown for tools with failures - if tool_errors_by_tool: - safe_print(f"\n[cyan]šŸ“Š Failure Type Breakdown:[/cyan]") - safe_print("-" * 70) - - # Sort tools by total error count - sorted_tools = sorted( - tool_errors_by_tool.items(), - key=lambda x: len(x[1]), - reverse=True - ) - - for tool_name, errors in sorted_tools: - # Count failure types for this tool - failure_types = {} - for error in errors: - error_type = error.get("error_type", "Other") - if error_type not in failure_types: - failure_types[error_type] = 0 - failure_types[error_type] += 1 - - # Display tool name and total failures - total_failures = len(errors) - safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):") - - # Sort failure types by count - sorted_types = sorted( - failure_types.items(), - key=lambda x: x[1], - reverse=True - ) - - # Display each failure type with count and percentage - for failure_type, count in sorted_types: - percentage = (count / total_failures) * 100 - safe_print(f" • {failure_type:<20} {count:>4} ({percentage:>5.1f}%)") - + safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total[/bold red]") + # Simplified error printing here, full detail is in json + for tool_name, errors in tool_errors_by_tool.items(): + 
safe_print(f" {tool_name}: {len(errors)} errors") + safe_print(f"\n[cyan]šŸ’¾ Results saved to:[/cyan] {self.output_dir}") - safe_print(f" - Trajectories: trajectories.jsonl (combined)") - safe_print(f" - Individual batches: batch_*.jsonl (for debugging)") - safe_print(f" - Statistics: {self.stats_file.name}") - safe_print(f" - Errors: {self.errors_file.name}") - safe_print(f" - Checkpoint: {self.checkpoint_file.name}") - if early_exit: - safe_print(f"\n[bold yellow]ā„¹ļø Run was stopped early due to tool failures.[/bold yellow]") - safe_print(f"[yellow] Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]") - safe_print(f"[yellow] You can resume this run later with --resume flag.[/yellow]") + def run(self, resume: bool = False): + """ + Run the batch processing pipeline (sync wrapper). + """ + asyncio.run(self._run_async(resume)) def main( dataset_file: str = None, - batch_size: int = None, run_name: str = None, distribution: str = "default", model: str = "claude-opus-4-20250514", @@ -1160,7 +888,6 @@ def main( Args: dataset_file (str): Path to JSONL file with 'prompt' field in each entry - batch_size (int): Number of prompts per batch run_name (str): Name for this run (used for output and checkpointing) distribution (str): Toolset distribution to use (default: "default") model (str): Model name to use (default: "claude-opus-4-20250514") @@ -1180,24 +907,13 @@ def main( Examples: # Basic usage - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run + python batch_runner.py --dataset_file=data.jsonl --run_name=my_run # Resume interrupted run - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume + python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume # Use specific distribution - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen - - # With ephemeral system prompt (not saved to dataset) - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ - --ephemeral_system_prompt="You are a helpful assistant focused on image generation." 
- - # With custom tool failure thresholds - python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ - --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10 - - # List available distributions - python batch_runner.py --list_distributions + python batch_runner.py --dataset_file=data.jsonl --run_name=image_test --distribution=image_gen """ # Handle list distributions if list_distributions: @@ -1209,10 +925,6 @@ def main( all_dists = get_all_dists() for dist_name in sorted(all_dists.keys()): print_distribution_info(dist_name) - - print("\nšŸ’” Usage:") - print(" python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\") - print(" --run_name=my_run --distribution=") return # Validate required arguments @@ -1220,10 +932,6 @@ def main( print("āŒ Error: --dataset_file is required") return - if not batch_size or batch_size < 1: - print("āŒ Error: --batch_size must be a positive integer") - return - if not run_name: print("āŒ Error: --run_name is required") return @@ -1232,7 +940,6 @@ def main( try: runner = BatchRunner( dataset_file=dataset_file, - batch_size=batch_size, run_name=run_name, distribution=distribution, max_iterations=max_turns, diff --git a/model_tools.py b/model_tools.py index eb27b753534..93f15d30fec 100644 --- a/model_tools.py +++ b/model_tools.py @@ -23,7 +23,7 @@ Usage: web_tools = get_tool_definitions(enabled_toolsets=['web_tools']) # Handle function calls from model - result = handle_function_call("web_search", {"query": "Python"}) + result = await handle_function_call("web_search", {"query": "Python"}) """ import json @@ -439,7 +439,7 @@ def get_tool_definitions( return filtered_tools -def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str: +async def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str: """ Handle function calls for web tools. @@ -454,25 +454,25 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) query = function_args.get("query", "") # Always use fixed limit of 5 limit = 5 - return web_search_tool(query, limit) + return await web_search_tool(query, limit) elif function_name == "web_extract": urls = function_args.get("urls", []) # Limit URLs to prevent abuse urls = urls[:5] if isinstance(urls, list) else [] - # Run async function in event loop - return asyncio.run(web_extract_tool(urls, "markdown")) + # Run async function + return await web_extract_tool(urls, "markdown") elif function_name == "web_crawl": url = function_args.get("url", "") instructions = function_args.get("instructions") - # Run async function in event loop - return asyncio.run(web_crawl_tool(url, instructions, "basic")) + # Run async function + return await web_crawl_tool(url, instructions, "basic") else: return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False) -def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str: +async def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str: """ Handle function calls for terminal tools. 
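
The dispatchers above now await coroutine tools in place of the old `asyncio.run(...)` wrappers: once the batch runner owns the event loop, `asyncio.run()` raises RuntimeError when invoked from a coroutine already running on that loop. A minimal, self-contained illustration of the pattern (the names are placeholders, not the real tools):

import asyncio

async def fake_tool() -> str:
    return "ok"

async def dispatch() -> str:
    # asyncio.run() cannot be called from a running event loop (it raises
    # RuntimeError), so async handlers await their coroutines directly.
    return await fake_tool()

print(asyncio.run(dispatch()))
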
@@ -489,13 +489,20 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A background = function_args.get("background", False) timeout = function_args.get("timeout") - return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id) + # Run sync terminal tool in a thread to avoid blocking + return await asyncio.to_thread( + simple_terminal_tool, + command=command, + background=background, + timeout=timeout, + task_id=task_id + ) else: return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False) -def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str: +async def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str: """ Handle function calls for vision tools. @@ -512,14 +519,14 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}" - # Run async function in event loop - return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")) + # Run async function + return await vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash") else: return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False) -def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str: +async def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str: """ Handle function calls for Mixture-of-Agents tools. @@ -536,14 +543,14 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) if not user_prompt: return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False) - # Run async function in event loop - return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt)) + # Run async function + return await mixture_of_agents_tool(user_prompt=user_prompt) else: return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False) -def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str: +async def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str: """ Handle function calls for image generation tools. 
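
The terminal branch above is the one tool that stays synchronous, so it is handed to a worker thread via asyncio.to_thread (Python 3.9+) instead of being awaited natively. A sketch of why that matters under concurrency, using a stand-in for simple_terminal_tool:

import asyncio
import time

def blocking_tool(command: str) -> str:
    # Stand-in for simple_terminal_tool: a synchronous, blocking call.
    time.sleep(1)
    return f"ran: {command}"

async def main() -> None:
    # Both calls run on worker threads concurrently (~1s total, not ~2s),
    # and the event loop stays free to serve other conversations.
    results = await asyncio.gather(
        asyncio.to_thread(blocking_tool, command="echo one"),
        asyncio.to_thread(blocking_tool, command="echo two"),
    )
    print(results)

asyncio.run(main())
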
@@ -572,21 +579,8 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any] allow_nsfw_images = True seed = None - # Run async function in event loop with proper handling for multiprocessing - try: - # Try to get existing event loop - loop = asyncio.get_event_loop() - if loop.is_closed(): - # If closed, create a new one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - except RuntimeError: - # No event loop in current thread, create one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Run the coroutine in the event loop - result = loop.run_until_complete(image_generate_tool( + # Run async function + return await image_generate_tool( prompt=prompt, image_size=image_size, num_inference_steps=num_inference_steps, @@ -597,15 +591,13 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any] acceleration=acceleration, allow_nsfw_images=allow_nsfw_images, seed=seed - )) - - return result + ) else: return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False) -def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str: +async def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str: """ Main function call dispatcher that routes calls to appropriate toolsets. @@ -627,23 +619,23 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task try: # Route web tools if function_name in ["web_search", "web_extract", "web_crawl"]: - return handle_web_function_call(function_name, function_args) + return await handle_web_function_call(function_name, function_args) # Route terminal tools elif function_name in ["terminal"]: - return handle_terminal_function_call(function_name, function_args, task_id) + return await handle_terminal_function_call(function_name, function_args, task_id) # Route vision tools elif function_name in ["vision_analyze"]: - return handle_vision_function_call(function_name, function_args) + return await handle_vision_function_call(function_name, function_args) # Route MoA tools elif function_name in ["mixture_of_agents"]: - return handle_moa_function_call(function_name, function_args) + return await handle_moa_function_call(function_name, function_args) # Route image generation tools elif function_name in ["image_generate"]: - return handle_image_function_call(function_name, function_args) + return await handle_image_function_call(function_name, function_args) else: error_msg = f"Unknown function: {function_name}" @@ -773,4 +765,4 @@ if __name__ == "__main__": if "terminal" in all_tool_names: no_terminal = get_tool_definitions(disabled_tools=["terminal"]) - print(f" All except terminal: {len(no_terminal)} tools") + print(f" All except terminal: {len(no_terminal)} tools") \ No newline at end of file diff --git a/run_agent.py b/run_agent.py index ef3f28648b6..27e16a0c607 100644 --- a/run_agent.py +++ b/run_agent.py @@ -24,121 +24,17 @@ import json import logging import os import time +import asyncio import sys from typing import List, Dict, Any, Optional -from openai import OpenAI +from openai import AsyncOpenAI import fire from datetime import datetime from pathlib import Path from rich import print from prokletor.formatters.hermes_formatter import HermesToolFormatterWithReasoning - -class SyncCustomToolCompletions: - def __init__(self, completions, formatter): - self._completions = completions - self._formatter = formatter - 
- def create(self, *args, messages, tools=None, **kwargs): - if not tools: - return self._completions.create(*args, messages=messages, **kwargs) - - # 1. Format system message with tools - system_prompt = self._formatter.format_system_message(tools) - - new_messages = list(messages) - if new_messages and new_messages[0]["role"] == "system": - # Append to existing system message - existing_content = new_messages[0]["content"] - if isinstance(existing_content, str): - new_messages[0] = { - "role": "system", - "content": existing_content + "\n\n" + system_prompt - } - else: - # Insert new system message - new_messages.insert(0, { - "role": "system", - "content": system_prompt - }) - - # 2. Call the API without the 'tools' parameter - kwargs.pop("tool_choice", None) - - # Process messages (e.g. convert roles and add reasoning prompt) - # use_tool_role=False for API compatibility - new_messages = self._formatter.process_messages(new_messages, use_tool_role=False) - - response = self._completions.create( - *args, - messages=new_messages, - # tools=tools, # Do NOT pass tools to the model - **kwargs - ) - - # 3. Parse the response - choice = response.choices[0] - if choice.message.content: - tool_calls = self._formatter.parse_response(choice.message.content, tools=tools) - if tool_calls: - choice.message.tool_calls = tool_calls - - # Clean the content if the formatter supports it - if hasattr(self._formatter, "extract_text_from_content"): - cleaned_content = self._formatter.extract_text_from_content(choice.message.content) - choice.message.content = cleaned_content - - if not choice.message.content: - choice.message.content = None - - return response - - def __getattr__(self, name): - return getattr(self._completions, name) - -class SyncCustomChat: - def __init__(self, chat, formatter): - self._chat = chat - self.completions = SyncCustomToolCompletions(chat.completions, formatter) - - def __getattr__(self, name): - return getattr(self._chat, name) - -class SyncHermesToolClientWithReasoning: - def __init__(self, client): - self._client = client - self.formatter = HermesToolFormatterWithReasoning() - self.chat = SyncCustomChat(client.chat, self.formatter) - - def format(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], use_tool_role: bool = True) -> List[Dict[str, Any]]: - """ - Format messages and tools into the Hermes XML format. - Useful for debugging or manual inspection of what will be sent to the model. - """ - # 1. Format system message with tools - system_prompt = self.formatter.format_system_message(tools) - - new_messages = list(messages) - if new_messages and new_messages[0]["role"] == "system": - # Append to existing system message - existing_content = new_messages[0]["content"] - if isinstance(existing_content, str): - new_messages[0] = { - "role": "system", - "content": existing_content + "\n\n" + system_prompt - } - else: - # Insert new system message - new_messages.insert(0, { - "role": "system", - "content": system_prompt - }) - - # 2. 
Process messages (convert tool calls and results to XML) - return self.formatter.process_messages(new_messages, use_tool_role=use_tool_role) - - def __getattr__(self, name): - return getattr(self._client, name) +from prokletor.clients.hermes import HermesToolClientWithReasoning # Load environment variables from .env file from dotenv import load_dotenv @@ -242,10 +138,10 @@ class AIAgent: client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key") try: - oai_client = OpenAI(**client_kwargs) + oai_client = AsyncOpenAI(**client_kwargs) # self.client = oai_client - self.client = SyncHermesToolClientWithReasoning(oai_client) - print(f"🧠 Wrapped OpenAI client with SyncHermesToolClientWithReasoning") + self.client = HermesToolClientWithReasoning(oai_client) + print(f"🧠 Wrapped OpenAI client with AsyncHermesToolClientWithReasoning") print(f"šŸ¤– AI Agent initialized with model: {self.model}") if base_url: @@ -494,7 +390,7 @@ class AIAgent: except Exception as e: print(f"āš ļø Failed to save trajectory: {e}") - def run_conversation( + async def run_conversation( self, user_message: str, system_message: str = None, @@ -567,7 +463,7 @@ class AIAgent: # Make API call with tools - response = self.client.chat.completions.create( + response = await self.client.chat.completions.create( model=self.model, messages=api_messages, tools=self.tools if self.tools else None, @@ -603,7 +499,7 @@ class AIAgent: print(f"āš ļø OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}") print(f"ā³ Retrying in {wait_time}s...") logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}") - time.sleep(wait_time) + await asyncio.sleep(wait_time) if response is None: raise last_api_error if last_api_error else RuntimeError("OpenAI-compatible API call failed without a response") @@ -694,7 +590,7 @@ class AIAgent: tool_start_time = time.time() # Execute the tool with task_id to isolate VMs between concurrent tasks - function_result = handle_function_call(function_name, function_args, effective_task_id) + function_result = await handle_function_call(function_name, function_args, effective_task_id) tool_duration = time.time() - tool_start_time result_preview = function_result[:200] if len(function_result) > 200 else function_result @@ -720,7 +616,7 @@ class AIAgent: # Delay between tool calls if self.tool_delay > 0 and i < len(assistant_message.tool_calls): - time.sleep(self.tool_delay) + await asyncio.sleep(self.tool_delay) # Continue loop for next response continue @@ -838,7 +734,7 @@ class AIAgent: # Clean up VM for this task after conversation completes try: - cleanup_vm(effective_task_id) + await asyncio.to_thread(cleanup_vm, effective_task_id) except Exception as e: if self.verbose_logging: logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}") @@ -854,7 +750,7 @@ class AIAgent: "profiling_stats": profiling_stats } - def chat(self, message: str) -> str: + async def chat(self, message: str) -> str: """ Simple chat interface that returns just the final response. 
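
With run_conversation and chat converted to coroutines, several conversations can now share a single event loop. A usage sketch; the bare AIAgent() construction and the prompts are illustrative assumptions (real callers pass model, base_url, and the other constructor arguments shown in main()):

import asyncio
from run_agent import AIAgent

async def main() -> None:
    agent = AIAgent()  # assumed defaults; see main() for the full argument set
    # Distinct task_ids keep each conversation on its own isolated VM.
    results = await asyncio.gather(
        agent.run_conversation("What is 2 + 2?", task_id="task_0"),
        agent.run_conversation("Name three prime numbers.", task_id="task_1"),
    )
    for result in results:
        print(result["final_response"])

asyncio.run(main())
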
@@ -864,7 +760,7 @@ class AIAgent: Returns: str: Final assistant response """ - result = self.run_conversation(message) + result = await self.run_conversation(message) return result["final_response"] @@ -1037,7 +933,7 @@ def main( print("\n" + "=" * 50) # Run conversation - result = agent.run_conversation(user_query) + result = asyncio.run(agent.run_conversation(user_query)) print("\n" + "=" * 50) print("šŸ“‹ CONVERSATION SUMMARY") diff --git a/tools/web_tools.py b/tools/web_tools.py index 3f7df9f43ec..df7765b5239 100644 --- a/tools/web_tools.py +++ b/tools/web_tools.py @@ -48,11 +48,11 @@ import uuid import datetime from pathlib import Path from typing import List, Dict, Any, Optional -from firecrawl import Firecrawl +from firecrawl import AsyncFirecrawl from openai import AsyncOpenAI # Initialize Firecrawl client once at module level -firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY")) +firecrawl_client = AsyncFirecrawl(api_key=os.getenv("FIRECRAWL_API_KEY")) # Initialize Nous Research API client for LLM processing (async) nous_client = AsyncOpenAI( @@ -261,7 +261,7 @@ def clean_base64_images(text: str) -> str: return cleaned_text -def web_search_tool(query: str, limit: int = 5) -> str: +async def web_search_tool(query: str, limit: int = 5) -> str: """ Search the web for information using available search API backend. @@ -312,7 +312,7 @@ def web_search_tool(query: str, limit: int = 5) -> str: # Use Firecrawl's v2 search functionality WITHOUT scraping # We only want search result metadata, not scraped content # Docs: https://docs.firecrawl.dev/features/search - response = firecrawl_client.search( + response = await firecrawl_client.search( query=query, limit=limit ) @@ -446,7 +446,7 @@ async def web_extract_tool( for url in urls: try: print(f" šŸ“„ Scraping: {url}") - scrape_result = firecrawl_client.scrape( + scrape_result = await firecrawl_client.scrape( url=url, formats=formats ) @@ -703,7 +703,7 @@ async def web_crawl_tool( # Use the crawl method which waits for completion automatically try: - crawl_result = firecrawl_client.crawl( + crawl_result = await firecrawl_client.crawl( url=url, **crawl_params )
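
For orientation, the producer-consumer shape that _run_async in batch_runner.py builds on, reduced to a runnable sketch: the doubling step stands in for _process_single_prompt, and checkpointing would hook into the result-consumption loop.

import asyncio

async def worker(work_q: asyncio.Queue, result_q: asyncio.Queue) -> None:
    while True:
        item = await work_q.get()
        if item is None:  # sentinel: shut this worker down
            work_q.task_done()
            break
        await result_q.put(item * 2)  # stand-in for _process_single_prompt
        work_q.task_done()

async def main() -> None:
    work_q: asyncio.Queue = asyncio.Queue()
    result_q: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        work_q.put_nowait(i)  # producer: enqueue all pending prompts up front
    workers = [asyncio.create_task(worker(work_q, result_q)) for _ in range(2)]
    # Consume results as they complete, like the checkpointing loop does.
    results = [await result_q.get() for _ in range(5)]
    for _ in workers:
        work_q.put_nowait(None)  # one sentinel per worker
    await asyncio.gather(*workers)
    print(sorted(results))

asyncio.run(main())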