diff --git a/batch_runner.py b/batch_runner.py
index 9d21aebc3..c45d50423 100644
--- a/batch_runner.py
+++ b/batch_runner.py
@@ -9,15 +9,21 @@ across multiple prompts from a dataset. It includes:
 - Checkpointing for fault tolerance and resumption
 - Trajectory saving in the proper format (from/value pairs)
 - Tool usage statistics aggregation across all batches
+- Cluster failure detection and graceful shutdown (morph, firecrawl, API errors)
+- Configurable failure thresholds with automatic data consolidation
 
 Usage:
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
-    
+
     # Resume an interrupted run
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume
-    
+
     # Use a specific toolset distribution
     python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen
+
+    # Configure tool failure thresholds
+    python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+        --max_tool_failures=20 --max_tool_failure_rate=0.3
 """
 
 import json
@@ -29,22 +35,94 @@ from typing import List, Dict, Any, Optional, Tuple
 from datetime import datetime
 from multiprocessing import Pool, Manager, Lock
 import traceback
 
 import fire
 
 from run_agent import AIAgent
 from toolset_distributions import (
-    get_distribution, 
-    list_distributions, 
+    get_distribution,
+    list_distributions,
     sample_toolsets_from_distribution,
     validate_distribution
 )
+from safe_print import safe_print
 
 # Global configuration for worker processes
 _WORKER_CONFIG = {}
 
 
+def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """
+    Extract tool errors from message history with tool names.
+
+    Args:
+        messages (List[Dict]): Message history
+
+    Returns:
+        List[Dict]: List of tool errors with tool name, error message, and context
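+
+    Example (illustrative message shapes, not a real transcript):
+        An assistant message whose tool_calls include id "call_1" for tool
+        "browser", followed by {"role": "tool", "tool_call_id": "call_1",
+        "content": '{"error": "timeout"}'}, yields:
+            [{"tool_name": "browser", "error_message": "timeout",
+              "full_content": '{"error": "timeout"}'}]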
+    """
+    tool_errors = []
+    tool_calls_map = {}  # Map tool_call_id to tool name
+
+    for msg in messages:
+        # Track tool calls from assistant messages
+        if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]:
+            for tool_call in msg["tool_calls"]:
+                tool_name = tool_call["function"]["name"]
+                tool_call_id = tool_call["id"]
+                tool_calls_map[tool_call_id] = tool_name
+
+        # Check tool responses for errors
+        elif msg["role"] == "tool":
+            tool_call_id = msg.get("tool_call_id", "")
+            content = msg.get("content", "")
+
+            # Determine if tool call had an error
+            has_error = False
+            error_msg = None
+
+            try:
+                content_json = json.loads(content) if isinstance(content, str) else content
+
+                if isinstance(content_json, dict):
+                    # Check if error field exists AND has a non-null value
+                    if "error" in content_json and content_json["error"] is not None:
+                        has_error = True
+                        error_msg = str(content_json["error"])
+
+                    # Special handling for terminal tool responses
+                    if "content" in content_json and isinstance(content_json["content"], dict):
+                        inner_content = content_json["content"]
+                        if inner_content.get("error") is not None or inner_content.get("exit_code", 0) != 0:
+                            has_error = True
+                            error_msg = inner_content.get("error") or f"Exit code: {inner_content.get('exit_code')}"
+
+                    # Check for "success": false pattern
+                    if content_json.get("success") is False:
+                        has_error = True
+                        if not error_msg:
+                            error_msg = str(content_json.get("message", content_json.get("error", "Unknown error")))
+
+            except (json.JSONDecodeError, TypeError):
+                # If not JSON, check if content explicitly states an error
+                if content.strip().lower().startswith("error:"):
+                    has_error = True
+                    error_msg = content.strip()
+
+            # Record error if found
+            if has_error and tool_call_id in tool_calls_map:
+                tool_name = tool_calls_map[tool_call_id]
+                tool_errors.append({
+                    "tool_name": tool_name,
+                    "error_message": error_msg or "Unknown error",
+                    "full_content": str(content)[:500]  # Keep first 500 chars of the full response
+                })
+
+    return tool_errors
+
+
 def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]:
     """
     Extract tool usage statistics from message history.
@@ -170,22 +248,26 @@ def _process_single_prompt(
         # Run the agent with task_id to ensure each task gets its own isolated VM
         result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}")
-        
+
         # Extract tool usage statistics
         tool_stats = _extract_tool_stats(result["messages"])
-        
+
+        # Extract tool errors from conversation
+        tool_errors = _extract_tool_errors_from_messages(result["messages"])
+
         # Convert to trajectory format (using existing method)
         trajectory = agent._convert_to_trajectory_format(
             result["messages"],
             prompt,
             result["completed"]
         )
-        
+
         return {
             "success": True,
             "prompt_index": prompt_index,
             "trajectory": trajectory,
             "tool_stats": tool_stats,
+            "tool_errors": tool_errors,
             "completed": result["completed"],
             "api_calls": result["api_calls"],
             "toolsets_used": selected_toolsets,
@@ -197,14 +279,18 @@
         }
 
     except Exception as e:
-        print(f"āŒ Error processing prompt {prompt_index}: {e}")
+        error_msg = str(e)
+        tb = traceback.format_exc()
+        safe_print(f"[bold red]āŒ Error processing prompt {prompt_index}:[/bold red] {error_msg}")
         if config.get("verbose"):
-            traceback.print_exc()
-        
+            safe_print(tb)
+
         return {
             "success": False,
             "prompt_index": prompt_index,
-            "error": str(e),
+            "error": error_msg,
+            "traceback": tb,
+            "tool_errors": [],
             "trajectory": None,
             "tool_stats": {},
             "toolsets_used": [],
@@ -254,7 +340,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
     # Initialize aggregated stats for this batch
     batch_tool_stats = {}
     completed_in_batch = []
-    
+    all_tool_errors = []  # Track all tool errors in this batch
+    exception_errors = []  # Track top-level exceptions
+
     # Process each prompt sequentially in this batch
     for prompt_index, prompt_data in prompts_to_process:
         # Process the prompt
@@ -264,7 +352,26 @@
             batch_num,
             config
         )
-        
+
+        # Track tool errors from the conversation
+        if result.get("tool_errors"):
+            for tool_error in result["tool_errors"]:
+                all_tool_errors.append({
+                    "prompt_index": prompt_index,
+                    "tool_name": tool_error["tool_name"],
+                    "error_message": tool_error["error_message"],
+                    "full_content": tool_error.get("full_content", "")
+                })
+
+        # Track top-level exceptions (not tool errors)
+        if not result["success"]:
+            exception_errors.append({
+                "prompt_index": prompt_index,
+                "error": result.get("error", "Unknown error"),
+                "traceback": result.get("traceback", "")
+            })
+            safe_print(f"[bold red]āŒ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}")
+
         # Save trajectory if successful
         if result["success"] and result["trajectory"]:
             trajectory_entry = {
@@ -275,7 +382,7 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
                 "api_calls": result["api_calls"],
                 "toolsets_used": result["toolsets_used"]
             }
-            
+
             # Append to batch output file
             with open(batch_output_file, 'a', encoding='utf-8') as f:
                 f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n")
@@ -303,7 +410,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
         "processed": len(prompts_to_process),
"skipped": len(batch_data) - len(prompts_to_process), "tool_stats": batch_tool_stats, - "completed_prompts": completed_in_batch + "completed_prompts": completed_in_batch, + "tool_errors": all_tool_errors, + "exception_errors": exception_errors } @@ -326,6 +435,9 @@ class BatchRunner: verbose: bool = False, ephemeral_system_prompt: str = None, log_prefix_chars: int = 100, + max_tool_failures: int = 10, + max_tool_failure_rate: float = 0.5, + keep_recent_errors: int = 5, ): """ Initialize the batch runner. @@ -343,6 +455,9 @@ class BatchRunner: verbose (bool): Enable verbose logging ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) + max_tool_failures (int): Maximum number of tool failures before stopping (default: 10) + max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5) + keep_recent_errors (int): Number of recent errors to keep per tool (default: 5) """ self.dataset_file = Path(dataset_file) self.batch_size = batch_size @@ -356,6 +471,9 @@ class BatchRunner: self.verbose = verbose self.ephemeral_system_prompt = ephemeral_system_prompt self.log_prefix_chars = log_prefix_chars + self.max_tool_failures = max_tool_failures + self.max_tool_failure_rate = max_tool_failure_rate + self.keep_recent_errors = keep_recent_errors # Validate distribution if not validate_distribution(distribution): @@ -377,17 +495,21 @@ class BatchRunner: # Create batches self.batches = self._create_batches() - print(f"šŸ“Š Batch Runner Initialized") - print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") - print(f" Batch size: {self.batch_size}") - print(f" Total batches: {len(self.batches)}") - print(f" Run name: {self.run_name}") - print(f" Distribution: {self.distribution}") - print(f" Output directory: {self.output_dir}") - print(f" Workers: {self.num_workers}") + safe_print("[bold cyan]šŸ“Š Batch Runner Initialized[/bold cyan]") + safe_print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") + safe_print(f" Batch size: {self.batch_size}") + safe_print(f" Total batches: {len(self.batches)}") + safe_print(f" Run name: {self.run_name}") + safe_print(f" Distribution: {self.distribution}") + safe_print(f" Output directory: {self.output_dir}") + safe_print(f" Workers: {self.num_workers}") + safe_print(f" [yellow]Tool failure limits:[/yellow]") + safe_print(f" Max failures: {self.max_tool_failures}") + safe_print(f" Max failure rate: {self.max_tool_failure_rate:.1%}") + safe_print(f" Keep recent errors: {self.keep_recent_errors}") if self.ephemeral_system_prompt: prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt - print(f" šŸ”’ Ephemeral system prompt: '{prompt_preview}'") + safe_print(f" šŸ”’ Ephemeral system prompt: '{prompt_preview}'") def _load_dataset(self) -> List[Dict[str, Any]]: """ @@ -465,13 +587,13 @@ class BatchRunner: def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None): """ Save checkpoint data. 
         """
         self.dataset_file = Path(dataset_file)
         self.batch_size = batch_size
@@ -356,6 +471,9 @@ class BatchRunner:
         self.verbose = verbose
         self.ephemeral_system_prompt = ephemeral_system_prompt
         self.log_prefix_chars = log_prefix_chars
+        self.max_tool_failures = max_tool_failures
+        self.max_tool_failure_rate = max_tool_failure_rate
+        self.keep_recent_errors = keep_recent_errors
 
         # Validate distribution
         if not validate_distribution(distribution):
@@ -377,17 +495,21 @@ class BatchRunner:
         # Create batches
         self.batches = self._create_batches()
 
-        print(f"šŸ“Š Batch Runner Initialized")
-        print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
-        print(f"   Batch size: {self.batch_size}")
-        print(f"   Total batches: {len(self.batches)}")
-        print(f"   Run name: {self.run_name}")
-        print(f"   Distribution: {self.distribution}")
-        print(f"   Output directory: {self.output_dir}")
-        print(f"   Workers: {self.num_workers}")
+        safe_print("[bold cyan]šŸ“Š Batch Runner Initialized[/bold cyan]")
+        safe_print(f"   Dataset: {self.dataset_file} ({len(self.dataset)} prompts)")
+        safe_print(f"   Batch size: {self.batch_size}")
+        safe_print(f"   Total batches: {len(self.batches)}")
+        safe_print(f"   Run name: {self.run_name}")
+        safe_print(f"   Distribution: {self.distribution}")
+        safe_print(f"   Output directory: {self.output_dir}")
+        safe_print(f"   Workers: {self.num_workers}")
+        safe_print(f"   [yellow]Tool failure limits:[/yellow]")
+        safe_print(f"     Max failures: {self.max_tool_failures}")
+        safe_print(f"     Max failure rate: {self.max_tool_failure_rate:.1%}")
+        safe_print(f"     Keep recent errors: {self.keep_recent_errors}")
         if self.ephemeral_system_prompt:
             prompt_preview = self.ephemeral_system_prompt[:60] + "..." if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt
-            print(f"   šŸ”’ Ephemeral system prompt: '{prompt_preview}'")
+            safe_print(f"   šŸ”’ Ephemeral system prompt: '{prompt_preview}'")
 
     def _load_dataset(self) -> List[Dict[str, Any]]:
         """
@-465,13 +587,13 @@ class BatchRunner:
     def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None):
         """
         Save checkpoint data.
-        
+
         Args:
            checkpoint_data (Dict): Checkpoint data to save
            lock (Lock): Optional lock for thread-safe access
         """
         checkpoint_data["last_updated"] = datetime.now().isoformat()
-        
+
         if lock:
             with lock:
                 with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
@@ -479,6 +601,69 @@ class BatchRunner:
         else:
             with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
                 json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
+
+    def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]],
+                          start_time: float, tool_errors_by_tool: Dict[str, List[Dict]],
+                          exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None):
+        """
+        Consolidate batch data into trajectories.jsonl and save statistics.
+
+        Args:
+            num_batches (int): Number of batches processed
+            tool_stats (Dict): Aggregated tool statistics
+            start_time (float): Start time of the run
+            tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent
+            exception_errors (List): Top-level exceptions
+            early_exit (bool): Whether this is an early exit
+            exit_reason (str): Reason for early exit
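+
+        Output layout (illustrative; the stats filename comes from self.stats_file):
+            <output_dir>/trajectories.jsonl   combined from batch_0 .. batch_{num_batches-1}
+            <output_dir>/<stats file>         run metadata, tool statistics, recent errors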
+        """
+        # Combine all batch files into a single trajectories.jsonl file
+        combined_file = self.output_dir / "trajectories.jsonl"
+        safe_print(f"\n[cyan]šŸ“¦ Combining batch files into {combined_file.name}...[/cyan]")
+
+        entries_written = 0
+        with open(combined_file, 'w', encoding='utf-8') as outfile:
+            for batch_num in range(num_batches):
+                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
+                if batch_file.exists():
+                    with open(batch_file, 'r', encoding='utf-8') as infile:
+                        for line in infile:
+                            outfile.write(line)
+                            entries_written += 1
+
+        safe_print(f"[green]āœ… Combined {num_batches} batch files into trajectories.jsonl ({entries_written} entries)[/green]")
+
+        # Calculate success rates for tool stats
+        for tool_name in tool_stats:
+            stats = tool_stats[tool_name]
+            total_calls = stats["success"] + stats["failure"]
+            if total_calls > 0:
+                stats["success_rate"] = round(stats["success"] / total_calls * 100, 2)
+                stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2)
+            else:
+                stats["success_rate"] = 0.0
+                stats["failure_rate"] = 0.0
+
+        # Save final statistics
+        final_stats = {
+            "run_name": self.run_name,
+            "distribution": self.distribution,
+            "total_prompts": len(self.dataset),
+            "total_batches": len(self.batches),
+            "batches_processed": num_batches,
+            "batch_size": self.batch_size,
+            "model": self.model,
+            "completed_at": datetime.now().isoformat(),
+            "duration_seconds": round(time.time() - start_time, 2),
+            "early_exit": early_exit,
+            "exit_reason": exit_reason,
+            "tool_errors": tool_errors_by_tool,
+            "exception_errors": exception_errors[-self.keep_recent_errors:],  # Keep k most recent
+            "tool_statistics": tool_stats
+        }
+
+        with open(self.stats_file, 'w', encoding='utf-8') as f:
+            json.dump(final_stats, f, indent=2, ensure_ascii=False)
 
 
     def run(self, resume: bool = False):
@@ -520,9 +705,16 @@ class BatchRunner:
         # Aggregate statistics across all batches
         total_tool_stats = {}
-        
+        tool_errors_by_tool = {}  # {tool_name: [list of k most recent errors]}
+        all_exception_errors = []
+        all_completed_prompts = list(completed_prompts_set)
+        total_processed = len(completed_prompts_set)
+        total_tool_errors = 0
+        early_exit = False
+        exit_reason = None
+
         start_time = time.time()
-        
+
         # Process batches in parallel
         with Pool(processes=self.num_workers) as pool:
             # Create tasks for each batch
@@ -536,84 +728,147 @@ class BatchRunner:
                 )
                 for batch_num, batch_data in enumerate(self.batches)
             ]
-
-            # Use map to process batches in parallel
-            results = pool.map(_process_batch_worker, tasks)
-
-            # Aggregate all batch statistics and update checkpoint
-            all_completed_prompts = list(completed_prompts_set)
-            for batch_result in results:
-                # Add newly completed prompts
-                all_completed_prompts.extend(batch_result.get("completed_prompts", []))
-
-                # Aggregate tool stats
-                for tool_name, stats in batch_result.get("tool_stats", {}).items():
-                    if tool_name not in total_tool_stats:
-                        total_tool_stats[tool_name] = {
-                            "count": 0,
-                            "success": 0,
-                            "failure": 0
-                        }
-
-                    total_tool_stats[tool_name]["count"] += stats["count"]
-                    total_tool_stats[tool_name]["success"] += stats["success"]
-                    total_tool_stats[tool_name]["failure"] += stats["failure"]
-
+
+            # Process batches and check tool failure thresholds after each one.
+            # imap yields results in submission order while workers run in
+            # parallel, so an early break below still exits cleanly: leaving
+            # the Pool context terminates any in-flight batches.
+            for batch_num, result in enumerate(pool.imap(_process_batch_worker, tasks)):
+                # Update statistics
+                all_completed_prompts.extend(result.get("completed_prompts", []))
+                total_processed += result.get("processed", 0)
+
+                # Aggregate tool stats
+                for tool_name, stats in result.get("tool_stats", {}).items():
+                    if tool_name not in total_tool_stats:
+                        total_tool_stats[tool_name] = {
+                            "count": 0,
+                            "success": 0,
+                            "failure": 0
+                        }
+
+                    total_tool_stats[tool_name]["count"] += stats["count"]
+                    total_tool_stats[tool_name]["success"] += stats["success"]
+                    total_tool_stats[tool_name]["failure"] += stats["failure"]
+
+                # Aggregate tool errors (keep k most recent per tool)
+                for tool_error in result.get("tool_errors", []):
+                    tool_name = tool_error["tool_name"]
+                    if tool_name not in tool_errors_by_tool:
+                        tool_errors_by_tool[tool_name] = []
+
+                    # Add error and keep only k most recent
+                    tool_errors_by_tool[tool_name].append(tool_error)
+                    if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors:
+                        tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:]
+
+                    total_tool_errors += 1
+
+                # Track exception errors
+                all_exception_errors.extend(result.get("exception_errors", []))
+
+                # Check tool failure thresholds
+                if total_processed > 0:
+                    tool_failure_rate = total_tool_errors / total_processed
+
+                    # Check absolute count threshold
+                    if total_tool_errors >= self.max_tool_failures:
+                        early_exit = True
+                        exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})"
+                        safe_print(f"\n[bold red]šŸ›‘ STOPPING: {exit_reason}[/bold red]")
+                        break
+
+                    # Check rate threshold
+                    if tool_failure_rate >= self.max_tool_failure_rate:
+                        early_exit = True
+                        exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%})"
+                        safe_print(f"\n[bold red]šŸ›‘ STOPPING: {exit_reason}[/bold red]")
+                        break
+
+                # Update checkpoint after each batch
+                checkpoint_data["completed_prompts"] = all_completed_prompts
+                self._save_checkpoint(checkpoint_data)
+
         # Save final checkpoint
         checkpoint_data["completed_prompts"] = all_completed_prompts
         self._save_checkpoint(checkpoint_data)
"trajectories.jsonl" - print(f"\nšŸ“¦ Combining batch files into {combined_file.name}...") - - with open(combined_file, 'w', encoding='utf-8') as outfile: - for batch_num in range(len(self.batches)): - batch_file = self.output_dir / f"batch_{batch_num}.jsonl" - if batch_file.exists(): - with open(batch_file, 'r', encoding='utf-8') as infile: - for line in infile: - outfile.write(line) - - print(f"āœ… Combined {len(self.batches)} batch files into trajectories.jsonl") - - # Save final statistics - final_stats = { - "run_name": self.run_name, - "distribution": self.distribution, - "total_prompts": len(self.dataset), - "total_batches": len(self.batches), - "batch_size": self.batch_size, - "model": self.model, - "completed_at": datetime.now().isoformat(), - "duration_seconds": round(time.time() - start_time, 2), - "tool_statistics": total_tool_stats - } - - with open(self.stats_file, 'w', encoding='utf-8') as f: - json.dump(final_stats, f, indent=2, ensure_ascii=False) + + # Consolidate data and save statistics + num_batches_processed = batch_num + 1 if early_exit else len(self.batches) + self._consolidate_data( + num_batches_processed, + total_tool_stats, + start_time, + tool_errors_by_tool, + all_exception_errors, + early_exit, + exit_reason + ) # Print summary - print("\n" + "=" * 70) - print("šŸ“Š BATCH PROCESSING COMPLETE") - print("=" * 70) - print(f"āœ… Total prompts processed: {len(self.dataset)}") - print(f"āœ… Total batches: {len(self.batches)}") - print(f"ā±ļø Total duration: {round(time.time() - start_time, 2)}s") - print(f"\nšŸ“ˆ Tool Usage Statistics:") - print("-" * 70) - + safe_print("\n" + "=" * 70) + if early_exit: + safe_print("[bold yellow]āš ļø BATCH PROCESSING STOPPED EARLY[/bold yellow]") + safe_print(f"[yellow]Reason: {exit_reason}[/yellow]") + else: + safe_print("[bold green]šŸ“Š BATCH PROCESSING COMPLETE[/bold green]") + safe_print("=" * 70) + + safe_print(f"āœ… Total prompts processed: {total_processed}") + safe_print(f"āœ… Batches completed: {num_batches_processed}/{len(self.batches)}") + safe_print(f"ā±ļø Total duration: {round(time.time() - start_time, 2)}s") + + # Tool error summary + if tool_errors_by_tool: + total_errors = sum(len(errors) for errors in tool_errors_by_tool.values()) + safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]") + safe_print("[red]-[/red]" * 70) + + # Sort tools by error count + sorted_tools = sorted( + tool_errors_by_tool.items(), + key=lambda x: len(x[1]), + reverse=True + ) + + for tool_name, errors in sorted_tools: + # Count unique error messages + unique_errors = {} + for error in errors: + error_msg = error["error_message"][:100] # Truncate for grouping + if error_msg not in unique_errors: + unique_errors[error_msg] = [] + unique_errors[error_msg].append(error) + + safe_print(f"\n [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)") + + # Show up to 3 most recent unique error types + for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]): + error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..." + safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})") + # Show one example with prompt index + example = instances[-1] # Most recent + safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]") + + if len(unique_errors) > 3: + safe_print(f" [dim]... 
 
         # Print summary
-        print("\n" + "=" * 70)
-        print("šŸ“Š BATCH PROCESSING COMPLETE")
-        print("=" * 70)
-        print(f"āœ… Total prompts processed: {len(self.dataset)}")
-        print(f"āœ… Total batches: {len(self.batches)}")
-        print(f"ā±ļø  Total duration: {round(time.time() - start_time, 2)}s")
-        print(f"\nšŸ“ˆ Tool Usage Statistics:")
-        print("-" * 70)
-
+        safe_print("\n" + "=" * 70)
+        if early_exit:
+            safe_print("[bold yellow]āš ļø  BATCH PROCESSING STOPPED EARLY[/bold yellow]")
+            safe_print(f"[yellow]Reason: {exit_reason}[/yellow]")
+        else:
+            safe_print("[bold green]šŸ“Š BATCH PROCESSING COMPLETE[/bold green]")
+        safe_print("=" * 70)
+
+        safe_print(f"āœ… Total prompts processed: {total_processed}")
+        safe_print(f"āœ… Batches completed: {num_batches_processed}/{len(self.batches)}")
+        safe_print(f"ā±ļø  Total duration: {round(time.time() - start_time, 2)}s")
+
+        # Tool error summary
+        if tool_errors_by_tool:
+            safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]")
+            safe_print("[red]" + "-" * 70 + "[/red]")
+
+            # Sort tools by error count
+            sorted_tools = sorted(
+                tool_errors_by_tool.items(),
+                key=lambda x: len(x[1]),
+                reverse=True
+            )
+
+            for tool_name, errors in sorted_tools:
+                # Count unique error messages
+                unique_errors = {}
+                for error in errors:
+                    error_msg = error["error_message"][:100]  # Truncate for grouping
+                    if error_msg not in unique_errors:
+                        unique_errors[error_msg] = []
+                    unique_errors[error_msg].append(error)
+
+                safe_print(f"\n  [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)")
+
+                # Show up to 3 most recent unique error types
+                for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]):
+                    error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..."
+                    safe_print(f"    [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})")
+                    # Show one example with prompt index
+                    example = instances[-1]  # Most recent
+                    safe_print(f"        [dim]Prompt {example['prompt_index']}[/dim]")
+
+                if len(unique_errors) > 3:
+                    safe_print(f"    [dim]... and {len(unique_errors) - 3} more error types[/dim]")
+
+            tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0
+            safe_print(f"\n  [red]Tool failure rate: {tool_failure_rate:.2%}[/red]")
+
+        # Exception errors
+        if all_exception_errors:
+            safe_print(f"\n[bold red]šŸ’„ Top-level Exceptions: {len(all_exception_errors)}[/bold red]")
+            safe_print("[red]" + "-" * 70 + "[/red]")
+            for error in all_exception_errors[-self.keep_recent_errors:]:
+                error_preview = error["error"][:100]
+                if len(error["error"]) > 100:
+                    error_preview += "..."
+                safe_print(f"  Prompt {error['prompt_index']}: [dim]{error_preview}[/dim]")
+
+        safe_print(f"\n[cyan]šŸ“ˆ Tool Usage Statistics:[/cyan]")
+        safe_print("-" * 70)
+
         if total_tool_stats:
             # Sort by count descending
             sorted_tools = sorted(
@@ -621,25 +876,30 @@ class BatchRunner:
                 key=lambda x: x[1]["count"],
                 reverse=True
             )
-            
-            print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
-            print("-" * 70)
+
+            safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}")
+            safe_print("-" * 70)
 
             for tool_name, stats in sorted_tools:
-                print(
+                safe_print(
                     f"{tool_name:<25} "
                     f"{stats['count']:<10} "
                     f"{stats['success']:<10} "
                     f"{stats['failure']:<10} "
-                    f"{stats['success_rate']:.1f}%"
+                    f"{stats.get('success_rate', 0):.1f}%"
                 )
         else:
-            print("No tool calls were made during this run.")
-
-        print(f"\nšŸ’¾ Results saved to: {self.output_dir}")
-        print(f"   - Trajectories: trajectories.jsonl (combined)")
-        print(f"   - Individual batches: batch_*.jsonl (for debugging)")
-        print(f"   - Statistics: {self.stats_file.name}")
-        print(f"   - Checkpoint: {self.checkpoint_file.name}")
+            safe_print("No tool calls were made during this run.")
+
+        safe_print(f"\n[cyan]šŸ’¾ Results saved to:[/cyan] {self.output_dir}")
+        safe_print(f"   - Trajectories: trajectories.jsonl (combined)")
+        safe_print(f"   - Individual batches: batch_*.jsonl (for debugging)")
+        safe_print(f"   - Statistics: {self.stats_file.name}")
+        safe_print(f"   - Checkpoint: {self.checkpoint_file.name}")
+
+        if early_exit:
+            safe_print(f"\n[bold yellow]ā„¹ļø  Run was stopped early due to tool failures.[/bold yellow]")
+            safe_print(f"[yellow]   Check {self.stats_file.name} for detailed error information including tracebacks.[/yellow]")
+            safe_print(f"[yellow]   You can resume this run later with --resume flag.[/yellow]")
 
 
 def main(
     dataset_file: str,
     batch_size: int,
     run_name: str,
     model: str = None,
     distribution: str = "default",
     output_dir: str = None,
     resume: bool = False,
     num_workers: int = 4,
     verbose: bool = False,
     list_distributions: bool = False,
     ephemeral_system_prompt: str = None,
     log_prefix_chars: int = 100,
+    max_tool_failures: int = 10,
+    max_tool_failure_rate: float = 0.5,
+    keep_recent_errors: int = 5,
 ):
     """
     Run batch processing of agent prompts from a dataset.
@@ -676,7 +939,10 @@ def main(
         list_distributions (bool): List available toolset distributions and exit
         ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
         log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 100)
-    
+        max_tool_failures (int): Maximum number of tool failures before stopping (default: 10)
+        max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
+        keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
+
     Examples:
         # Basic usage
         python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run
@@ -690,7 +956,11 @@ def main(
         # With ephemeral system prompt (not saved to dataset)
         python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
             --ephemeral_system_prompt="You are a helpful assistant focused on image generation."
-    
+
+        # With custom tool failure thresholds
+        python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\
+            --max_tool_failures=20 --max_tool_failure_rate=0.3 --keep_recent_errors=10
+
         # List available distributions
         python batch_runner.py --list_distributions
     """
@@ -737,7 +1007,10 @@ def main(
         num_workers=num_workers,
         verbose=verbose,
         ephemeral_system_prompt=ephemeral_system_prompt,
-        log_prefix_chars=log_prefix_chars
+        log_prefix_chars=log_prefix_chars,
+        max_tool_failures=max_tool_failures,
+        max_tool_failure_rate=max_tool_failure_rate,
+        keep_recent_errors=keep_recent_errors
     )
 
     runner.run(resume=resume)
diff --git a/safe_print.py b/safe_print.py
new file mode 100644
index 000000000..4cddcc6ec
--- /dev/null
+++ b/safe_print.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+"""Simple safe print that tries rich, falls back to regular print."""
+
+try:
+    from rich import print as rich_print
+    RICH_AVAILABLE = True
+except ImportError:
+    RICH_AVAILABLE = False
+
+
+def safe_print(*args, **kwargs):
+    """Try rich.print, fall back to regular print if it fails."""
+    if RICH_AVAILABLE:
+        try:
+            rich_print(*args, **kwargs)
+            return
+        except Exception:
+            pass
+    # Fallback to regular print (note: rich markup tags will appear verbatim)
+    print(*args, **kwargs)
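+
+
+if __name__ == "__main__":
+    # Minimal self-check (illustrative, not part of the batch runner API).
+    # With rich installed this renders bold green text; without rich the
+    # plain-print fallback emits the markup tags literally.
+    safe_print("[bold green]safe_print self-test[/bold green]")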