switch to asyncio

This commit is contained in:
hjc-puro 2025-11-22 11:25:23 -05:00
parent 98321be8b0
commit e91d9e839a
4 changed files with 261 additions and 666 deletions

View file

@ -4,25 +4,25 @@ Batch Agent Runner
This module provides parallel batch processing capabilities for running the agent
across multiple prompts from a dataset. It includes:
- Dataset loading and batching
- Parallel batch processing with multiprocessing
- Dataset loading
- Concurrent processing with asyncio (Producer-Consumer pattern)
- Checkpointing for fault tolerance and resumption
- Trajectory saving in the proper format (from/value pairs)
- Tool usage statistics aggregation across all batches
- Tool usage statistics aggregation across all prompts
- Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
- Configurable failure thresholds with automatic data consolidation
Usage:
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
python batch_runner.py --dataset_file=data.jsonl --run_name=my_run
# Resume an interrupted run
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume
# Use a specific toolset distribution
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --distribution=image_gen
# Configure tool failure thresholds
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
python batch_runner.py --dataset_file=data.jsonl --run_name=my_run \\
--max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10
"""
@ -30,10 +30,10 @@ import json
import logging
import os
import time
import asyncio
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from typing import List, Dict, Any, Optional, Tuple, Set
from datetime import datetime
from multiprocessing import Pool, Manager, Lock
import traceback
import re
@ -49,9 +49,6 @@ from toolset_distributions import (
from safe_print import safe_print
# Global configuration for worker processes
_WORKER_CONFIG = {}
# Canonical names for the terminal tool (old & new implementations)
_TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"}
@ -295,10 +292,9 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
return tool_stats
def _process_single_prompt(
async def _process_single_prompt(
prompt_index: int,
prompt_data: Dict[str, Any],
batch_num: int,
config: Dict[str, Any]
) -> Dict[str, Any]:
"""
@ -307,7 +303,6 @@ def _process_single_prompt(
Args:
prompt_index (int): Index of prompt in dataset
prompt_data (Dict): Prompt data containing 'prompt' field
batch_num (int): Batch number
config (Dict): Configuration dict with agent parameters
Returns:
@ -336,7 +331,7 @@ def _process_single_prompt(
)
# Run the agent with task_id to ensure each task gets its own isolated VM
result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
result = await agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
# Extract tool usage statistics
tool_stats = _extract_tool_stats(result["messages"])
@ -365,7 +360,6 @@ def _process_single_prompt(
"api_calls": result["api_calls"],
"toolsets_used": selected_toolsets,
"metadata": {
"batch_num": batch_num,
"timestamp": datetime.now().isoformat(),
"model": config["model"]
}
@ -389,132 +383,38 @@ def _process_single_prompt(
"tool_stats": {},
"toolsets_used": [],
"metadata": {
"batch_num": batch_num,
"timestamp": datetime.now().isoformat()
}
}
def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
async def worker(
work_queue: asyncio.Queue,
result_queue: asyncio.Queue,
config: Dict[str, Any]
):
"""
Worker function to process a single batch of prompts.
Args:
args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)
Returns:
Dict: Batch results with statistics
Consumer worker that processes prompts from the work queue.
"""
batch_num, batch_data, output_dir, completed_prompts_set, config = args
output_dir = Path(output_dir)
print(f"\n🔄 Batch {batch_num}: Starting ({len(batch_data)} prompts)")
# Output file for this batch
batch_output_file = output_dir / f"batch_{batch_num}.jsonl"
# Filter out already completed prompts
prompts_to_process = [
(idx, data) for idx, data in batch_data
if idx not in completed_prompts_set
]
if not prompts_to_process:
print(f"✅ Batch {batch_num}: Already completed (skipping)")
return {
"batch_num": batch_num,
"processed": 0,
"skipped": len(batch_data),
"tool_stats": {},
"completed_prompts": []
}
print(f" Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")
# Initialize aggregated stats for this batch
batch_tool_stats = {}
batch_profiling_stats = [] # Collect profiling stats from each prompt
completed_in_batch = []
all_tool_errors = [] # Track all tool errors in this batch
exception_errors = [] # Track top-level exceptions
# Process each prompt sequentially in this batch
for prompt_index, prompt_data in prompts_to_process:
# Process the prompt
result = _process_single_prompt(
prompt_index,
prompt_data,
batch_num,
config
)
# Track tool errors from the conversation
if result.get("tool_errors"):
for tool_error in result["tool_errors"]:
all_tool_errors.append({
"prompt_index": prompt_index,
"tool_name": tool_error["tool_name"],
"error_message": tool_error["error_message"],
"full_content": tool_error.get("full_content", ""),
"error_type": tool_error.get("error_type", "Other")
})
# Track top-level exceptions (not tool errors)
if not result["success"]:
exception_errors.append({
"prompt_index": prompt_index,
"error": result.get("error", "Unknown error"),
"traceback": result.get("traceback", "")
})
safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
# Save trajectory if successful
if result["success"] and result["trajectory"]:
trajectory_entry = {
"prompt_index": prompt_index,
"conversations": result["trajectory"],
"metadata": result["metadata"],
"completed": result["completed"],
"api_calls": result["api_calls"],
"toolsets_used": result["toolsets_used"]
}
# Append to batch output file
with open(batch_output_file, 'a', encoding='utf-8') as f:
f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
# Aggregate tool statistics
for tool_name, stats in result.get("tool_stats", {}).items():
if tool_name not in batch_tool_stats:
batch_tool_stats[tool_name] = {
"count": 0,
"success": 0,
"failure": 0
}
batch_tool_stats[tool_name]["count"] += stats["count"]
batch_tool_stats[tool_name]["success"] += stats["success"]
batch_tool_stats[tool_name]["failure"] += stats["failure"]
# Collect profiling statistics
if result.get("profiling_stats"):
batch_profiling_stats.append(result["profiling_stats"])
completed_in_batch.append(prompt_index)
print(f" ✅ Prompt {prompt_index} completed")
print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
return {
"batch_num": batch_num,
"processed": len(prompts_to_process),
"skipped": len(batch_data) - len(prompts_to_process),
"tool_stats": batch_tool_stats,
"profiling_stats": batch_profiling_stats,
"completed_prompts": completed_in_batch,
"tool_errors": all_tool_errors,
"exception_errors": exception_errors
}
while True:
try:
task = await work_queue.get()
if task is None:
# Sentinel to stop worker
work_queue.task_done()
break
prompt_index, prompt_data = task
result = await _process_single_prompt(prompt_index, prompt_data, config)
await result_queue.put(result)
work_queue.task_done()
except Exception as e:
print(f"Error in worker: {e}")
if 'task' in locals() and task is not None:
work_queue.task_done()
class BatchRunner:
@ -525,7 +425,6 @@ class BatchRunner:
def __init__(
self,
dataset_file: str,
batch_size: int,
run_name: str,
distribution: str = "default",
max_iterations: int = 10,
@ -546,14 +445,13 @@ class BatchRunner:
Args:
dataset_file (str): Path to the dataset JSONL file with 'prompt' field
batch_size (int): Number of prompts per batch
run_name (str): Name for this run (used for checkpointing and output)
distribution (str): Toolset distribution to use (default: "default")
max_iterations (int): Max iterations per agent run
base_url (str): Base URL for model API
api_key (str): API key for model
model (str): Model name to use
num_workers (int): Number of parallel workers
num_workers (int): Number of parallel workers (default: 4)
verbose (bool): Enable verbose logging
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
@ -563,7 +461,6 @@ class BatchRunner:
min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
"""
self.dataset_file = Path(dataset_file)
self.batch_size = batch_size
self.run_name = run_name
self.distribution = distribution
self.max_iterations = max_iterations
@ -596,16 +493,14 @@ class BatchRunner:
# Errors file
self.errors_file = self.output_dir / "errors.json"
# Trajectories file
self.trajectories_file = self.output_dir / "trajectories.jsonl"
# Load dataset
self.dataset = self._load_dataset()
# Create batches
self.batches = self._create_batches()
safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]")
safe_print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
safe_print(f" Batch size: {self.batch_size}")
safe_print(f" Total batches: {len(self.batches)}")
safe_print(f" Run name: {self.run_name}")
safe_print(f" Distribution: {self.distribution}")
safe_print(f" Output directory: {self.output_dir}")
@ -651,20 +546,6 @@ class BatchRunner:
return dataset
def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
"""
Split dataset into batches with indices.
Returns:
List of batches, where each batch is a list of (index, entry) tuples
"""
batches = []
for i in range(0, len(self.dataset), self.batch_size):
batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
batches.append(batch)
return batches
def _load_checkpoint(self) -> Dict[str, Any]:
"""
Load checkpoint data if it exists.
@ -676,7 +557,6 @@ class BatchRunner:
return {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
@ -688,61 +568,34 @@ class BatchRunner:
return {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
def _save_checkpoint(self, checkpoint_data: Dict[str, Any]):
"""
Save checkpoint data.
Args:
checkpoint_data (Dict): Checkpoint data to save
lock (Lock): Optional lock for thread-safe access
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
if lock:
with lock:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
else:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None,
profiling_stats_list: List[Dict] = None):
def _save_final_stats(
self,
num_processed: int,
tool_stats: Dict[str, Dict[str, int]],
start_time: float,
tool_errors_by_tool: Dict[str, List[Dict]],
exception_errors: List[Dict],
early_exit: bool = False,
exit_reason: str = None,
profiling_stats_list: List[Dict] = None
):
"""
Consolidate batch data into trajectories.jsonl and save statistics.
Args:
num_batches (int): Number of batches processed
tool_stats (Dict): Aggregated tool statistics
start_time (float): Start time of the run
tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
exception_errors (List): Top-level exceptions
early_exit (bool): Whether this is an early exit
exit_reason (str): Reason for early exit
profiling_stats_list (List[Dict]): List of profiling statistics from each conversation
Save final statistics and errors.
"""
# Combine all batch files into a single trajectories.jsonl file
combined_file = self.output_dir / "trajectories.jsonl"
safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]")
entries_written = 0
with open(combined_file, 'w', encoding='utf-8') as outfile:
for batch_num in range(num_batches):
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
if batch_file.exists():
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
outfile.write(line)
entries_written += 1
safe_print(f"[green]✅ Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
# Calculate success rates for tool stats
for tool_name in tool_stats:
stats = tool_stats[tool_name]
@ -794,17 +647,21 @@ class BatchRunner:
# Aggregate profiling statistics if available
aggregated_profiling_stats = None
if profiling_stats_list:
from profiling import aggregate_profiling_stats
aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
try:
from profiling import aggregate_profiling_stats, print_aggregated_statistics
aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
# Display aggregated profiling statistics
print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
except ImportError:
pass
# Save final statistics (without detailed errors)
# Save final statistics
final_stats = {
"run_name": self.run_name,
"distribution": self.distribution,
"total_prompts": len(self.dataset),
"total_batches": len(self.batches),
"batches_processed": num_batches,
"batch_size": self.batch_size,
"processed": num_processed,
"model": self.model,
"completed_at": datetime.now().isoformat(),
"duration_seconds": round(time.time() - start_time, 2),
@ -817,18 +674,9 @@ class BatchRunner:
with open(self.stats_file, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, indent=2, ensure_ascii=False)
# Display aggregated profiling statistics
if aggregated_profiling_stats:
from profiling import print_aggregated_statistics
print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
def run(self, resume: bool = False):
async def _run_async(self, resume: bool = False):
"""
Run the batch processing pipeline.
Args:
resume (bool): Whether to resume from checkpoint
Async implementation of the batch runner pipeline.
"""
print("\n" + "=" * 70)
print("🚀 Starting Batch Processing")
@ -838,15 +686,32 @@ class BatchRunner:
checkpoint_data = self._load_checkpoint() if resume else {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
if resume and checkpoint_data.get("completed_prompts"):
print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
# Prepare configuration for workers
config = {
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# Prepare queues
work_queue = asyncio.Queue()
result_queue = asyncio.Queue()
# Enqueue prompts to process
prompts_to_process = []
for idx, entry in enumerate(self.dataset):
if idx not in completed_prompts_set:
prompts_to_process.append((idx, entry))
work_queue.put_nowait((idx, entry))
total_to_process = len(prompts_to_process)
if total_to_process == 0:
print("✅ All prompts already completed.")
return
# Worker configuration
worker_config = {
"distribution": self.distribution,
"model": self.model,
"max_iterations": self.max_iterations,
@ -857,120 +722,119 @@ class BatchRunner:
"log_prefix_chars": self.log_prefix_chars
}
# Get completed prompts set
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# Start workers
workers = []
for _ in range(min(self.num_workers, total_to_process)):
w = asyncio.create_task(worker(work_queue, result_queue, worker_config))
workers.append(w)
print(f" Processing {total_to_process} prompts with {len(workers)} workers...")
# Aggregate statistics across all batches
# Aggregate statistics
total_tool_stats = {}
all_profiling_stats = [] # Collect all profiling stats for aggregation
tool_errors_by_tool = {} # {tool_name: [list of k most recent errors]}
all_profiling_stats = []
tool_errors_by_tool = {}
all_exception_errors = []
all_completed_prompts = list(completed_prompts_set)
total_processed = len(completed_prompts_set)
total_tool_errors = 0
early_exit = False
exit_reason = None
processed_count = 0
start_time = time.time()
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
tasks = [
(
batch_num,
batch_data,
str(self.output_dir), # Convert Path to string for pickling
completed_prompts_set,
config
)
for batch_num, batch_data in enumerate(self.batches)
]
# Process batches in parallel and check tool failure threshold as results come in
# imap_unordered allows parallel processing while getting results as they complete
batch_num = 0
try:
for result in pool.imap_unordered(_process_batch_worker, tasks):
# Update statistics
all_completed_prompts.extend(result.get("completed_prompts", []))
total_processed += result.get("processed", 0)
# Aggregate tool stats
for tool_name, stats in result.get("tool_stats", {}).items():
if tool_name not in total_tool_stats:
total_tool_stats[tool_name] = {
"count": 0,
"success": 0,
"failure": 0
}
total_tool_stats[tool_name]["count"] += stats["count"]
total_tool_stats[tool_name]["success"] += stats["success"]
total_tool_stats[tool_name]["failure"] += stats["failure"]
# Collect profiling stats from this batch
if result.get("profiling_stats"):
all_profiling_stats.extend(result["profiling_stats"])
# Aggregate tool errors (keep k most recent per tool)
for tool_error in result.get("tool_errors", []):
tool_name = tool_error["tool_name"]
if tool_name not in tool_errors_by_tool:
tool_errors_by_tool[tool_name] = []
# Add error and keep only k most recent
tool_errors_by_tool[tool_name].append(tool_error)
if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
total_tool_errors += 1
# Track exception errors
all_exception_errors.extend(result.get("exception_errors", []))
# Check tool failure thresholds
# Calculate total tool calls (not prompts)
total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
# Check absolute count threshold
if total_tool_errors >= self.max_tool_failures:
# Process results as they arrive
try:
while processed_count < total_to_process:
result = await result_queue.get()
processed_count += 1
prompt_index = result["prompt_index"]
# Track exceptions
if not result["success"]:
safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
all_exception_errors.append({
"prompt_index": prompt_index,
"error": result.get("error", "Unknown error"),
"traceback": result.get("traceback", "")
})
else:
print(f" ✅ Prompt {prompt_index} completed")
# Save trajectory immediately
if result.get("trajectory"):
trajectory_entry = {
"prompt_index": prompt_index,
"conversations": result["trajectory"],
"metadata": result["metadata"],
"completed": result["completed"],
"api_calls": result["api_calls"],
"toolsets_used": result["toolsets_used"]
}
with open(self.trajectories_file, 'a', encoding='utf-8') as f:
f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
# Aggregate tool stats
for tool_name, stats in result.get("tool_stats", {}).items():
if tool_name not in total_tool_stats:
total_tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0}
total_tool_stats[tool_name]["count"] += stats["count"]
total_tool_stats[tool_name]["success"] += stats["success"]
total_tool_stats[tool_name]["failure"] += stats["failure"]
# Collect profiling stats
if result.get("profiling_stats"):
all_profiling_stats.append(result["profiling_stats"])
# Aggregate tool errors
for tool_error in result.get("tool_errors", []):
tool_name = tool_error["tool_name"]
if tool_name not in tool_errors_by_tool:
tool_errors_by_tool[tool_name] = []
tool_errors_by_tool[tool_name].append(tool_error)
# Keep only k most recent
if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
total_tool_errors += 1
# Update checkpoint
completed_prompts_set.add(prompt_index)
checkpoint_data["completed_prompts"] = list(completed_prompts_set)
self._save_checkpoint(checkpoint_data)
# Check failure thresholds
total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
if total_tool_errors >= self.max_tool_failures:
early_exit = True
exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
break
if total_tool_calls >= self.min_tool_calls_for_rate:
tool_failure_rate = total_tool_errors / total_tool_calls
if tool_failure_rate >= self.max_tool_failure_rate:
early_exit = True
exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
pool.terminate() # Stop all workers immediately
exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%})"
break
# Check rate threshold (only if we have enough tool calls to trust the rate)
if total_tool_calls >= self.min_tool_calls_for_rate:
tool_failure_rate = total_tool_errors / total_tool_calls
if tool_failure_rate >= self.max_tool_failure_rate:
early_exit = True
exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)"
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
pool.terminate() # Stop all workers immediately
break
# Update checkpoint after each batch completes
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
batch_num += 1
except KeyboardInterrupt:
safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]")
pool.terminate()
early_exit = True
exit_reason = "Interrupted by user"
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Consolidate data and save statistics
num_batches_processed = batch_num + 1 if early_exit else len(self.batches)
self._consolidate_data(
num_batches_processed,
except asyncio.CancelledError:
early_exit = True
exit_reason = "Run cancelled"
finally:
# Stop all workers
for _ in range(len(workers)):
work_queue.put_nowait(None)
await asyncio.gather(*workers, return_exceptions=True)
if early_exit:
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
# Save final statistics
self._save_final_stats(
processed_count,
total_tool_stats,
start_time,
tool_errors_by_tool,
@ -980,164 +844,28 @@ class BatchRunner:
all_profiling_stats
)
# Print summary
# Summary output
safe_print("\n" + "=" * 70)
if early_exit:
safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]")
safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
else:
safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]")
safe_print("=" * 70)
safe_print(f"✅ Total prompts processed: {total_processed}")
safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}")
safe_print(f"✅ Total prompts processed: {processed_count}/{total_to_process}")
safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
# Tool error summary
if tool_errors_by_tool:
total_errors = sum(len(errors) for errors in tool_errors_by_tool.values())
safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
safe_print("[red]-[/red]" * 70)
# Sort tools by error count
sorted_tools = sorted(
tool_errors_by_tool.items(),
key=lambda x: len(x[1]),
reverse=True
)
for tool_name, errors in sorted_tools:
# Count unique error messages
unique_errors = {}
for error in errors:
error_msg = error["error_message"][:100] # Truncate for grouping
if error_msg not in unique_errors:
unique_errors[error_msg] = []
unique_errors[error_msg].append(error)
safe_print(f"\n [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
# Show up to 3 most recent unique error types
for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
# Show one example with prompt index and full content prefix
example = instances[-1] # Most recent
safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]")
# Show full content prefix (first 200 chars)
full_content = example.get('full_content', '')
if full_content and full_content != error_preview:
content_preview = full_content[:200]
if len(full_content) > 200:
content_preview += "..."
# Show with prefix indicator
safe_print(f" [dim]Content: {content_preview}[/dim]")
if len(unique_errors) > 3:
safe_print(f" [dim]... and {len(unique_errors) - 3} more error types[/dim]")
tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
safe_print(f"\n [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
# Exception errors
if all_exception_errors:
safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
safe_print("[red]-[/red]" * 70)
for error in all_exception_errors[:self.keep_recent_errors]:
error_msg = error["error"]
error_preview = error_msg[:150]
if len(error_msg) > 150:
error_preview += "..."
safe_print(f" [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]")
# Show traceback prefix if available
traceback_text = error.get("traceback", "")
if traceback_text:
# Show last 3 lines of traceback for context
tb_lines = traceback_text.strip().split('\n')
relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines
for line in relevant_lines:
safe_print(f" [dim]{line}[/dim]")
safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
safe_print("-" * 70)
if total_tool_stats:
# Sort by count descending
sorted_tools = sorted(
total_tool_stats.items(),
key=lambda x: x[1]["count"],
reverse=True
)
safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
safe_print("-" * 70)
for tool_name, stats in sorted_tools:
safe_print(
f"{tool_name:<25} "
f"{stats['count']:<10} "
f"{stats['success']:<10} "
f"{stats['failure']:<10} "
f"{stats.get('success_rate', 0):.1f}%"
)
else:
safe_print("No tool calls were made during this run.")
# Display failure type breakdown for tools with failures
if tool_errors_by_tool:
safe_print(f"\n[cyan]📊 Failure Type Breakdown:[/cyan]")
safe_print("-" * 70)
# Sort tools by total error count
sorted_tools = sorted(
tool_errors_by_tool.items(),
key=lambda x: len(x[1]),
reverse=True
)
for tool_name, errors in sorted_tools:
# Count failure types for this tool
failure_types = {}
for error in errors:
error_type = error.get("error_type", "Other")
if error_type not in failure_types:
failure_types[error_type] = 0
failure_types[error_type] += 1
# Display tool name and total failures
total_failures = len(errors)
safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):")
# Sort failure types by count
sorted_types = sorted(
failure_types.items(),
key=lambda x: x[1],
reverse=True
)
# Display each failure type with count and percentage
for failure_type, count in sorted_types:
percentage = (count / total_failures) * 100
safe_print(f"{failure_type:<20} {count:>4} ({percentage:>5.1f}%)")
safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total[/bold red]")
# Simplified error printing here, full detail is in json
for tool_name, errors in tool_errors_by_tool.items():
safe_print(f" {tool_name}: {len(errors)} errors")
safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
safe_print(f" - Trajectories: trajectories.jsonl (combined)")
safe_print(f" - Individual batches: batch_*.jsonl (for debugging)")
safe_print(f" - Statistics: {self.stats_file.name}")
safe_print(f" - Errors: {self.errors_file.name}")
safe_print(f" - Checkpoint: {self.checkpoint_file.name}")
if early_exit:
safe_print(f"\n[bold yellow] Run was stopped early due to tool failures.[/bold yellow]")
safe_print(f"[yellow] Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]")
safe_print(f"[yellow] You can resume this run later with --resume flag.[/yellow]")
def run(self, resume: bool = False):
"""
Run the batch processing pipeline (sync wrapper).
"""
asyncio.run(self._run_async(resume))
def main(
dataset_file: str = None,
batch_size: int = None,
run_name: str = None,
distribution: str = "default",
model: str = "claude-opus-4-20250514",
@ -1160,7 +888,6 @@ def main(
Args:
dataset_file (str): Path to JSONL file with 'prompt' field in each entry
batch_size (int): Number of prompts per batch
run_name (str): Name for this run (used for output and checkpointing)
distribution (str): Toolset distribution to use (default: "default")
model (str): Model name to use (default: "claude-opus-4-20250514")
@ -1180,24 +907,13 @@ def main(
Examples:
# Basic usage
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
python batch_runner.py --dataset_file=data.jsonl --run_name=my_run
# Resume interrupted run
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume
# Use specific distribution
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
# With ephemeral system prompt (not saved to dataset)
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
# With custom tool failure thresholds
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
--max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10
# List available distributions
python batch_runner.py --list_distributions
python batch_runner.py --dataset_file=data.jsonl --run_name=image_test --distribution=image_gen
"""
# Handle list distributions
if list_distributions:
@ -1209,10 +925,6 @@ def main(
all_dists = get_all_dists()
for dist_name in sorted(all_dists.keys()):
print_distribution_info(dist_name)
print("\n💡 Usage:")
print(" python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
print(" --run_name=my_run --distribution=<name>")
return
# Validate required arguments
@ -1220,10 +932,6 @@ def main(
print("❌ Error: --dataset_file is required")
return
if not batch_size or batch_size < 1:
print("❌ Error: --batch_size must be a positive integer")
return
if not run_name:
print("❌ Error: --run_name is required")
return
@ -1232,7 +940,6 @@ def main(
try:
runner = BatchRunner(
dataset_file=dataset_file,
batch_size=batch_size,
run_name=run_name,
distribution=distribution,
max_iterations=max_turns,

View file

@ -23,7 +23,7 @@ Usage:
web_tools = get_tool_definitions(enabled_toolsets=['web_tools'])
# Handle function calls from model
result = handle_function_call("web_search", {"query": "Python"})
result = await handle_function_call("web_search", {"query": "Python"})
"""
import json
@ -439,7 +439,7 @@ def get_tool_definitions(
return filtered_tools
def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
async def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
"""
Handle function calls for web tools.
@ -454,25 +454,25 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any])
query = function_args.get("query", "")
# Always use fixed limit of 5
limit = 5
return web_search_tool(query, limit)
return await web_search_tool(query, limit)
elif function_name == "web_extract":
urls = function_args.get("urls", [])
# Limit URLs to prevent abuse
urls = urls[:5] if isinstance(urls, list) else []
# Run async function in event loop
return asyncio.run(web_extract_tool(urls, "markdown"))
# Run async function
return await web_extract_tool(urls, "markdown")
elif function_name == "web_crawl":
url = function_args.get("url", "")
instructions = function_args.get("instructions")
# Run async function in event loop
return asyncio.run(web_crawl_tool(url, instructions, "basic"))
# Run async function
return await web_crawl_tool(url, instructions, "basic")
else:
return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
async def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
"""
Handle function calls for terminal tools.
@ -489,13 +489,20 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, A
background = function_args.get("background", False)
timeout = function_args.get("timeout")
return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
# Run sync terminal tool in a thread to avoid blocking
return await asyncio.to_thread(
simple_terminal_tool,
command=command,
background=background,
timeout=timeout,
task_id=task_id
)
else:
return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)
def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
async def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
"""
Handle function calls for vision tools.
@ -512,14 +519,14 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any
full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
# Run async function in event loop
return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash"))
# Run async function
return await vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")
else:
return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)
def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
async def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
"""
Handle function calls for Mixture-of-Agents tools.
@ -536,14 +543,14 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any])
if not user_prompt:
return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False)
# Run async function in event loop
return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt))
# Run async function
return await mixture_of_agents_tool(user_prompt=user_prompt)
else:
return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False)
def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
async def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
"""
Handle function calls for image generation tools.
@ -572,21 +579,8 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
allow_nsfw_images = True
seed = None
# Run async function in event loop with proper handling for multiprocessing
try:
# Try to get existing event loop
loop = asyncio.get_event_loop()
if loop.is_closed():
# If closed, create a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
except RuntimeError:
# No event loop in current thread, create one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Run the coroutine in the event loop
result = loop.run_until_complete(image_generate_tool(
# Run async function
return await image_generate_tool(
prompt=prompt,
image_size=image_size,
num_inference_steps=num_inference_steps,
@ -597,15 +591,13 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]
acceleration=acceleration,
allow_nsfw_images=allow_nsfw_images,
seed=seed
))
return result
)
else:
return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False)
def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
async def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
"""
Main function call dispatcher that routes calls to appropriate toolsets.
@ -627,23 +619,23 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task
try:
# Route web tools
if function_name in ["web_search", "web_extract", "web_crawl"]:
return handle_web_function_call(function_name, function_args)
return await handle_web_function_call(function_name, function_args)
# Route terminal tools
elif function_name in ["terminal"]:
return handle_terminal_function_call(function_name, function_args, task_id)
return await handle_terminal_function_call(function_name, function_args, task_id)
# Route vision tools
elif function_name in ["vision_analyze"]:
return handle_vision_function_call(function_name, function_args)
return await handle_vision_function_call(function_name, function_args)
# Route MoA tools
elif function_name in ["mixture_of_agents"]:
return handle_moa_function_call(function_name, function_args)
return await handle_moa_function_call(function_name, function_args)
# Route image generation tools
elif function_name in ["image_generate"]:
return handle_image_function_call(function_name, function_args)
return await handle_image_function_call(function_name, function_args)
else:
error_msg = f"Unknown function: {function_name}"
@ -773,4 +765,4 @@ if __name__ == "__main__":
if "terminal" in all_tool_names:
no_terminal = get_tool_definitions(disabled_tools=["terminal"])
print(f" All except terminal: {len(no_terminal)} tools")
print(f" All except terminal: {len(no_terminal)} tools")

View file

@ -24,121 +24,17 @@ import json
import logging
import os
import time
import asyncio
import sys
from typing import List, Dict, Any, Optional
from openai import OpenAI
from openai import AsyncOpenAI
import fire
from datetime import datetime
from pathlib import Path
from rich import print
from prokletor.formatters.hermes_formatter import HermesToolFormatterWithReasoning
class SyncCustomToolCompletions:
    """Synchronous ``chat.completions`` wrapper that emulates tool calling.

    Instead of passing ``tools`` to the API, the tool schemas are rendered
    into the system prompt by *formatter*, and the model's textual response
    is parsed back into ``tool_calls`` objects.
    """

    def __init__(self, completions, formatter):
        # completions: the underlying client's ``chat.completions`` object.
        # formatter: renders tool schemas into a system prompt and parses
        # tool calls back out of model text (Hermes XML style).
        self._completions = completions
        self._formatter = formatter

    def create(self, *args, messages, tools=None, **kwargs):
        """Create a chat completion, emulating tool support via prompting.

        Without ``tools`` this is a plain pass-through. With ``tools`` the
        schemas are folded into the system message, the completion is
        requested WITHOUT ``tools``/``tool_choice``, and the response text
        is parsed for tool calls.
        """
        if not tools:
            return self._completions.create(*args, messages=messages, **kwargs)

        # 1. Fold the tool schemas into the system message.
        system_prompt = self._formatter.format_system_message(tools)
        new_messages = list(messages)
        if new_messages and new_messages[0]["role"] == "system":
            existing_content = new_messages[0]["content"]
            if isinstance(existing_content, str):
                # Append to the existing plain-text system message.
                new_messages[0] = {
                    "role": "system",
                    "content": existing_content + "\n\n" + system_prompt
                }
            else:
                # Structured (list-of-parts) system content: append the tool
                # prompt as an extra text part instead of silently dropping
                # it (previously this branch did nothing, losing the tools).
                new_messages[0] = {
                    "role": "system",
                    "content": list(existing_content) + [
                        {"type": "text", "text": system_prompt}
                    ]
                }
        else:
            # No system message yet: insert one carrying the tool prompt.
            new_messages.insert(0, {
                "role": "system",
                "content": system_prompt
            })

        # 2. Call the API without the 'tools' parameter — the model only
        # ever sees the prompt-encoded schemas.
        kwargs.pop("tool_choice", None)
        # Process messages (e.g. convert roles and add reasoning prompt).
        # use_tool_role=False for API compatibility.
        new_messages = self._formatter.process_messages(new_messages, use_tool_role=False)
        response = self._completions.create(
            *args,
            messages=new_messages,
            # tools=tools,  # Do NOT pass tools to the model
            **kwargs
        )

        # 3. Parse tool calls back out of the response text.
        choice = response.choices[0]
        if choice.message.content:
            tool_calls = self._formatter.parse_response(choice.message.content, tools=tools)
            if tool_calls:
                choice.message.tool_calls = tool_calls
            # Strip tool-call markup from the visible content if the
            # formatter supports it.
            if hasattr(self._formatter, "extract_text_from_content"):
                cleaned_content = self._formatter.extract_text_from_content(choice.message.content)
                choice.message.content = cleaned_content
            if not choice.message.content:
                # Normalize cleaned-away / empty content to None.
                choice.message.content = None
        return response

    def __getattr__(self, name):
        # Delegate everything else to the wrapped completions object.
        return getattr(self._completions, name)
class SyncCustomChat:
    """Drop-in ``chat`` namespace that swaps in tool-emulating completions."""

    def __init__(self, chat, formatter):
        # Shadow the real completions with the prompt-based tool wrapper;
        # everything else is served by the wrapped chat object.
        self.completions = SyncCustomToolCompletions(chat.completions, formatter)
        self._chat = chat

    def __getattr__(self, attr):
        """Delegate any attribute we do not override to the wrapped chat."""
        return getattr(self._chat, attr)
class SyncHermesToolClientWithReasoning:
    """Client wrapper that emulates Hermes-style tool calling via prompting.

    Wraps an OpenAI-compatible client so that ``chat.completions.create``
    accepts ``tools`` without ever sending them to the API: tool schemas
    are rendered into the system prompt and tool calls are parsed from the
    model's text output.
    """

    def __init__(self, client):
        # client: an OpenAI-compatible client instance.
        self._client = client
        self.formatter = HermesToolFormatterWithReasoning()
        self.chat = SyncCustomChat(client.chat, self.formatter)

    def format(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], use_tool_role: bool = True) -> List[Dict[str, Any]]:
        """
        Format messages and tools into the Hermes XML format.

        Useful for debugging or manual inspection of what will be sent to the model.
        """
        # 1. Render the tool schemas into a system prompt.
        system_prompt = self.formatter.format_system_message(tools)
        new_messages = list(messages)
        if new_messages and new_messages[0]["role"] == "system":
            existing_content = new_messages[0]["content"]
            if isinstance(existing_content, str):
                # Append to the existing plain-text system message.
                new_messages[0] = {
                    "role": "system",
                    "content": existing_content + "\n\n" + system_prompt
                }
            else:
                # Structured (list-of-parts) system content: append the tool
                # prompt as an extra text part instead of silently dropping
                # it (previously this branch did nothing, losing the tools).
                new_messages[0] = {
                    "role": "system",
                    "content": list(existing_content) + [
                        {"type": "text", "text": system_prompt}
                    ]
                }
        else:
            # No system message: insert one carrying the tool prompt.
            new_messages.insert(0, {
                "role": "system",
                "content": system_prompt
            })

        # 2. Process messages (convert tool calls and results to XML).
        return self.formatter.process_messages(new_messages, use_tool_role=use_tool_role)

    def __getattr__(self, name):
        # Delegate everything else to the wrapped client.
        return getattr(self._client, name)
from prokletor.clients.hermes import HermesToolClientWithReasoning
# Load environment variables from .env file
from dotenv import load_dotenv
@ -242,10 +138,10 @@ class AIAgent:
client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key")
try:
oai_client = OpenAI(**client_kwargs)
oai_client = AsyncOpenAI(**client_kwargs)
# self.client = oai_client
self.client = SyncHermesToolClientWithReasoning(oai_client)
print(f"🧠 Wrapped OpenAI client with SyncHermesToolClientWithReasoning")
self.client = HermesToolClientWithReasoning(oai_client)
print(f"🧠 Wrapped OpenAI client with AsyncHermesToolClientWithReasoning")
print(f"🤖 AI Agent initialized with model: {self.model}")
if base_url:
@ -494,7 +390,7 @@ class AIAgent:
except Exception as e:
print(f"⚠️ Failed to save trajectory: {e}")
def run_conversation(
async def run_conversation(
self,
user_message: str,
system_message: str = None,
@ -567,7 +463,7 @@ class AIAgent:
# Make API call with tools
response = self.client.chat.completions.create(
response = await self.client.chat.completions.create(
model=self.model,
messages=api_messages,
tools=self.tools if self.tools else None,
@ -603,7 +499,7 @@ class AIAgent:
print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
print(f"⏳ Retrying in {wait_time}s...")
logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
time.sleep(wait_time)
await asyncio.sleep(wait_time)
if response is None:
raise last_api_error if last_api_error else RuntimeError("OpenAI-compatible API call failed without a response")
@ -694,7 +590,7 @@ class AIAgent:
tool_start_time = time.time()
# Execute the tool with task_id to isolate VMs between concurrent tasks
function_result = handle_function_call(function_name, function_args, effective_task_id)
function_result = await handle_function_call(function_name, function_args, effective_task_id)
tool_duration = time.time() - tool_start_time
result_preview = function_result[:200] if len(function_result) > 200 else function_result
@ -720,7 +616,7 @@ class AIAgent:
# Delay between tool calls
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
time.sleep(self.tool_delay)
await asyncio.sleep(self.tool_delay)
# Continue loop for next response
continue
@ -838,7 +734,7 @@ class AIAgent:
# Clean up VM for this task after conversation completes
try:
cleanup_vm(effective_task_id)
await asyncio.to_thread(cleanup_vm, effective_task_id)
except Exception as e:
if self.verbose_logging:
logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
@ -854,7 +750,7 @@ class AIAgent:
"profiling_stats": profiling_stats
}
def chat(self, message: str) -> str:
async def chat(self, message: str) -> str:
"""
Simple chat interface that returns just the final response.
@ -864,7 +760,7 @@ class AIAgent:
Returns:
str: Final assistant response
"""
result = self.run_conversation(message)
result = await self.run_conversation(message)
return result["final_response"]
@ -1037,7 +933,7 @@ def main(
print("\n" + "=" * 50)
# Run conversation
result = agent.run_conversation(user_query)
result = asyncio.run(agent.run_conversation(user_query))
print("\n" + "=" * 50)
print("📋 CONVERSATION SUMMARY")

View file

@ -48,11 +48,11 @@ import uuid
import datetime
from pathlib import Path
from typing import List, Dict, Any, Optional
from firecrawl import Firecrawl
from firecrawl import AsyncFirecrawl
from openai import AsyncOpenAI
# Initialize Firecrawl client once at module level
firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
firecrawl_client = AsyncFirecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
# Initialize Nous Research API client for LLM processing (async)
nous_client = AsyncOpenAI(
@ -261,7 +261,7 @@ def clean_base64_images(text: str) -> str:
return cleaned_text
def web_search_tool(query: str, limit: int = 5) -> str:
async def web_search_tool(query: str, limit: int = 5) -> str:
"""
Search the web for information using available search API backend.
@ -312,7 +312,7 @@ def web_search_tool(query: str, limit: int = 5) -> str:
# Use Firecrawl's v2 search functionality WITHOUT scraping
# We only want search result metadata, not scraped content
# Docs: https://docs.firecrawl.dev/features/search
response = firecrawl_client.search(
response = await firecrawl_client.search(
query=query,
limit=limit
)
@ -446,7 +446,7 @@ async def web_extract_tool(
for url in urls:
try:
print(f" 📄 Scraping: {url}")
scrape_result = firecrawl_client.scrape(
scrape_result = await firecrawl_client.scrape(
url=url,
formats=formats
)
@ -703,7 +703,7 @@ async def web_crawl_tool(
# Use the crawl method which waits for completion automatically
try:
crawl_result = firecrawl_client.crawl(
crawl_result = await firecrawl_client.crawl(
url=url,
**crawl_params
)