mirror of https://github.com/NousResearch/hermes-agent.git (synced 2026-04-25 00:51:20 +00:00)

Commit e91d9e839a ("switch to asyncio"), parent 98321be8b0
4 changed files with 261 additions and 666 deletions

batch_runner.py (707 lines changed)
@@ -4,25 +4,25 @@ Batch Agent Runner
 This module provides parallel batch processing capabilities for running the agent
 across multiple prompts from a dataset. It includes:
-- Dataset loading and batching
-- Parallel batch processing with multiprocessing
+- Dataset loading
+- Concurrent processing with asyncio (Producer-Consumer pattern)
 - Checkpointing for fault tolerance and resumption
 - Trajectory saving in the proper format (from/value pairs)
-- Tool usage statistics aggregation across all batches
+- Tool usage statistics aggregation across all prompts
 - Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
 - Configurable failure thresholds with automatic data consolidation
 
 Usage:
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run
 
     # Resume an interrupted run
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume
 
     # Use a specific toolset distribution
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --distribution=image_gen
 
     # Configure tool failure thresholds
-    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+    python batch_runner.py --dataset_file=data.jsonl --run_name=my_run \\
         --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10
 """
 
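The Producer-Consumer pattern named above, in its minimal asyncio form: a Queue of work items, a fixed pool of consumer tasks, and one None sentinel per worker for shutdown. A runnable sketch with illustrative names, not this module's API:

    import asyncio

    async def consumer(queue: asyncio.Queue, results: list) -> None:
        while True:
            item = await queue.get()
            if item is None:            # sentinel: no more work
                queue.task_done()
                break
            results.append(item * 2)    # stands in for real processing
            queue.task_done()

    async def main(prompts, num_workers: int = 4):
        queue, results = asyncio.Queue(), []
        for p in prompts:
            queue.put_nowait(p)
        workers = [asyncio.create_task(consumer(queue, results)) for _ in range(num_workers)]
        for _ in workers:
            queue.put_nowait(None)      # one sentinel per worker
        await asyncio.gather(*workers)
        return results

    print(asyncio.run(main([1, 2, 3])))  # [2, 4, 6], order may vary

The batch_runner changes below follow this shape, plus a result queue so the coordinator can checkpoint as each prompt finishes.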
@@ -30,10 +30,10 @@ import json
 import logging
 import os
 import time
+import asyncio
 from pathlib import Path
-from typing import List, Dict, Any, Optional, Tuple
+from typing import List, Dict, Any, Optional, Tuple, Set
 from datetime import datetime
-from multiprocessing import Pool, Manager, Lock
 import traceback
 import re
 
@@ -49,9 +49,6 @@ from toolset_distributions import (
 from safe_print import safe_print
 
 
-# Global configuration for worker processes
-_WORKER_CONFIG = {}
-
 # Canonical names for the terminal tool (old & new implementations)
 _TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"}
 
@@ -295,10 +292,9 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
     return tool_stats
 
 
-def _process_single_prompt(
+async def _process_single_prompt(
     prompt_index: int,
     prompt_data: Dict[str, Any],
-    batch_num: int,
     config: Dict[str, Any]
 ) -> Dict[str, Any]:
     """
 
@@ -307,7 +303,6 @@ def _process_single_prompt(
     Args:
         prompt_index (int): Index of prompt in dataset
         prompt_data (Dict): Prompt data containing 'prompt' field
-        batch_num (int): Batch number
         config (Dict): Configuration dict with agent parameters
 
     Returns:
 
@@ -336,7 +331,7 @@ def _process_single_prompt(
     )
 
     # Run the agent with task_id to ensure each task gets its own isolated VM
-    result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
+    result = await agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
 
     # Extract tool usage statistics
     tool_stats = _extract_tool_stats(result["messages"])
 
@@ -365,7 +360,6 @@ def _process_single_prompt(
         "api_calls": result["api_calls"],
         "toolsets_used": selected_toolsets,
         "metadata": {
-            "batch_num": batch_num,
             "timestamp": datetime.now().isoformat(),
             "model": config["model"]
         }
@@ -389,132 +383,38 @@ def _process_single_prompt(
         "tool_stats": {},
         "toolsets_used": [],
         "metadata": {
-            "batch_num": batch_num,
             "timestamp": datetime.now().isoformat()
         }
     }
 
 
-def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
+async def worker(
+    work_queue: asyncio.Queue,
+    result_queue: asyncio.Queue,
+    config: Dict[str, Any]
+):
     """
-    Worker function to process a single batch of prompts.
-
-    Args:
-        args (Tuple): (batch_num, batch_data, output_dir, completed_prompts, config)
-
-    Returns:
-        Dict: Batch results with statistics
+    Consumer worker that processes prompts from the work queue.
     """
-    batch_num, batch_data, output_dir, completed_prompts_set, config = args
-
-    output_dir = Path(output_dir)
-    print(f"\n🔄 Batch {batch_num}: Starting ({len(batch_data)} prompts)")
-
-    # Output file for this batch
-    batch_output_file = output_dir / f"batch_{batch_num}.jsonl"
-
-    # Filter out already completed prompts
-    prompts_to_process = [
-        (idx, data) for idx, data in batch_data
-        if idx not in completed_prompts_set
-    ]
-
-    if not prompts_to_process:
-        print(f"✅ Batch {batch_num}: Already completed (skipping)")
-        return {
-            "batch_num": batch_num,
-            "processed": 0,
-            "skipped": len(batch_data),
-            "tool_stats": {},
-            "completed_prompts": []
-        }
-
-    print(f"   Processing {len(prompts_to_process)} prompts (skipping {len(batch_data) - len(prompts_to_process)} already completed)")
-
-    # Initialize aggregated stats for this batch
-    batch_tool_stats = {}
-    batch_profiling_stats = []  # Collect profiling stats from each prompt
-    completed_in_batch = []
-    all_tool_errors = []  # Track all tool errors in this batch
-    exception_errors = []  # Track top-level exceptions
-
-    # Process each prompt sequentially in this batch
-    for prompt_index, prompt_data in prompts_to_process:
-        # Process the prompt
-        result = _process_single_prompt(
-            prompt_index,
-            prompt_data,
-            batch_num,
-            config
-        )
-
-        # Track tool errors from the conversation
-        if result.get("tool_errors"):
-            for tool_error in result["tool_errors"]:
-                all_tool_errors.append({
-                    "prompt_index": prompt_index,
-                    "tool_name": tool_error["tool_name"],
-                    "error_message": tool_error["error_message"],
-                    "full_content": tool_error.get("full_content", ""),
-                    "error_type": tool_error.get("error_type", "Other")
-                })
-
-        # Track top-level exceptions (not tool errors)
-        if not result["success"]:
-            exception_errors.append({
-                "prompt_index": prompt_index,
-                "error": result.get("error", "Unknown error"),
-                "traceback": result.get("traceback", "")
-            })
-            safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
-
-        # Save trajectory if successful
-        if result["success"] and result["trajectory"]:
-            trajectory_entry = {
-                "prompt_index": prompt_index,
-                "conversations": result["trajectory"],
-                "metadata": result["metadata"],
-                "completed": result["completed"],
-                "api_calls": result["api_calls"],
-                "toolsets_used": result["toolsets_used"]
-            }
-
-            # Append to batch output file
-            with open(batch_output_file, 'a', encoding='utf-8') as f:
-                f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
-
-        # Aggregate tool statistics
-        for tool_name, stats in result.get("tool_stats", {}).items():
-            if tool_name not in batch_tool_stats:
-                batch_tool_stats[tool_name] = {
-                    "count": 0,
-                    "success": 0,
-                    "failure": 0
-                }
-
-            batch_tool_stats[tool_name]["count"] += stats["count"]
-            batch_tool_stats[tool_name]["success"] += stats["success"]
-            batch_tool_stats[tool_name]["failure"] += stats["failure"]
-
-        # Collect profiling statistics
-        if result.get("profiling_stats"):
-            batch_profiling_stats.append(result["profiling_stats"])
-
-        completed_in_batch.append(prompt_index)
-        print(f"   ✅ Prompt {prompt_index} completed")
-
-    print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
-
-    return {
-        "batch_num": batch_num,
-        "processed": len(prompts_to_process),
-        "skipped": len(batch_data) - len(prompts_to_process),
-        "tool_stats": batch_tool_stats,
-        "profiling_stats": batch_profiling_stats,
-        "completed_prompts": completed_in_batch,
-        "tool_errors": all_tool_errors,
-        "exception_errors": exception_errors
-    }
+    while True:
+        try:
+            task = await work_queue.get()
+            if task is None:
+                # Sentinel to stop worker
+                work_queue.task_done()
+                break
+
+            prompt_index, prompt_data = task
+
+            result = await _process_single_prompt(prompt_index, prompt_data, config)
+
+            await result_queue.put(result)
+            work_queue.task_done()
+
+        except Exception as e:
+            print(f"Error in worker: {e}")
+            if 'task' in locals() and task is not None:
+                work_queue.task_done()
 
 
 class BatchRunner:
@@ -525,7 +425,6 @@ class BatchRunner:
     def __init__(
         self,
         dataset_file: str,
-        batch_size: int,
         run_name: str,
         distribution: str = "default",
        max_iterations: int = 10,
 
@@ -546,14 +445,13 @@ class BatchRunner:
 
         Args:
             dataset_file (str): Path to the dataset JSONL file with 'prompt' field
-            batch_size (int): Number of prompts per batch
             run_name (str): Name for this run (used for checkpointing and output)
             distribution (str): Toolset distribution to use (default: "default")
             max_iterations (int): Max iterations per agent run
             base_url (str): Base URL for model API
             api_key (str): API key for model
             model (str): Model name to use
-            num_workers (int): Number of parallel workers
+            num_workers (int): Number of parallel workers (default: 4)
             verbose (bool): Enable verbose logging
             ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
             log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
 
@@ -563,7 +461,6 @@ class BatchRunner:
             min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
         """
         self.dataset_file = Path(dataset_file)
-        self.batch_size = batch_size
         self.run_name = run_name
         self.distribution = distribution
         self.max_iterations = max_iterations
@@ -596,16 +493,14 @@ class BatchRunner:
         # Errors file
         self.errors_file = self.output_dir / "errors.json"
 
+        # Trajectories file
+        self.trajectories_file = self.output_dir / "trajectories.jsonl"
+
         # Load dataset
         self.dataset = self._load_dataset()
 
-        # Create batches
-        self.batches = self._create_batches()
-
         safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]")
         safe_print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
-        safe_print(f"   Batch size: {self.batch_size}")
-        safe_print(f"   Total batches: {len(self.batches)}")
         safe_print(f"   Run name: {self.run_name}")
         safe_print(f"   Distribution: {self.distribution}")
         safe_print(f"   Output directory: {self.output_dir}")
 
@@ -651,20 +546,6 @@ class BatchRunner:
 
         return dataset
 
-    def _create_batches(self) -> List[List[Tuple[int, Dict[str, Any]]]]:
-        """
-        Split dataset into batches with indices.
-
-        Returns:
-            List of batches, where each batch is a list of (index, entry) tuples
-        """
-        batches = []
-        for i in range(0, len(self.dataset), self.batch_size):
-            batch = [(idx, entry) for idx, entry in enumerate(self.dataset[i:i + self.batch_size], start=i)]
-            batches.append(batch)
-
-        return batches
-
     def _load_checkpoint(self) -> Dict[str, Any]:
         """
         Load checkpoint data if it exists.
@@ -676,7 +557,6 @@ class BatchRunner:
             return {
                 "run_name": self.run_name,
                 "completed_prompts": [],
                 "batch_stats": {},
                 "last_updated": None
             }
 
@@ -688,61 +568,34 @@ class BatchRunner:
         return {
             "run_name": self.run_name,
             "completed_prompts": [],
             "batch_stats": {},
             "last_updated": None
         }
 
-    def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
+    def _save_checkpoint(self, checkpoint_data: Dict[str, Any]):
         """
         Save checkpoint data.
 
         Args:
             checkpoint_data (Dict): Checkpoint data to save
-            lock (Lock): Optional lock for thread-safe access
         """
         checkpoint_data["last_updated"] = datetime.now().isoformat()
+        with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
+            json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
 
-        if lock:
-            with lock:
-                with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
-                    json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
-        else:
-            with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
-                json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
 
-    def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
-                          start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
-                          exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None,
-                          profiling_stats_list: List[Dict] = None):
+    def _save_final_stats(
+        self,
+        num_processed: int,
+        tool_stats: Dict[str, Dict[str, int]],
+        start_time: float,
+        tool_errors_by_tool: Dict[str, List[Dict]],
+        exception_errors: List[Dict],
+        early_exit: bool = False,
+        exit_reason: str = None,
+        profiling_stats_list: List[Dict] = None
+    ):
         """
-        Consolidate batch data into trajectories.jsonl and save statistics.
-
-        Args:
-            num_batches (int): Number of batches processed
-            tool_stats (Dict): Aggregated tool statistics
-            start_time (float): Start time of the run
-            tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
-            exception_errors (List): Top-level exceptions
-            early_exit (bool): Whether this is an early exit
-            exit_reason (str): Reason for early exit
-            profiling_stats_list (List[Dict]): List of profiling statistics from each conversation
+        Save final statistics and errors.
         """
-        # Combine all batch files into a single trajectories.jsonl file
-        combined_file = self.output_dir / "trajectories.jsonl"
-        safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]")
-
-        entries_written = 0
-        with open(combined_file, 'w', encoding='utf-8') as outfile:
-            for batch_num in range(num_batches):
-                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
-                if batch_file.exists():
-                    with open(batch_file, 'r', encoding='utf-8') as infile:
-                        for line in infile:
-                            outfile.write(line)
-                            entries_written += 1
-
-        safe_print(f"[green]✅ Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
-
         # Calculate success rates for tool stats
         for tool_name in tool_stats:
             stats = tool_stats[tool_name]
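A side note on the `_save_checkpoint` change above: the `lock` parameter existed for multiprocessing, and with a single event loop the synchronous open/json.dump sequence contains no await point, so no lock is needed. If a save ever did span an await, asyncio.Lock would be the idiomatic guard. A sketch under that assumption, names illustrative:

    import asyncio
    import json

    _checkpoint_lock = asyncio.Lock()  # hypothetical guard, not in this diff

    async def save_checkpoint(path: str, data: dict) -> None:
        async with _checkpoint_lock:
            # the write itself is synchronous and never yields control
            with open(path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)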
@@ -794,17 +647,21 @@ class BatchRunner:
         # Aggregate profiling statistics if available
         aggregated_profiling_stats = None
         if profiling_stats_list:
-            from profiling import aggregate_profiling_stats
-            aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
+            try:
+                from profiling import aggregate_profiling_stats, print_aggregated_statistics
+                aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list)
+
+                # Display aggregated profiling statistics
+                print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
+            except ImportError:
+                pass
 
-        # Save final statistics (without detailed errors)
+        # Save final statistics
         final_stats = {
             "run_name": self.run_name,
             "distribution": self.distribution,
             "total_prompts": len(self.dataset),
-            "total_batches": len(self.batches),
-            "batches_processed": num_batches,
-            "batch_size": self.batch_size,
+            "processed": num_processed,
             "model": self.model,
             "completed_at": datetime.now().isoformat(),
             "duration_seconds": round(time.time() - start_time, 2),
 
@@ -817,18 +674,9 @@ class BatchRunner:
         with open(self.stats_file, 'w', encoding='utf-8') as f:
             json.dump(final_stats, f, indent=2, ensure_ascii=False)
 
-        # Display aggregated profiling statistics
-        if aggregated_profiling_stats:
-            from profiling import print_aggregated_statistics
-            print_aggregated_statistics(aggregated_profiling_stats, detailed=True)
-
 
-    def run(self, resume: bool = False):
+    async def _run_async(self, resume: bool = False):
         """
-        Run the batch processing pipeline.
-
-        Args:
-            resume (bool): Whether to resume from checkpoint
+        Async implementation of the batch runner pipeline.
         """
         print("\n" + "=" * 70)
         print("🚀 Starting Batch Processing")
 
@@ -838,15 +686,32 @@ class BatchRunner:
         checkpoint_data = self._load_checkpoint() if resume else {
             "run_name": self.run_name,
             "completed_prompts": [],
             "batch_stats": {},
             "last_updated": None
         }
 
         if resume and checkpoint_data.get("completed_prompts"):
             print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
 
-        # Prepare configuration for workers
-        config = {
+        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
+
+        # Prepare queues
+        work_queue = asyncio.Queue()
+        result_queue = asyncio.Queue()
+
+        # Enqueue prompts to process
+        prompts_to_process = []
+        for idx, entry in enumerate(self.dataset):
+            if idx not in completed_prompts_set:
+                prompts_to_process.append((idx, entry))
+                work_queue.put_nowait((idx, entry))
+
+        total_to_process = len(prompts_to_process)
+        if total_to_process == 0:
+            print("✅ All prompts already completed.")
+            return
+
+        # Worker configuration
+        worker_config = {
             "distribution": self.distribution,
             "model": self.model,
             "max_iterations": self.max_iterations,
@@ -857,120 +722,119 @@ class BatchRunner:
             "log_prefix_chars": self.log_prefix_chars
         }
 
-        # Get completed prompts set
-        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
+        # Start workers
+        workers = []
+        for _ in range(min(self.num_workers, total_to_process)):
+            w = asyncio.create_task(worker(work_queue, result_queue, worker_config))
+            workers.append(w)
+
+        print(f"   Processing {total_to_process} prompts with {len(workers)} workers...")
 
-        # Aggregate statistics across all batches
+        # Aggregate statistics
         total_tool_stats = {}
-        all_profiling_stats = []  # Collect all profiling stats for aggregation
-        tool_errors_by_tool = {}  # {tool_name: [list of k most recent errors]}
+        all_profiling_stats = []
+        tool_errors_by_tool = {}
         all_exception_errors = []
-        all_completed_prompts = list(completed_prompts_set)
-        total_processed = len(completed_prompts_set)
         total_tool_errors = 0
         early_exit = False
         exit_reason = None
 
+        processed_count = 0
+
         start_time = time.time()
 
-        # Process batches in parallel
-        with Pool(processes=self.num_workers) as pool:
-            # Create tasks for each batch
-            tasks = [
-                (
-                    batch_num,
-                    batch_data,
-                    str(self.output_dir),  # Convert Path to string for pickling
-                    completed_prompts_set,
-                    config
-                )
-                for batch_num, batch_data in enumerate(self.batches)
-            ]
-
-            # Process batches in parallel and check tool failure threshold as results come in
-            # imap_unordered allows parallel processing while getting results as they complete
-            batch_num = 0
-            try:
-                for result in pool.imap_unordered(_process_batch_worker, tasks):
-                    # Update statistics
-                    all_completed_prompts.extend(result.get("completed_prompts", []))
-                    total_processed += result.get("processed", 0)
-
-                    # Aggregate tool stats
-                    for tool_name, stats in result.get("tool_stats", {}).items():
-                        if tool_name not in total_tool_stats:
-                            total_tool_stats[tool_name] = {
-                                "count": 0,
-                                "success": 0,
-                                "failure": 0
-                            }
-
-                        total_tool_stats[tool_name]["count"] += stats["count"]
-                        total_tool_stats[tool_name]["success"] += stats["success"]
-                        total_tool_stats[tool_name]["failure"] += stats["failure"]
-
-                    # Collect profiling stats from this batch
-                    if result.get("profiling_stats"):
-                        all_profiling_stats.extend(result["profiling_stats"])
-
-                    # Aggregate tool errors (keep k most recent per tool)
-                    for tool_error in result.get("tool_errors", []):
-                        tool_name = tool_error["tool_name"]
-                        if tool_name not in tool_errors_by_tool:
-                            tool_errors_by_tool[tool_name] = []
-
-                        # Add error and keep only k most recent
-                        tool_errors_by_tool[tool_name].append(tool_error)
-                        if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
-                            tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
-
-                        total_tool_errors += 1
-
-                    # Track exception errors
-                    all_exception_errors.extend(result.get("exception_errors", []))
-
-                    # Check tool failure thresholds
-                    # Calculate total tool calls (not prompts)
-                    total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
-
-                    # Check absolute count threshold
-                    if total_tool_errors >= self.max_tool_failures:
-
+        # Process results as they arrive
+        try:
+            while processed_count < total_to_process:
+                result = await result_queue.get()
+                processed_count += 1
+
+                prompt_index = result["prompt_index"]
+
+                # Track exceptions
+                if not result["success"]:
+                    safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
+                    all_exception_errors.append({
+                        "prompt_index": prompt_index,
+                        "error": result.get("error", "Unknown error"),
+                        "traceback": result.get("traceback", "")
+                    })
+                else:
+                    print(f"   ✅ Prompt {prompt_index} completed")
+
+                # Save trajectory immediately
+                if result.get("trajectory"):
+                    trajectory_entry = {
+                        "prompt_index": prompt_index,
+                        "conversations": result["trajectory"],
+                        "metadata": result["metadata"],
+                        "completed": result["completed"],
+                        "api_calls": result["api_calls"],
+                        "toolsets_used": result["toolsets_used"]
+                    }
+                    with open(self.trajectories_file, 'a', encoding='utf-8') as f:
+                        f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
+
+                # Aggregate tool stats
+                for tool_name, stats in result.get("tool_stats", {}).items():
+                    if tool_name not in total_tool_stats:
+                        total_tool_stats[tool_name] = {"count": 0, "success": 0, "failure": 0}
+
+                    total_tool_stats[tool_name]["count"] += stats["count"]
+                    total_tool_stats[tool_name]["success"] += stats["success"]
+                    total_tool_stats[tool_name]["failure"] += stats["failure"]
+
+                # Collect profiling stats
+                if result.get("profiling_stats"):
+                    all_profiling_stats.append(result["profiling_stats"])
+
+                # Aggregate tool errors
+                for tool_error in result.get("tool_errors", []):
+                    tool_name = tool_error["tool_name"]
+                    if tool_name not in tool_errors_by_tool:
+                        tool_errors_by_tool[tool_name] = []
+
+                    tool_errors_by_tool[tool_name].append(tool_error)
+                    # Keep only k most recent
+                    if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
+                        tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
+
+                    total_tool_errors += 1
+
+                # Update checkpoint
+                completed_prompts_set.add(prompt_index)
+                checkpoint_data["completed_prompts"] = list(completed_prompts_set)
+                self._save_checkpoint(checkpoint_data)
+
+                # Check failure thresholds
+                total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values())
+
+                if total_tool_errors >= self.max_tool_failures:
                     early_exit = True
                     exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
                     break
 
+                if total_tool_calls >= self.min_tool_calls_for_rate:
+                    tool_failure_rate = total_tool_errors / total_tool_calls
+                    if tool_failure_rate >= self.max_tool_failure_rate:
                         early_exit = True
-                        exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
-                        safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
-                        pool.terminate()  # Stop all workers immediately
+                        exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%})"
                         break
-
-                    # Check rate threshold (only if we have enough tool calls to trust the rate)
-                    if total_tool_calls >= self.min_tool_calls_for_rate:
-                        tool_failure_rate = total_tool_errors / total_tool_calls
-
-                        if tool_failure_rate >= self.max_tool_failure_rate:
-                            early_exit = True
-                            exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)"
-                            safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
-                            pool.terminate()  # Stop all workers immediately
-                            break
-
-                    # Update checkpoint after each batch completes
-                    checkpoint_data["completed_prompts"] = all_completed_prompts
-                    self._save_checkpoint(checkpoint_data)
-
-                    batch_num += 1
-            except KeyboardInterrupt:
-                safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]")
-                pool.terminate()
-                early_exit = True
-                exit_reason = "Interrupted by user"
-
-            # Save final checkpoint
-            checkpoint_data["completed_prompts"] = all_completed_prompts
-            self._save_checkpoint(checkpoint_data)
-
-            # Consolidate data and save statistics
-            num_batches_processed = batch_num + 1 if early_exit else len(self.batches)
-            self._consolidate_data(
-                num_batches_processed,
 
+        except KeyboardInterrupt:
+            safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]")
+            early_exit = True
+            exit_reason = "Interrupted by user"
+        except asyncio.CancelledError:
+            early_exit = True
+            exit_reason = "Run cancelled"
+        finally:
+            # Stop all workers
+            for _ in range(len(workers)):
+                work_queue.put_nowait(None)
+            await asyncio.gather(*workers, return_exceptions=True)
+
+        if early_exit:
+            safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
+
+        # Save final statistics
+        self._save_final_stats(
+            processed_count,
             total_tool_stats,
             start_time,
             tool_errors_by_tool,
@@ -980,164 +844,28 @@ class BatchRunner:
             all_profiling_stats
         )
 
-        # Print summary
+        # Summary output
         safe_print("\n" + "=" * 70)
         if early_exit:
             safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]")
             safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
         else:
             safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]")
         safe_print("=" * 70)
 
-        safe_print(f"✅ Total prompts processed: {total_processed}")
-        safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}")
+        safe_print(f"✅ Total prompts processed: {processed_count}/{total_to_process}")
         safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
 
-        # Tool error summary
-        if tool_errors_by_tool:
-            total_errors = sum(len(errors) for errors in tool_errors_by_tool.values())
-            safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
-            safe_print("[red]-[/red]" * 70)
-
-            # Sort tools by error count
-            sorted_tools = sorted(
-                tool_errors_by_tool.items(),
-                key=lambda x: len(x[1]),
-                reverse=True
-            )
-
-            for tool_name, errors in sorted_tools:
-                # Count unique error messages
-                unique_errors = {}
-                for error in errors:
-                    error_msg = error["error_message"][:100]  # Truncate for grouping
-                    if error_msg not in unique_errors:
-                        unique_errors[error_msg] = []
-                    unique_errors[error_msg].append(error)
-
-                safe_print(f"\n  [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
-
-                # Show up to 3 most recent unique error types
-                for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
-                    error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
-                    safe_print(f"    [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
-
-                    # Show one example with prompt index and full content prefix
-                    example = instances[-1]  # Most recent
-                    safe_print(f"        [dim]Prompt {example['prompt_index']}[/dim]")
-
-                    # Show full content prefix (first 200 chars)
-                    full_content = example.get('full_content', '')
-                    if full_content and full_content != error_preview:
-                        content_preview = full_content[:200]
-                        if len(full_content) > 200:
-                            content_preview += "..."
-                        # Show with prefix indicator
-                        safe_print(f"        [dim]Content: {content_preview}[/dim]")
-
-                if len(unique_errors) > 3:
-                    safe_print(f"    [dim]... and {len(unique_errors) - 3} more error types[/dim]")
-
-            tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
-            safe_print(f"\n  [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
-
-        # Exception errors
-        if all_exception_errors:
-            safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
-            safe_print("[red]-[/red]" * 70)
-            for error in all_exception_errors[:self.keep_recent_errors]:
-                error_msg = error["error"]
-                error_preview = error_msg[:150]
-                if len(error_msg) > 150:
-                    error_preview += "..."
-                safe_print(f"  [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]")
-
-                # Show traceback prefix if available
-                traceback_text = error.get("traceback", "")
-                if traceback_text:
-                    # Show last 3 lines of traceback for context
-                    tb_lines = traceback_text.strip().split('\n')
-                    relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines
-                    for line in relevant_lines:
-                        safe_print(f"    [dim]{line}[/dim]")
-
-        safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
-        safe_print("-" * 70)
-
-        if total_tool_stats:
-            # Sort by count descending
-            sorted_tools = sorted(
-                total_tool_stats.items(),
-                key=lambda x: x[1]["count"],
-                reverse=True
-            )
-
-            safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
-            safe_print("-" * 70)
-            for tool_name, stats in sorted_tools:
-                safe_print(
-                    f"{tool_name:<25} "
-                    f"{stats['count']:<10} "
-                    f"{stats['success']:<10} "
-                    f"{stats['failure']:<10} "
-                    f"{stats.get('success_rate', 0):.1f}%"
-                )
-        else:
-            safe_print("No tool calls were made during this run.")
-
-        # Display failure type breakdown for tools with failures
-        if tool_errors_by_tool:
-            safe_print(f"\n[cyan]📊 Failure Type Breakdown:[/cyan]")
-            safe_print("-" * 70)
-
-            # Sort tools by total error count
-            sorted_tools = sorted(
-                tool_errors_by_tool.items(),
-                key=lambda x: len(x[1]),
-                reverse=True
-            )
-
-            for tool_name, errors in sorted_tools:
-                # Count failure types for this tool
-                failure_types = {}
-                for error in errors:
-                    error_type = error.get("error_type", "Other")
-                    if error_type not in failure_types:
-                        failure_types[error_type] = 0
-                    failure_types[error_type] += 1
-
-                # Display tool name and total failures
-                total_failures = len(errors)
-                safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):")
-
-                # Sort failure types by count
-                sorted_types = sorted(
-                    failure_types.items(),
-                    key=lambda x: x[1],
-                    reverse=True
-                )
-
-                # Display each failure type with count and percentage
-                for failure_type, count in sorted_types:
-                    percentage = (count / total_failures) * 100
-                    safe_print(f"  • {failure_type:<20} {count:>4} ({percentage:>5.1f}%)")
-
+        safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total[/bold red]")
+        # Simplified error printing here, full detail is in json
+        for tool_name, errors in tool_errors_by_tool.items():
+            safe_print(f"  {tool_name}: {len(errors)} errors")
 
         safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
         safe_print(f"  - Trajectories: trajectories.jsonl (combined)")
-        safe_print(f"  - Individual batches: batch_*.jsonl (for debugging)")
         safe_print(f"  - Statistics: {self.stats_file.name}")
         safe_print(f"  - Errors: {self.errors_file.name}")
         safe_print(f"  - Checkpoint: {self.checkpoint_file.name}")
 
-        if early_exit:
-            safe_print(f"\n[bold yellow]ℹ️ Run was stopped early due to tool failures.[/bold yellow]")
-            safe_print(f"[yellow]  Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]")
-            safe_print(f"[yellow]  You can resume this run later with --resume flag.[/yellow]")
+    def run(self, resume: bool = False):
+        """
+        Run the batch processing pipeline (sync wrapper).
+        """
+        asyncio.run(self._run_async(resume))
 
 
 def main(
     dataset_file: str = None,
     batch_size: int = None,
     run_name: str = None,
     distribution: str = "default",
     model: str = "claude-opus-4-20250514",
@@ -1160,7 +888,6 @@ def main(
 
     Args:
         dataset_file (str): Path to JSONL file with 'prompt' field in each entry
-        batch_size (int): Number of prompts per batch
         run_name (str): Name for this run (used for output and checkpointing)
         distribution (str): Toolset distribution to use (default: "default")
         model (str): Model name to use (default: "claude-opus-4-20250514")
 
@@ -1180,24 +907,13 @@ def main(
 
     Examples:
         # Basic usage
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
+        python batch_runner.py --dataset_file=data.jsonl --run_name=my_run
 
         # Resume interrupted run
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
+        python batch_runner.py --dataset_file=data.jsonl --run_name=my_run --resume
 
         # Use specific distribution
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=image_test --distribution=image_gen
-
-        # With ephemeral system prompt (not saved to dataset)
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
-            --ephemeral_system_prompt="You are a helpful assistant focused on image generation."
-
-        # With custom tool failure thresholds
-        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
-            --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10
-
-        # List available distributions
-        python batch_runner.py --list_distributions
+        python batch_runner.py --dataset_file=data.jsonl --run_name=image_test --distribution=image_gen
     """
     # Handle list distributions
     if list_distributions:
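A note on the failure-threshold flags referenced in these docstrings: --max_tool_failures is an absolute cap on total tool errors, while --max_tool_failure_rate only kicks in once at least --min_tool_calls_for_rate tool calls have been made. With the values 20 / 0.3 / 10 from the module docstring: 4 failures in 12 calls gives 4/12 ≈ 33% ≥ 30%, so the run stops early; 3 failures in 12 calls gives 25%, so it continues; and 2 failures in 6 calls (33%) triggers nothing, because 6 < 10 calls is too few to trust the rate.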
@@ -1209,10 +925,6 @@ def main(
         all_dists = get_all_dists()
         for dist_name in sorted(all_dists.keys()):
             print_distribution_info(dist_name)
 
-        print("\n💡 Usage:")
-        print("   python batch_runner.py --dataset_file=data.jsonl --batch_size=10 \\")
-        print("       --run_name=my_run --distribution=<name>")
-
         return
 
     # Validate required arguments
 
@@ -1220,10 +932,6 @@ def main(
         print("❌ Error: --dataset_file is required")
         return
 
-    if not batch_size or batch_size < 1:
-        print("❌ Error: --batch_size must be a positive integer")
-        return
-
     if not run_name:
         print("❌ Error: --run_name is required")
         return
 
@@ -1232,7 +940,6 @@ def main(
     try:
         runner = BatchRunner(
             dataset_file=dataset_file,
-            batch_size=batch_size,
             run_name=run_name,
             distribution=distribution,
             max_iterations=max_turns,
@@ -23,7 +23,7 @@ Usage:
     web_tools = get_tool_definitions(enabled_toolsets=['web_tools'])
 
     # Handle function calls from model
-    result = handle_function_call("web_search", {"query": "Python"})
+    result = await handle_function_call("web_search", {"query": "Python"})
 """
 
 import json
 
@@ -439,7 +439,7 @@ def get_tool_definitions(
 
     return filtered_tools
 
-def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
     """
     Handle function calls for web tools.
 
@@ -454,25 +454,25 @@ def handle_web_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
         query = function_args.get("query", "")
         # Always use fixed limit of 5
         limit = 5
-        return web_search_tool(query, limit)
+        return await web_search_tool(query, limit)
 
     elif function_name == "web_extract":
         urls = function_args.get("urls", [])
         # Limit URLs to prevent abuse
         urls = urls[:5] if isinstance(urls, list) else []
-        # Run async function in event loop
-        return asyncio.run(web_extract_tool(urls, "markdown"))
+        # Run async function
+        return await web_extract_tool(urls, "markdown")
 
     elif function_name == "web_crawl":
         url = function_args.get("url", "")
         instructions = function_args.get("instructions")
-        # Run async function in event loop
-        return asyncio.run(web_crawl_tool(url, instructions, "basic"))
+        # Run async function
+        return await web_crawl_tool(url, instructions, "basic")
 
     else:
        return json.dumps({"error": f"Unknown web function: {function_name}"}, ensure_ascii=False)
 
-def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
+async def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
     """
     Handle function calls for terminal tools.
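The handler changes in the hunk above are the crux of this commit: the old code called asyncio.run() inside synchronous handlers, which works when each worker process owns its own loop, but raises RuntimeError the moment the caller itself runs inside an event loop, as the new asyncio batch runner does. A minimal reproduction, illustrative only:

    import asyncio

    async def fetch() -> int:
        return 42

    async def handler_old() -> int:
        # RuntimeError: asyncio.run() cannot be called from a running event loop
        return asyncio.run(fetch())

    async def handler_new() -> int:
        return await fetch()  # await the coroutine instead

    print(asyncio.run(handler_new()))  # 42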
@@ -489,13 +489,20 @@ def handle_terminal_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
         background = function_args.get("background", False)
         timeout = function_args.get("timeout")
 
-        return simple_terminal_tool(command=command, background=background, timeout=timeout, task_id=task_id)
+        # Run sync terminal tool in a thread to avoid blocking
+        return await asyncio.to_thread(
+            simple_terminal_tool,
+            command=command,
+            background=background,
+            timeout=timeout,
+            task_id=task_id
+        )
 
     else:
         return json.dumps({"error": f"Unknown terminal function: {function_name}"}, ensure_ascii=False)
 
 
-def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
     """
     Handle function calls for vision tools.
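simple_terminal_tool stays synchronous (it blocks on VM I/O), so the hunk above offloads it with asyncio.to_thread instead of rewriting it. to_thread (Python 3.9+) runs the callable in the default thread-pool executor and returns an awaitable, so the event loop stays responsive while the call blocks. A generic sketch:

    import asyncio
    import time

    def blocking_call(duration: float) -> str:
        time.sleep(duration)  # stands in for a blocking terminal/VM call
        return f"done after {duration}s"

    async def main() -> None:
        # runs in a worker thread; other tasks keep running on the loop
        print(await asyncio.to_thread(blocking_call, 0.1))

    asyncio.run(main())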
@@ -512,14 +519,14 @@ def handle_vision_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
 
         full_prompt = f"Fully describe and explain everything about this image, then answer the following question:\n\n{question}"
 
-        # Run async function in event loop
-        return asyncio.run(vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash"))
+        # Run async function
+        return await vision_analyze_tool(image_url, full_prompt, "gemini-2.5-flash")
 
     else:
         return json.dumps({"error": f"Unknown vision function: {function_name}"}, ensure_ascii=False)
 
 
-def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
     """
     Handle function calls for Mixture-of-Agents tools.
 
@@ -536,14 +543,14 @@ def handle_moa_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
         if not user_prompt:
             return json.dumps({"error": "user_prompt is required for MoA processing"}, ensure_ascii=False)
 
-        # Run async function in event loop
-        return asyncio.run(mixture_of_agents_tool(user_prompt=user_prompt))
+        # Run async function
+        return await mixture_of_agents_tool(user_prompt=user_prompt)
 
     else:
         return json.dumps({"error": f"Unknown MoA function: {function_name}"}, ensure_ascii=False)
 
 
-def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
+async def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
     """
     Handle function calls for image generation tools.
 
@@ -572,21 +579,8 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
         allow_nsfw_images = True
         seed = None
 
-        # Run async function in event loop with proper handling for multiprocessing
-        try:
-            # Try to get existing event loop
-            loop = asyncio.get_event_loop()
-            if loop.is_closed():
-                # If closed, create a new one
-                loop = asyncio.new_event_loop()
-                asyncio.set_event_loop(loop)
-        except RuntimeError:
-            # No event loop in current thread, create one
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        # Run the coroutine in the event loop
-        result = loop.run_until_complete(image_generate_tool(
+        # Run async function
+        return await image_generate_tool(
             prompt=prompt,
             image_size=image_size,
             num_inference_steps=num_inference_steps,
 
@@ -597,15 +591,13 @@ def handle_image_function_call(function_name: str, function_args: Dict[str, Any]) -> str:
             acceleration=acceleration,
             allow_nsfw_images=allow_nsfw_images,
             seed=seed
-        ))
-
-        return result
+        )
 
     else:
         return json.dumps({"error": f"Unknown image generation function: {function_name}"}, ensure_ascii=False)
 
 
-def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
+async def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
     """
     Main function call dispatcher that routes calls to appropriate toolsets.
 
@@ -627,23 +619,23 @@ def handle_function_call(function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None) -> str:
     try:
         # Route web tools
         if function_name in ["web_search", "web_extract", "web_crawl"]:
-            return handle_web_function_call(function_name, function_args)
+            return await handle_web_function_call(function_name, function_args)
 
         # Route terminal tools
         elif function_name in ["terminal"]:
-            return handle_terminal_function_call(function_name, function_args, task_id)
+            return await handle_terminal_function_call(function_name, function_args, task_id)
 
         # Route vision tools
         elif function_name in ["vision_analyze"]:
-            return handle_vision_function_call(function_name, function_args)
+            return await handle_vision_function_call(function_name, function_args)
 
         # Route MoA tools
         elif function_name in ["mixture_of_agents"]:
-            return handle_moa_function_call(function_name, function_args)
+            return await handle_moa_function_call(function_name, function_args)
 
         # Route image generation tools
         elif function_name in ["image_generate"]:
-            return handle_image_function_call(function_name, function_args)
+            return await handle_image_function_call(function_name, function_args)
 
         else:
             error_msg = f"Unknown function: {function_name}"
 
@@ -773,4 +765,4 @@ if __name__ == "__main__":
 
     if "terminal" in all_tool_names:
         no_terminal = get_tool_definitions(disabled_tools=["terminal"])
-        print(f"   All except terminal: {len(no_terminal)} tools")
+        print(f"   All except terminal: {len(no_terminal)} tools")

run_agent.py (134 lines changed)
@@ -24,121 +24,17 @@ import json
 import logging
 import os
 import time
+import asyncio
 import sys
 from typing import List, Dict, Any, Optional
-from openai import OpenAI
+from openai import AsyncOpenAI
 import fire
 from datetime import datetime
 from pathlib import Path
 from rich import print
 
-from prokletor.formatters.hermes_formatter import HermesToolFormatterWithReasoning
-
-class SyncCustomToolCompletions:
-    def __init__(self, completions, formatter):
-        self._completions = completions
-        self._formatter = formatter
-
-    def create(self, *args, messages, tools=None, **kwargs):
-        if not tools:
-            return self._completions.create(*args, messages=messages, **kwargs)
-
-        # 1. Format system message with tools
-        system_prompt = self._formatter.format_system_message(tools)
-
-        new_messages = list(messages)
-        if new_messages and new_messages[0]["role"] == "system":
-            # Append to existing system message
-            existing_content = new_messages[0]["content"]
-            if isinstance(existing_content, str):
-                new_messages[0] = {
-                    "role": "system",
-                    "content": existing_content + "\n\n" + system_prompt
-                }
-        else:
-            # Insert new system message
-            new_messages.insert(0, {
-                "role": "system",
-                "content": system_prompt
-            })
-
-        # 2. Call the API without the 'tools' parameter
-        kwargs.pop("tool_choice", None)
-
-        # Process messages (e.g. convert roles and add reasoning prompt)
-        # use_tool_role=False for API compatibility
-        new_messages = self._formatter.process_messages(new_messages, use_tool_role=False)
-
-        response = self._completions.create(
-            *args,
-            messages=new_messages,
-            # tools=tools,  # Do NOT pass tools to the model
-            **kwargs
-        )
-
-        # 3. Parse the response
-        choice = response.choices[0]
-        if choice.message.content:
-            tool_calls = self._formatter.parse_response(choice.message.content, tools=tools)
-            if tool_calls:
-                choice.message.tool_calls = tool_calls
-
-            # Clean the content if the formatter supports it
-            if hasattr(self._formatter, "extract_text_from_content"):
-                cleaned_content = self._formatter.extract_text_from_content(choice.message.content)
-                choice.message.content = cleaned_content
-
-            if not choice.message.content:
-                choice.message.content = None
-
-        return response
-
-    def __getattr__(self, name):
-        return getattr(self._completions, name)
-
-class SyncCustomChat:
-    def __init__(self, chat, formatter):
-        self._chat = chat
-        self.completions = SyncCustomToolCompletions(chat.completions, formatter)
-
-    def __getattr__(self, name):
-        return getattr(self._chat, name)
-
-class SyncHermesToolClientWithReasoning:
-    def __init__(self, client):
-        self._client = client
-        self.formatter = HermesToolFormatterWithReasoning()
-        self.chat = SyncCustomChat(client.chat, self.formatter)
-
-    def format(self, messages: List[Dict[str, Any]], tools: List[Dict[str, Any]], use_tool_role: bool = True) -> List[Dict[str, Any]]:
-        """
-        Format messages and tools into the Hermes XML format.
-        Useful for debugging or manual inspection of what will be sent to the model.
-        """
-        # 1. Format system message with tools
-        system_prompt = self.formatter.format_system_message(tools)
-
-        new_messages = list(messages)
-        if new_messages and new_messages[0]["role"] == "system":
-            # Append to existing system message
-            existing_content = new_messages[0]["content"]
-            if isinstance(existing_content, str):
-                new_messages[0] = {
-                    "role": "system",
-                    "content": existing_content + "\n\n" + system_prompt
-                }
-        else:
-            # Insert new system message
-            new_messages.insert(0, {
-                "role": "system",
-                "content": system_prompt
-            })
-
-        # 2. Process messages (convert tool calls and results to XML)
-        return self.formatter.process_messages(new_messages, use_tool_role=use_tool_role)
-
-    def __getattr__(self, name):
-        return getattr(self._client, name)
+from prokletor.clients.hermes import HermesToolClientWithReasoning
 
 # Load environment variables from .env file
 from dotenv import load_dotenv
@@ -242,10 +138,10 @@ class AIAgent:
             client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key")
 
         try:
-            oai_client = OpenAI(**client_kwargs)
+            oai_client = AsyncOpenAI(**client_kwargs)
             # self.client = oai_client
-            self.client = SyncHermesToolClientWithReasoning(oai_client)
-            print(f"🧠 Wrapped OpenAI client with SyncHermesToolClientWithReasoning")
+            self.client = HermesToolClientWithReasoning(oai_client)
+            print(f"🧠 Wrapped OpenAI client with AsyncHermesToolClientWithReasoning")
 
         print(f"🤖 AI Agent initialized with model: {self.model}")
         if base_url:
 
@@ -494,7 +390,7 @@ class AIAgent:
         except Exception as e:
             print(f"⚠️ Failed to save trajectory: {e}")
 
-    def run_conversation(
+    async def run_conversation(
         self,
         user_message: str,
         system_message: str = None,
 
@@ -567,7 +463,7 @@ class AIAgent:
 
                 # Make API call with tools
 
-                response = self.client.chat.completions.create(
+                response = await self.client.chat.completions.create(
                     model=self.model,
                     messages=api_messages,
                     tools=self.tools if self.tools else None,
 
@@ -603,7 +499,7 @@ class AIAgent:
                     print(f"⚠️ OpenAI-compatible API call failed (attempt {retry_count}/{max_retries}): {str(api_error)[:100]}")
                     print(f"⏳ Retrying in {wait_time}s...")
                     logging.warning(f"API retry {retry_count}/{max_retries} after error: {api_error}")
-                    time.sleep(wait_time)
+                    await asyncio.sleep(wait_time)
 
             if response is None:
                 raise last_api_error if last_api_error else RuntimeError("OpenAI-compatible API call failed without a response")
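With AsyncOpenAI both the request and the retry backoff are now awaited, so one stalled retry no longer freezes sibling conversations the way time.sleep did. The retry shape used here, reduced to a sketch (the real loop's bookkeeping lives in the surrounding code):

    import asyncio

    async def call_with_retries(make_request, max_retries: int = 3):
        last_error = None
        for attempt in range(1, max_retries + 1):
            try:
                return await make_request()        # e.g. client.chat.completions.create(...)
            except Exception as e:
                last_error = e
                await asyncio.sleep(2 ** attempt)  # yields the loop while waiting
        raise last_error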
@@ -694,7 +590,7 @@ class AIAgent:
                     tool_start_time = time.time()
 
                     # Execute the tool with task_id to isolate VMs between concurrent tasks
-                    function_result = handle_function_call(function_name, function_args, effective_task_id)
+                    function_result = await handle_function_call(function_name, function_args, effective_task_id)
 
                     tool_duration = time.time() - tool_start_time
                     result_preview = function_result[:200] if len(function_result) > 200 else function_result
 
@@ -720,7 +616,7 @@ class AIAgent:
 
                     # Delay between tool calls
                     if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
-                        time.sleep(self.tool_delay)
+                        await asyncio.sleep(self.tool_delay)
 
                 # Continue loop for next response
                 continue
 
@@ -838,7 +734,7 @@ class AIAgent:
 
         # Clean up VM for this task after conversation completes
         try:
-            cleanup_vm(effective_task_id)
+            await asyncio.to_thread(cleanup_vm, effective_task_id)
         except Exception as e:
            if self.verbose_logging:
                logging.warning(f"Failed to cleanup VM for task {effective_task_id}: {e}")
 
@@ -854,7 +750,7 @@ class AIAgent:
             "profiling_stats": profiling_stats
         }
 
-    def chat(self, message: str) -> str:
+    async def chat(self, message: str) -> str:
         """
         Simple chat interface that returns just the final response.
 
@@ -864,7 +760,7 @@ class AIAgent:
         Returns:
             str: Final assistant response
         """
-        result = self.run_conversation(message)
+        result = await self.run_conversation(message)
         return result["final_response"]
 
 
@@ -1037,7 +933,7 @@ def main(
     print("\n" + "=" * 50)
 
     # Run conversation
-    result = agent.run_conversation(user_query)
+    result = asyncio.run(agent.run_conversation(user_query))
 
     print("\n" + "=" * 50)
     print("📋 CONVERSATION SUMMARY")
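That single asyncio.run() in run_agent's main() is the whole sync/async boundary: it creates an event loop, runs the coroutine to completion, and tears the loop down; everything beneath it awaits. The two call shapes, both taken from this diff:

    # sync entry point (CLI):
    result = asyncio.run(agent.run_conversation(user_query))

    # from already-async code (e.g. a batch worker), await directly:
    result = await agent.run_conversation(prompt, task_id=f"task_{prompt_index}")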
@@ -48,11 +48,11 @@ import uuid
 import datetime
 from pathlib import Path
 from typing import List, Dict, Any, Optional
-from firecrawl import Firecrawl
+from firecrawl import AsyncFirecrawl
 from openai import AsyncOpenAI
 
 # Initialize Firecrawl client once at module level
-firecrawl_client = Firecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
+firecrawl_client = AsyncFirecrawl(api_key=os.getenv("FIRECRAWL_API_KEY"))
 
 # Initialize Nous Research API client for LLM processing (async)
 nous_client = AsyncOpenAI(
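This last file swaps Firecrawl for AsyncFirecrawl; the method names (search, scrape, crawl) are unchanged, but each call in the hunks below gains an await, and the one module-level client is reused across tasks, as the diff shows. The awaited call shape, sketched from the calls this diff touches:

    from firecrawl import AsyncFirecrawl

    client = AsyncFirecrawl(api_key="...")

    async def search(query: str):
        # search/scrape/crawl on the async client are coroutines
        return await client.search(query=query, limit=5)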
@@ -261,7 +261,7 @@ def clean_base64_images(text: str) -> str:
     return cleaned_text
 
 
-def web_search_tool(query: str, limit: int = 5) -> str:
+async def web_search_tool(query: str, limit: int = 5) -> str:
     """
     Search the web for information using available search API backend.
 
@@ -312,7 +312,7 @@ def web_search_tool(query: str, limit: int = 5) -> str:
         # Use Firecrawl's v2 search functionality WITHOUT scraping
         # We only want search result metadata, not scraped content
         # Docs: https://docs.firecrawl.dev/features/search
-        response = firecrawl_client.search(
+        response = await firecrawl_client.search(
             query=query,
             limit=limit
         )
 
@@ -446,7 +446,7 @@ async def web_extract_tool(
     for url in urls:
         try:
             print(f"  📄 Scraping: {url}")
-            scrape_result = firecrawl_client.scrape(
+            scrape_result = await firecrawl_client.scrape(
                 url=url,
                 formats=formats
             )
 
@@ -703,7 +703,7 @@ async def web_crawl_tool(
 
     # Use the crawl method which waits for completion automatically
     try:
-        crawl_result = firecrawl_client.crawl(
+        crawl_result = await firecrawl_client.crawl(
             url=url,
             **crawl_params
         )