add tracking for cluster failures

hjc-puro 2025-11-15 00:01:19 -05:00
parent 0c618482c4
commit 31c733383b
2 changed files with 409 additions and 116 deletions

batch_runner.py

@@ -9,15 +9,21 @@ across multiple prompts from a dataset. It includes:
- Checkpointing for fault tolerance and resumption
- Trajectory saving in the proper format (from/value pairs)
- Tool usage statistics aggregation across all batches
- Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
- Configurable failure thresholds with automatic data consolidation
Usage:
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
# Resume an interrupted run
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
# Use a specific toolset distribution
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
# Configure tool failure thresholds
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
--max_tool_failures=20 --max_tool_failure_rate=0.3
"""
import json
@@ -29,22 +35,94 @@ from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from multiprocessing import Pool, Manager, Lock
import traceback
import re
import fire
from run_agent import AIAgent
from toolset_distributions import (
get_distribution,
list_distributions,
get_distribution,
list_distributions,
sample_toolsets_from_distribution,
validate_distribution
)
from safe_print import safe_print
# Global configuration for worker processes
_WORKER_CONFIG = {}
def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Extract tool errors from message history with tool names.
Args:
messages (List[Dict]): Message history
Returns:
List[Dict]: List of tool errors with tool name, error message, and context
"""
tool_errors = []
tool_calls_map = {} # Map tool_call_id to tool name
for msg in messages:
# Track tool calls from assistant messages
if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
for tool_call in msg["tool_calls"]:
tool_name = tool_call["function"]["name"]
tool_call_id = tool_call["id"]
tool_calls_map[tool_call_id] = tool_name
# Check tool responses for errors
elif msg["role"] == "tool":
tool_call_id = msg.get("tool_call_id", "")
content = msg.get("content", "")
# Determine if tool call had an error
has_error = False
error_msg = None
try:
content_json = json.loads(content) if isinstance(content, str) else content
if isinstance(content_json, dict):
# Check if error field exists AND has a non-null value
if "error" in content_json and content_json["error"] is not None:
has_error = True
error_msg = str(content_json["error"])
# Special handling for terminal tool responses
if "content" in content_json and isinstance(content_json["content"], dict):
inner_content = content_json["content"]
if inner_content.get("error") is not None or inner_content.get("exit_code", 0) != 0:
has_error = True
error_msg = inner_content.get("error") or f"Exit code: {inner_content.get('exit_code')}"
# Check for "success": false pattern
if content_json.get("success") is False:
has_error = True
if not error_msg:
error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
except (json.JSONDecodeError, TypeError):
# If not JSON, check if content explicitly states an error
if content.strip().lower().startswith("error:"):
has_error = True
error_msg = content.strip()
# Record error if found
if has_error and tool_call_id in tool_calls_map:
tool_name = tool_calls_map[tool_call_id]
tool_errors.append({
"tool_name": tool_name,
"error_message": error_msg or "Unknown error",
"full_content": content[:500] # Keep first 500 chars of full response
})
return tool_errors
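# Illustrative sketch (not part of the original commit): the OpenAI-style
# tool-calling message shape this helper assumes, and the record it emits
# for a failed terminal call (non-zero exit code).
#
#   messages = [
#       {"role": "assistant", "tool_calls": [
#           {"id": "call_1", "function": {"name": "terminal", "arguments": "{}"}}]},
#       {"role": "tool", "tool_call_id": "call_1",
#        "content": '{"content": {"exit_code": 1, "error": null}}'},
#   ]
#   _extract_tool_errors_from_messages(messages)
#   # -> [{"tool_name": "terminal",
#   #      "error_message": "Exit code: 1",
#   #      "full_content": '{"content": {"exit_code": 1, "error": null}}'}]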
def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
"""
Extract tool usage statistics from message history.
@@ -170,22 +248,26 @@ def _process_single_prompt(
# Run the agent with task_id to ensure each task gets its own isolated VM
result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
# Extract tool usage statistics
tool_stats = _extract_tool_stats(result["messages"])
# Extract tool errors from conversation
tool_errors = _extract_tool_errors_from_messages(result["messages"])
# Convert to trajectory format (using existing method)
trajectory = agent._convert_to_trajectory_format(
result["messages"],
prompt,
result["completed"]
)
return {
"success": True,
"prompt_index": prompt_index,
"trajectory": trajectory,
"tool_stats": tool_stats,
"tool_errors": tool_errors,
"completed": result["completed"],
"api_calls": result["api_calls"],
"toolsets_used": selected_toolsets,
@@ -197,14 +279,18 @@ def _process_single_prompt(
}
except Exception as e:
print(f"❌ Error processing prompt {prompt_index}: {e}")
error_msg = str(e)
tb = traceback.format_exc()
safe_print(f"[bold red]❌ Error processing prompt {prompt_index}:[/bold red] {error_msg}")
if config.get("verbose"):
traceback.print_exc()
safe_print(tb)
return {
"success": False,
"prompt_index": prompt_index,
"error": str(e),
"error": error_msg,
"traceback": tb,
"tool_errors": [],
"trajectory": None,
"tool_stats": {},
"toolsets_used": [],
@@ -254,7 +340,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
# Initialize aggregated stats for this batch
batch_tool_stats = {}
completed_in_batch = []
all_tool_errors = [] # Track all tool errors in this batch
exception_errors = [] # Track top-level exceptions
# Process each prompt sequentially in this batch
for prompt_index, prompt_data in prompts_to_process:
# Process the prompt
@@ -264,7 +352,26 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
batch_num,
config
)
# Track tool errors from the conversation
if result.get("tool_errors"):
for tool_error in result["tool_errors"]:
all_tool_errors.append({
"prompt_index": prompt_index,
"tool_name": tool_error["tool_name"],
"error_message": tool_error["error_message"],
"full_content": tool_error.get("full_content", "")
})
# Track top-level exceptions (not tool errors)
if not result["success"]:
exception_errors.append({
"prompt_index": prompt_index,
"error": result.get("error", "Unknown error"),
"traceback": result.get("traceback", "")
})
safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
# Save trajectory if successful
if result["success"] and result["trajectory"]:
trajectory_entry = {
@@ -275,7 +382,7 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
"api_calls": result["api_calls"],
"toolsets_used": result["toolsets_used"]
}
# Append to batch output file
with open(batch_output_file, 'a', encoding='utf-8') as f:
f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
@@ -303,7 +410,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
"processed": len(prompts_to_process),
"skipped": len(batch_data) - len(prompts_to_process),
"tool_stats": batch_tool_stats,
"completed_prompts": completed_in_batch
"completed_prompts": completed_in_batch,
"tool_errors": all_tool_errors,
"exception_errors": exception_errors
}
@@ -326,6 +435,9 @@ class BatchRunner:
verbose: bool = False,
ephemeral_system_prompt: str = None,
log_prefix_chars: int = 100,
max_tool_failures: int = 10,
max_tool_failure_rate: float = 0.5,
keep_recent_errors: int = 5,
):
"""
Initialize the batch runner.
@@ -343,6 +455,9 @@ class BatchRunner:
verbose (bool): Enable verbose logging
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
keep_recent_errors (int): Number of recent errors to keep per tool (default: 5)
"""
self.dataset_file = Path(dataset_file)
self.batch_size = batch_size
@@ -356,6 +471,9 @@ class BatchRunner:
self.verbose = verbose
self.ephemeral_system_prompt = ephemeral_system_prompt
self.log_prefix_chars = log_prefix_chars
self.max_tool_failures = max_tool_failures
self.max_tool_failure_rate = max_tool_failure_rate
self.keep_recent_errors = keep_recent_errors
# Validate distribution
if not validate_distribution(distribution):
@@ -377,17 +495,21 @@ class BatchRunner:
# Create batches
self.batches = self._create_batches()
print(f"📊 Batch Runner Initialized")
print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
print(f" Batch size: {self.batch_size}")
print(f" Total batches: {len(self.batches)}")
print(f" Run name: {self.run_name}")
print(f" Distribution: {self.distribution}")
print(f" Output directory: {self.output_dir}")
print(f" Workers: {self.num_workers}")
safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]")
safe_print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
safe_print(f" Batch size: {self.batch_size}")
safe_print(f" Total batches: {len(self.batches)}")
safe_print(f" Run name: {self.run_name}")
safe_print(f" Distribution: {self.distribution}")
safe_print(f" Output directory: {self.output_dir}")
safe_print(f" Workers: {self.num_workers}")
safe_print(f" [yellow]Tool failure limits:[/yellow]")
safe_print(f" Max failures: {self.max_tool_failures}")
safe_print(f" Max failure rate: {self.max_tool_failure_rate:.1%}")
safe_print(f" Keep recent errors: {self.keep_recent_errors}")
if self.ephemeral_system_prompt:
prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'")
safe_print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'")
def _load_dataset(self) -> List[Dict[str, Any]]:
"""
@@ -465,13 +587,13 @@ class BatchRunner:
def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
"""
Save checkpoint data.
Args:
checkpoint_data (Dict): Checkpoint data to save
lock (Lock): Optional lock for thread-safe access
"""
checkpoint_data["last_updated"] = datetime.now().isoformat()
if lock:
with lock:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
@@ -479,6 +601,69 @@ class BatchRunner:
else:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None):
"""
Consolidate batch data into trajectories.jsonl and save statistics.
Args:
num_batches (int): Number of batches processed
tool_stats (Dict): Aggregated tool statistics
start_time (float): Start time of the run
tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
exception_errors (List): Top-level exceptions
early_exit (bool): Whether this is an early exit
exit_reason (str): Reason for early exit
"""
# Combine all batch files into a single trajectories.jsonl file
combined_file = self.output_dir / "trajectories.jsonl"
safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]")
entries_written = 0
with open(combined_file, 'w', encoding='utf-8') as outfile:
for batch_num in range(num_batches):
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
if batch_file.exists():
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
outfile.write(line)
entries_written += 1
safe_print(f"[green]✅ Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
# Calculate success rates for tool stats
for tool_name in tool_stats:
stats = tool_stats[tool_name]
total_calls = stats["success"] + stats["failure"]
if total_calls > 0:
stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
else:
stats["success_rate"] = 0.0
stats["failure_rate"] = 0.0
# Save final statistics
final_stats = {
"run_name": self.run_name,
"distribution": self.distribution,
"total_prompts": len(self.dataset),
"total_batches": len(self.batches),
"batches_processed": num_batches,
"batch_size": self.batch_size,
"model": self.model,
"completed_at": datetime.now().isoformat(),
"duration_seconds": round(time.time() - start_time, 2),
"early_exit": early_exit,
"exit_reason": exit_reason,
"tool_errors": tool_errors_by_tool,
"exception_errors": exception_errors[:self.keep_recent_errors], # Keep k most recent
"tool_statistics": tool_stats
}
with open(self.stats_file, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, indent=2, ensure_ascii=False)
def run(self, resume: bool = False):
@@ -520,9 +705,16 @@ class BatchRunner:
# Aggregate statistics across all batches
total_tool_stats = {}
tool_errors_by_tool = {} # {tool_name: [list of k most recent errors]}
all_exception_errors = []
all_completed_prompts = list(completed_prompts_set)
total_processed = len(completed_prompts_set)
total_tool_errors = 0
early_exit = False
exit_reason = None
start_time = time.time()
# Process batches in parallel
with Pool(processes=self.num_workers) as pool:
# Create tasks for each batch
@@ -536,84 +728,147 @@ class BatchRunner:
)
for batch_num, batch_data in enumerate(self.batches)
]
# Use map to process batches in parallel
results = pool.map(_process_batch_worker, tasks)
# Aggregate all batch statistics and update checkpoint
all_completed_prompts = list(completed_prompts_set)
for batch_result in results:
# Add newly completed prompts
all_completed_prompts.extend(batch_result.get("completed_prompts", []))
# Aggregate tool stats
for tool_name, stats in batch_result.get("tool_stats", {}).items():
if tool_name not in total_tool_stats:
total_tool_stats[tool_name] = {
"count": 0,
"success": 0,
"failure": 0
}
total_tool_stats[tool_name]["count"] += stats["count"]
total_tool_stats[tool_name]["success"] += stats["success"]
total_tool_stats[tool_name]["failure"] += stats["failure"]
# Process batches and check tool failure threshold after each batch
for batch_num, task in enumerate(tasks):
# Process single batch
result = pool.apply(_process_batch_worker, (task,))
# Update statistics
all_completed_prompts.extend(result.get("completed_prompts", []))
total_processed += result.get("processed", 0)
# Aggregate tool stats
for tool_name, stats in result.get("tool_stats", {}).items():
if tool_name not in total_tool_stats:
total_tool_stats[tool_name] = {
"count": 0,
"success": 0,
"failure": 0
}
total_tool_stats[tool_name]["count"] += stats["count"]
total_tool_stats[tool_name]["success"] += stats["success"]
total_tool_stats[tool_name]["failure"] += stats["failure"]
# Aggregate tool errors (keep k most recent per tool)
for tool_error in result.get("tool_errors", []):
tool_name = tool_error["tool_name"]
if tool_name not in tool_errors_by_tool:
tool_errors_by_tool[tool_name] = []
# Add error and keep only k most recent
tool_errors_by_tool[tool_name].append(tool_error)
if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
total_tool_errors += 1
# Track exception errors
all_exception_errors.extend(result.get("exception_errors", []))
# Check tool failure thresholds
if total_processed > 0:
tool_failure_rate = total_tool_errors / total_processed
# Check absolute count threshold
if total_tool_errors >= self.max_tool_failures:
early_exit = True
exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
break
# Check rate threshold
if tool_failure_rate >= self.max_tool_failure_rate:
early_exit = True
exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%})"
safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]")
break
# Update checkpoint after each batch
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
self._save_checkpoint(checkpoint_data)
# Calculate success rates
for tool_name in total_tool_stats:
stats = total_tool_stats[tool_name]
total_calls = stats["success"] + stats["failure"]
if total_calls > 0:
stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
else:
stats["success_rate"] = 0.0
stats["failure_rate"] = 0.0
# Combine all batch files into a single trajectories.jsonl file
combined_file = self.output_dir / "trajectories.jsonl"
print(f"\n📦 Combining batch files into {combined_file.name}...")
with open(combined_file, 'w', encoding='utf-8') as outfile:
for batch_num in range(len(self.batches)):
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
if batch_file.exists():
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
outfile.write(line)
print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl")
# Save final statistics
final_stats = {
"run_name": self.run_name,
"distribution": self.distribution,
"total_prompts": len(self.dataset),
"total_batches": len(self.batches),
"batch_size": self.batch_size,
"model": self.model,
"completed_at": datetime.now().isoformat(),
"duration_seconds": round(time.time() - start_time, 2),
"tool_statistics": total_tool_stats
}
with open(self.stats_file, 'w', encoding='utf-8') as f:
json.dump(final_stats, f, indent=2, ensure_ascii=False)
# Consolidate data and save statistics
num_batches_processed = batch_num + 1 if early_exit else len(self.batches)
self._consolidate_data(
num_batches_processed,
total_tool_stats,
start_time,
tool_errors_by_tool,
all_exception_errors,
early_exit,
exit_reason
)
# Print summary
print("\n" + "=" * 70)
print("📊 BATCH PROCESSING COMPLETE")
print("=" * 70)
print(f"✅ Total prompts processed: {len(self.dataset)}")
print(f"✅ Total batches: {len(self.batches)}")
print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
print(f"\n📈 Tool Usage Statistics:")
print("-" * 70)
safe_print("\n" + "=" * 70)
if early_exit:
safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]")
safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
else:
safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]")
safe_print("=" * 70)
safe_print(f"✅ Total prompts processed: {total_processed}")
safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}")
safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
# Tool error summary
if tool_errors_by_tool:
total_errors = sum(len(errors) for errors in tool_errors_by_tool.values())
safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
safe_print("[red]-[/red]" * 70)
# Sort tools by error count
sorted_tools = sorted(
tool_errors_by_tool.items(),
key=lambda x: len(x[1]),
reverse=True
)
for tool_name, errors in sorted_tools:
# Count unique error messages
unique_errors = {}
for error in errors:
error_msg = error["error_message"][:100] # Truncate for grouping
if error_msg not in unique_errors:
unique_errors[error_msg] = []
unique_errors[error_msg].append(error)
safe_print(f"\n [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
# Show up to 3 unique error types for this tool
for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
# Show one example with prompt index
example = instances[-1] # Most recent
safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]")
if len(unique_errors) > 3:
safe_print(f" [dim]... and {len(unique_errors) - 3} more error types[/dim]")
tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
safe_print(f"\n [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
# Exception errors
if all_exception_errors:
safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
safe_print("[red]-[/red]" * 70)
for error in all_exception_errors[:self.keep_recent_errors]:
error_preview = error["error"][:100]
if len(error["error"]) > 100:
error_preview += "..."
safe_print(f" Prompt {error['prompt_index']}: [dim]{error_preview}[/dim]")
safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]")
safe_print("-" * 70)
if total_tool_stats:
# Sort by count descending
sorted_tools = sorted(
@@ -621,25 +876,30 @@ class BatchRunner:
key=lambda x: x[1]["count"],
reverse=True
)
print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
print("-" * 70)
safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
safe_print("-" * 70)
for tool_name, stats in sorted_tools:
print(
safe_print(
f"{tool_name:<25} "
f"{stats['count']:<10} "
f"{stats['success']:<10} "
f"{stats['failure']:<10} "
f"{stats['success_rate']:.1f}%"
f"{stats.get('success_rate', 0):.1f}%"
)
else:
print("No tool calls were made during this run.")
print(f"\n💾 Results saved to: {self.output_dir}")
print(f" - Trajectories: trajectories.jsonl (combined)")
print(f" - Individual batches: batch_*.jsonl (for debugging)")
print(f" - Statistics: {self.stats_file.name}")
print(f" - Checkpoint: {self.checkpoint_file.name}")
safe_print("No tool calls were made during this run.")
safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}")
safe_print(f" - Trajectories: trajectories.jsonl (combined)")
safe_print(f" - Individual batches: batch_*.jsonl (for debugging)")
safe_print(f" - Statistics: {self.stats_file.name}")
safe_print(f" - Checkpoint: {self.checkpoint_file.name}")
if early_exit:
safe_print(f"\n[bold yellow] Run was stopped early due to tool failures.[/bold yellow]")
safe_print(f"[yellow] Check {self.stats_file.name} for detailed error information including tracebacks.[/yellow]")
safe_print(f"[yellow] You can resume this run later with --resume flag.[/yellow]")
def main(
@@ -657,6 +917,9 @@ def main(
list_distributions: bool = False,
ephemeral_system_prompt: str = None,
log_prefix_chars: int = 100,
max_tool_failures: int = 10,
max_tool_failure_rate: float = 0.5,
keep_recent_errors: int = 5,
):
"""
Run batch processing of agent prompts from a dataset.
@@ -676,7 +939,10 @@ def main(
list_distributions (bool): List available toolset distributions and exit
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
Examples:
# Basic usage
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
@@ -690,7 +956,11 @@ def main(
# With ephemeral system prompt (not saved to dataset)
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
--ephemeral_system_prompt="You are a helpful assistant focused on image generation."
# With custom tool failure thresholds
python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
--max_tool_failures=20 --max_tool_failure_rate=0.3 --keep_recent_errors=10
# List available distributions
python batch_runner.py --list_distributions
"""
@@ -737,7 +1007,10 @@ def main(
num_workers=num_workers,
verbose=verbose,
ephemeral_system_prompt=ephemeral_system_prompt,
log_prefix_chars=log_prefix_chars
log_prefix_chars=log_prefix_chars,
max_tool_failures=max_tool_failures,
max_tool_failure_rate=max_tool_failure_rate,
keep_recent_errors=keep_recent_errors
)
runner.run(resume=resume)

safe_print.py (new file, 20 lines)

@@ -0,0 +1,20 @@
#!/usr/bin/env python3
"""Simple safe print that tries rich, falls back to regular print."""
try:
from rich import print as rich_print
RICH_AVAILABLE = True
except ImportError:
RICH_AVAILABLE = False
def safe_print(*args, **kwargs):
"""Try rich.print, fall back to regular print if it fails."""
if RICH_AVAILABLE:
try:
rich_print(*args, **kwargs)
return
except Exception:
pass
# Fallback to regular print
print(*args, **kwargs)
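# Illustrative usage (not part of the original commit):
#
#   from safe_print import safe_print
#   safe_print("[bold green]✅ done[/bold green]")  # styled via rich markup when rich is installed
#   safe_print("plain status message")              # falls back to the built-in print otherwise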