diff --git a/batch_runner.py b/batch_runner.py index a1f09719b..8a9632b83 100644 --- a/batch_runner.py +++ b/batch_runner.py @@ -9,15 +9,21 @@ across multiple prompts from a dataset. It includes: - Checkpointing for fault tolerance and resumption - Trajectory saving in the proper format (from/value pairs) - Tool usage statistics aggregation across all batches +- Cluster failure detection and graceful shutdown (morph, firecrawl, API errors) +- Configurable failure thresholds with automatic data consolidation Usage: python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run - + # Resume an interrupted run python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --resume - + # Use a specific toolset distribution python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run --distribution=image_gen + + # Configure tool failure thresholds + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ + --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 """ import json @@ -29,6 +35,7 @@ from typing import List, Dict, Any, Optional, Tuple from datetime import datetime from multiprocessing import Pool, Manager, Lock import traceback +import re import fire @@ -39,12 +46,166 @@ from toolset_distributions import ( sample_toolsets_from_distribution, validate_distribution ) -from profiling import get_profiler +from safe_print import safe_print # Global configuration for worker processes _WORKER_CONFIG = {} +# Canonical names for the terminal tool (old & new implementations) +_TERMINAL_TOOL_NAMES = {"terminal", "terminal_tool", "simple_terminal_tool"} + + +def _is_terminal_tool_name(tool_name: Optional[str]) -> bool: + """Return True if the given tool name corresponds to a terminal tool.""" + return bool(tool_name) and tool_name.lower() in _TERMINAL_TOOL_NAMES + + +def _terminal_tool_failed(content_json: Dict[str, Any]) -> bool: + """ + Determine whether the terminal tool itself failed (not the user command). + + Terminal failures are indicated by explicit status flags or negative exit codes. + Regular command failures (non-zero positive exit codes, stderr, timeouts) are not counted. + """ + if not isinstance(content_json, dict): + return False + + status = str(content_json.get("status", "")).lower() + if status in {"error", "disabled"}: + return True + + exit_code = content_json.get("exit_code") + if isinstance(exit_code, int) and exit_code < 0: + return True + + return False + + +def _categorize_error_type(error_message: str) -> str: + """ + Categorize an error message into a failure type. 
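+ + Matching is keyword-based, case-insensitive, and first-match-wins, e.g.: + + _categorize_error_type("Request timed out after 30s") # -> "Timeout" + _categorize_error_type("HTTP 429: rate limit exceeded") # -> "Rate Limit" + _categorize_error_type("something unexpected") # -> "Other"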
+ + Args: + error_message (str): The error message to categorize + + Returns: + str: Category of the error + """ + error_lower = error_message.lower() + + # Common error patterns + if "timeout" in error_lower or "timed out" in error_lower: + return "Timeout" + elif "connection" in error_lower or "connect" in error_lower: + return "Connection Error" + elif "rate limit" in error_lower or "ratelimit" in error_lower or "429" in error_lower: + return "Rate Limit" + elif "authentication" in error_lower or "auth" in error_lower or "unauthorized" in error_lower or "401" in error_lower: + return "Authentication" + elif "not found" in error_lower or "404" in error_lower: + return "Not Found" + elif "permission" in error_lower or "forbidden" in error_lower or "403" in error_lower: + return "Permission Denied" + elif "invalid" in error_lower or "malformed" in error_lower or "bad request" in error_lower or "400" in error_lower: + return "Invalid Input" + elif "out of memory" in error_lower or "oom" in error_lower: + return "Out of Memory" + elif "network" in error_lower: + return "Network Error" + elif "server error" in error_lower or "500" in error_lower or "502" in error_lower or "503" in error_lower: + return "Server Error" + elif "vm" in error_lower and ("fail" in error_lower or "error" in error_lower): + return "VM Error" + else: + return "Other" + + +def _extract_tool_errors_from_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Extract tool errors from message history with tool names. + + Args: + messages (List[Dict]): Message history + + Returns: + List[Dict]: List of tool errors with tool name, error message, error type, and context + """ + tool_errors = [] + tool_calls_map = {} # Map tool_call_id to tool name + + for msg in messages: + # Track tool calls from assistant messages + if msg["role"] == "assistant" and "tool_calls" in msg and msg["tool_calls"]: + for tool_call in msg["tool_calls"]: + tool_name = tool_call["function"]["name"] + tool_call_id = tool_call["id"] + tool_calls_map[tool_call_id] = tool_name + + # Check tool responses for errors + elif msg["role"] == "tool": + tool_call_id = msg.get("tool_call_id", "") + content = msg.get("content", "") + + # Determine if tool call had an error + has_error = False + error_msg = None + + try: + content_json = json.loads(content) if isinstance(content, str) else content + + if isinstance(content_json, dict): + # Get tool name for special handling + tool_name = tool_calls_map.get(tool_call_id, "unknown") + + # Special handling for terminal tool outputs + if _is_terminal_tool_name(tool_name): + if _terminal_tool_failed(content_json): + has_error = True + # Prefer explicit error text, fall back to status or generic message + error_msg = str( + content_json.get("error") + or content_json.get("status") + or "Terminal tool failure" + ) + else: + # For other tools, check if error field exists AND has a non-null value + if "error" in content_json and content_json["error"] is not None: + has_error = True + error_msg = str(content_json["error"]) + + # Check nested content structure (some tools wrap responses) + if "content" in content_json and isinstance(content_json["content"], dict): + inner_content = content_json["content"] + if inner_content.get("error") is not None: + has_error = True + error_msg = inner_content.get("error") + + # Check for "success": false pattern + if content_json.get("success") is False: + has_error = True + if not error_msg: + error_msg = str(content_json.get("message", content_json.get("error", 
"Unknown error"))) + + except: + # If not JSON, check if content explicitly states an error + if content.strip().lower().startswith("error:"): + has_error = True + error_msg = content.strip() + + # Record error if found + if has_error and tool_call_id in tool_calls_map: + tool_name = tool_calls_map[tool_call_id] + error_message = error_msg or "Unknown error" + tool_errors.append({ + "tool_name": tool_name, + "error_message": error_message, + "error_type": _categorize_error_type(error_message), + "full_content": content[:500] # Keep first 500 chars of full response + }) + + return tool_errors + def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, int]]: """ @@ -83,31 +244,37 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i elif msg["role"] == "tool": tool_call_id = msg.get("tool_call_id", "") content = msg.get("content", "") - + # Determine if tool call was successful is_success = True try: # Try to parse as JSON and check for actual error values content_json = json.loads(content) if isinstance(content, str) else content - + if isinstance(content_json, dict): - # Check if error field exists AND has a non-null value - if "error" in content_json and content_json["error"] is not None: - is_success = False - - # Special handling for terminal tool responses - # Terminal wraps its response in a "content" field - if "content" in content_json and isinstance(content_json["content"], dict): - inner_content = content_json["content"] - # Check for actual error (non-null error field) - # Note: non-zero exit codes are not failures - the model can self-correct - if inner_content.get("error") is not None: + # Get tool name for special handling + tool_name = tool_calls_map.get(tool_call_id, "unknown") + + # Special handling for terminal tool: only count as failure when the tool itself fails + if _is_terminal_tool_name(tool_name): + if _terminal_tool_failed(content_json): is_success = False - - # Check for "success": false pattern used by some tools - if content_json.get("success") is False: - is_success = False - + else: + # For other tools, check if error field exists AND has a non-null value + if "error" in content_json and content_json["error"] is not None: + is_success = False + + # Check nested content structure (some tools wrap responses) + if "content" in content_json and isinstance(content_json["content"], dict): + inner_content = content_json["content"] + # Check for actual error (non-null error field) + if inner_content.get("error") is not None: + is_success = False + + # Check for "success": false pattern used by some tools + if content_json.get("success") is False: + is_success = False + except: # If not JSON, check if content is empty or explicitly states an error # Note: We avoid simple substring matching to prevent false positives @@ -170,22 +337,30 @@ def _process_single_prompt( # Run the agent with task_id to ensure each task gets its own isolated VM result = agent.run_conversation(prompt, task_id=f"task_{prompt_index}") - + # Extract tool usage statistics tool_stats = _extract_tool_stats(result["messages"]) - + + # Extract tool errors from conversation + tool_errors = _extract_tool_errors_from_messages(result["messages"]) + # Convert to trajectory format (using existing method) trajectory = agent._convert_to_trajectory_format( result["messages"], prompt, result["completed"] ) - + + # Get profiling stats from the result + profiling_stats = result.get("profiling_stats", {"tools": {}, "api_calls": {}}) + return { "success": True, 
"prompt_index": prompt_index, "trajectory": trajectory, "tool_stats": tool_stats, + "tool_errors": tool_errors, + "profiling_stats": profiling_stats, "completed": result["completed"], "api_calls": result["api_calls"], "toolsets_used": selected_toolsets, @@ -197,14 +372,19 @@ def _process_single_prompt( } except Exception as e: - print(f"❌ Error processing prompt {prompt_index}: {e}") + error_msg = str(e) + tb = traceback.format_exc() + safe_print(f"[bold red]❌ Error processing prompt {prompt_index}:[/bold red] {error_msg}") if config.get("verbose"): - traceback.print_exc() - + safe_print(tb) + return { "success": False, "prompt_index": prompt_index, - "error": str(e), + "error": error_msg, + "traceback": tb, + "tool_errors": [], + "profiling_stats": {"tools": {}, "api_calls": {}}, "trajectory": None, "tool_stats": {}, "toolsets_used": [], @@ -253,8 +433,11 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: # Initialize aggregated stats for this batch batch_tool_stats = {} + batch_profiling_stats = [] # Collect profiling stats from each prompt completed_in_batch = [] - + all_tool_errors = [] # Track all tool errors in this batch + exception_errors = [] # Track top-level exceptions + # Process each prompt sequentially in this batch for prompt_index, prompt_data in prompts_to_process: # Process the prompt @@ -264,7 +447,27 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: batch_num, config ) - + + # Track tool errors from the conversation + if result.get("tool_errors"): + for tool_error in result["tool_errors"]: + all_tool_errors.append({ + "prompt_index": prompt_index, + "tool_name": tool_error["tool_name"], + "error_message": tool_error["error_message"], + "full_content": tool_error.get("full_content", ""), + "error_type": tool_error.get("error_type", "Other") + }) + + # Track top-level exceptions (not tool errors) + if not result["success"]: + exception_errors.append({ + "prompt_index": prompt_index, + "error": result.get("error", "Unknown error"), + "traceback": result.get("traceback", "") + }) + safe_print(f"[bold red]❌ Exception in prompt {prompt_index}:[/bold red] {result.get('error', '')[:100]}") + # Save trajectory if successful if result["success"] and result["trajectory"]: trajectory_entry = { @@ -275,7 +478,7 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: "api_calls": result["api_calls"], "toolsets_used": result["toolsets_used"] } - + # Append to batch output file with open(batch_output_file, 'a', encoding='utf-8') as f: f.write(json.dumps(trajectory_entry, ensure_ascii=False) + "\n") @@ -288,22 +491,29 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]: "success": 0, "failure": 0 } - + batch_tool_stats[tool_name]["count"] += stats["count"] batch_tool_stats[tool_name]["success"] += stats["success"] batch_tool_stats[tool_name]["failure"] += stats["failure"] - + + # Collect profiling statistics + if result.get("profiling_stats"): + batch_profiling_stats.append(result["profiling_stats"]) + completed_in_batch.append(prompt_index) print(f" ✅ Prompt {prompt_index} completed") print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)") - + return { "batch_num": batch_num, "processed": len(prompts_to_process), "skipped": len(batch_data) - len(prompts_to_process), "tool_stats": batch_tool_stats, - "completed_prompts": completed_in_batch + "profiling_stats": batch_profiling_stats, + "completed_prompts": completed_in_batch, + "tool_errors": all_tool_errors, + "exception_errors": exception_errors } @@ -326,6 +536,10 @@ 
class BatchRunner: verbose: bool = False, ephemeral_system_prompt: str = None, log_prefix_chars: int = 100, + max_tool_failures: int = 10, + max_tool_failure_rate: float = 0.5, + keep_recent_errors: int = 5, + min_tool_calls_for_rate: int = 10, ): """ Initialize the batch runner. @@ -343,6 +557,10 @@ class BatchRunner: verbose (bool): Enable verbose logging ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) + max_tool_failures (int): Maximum number of tool failures before stopping (default: 10) + max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5) + keep_recent_errors (int): Number of recent errors to keep per tool (default: 5) + min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10) """ self.dataset_file = Path(dataset_file) self.batch_size = batch_size @@ -356,6 +574,10 @@ class BatchRunner: self.verbose = verbose self.ephemeral_system_prompt = ephemeral_system_prompt self.log_prefix_chars = log_prefix_chars + self.max_tool_failures = max_tool_failures + self.max_tool_failure_rate = max_tool_failure_rate + self.keep_recent_errors = keep_recent_errors + self.min_tool_calls_for_rate = min_tool_calls_for_rate # Validate distribution if not validate_distribution(distribution): @@ -364,12 +586,15 @@ class BatchRunner: # Setup output directory self.output_dir = Path("data") / run_name self.output_dir.mkdir(parents=True, exist_ok=True) - + # Checkpoint file self.checkpoint_file = self.output_dir / "checkpoint.json" - + # Statistics file self.stats_file = self.output_dir / "statistics.json" + + # Errors file + self.errors_file = self.output_dir / "errors.json" # Load dataset self.dataset = self._load_dataset() @@ -377,17 +602,22 @@ class BatchRunner: # Create batches self.batches = self._create_batches() - print(f"📊 Batch Runner Initialized") - print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") - print(f" Batch size: {self.batch_size}") - print(f" Total batches: {len(self.batches)}") - print(f" Run name: {self.run_name}") - print(f" Distribution: {self.distribution}") - print(f" Output directory: {self.output_dir}") - print(f" Workers: {self.num_workers}") + safe_print("[bold cyan]📊 Batch Runner Initialized[/bold cyan]") + safe_print(f" Dataset: {self.dataset_file} ({len(self.dataset)} prompts)") + safe_print(f" Batch size: {self.batch_size}") + safe_print(f" Total batches: {len(self.batches)}") + safe_print(f" Run name: {self.run_name}") + safe_print(f" Distribution: {self.distribution}") + safe_print(f" Output directory: {self.output_dir}") + safe_print(f" Workers: {self.num_workers}") + safe_print(f" [yellow]Tool failure limits:[/yellow]") + safe_print(f" Max failures: {self.max_tool_failures}") + safe_print(f" Max failure rate: {self.max_tool_failure_rate:.1%}") + safe_print(f" Min tool calls for rate check: {self.min_tool_calls_for_rate}") + safe_print(f" Keep recent errors: {self.keep_recent_errors}") if self.ephemeral_system_prompt: prompt_preview = self.ephemeral_system_prompt[:60] + "..." 
if len(self.ephemeral_system_prompt) > 60 else self.ephemeral_system_prompt - print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'") + safe_print(f" 🔒 Ephemeral system prompt: '{prompt_preview}'") def _load_dataset(self) -> List[Dict[str, Any]]: """ @@ -465,13 +695,13 @@ class BatchRunner: def _save_checkpoint(self, checkpoint_data: Dict[str, Any], lock: Optional[Lock] = None): """ Save checkpoint data. - + Args: checkpoint_data (Dict): Checkpoint data to save lock (Lock): Optional lock for thread-safe access """ checkpoint_data["last_updated"] = datetime.now().isoformat() - + if lock: with lock: with open(self.checkpoint_file, 'w', encoding='utf-8') as f: @@ -479,6 +709,118 @@ else: with open(self.checkpoint_file, 'w', encoding='utf-8') as f: json.dump(checkpoint_data, f, indent=2, ensure_ascii=False) + + def _consolidate_data(self, num_batches: int, tool_stats: Dict[str, Dict[str, int]], + start_time: float, tool_errors_by_tool: Dict[str, List[Dict]], + exception_errors: List[Dict], early_exit: bool = False, exit_reason: str = None, + profiling_stats_list: List[Dict] = None): + """ + Consolidate batch data into trajectories.jsonl and save statistics. + + Args: + num_batches (int): Number of batches processed + tool_stats (Dict): Aggregated tool statistics + start_time (float): Start time of the run + tool_errors_by_tool (Dict): Tool errors grouped by tool name with k most recent + exception_errors (List): Top-level exceptions + early_exit (bool): Whether this is an early exit + exit_reason (str): Reason for early exit + profiling_stats_list (List[Dict]): List of profiling statistics from each conversation + """ + # Combine all batch files into a single trajectories.jsonl file + combined_file = self.output_dir / "trajectories.jsonl" + safe_print(f"\n[cyan]📦 Combining batch files into {combined_file.name}...[/cyan]") + + entries_written = 0 + files_combined = 0 + with open(combined_file, 'w', encoding='utf-8') as outfile: + # Scan every batch index: with imap_unordered, the batches completed before an early exit are not necessarily 0..num_batches-1, so later-numbered files must not be skipped + for batch_num in range(len(self.batches)): + batch_file = self.output_dir / f"batch_{batch_num}.jsonl" + if batch_file.exists(): + files_combined += 1 + with open(batch_file, 'r', encoding='utf-8') as infile: + for line in infile: + outfile.write(line) + entries_written += 1 + + safe_print(f"[green]✅ Combined {files_combined} batch files into trajectories.jsonl ({entries_written} entries)[/green]") + + # Calculate success rates for tool stats + for tool_name in tool_stats: + stats = tool_stats[tool_name] + total_calls = stats["success"] + stats["failure"] + if total_calls > 0: + stats["success_rate"] = round(stats["success"] / total_calls * 100, 2) + stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2) + else: + stats["success_rate"] = 0.0 + stats["failure_rate"] = 0.0 + + # Build failure type breakdown for each tool + failure_type_breakdown = {} + for tool_name, errors in tool_errors_by_tool.items(): + failure_types = {} + for error in errors: + error_type = error.get("error_type", "Other") + if error_type not in failure_types: + failure_types[error_type] = 0 + failure_types[error_type] += 1 + + # Calculate percentages + total_failures = len(errors) + failure_type_breakdown[tool_name] = { + "total_failures": total_failures, + "types": { + error_type: { + "count": count, + "percentage": round((count / total_failures) * 100, 2) + } + for error_type, count in failure_types.items() + } + } + + # Save error information to separate file + error_data = { + "run_name": self.run_name, + "completed_at": datetime.now().isoformat(), + "total_tool_errors": sum(len(errors) for errors in
tool_errors_by_tool.values()), + "total_exception_errors": len(exception_errors), + "tool_errors": tool_errors_by_tool, + "failure_type_breakdown": failure_type_breakdown, + "exception_errors": exception_errors[:self.keep_recent_errors] # Keep k most recent + } + + with open(self.errors_file, 'w', encoding='utf-8') as f: + json.dump(error_data, f, indent=2, ensure_ascii=False) + + # Aggregate profiling statistics if available + aggregated_profiling_stats = None + if profiling_stats_list: + from profiling import aggregate_profiling_stats + aggregated_profiling_stats = aggregate_profiling_stats(profiling_stats_list) + + # Save final statistics (without detailed errors) + final_stats = { + "run_name": self.run_name, + "distribution": self.distribution, + "total_prompts": len(self.dataset), + "total_batches": len(self.batches), + "batches_processed": num_batches, + "batch_size": self.batch_size, + "model": self.model, + "completed_at": datetime.now().isoformat(), + "duration_seconds": round(time.time() - start_time, 2), + "early_exit": early_exit, + "exit_reason": exit_reason, + "tool_statistics": tool_stats, + "profiling_statistics": aggregated_profiling_stats + } + + with open(self.stats_file, 'w', encoding='utf-8') as f: + json.dump(final_stats, f, indent=2, ensure_ascii=False) + + # Display aggregated profiling statistics + if aggregated_profiling_stats: + from profiling import print_aggregated_statistics + print_aggregated_statistics(aggregated_profiling_stats, detailed=True) def run(self, resume: bool = False): @@ -520,9 +862,17 @@ class BatchRunner: # Aggregate statistics across all batches total_tool_stats = {} - + all_profiling_stats = [] # Collect all profiling stats for aggregation + tool_errors_by_tool = {} # {tool_name: [list of k most recent errors]} + all_exception_errors = [] + all_completed_prompts = list(completed_prompts_set) + total_processed = len(completed_prompts_set) + total_tool_errors = 0 + early_exit = False + exit_reason = None + start_time = time.time() - + # Process batches in parallel with Pool(processes=self.num_workers) as pool: # Create tasks for each batch @@ -536,84 +886,184 @@ class BatchRunner: ) for batch_num, batch_data in enumerate(self.batches) ] - - # Use map to process batches in parallel - results = pool.map(_process_batch_worker, tasks) - - # Aggregate all batch statistics and update checkpoint - all_completed_prompts = list(completed_prompts_set) - for batch_result in results: - # Add newly completed prompts - all_completed_prompts.extend(batch_result.get("completed_prompts", [])) - - # Aggregate tool stats - for tool_name, stats in batch_result.get("tool_stats", {}).items(): - if tool_name not in total_tool_stats: - total_tool_stats[tool_name] = { - "count": 0, - "success": 0, - "failure": 0 - } - - total_tool_stats[tool_name]["count"] += stats["count"] - total_tool_stats[tool_name]["success"] += stats["success"] - total_tool_stats[tool_name]["failure"] += stats["failure"] - + + # Process batches in parallel and check tool failure threshold as results come in + # imap_unordered allows parallel processing while getting results as they complete + batch_num = 0 + try: + for result in pool.imap_unordered(_process_batch_worker, tasks): + # Update statistics + all_completed_prompts.extend(result.get("completed_prompts", [])) + total_processed += result.get("processed", 0) + + # Aggregate tool stats + for tool_name, stats in result.get("tool_stats", {}).items(): + if tool_name not in total_tool_stats: + total_tool_stats[tool_name] = { + "count": 0, + 
"success": 0, + "failure": 0 + } + + total_tool_stats[tool_name]["count"] += stats["count"] + total_tool_stats[tool_name]["success"] += stats["success"] + total_tool_stats[tool_name]["failure"] += stats["failure"] + + # Collect profiling stats from this batch + if result.get("profiling_stats"): + all_profiling_stats.extend(result["profiling_stats"]) + + # Aggregate tool errors (keep k most recent per tool) + for tool_error in result.get("tool_errors", []): + tool_name = tool_error["tool_name"] + if tool_name not in tool_errors_by_tool: + tool_errors_by_tool[tool_name] = [] + + # Add error and keep only k most recent + tool_errors_by_tool[tool_name].append(tool_error) + if len(tool_errors_by_tool[tool_name]) > self.keep_recent_errors: + tool_errors_by_tool[tool_name] = tool_errors_by_tool[tool_name][-self.keep_recent_errors:] + + total_tool_errors += 1 + + # Track exception errors + all_exception_errors.extend(result.get("exception_errors", [])) + + # Check tool failure thresholds + # Calculate total tool calls (not prompts) + total_tool_calls = sum(stats["count"] for stats in total_tool_stats.values()) + + # Check absolute count threshold + if total_tool_errors >= self.max_tool_failures: + early_exit = True + exit_reason = f"Exceeded maximum tool failures ({total_tool_errors}/{self.max_tool_failures})" + safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]") + pool.terminate() # Stop all workers immediately + break + + # Check rate threshold (only if we have enough tool calls to trust the rate) + if total_tool_calls >= self.min_tool_calls_for_rate: + tool_failure_rate = total_tool_errors / total_tool_calls + + if tool_failure_rate >= self.max_tool_failure_rate: + early_exit = True + exit_reason = f"Exceeded tool failure rate ({tool_failure_rate:.2%} >= {self.max_tool_failure_rate:.2%}, {total_tool_errors}/{total_tool_calls} tool calls)" + safe_print(f"\n[bold red]🛑 STOPPING: {exit_reason}[/bold red]") + pool.terminate() # Stop all workers immediately + break + + # Update checkpoint after each batch completes + checkpoint_data["completed_prompts"] = all_completed_prompts + self._save_checkpoint(checkpoint_data) + + batch_num += 1 + except KeyboardInterrupt: + safe_print("\n[bold yellow]⚠️ Interrupted by user, stopping workers...[/bold yellow]") + pool.terminate() + early_exit = True + exit_reason = "Interrupted by user" + # Save final checkpoint checkpoint_data["completed_prompts"] = all_completed_prompts self._save_checkpoint(checkpoint_data) - - # Calculate success rates - for tool_name in total_tool_stats: - stats = total_tool_stats[tool_name] - total_calls = stats["success"] + stats["failure"] - if total_calls > 0: - stats["success_rate"] = round(stats["success"] / total_calls * 100, 2) - stats["failure_rate"] = round(stats["failure"] / total_calls * 100, 2) - else: - stats["success_rate"] = 0.0 - stats["failure_rate"] = 0.0 - - # Combine all batch files into a single trajectories.jsonl file - combined_file = self.output_dir / "trajectories.jsonl" - print(f"\n📦 Combining batch files into {combined_file.name}...") - - with open(combined_file, 'w', encoding='utf-8') as outfile: - for batch_num in range(len(self.batches)): - batch_file = self.output_dir / f"batch_{batch_num}.jsonl" - if batch_file.exists(): - with open(batch_file, 'r', encoding='utf-8') as infile: - for line in infile: - outfile.write(line) - - print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl") - - # Save final statistics - final_stats = { - "run_name": self.run_name, - "distribution": 
self.distribution, - "total_prompts": len(self.dataset), - "total_batches": len(self.batches), - "batch_size": self.batch_size, - "model": self.model, - "completed_at": datetime.now().isoformat(), - "duration_seconds": round(time.time() - start_time, 2), - "tool_statistics": total_tool_stats - } - - with open(self.stats_file, 'w', encoding='utf-8') as f: - json.dump(final_stats, f, indent=2, ensure_ascii=False) + + # Consolidate data and save statistics + num_batches_processed = batch_num + 1 if early_exit else len(self.batches) + self._consolidate_data( + num_batches_processed, + total_tool_stats, + start_time, + tool_errors_by_tool, + all_exception_errors, + early_exit, + exit_reason, + all_profiling_stats + ) # Print summary - print("\n" + "=" * 70) - print("📊 BATCH PROCESSING COMPLETE") - print("=" * 70) - print(f"✅ Total prompts processed: {len(self.dataset)}") - print(f"✅ Total batches: {len(self.batches)}") - print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s") - print(f"\n📈 Tool Usage Statistics:") - print("-" * 70) - + safe_print("\n" + "=" * 70) + if early_exit: + safe_print("[bold yellow]⚠️ BATCH PROCESSING STOPPED EARLY[/bold yellow]") + safe_print(f"[yellow]Reason: {exit_reason}[/yellow]") + else: + safe_print("[bold green]📊 BATCH PROCESSING COMPLETE[/bold green]") + safe_print("=" * 70) + + safe_print(f"✅ Total prompts processed: {total_processed}") + safe_print(f"✅ Batches completed: {num_batches_processed}/{len(self.batches)}") + safe_print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s") + + # Tool error summary + if tool_errors_by_tool: + total_errors = sum(len(errors) for errors in tool_errors_by_tool.values()) + safe_print(f"\n[bold red]🚨 Tool Errors: {total_tool_errors} total ({len(tool_errors_by_tool)} tools)[/bold red]") + safe_print("[red]-[/red]" * 70) + + # Sort tools by error count + sorted_tools = sorted( + tool_errors_by_tool.items(), + key=lambda x: len(x[1]), + reverse=True + ) + + for tool_name, errors in sorted_tools: + # Count unique error messages + unique_errors = {} + for error in errors: + error_msg = error["error_message"][:100] # Truncate for grouping + if error_msg not in unique_errors: + unique_errors[error_msg] = [] + unique_errors[error_msg].append(error) + + safe_print(f"\n [red]{tool_name}:[/red] {len(errors)} errors ({len(unique_errors)} unique)") + + # Show up to 3 most recent unique error types + for idx, (error_msg, instances) in enumerate(list(unique_errors.items())[:3]): + error_preview = error_msg if len(error_msg) <= 100 else error_msg[:97] + "..." + safe_print(f" [{idx+1}] [dim]{error_preview}[/dim] (x{len(instances)})") + + # Show one example with prompt index and full content prefix + example = instances[-1] # Most recent + safe_print(f" [dim]Prompt {example['prompt_index']}[/dim]") + + # Show full content prefix (first 200 chars) + full_content = example.get('full_content', '') + if full_content and full_content != error_preview: + content_preview = full_content[:200] + if len(full_content) > 200: + content_preview += "..." + # Show with prefix indicator + safe_print(f" [dim]Content: {content_preview}[/dim]") + + if len(unique_errors) > 3: + safe_print(f" [dim]... 
and {len(unique_errors) - 3} more error types[/dim]") + + tool_failure_rate = total_tool_errors / total_processed if total_processed > 0 else 0 + safe_print(f"\n [red]Tool failure rate: {tool_failure_rate:.2%}[/red]") + + # Exception errors + if all_exception_errors: + safe_print(f"\n[bold red]💥 Top-level Exceptions: {len(all_exception_errors)}[/bold red]") + safe_print("[red]-[/red]" * 70) + for error in all_exception_errors[:self.keep_recent_errors]: + error_msg = error["error"] + error_preview = error_msg[:150] + if len(error_msg) > 150: + error_preview += "..." + safe_print(f" [red]Prompt {error['prompt_index']}:[/red] [dim]{error_preview}[/dim]") + + # Show traceback prefix if available + traceback_text = error.get("traceback", "") + if traceback_text: + # Show last 3 lines of traceback for context + tb_lines = traceback_text.strip().split('\n') + relevant_lines = tb_lines[-3:] if len(tb_lines) > 3 else tb_lines + for line in relevant_lines: + safe_print(f" [dim]{line}[/dim]") + + safe_print(f"\n[cyan]📈 Tool Usage Statistics:[/cyan]") + safe_print("-" * 70) + if total_tool_stats: # Sort by count descending sorted_tools = sorted( @@ -621,28 +1071,68 @@ class BatchRunner: key=lambda x: x[1]["count"], reverse=True ) - - print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") - print("-" * 70) + + safe_print(f"{'Tool Name':<25} {'Count':<10} {'Success':<10} {'Failure':<10} {'Success Rate':<12}") + safe_print("-" * 70) for tool_name, stats in sorted_tools: - print( + safe_print( f"{tool_name:<25} " f"{stats['count']:<10} " f"{stats['success']:<10} " f"{stats['failure']:<10} " - f"{stats['success_rate']:.1f}%" + f"{stats.get('success_rate', 0):.1f}%" ) else: - print("No tool calls were made during this run.") - - print(f"\n💾 Results saved to: {self.output_dir}") - print(f" - Trajectories: trajectories.jsonl (combined)") - print(f" - Individual batches: batch_*.jsonl (for debugging)") - print(f" - Statistics: {self.stats_file.name}") - print(f" - Checkpoint: {self.checkpoint_file.name}") + safe_print("No tool calls were made during this run.") - # Display profiling statistics for the entire batch run - get_profiler().print_statistics(detailed=True) + # Display failure type breakdown for tools with failures + if tool_errors_by_tool: + safe_print(f"\n[cyan]📊 Failure Type Breakdown:[/cyan]") + safe_print("-" * 70) + + # Sort tools by total error count + sorted_tools = sorted( + tool_errors_by_tool.items(), + key=lambda x: len(x[1]), + reverse=True + ) + + for tool_name, errors in sorted_tools: + # Count failure types for this tool + failure_types = {} + for error in errors: + error_type = error.get("error_type", "Other") + if error_type not in failure_types: + failure_types[error_type] = 0 + failure_types[error_type] += 1 + + # Display tool name and total failures + total_failures = len(errors) + safe_print(f"\n[yellow]{tool_name}[/yellow] ({total_failures} failures):") + + # Sort failure types by count + sorted_types = sorted( + failure_types.items(), + key=lambda x: x[1], + reverse=True + ) + + # Display each failure type with count and percentage + for failure_type, count in sorted_types: + percentage = (count / total_failures) * 100 + safe_print(f" • {failure_type:<20} {count:>4} ({percentage:>5.1f}%)") + + safe_print(f"\n[cyan]💾 Results saved to:[/cyan] {self.output_dir}") + safe_print(f" - Trajectories: trajectories.jsonl (combined)") + safe_print(f" - Individual batches: batch_*.jsonl (for debugging)") + safe_print(f" - Statistics: 
{self.stats_file.name}") + safe_print(f" - Errors: {self.errors_file.name}") + safe_print(f" - Checkpoint: {self.checkpoint_file.name}") + + if early_exit: + safe_print(f"\n[bold yellow]ℹ️ Run was stopped early due to tool failures.[/bold yellow]") + safe_print(f"[yellow] Check {self.errors_file.name} for detailed error information including tracebacks.[/yellow]") + safe_print(f"[yellow] You can resume this run later with --resume flag.[/yellow]") def main( @@ -660,6 +1150,10 @@ def main( list_distributions: bool = False, ephemeral_system_prompt: str = None, log_prefix_chars: int = 100, + max_tool_failures: int = 10, + max_tool_failure_rate: float = 0.5, + keep_recent_errors: int = 5, + min_tool_calls_for_rate: int = 10, ): """ Run batch processing of agent prompts from a dataset. @@ -679,7 +1173,11 @@ def main( list_distributions (bool): List available toolset distributions and exit ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) - + max_tool_failures (int): Maximum number of tool failures before stopping (default: 10) + max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5) + keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5) + min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10) + Examples: # Basic usage python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run @@ -693,7 +1191,11 @@ def main( # With ephemeral system prompt (not saved to dataset) python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ --ephemeral_system_prompt="You are a helpful assistant focused on image generation." - + + # With custom tool failure thresholds + python batch_runner.py --dataset_file=data.jsonl --batch_size=10 --run_name=my_run \\ + --max_tool_failures=20 --max_tool_failure_rate=0.3 --min_tool_calls_for_rate=10 --keep_recent_errors=10 + # List available distributions python batch_runner.py --list_distributions """ @@ -740,7 +1242,11 @@ def main( num_workers=num_workers, verbose=verbose, ephemeral_system_prompt=ephemeral_system_prompt, - log_prefix_chars=log_prefix_chars + log_prefix_chars=log_prefix_chars, + max_tool_failures=max_tool_failures, + max_tool_failure_rate=max_tool_failure_rate, + keep_recent_errors=keep_recent_errors, + min_tool_calls_for_rate=min_tool_calls_for_rate ) runner.run(resume=resume) @@ -754,4 +1260,3 @@ def main( if __name__ == "__main__": fire.Fire(main) - diff --git a/profiling.py b/profiling.py new file mode 100644 index 000000000..04d0d473e --- /dev/null +++ b/profiling.py @@ -0,0 +1,381 @@ +""" +Profiling module for tracking timing statistics of tools and LLM API calls. 
+ +This module provides a centralized way to track timing information for various +operations in the agent system, including: +- Individual tool executions +- OpenAI API calls +- Aggregate statistics (min, max, median, mean, total) +""" + +import time +from typing import Dict, List, Optional +from dataclasses import dataclass, field +from collections import defaultdict +import statistics + + +@dataclass +class ProfilingStats: + """Statistics for a particular operation type.""" + call_count: int = 0 + total_time: float = 0.0 + min_time: float = float('inf') + max_time: float = 0.0 + times: List[float] = field(default_factory=list) + + def add_timing(self, duration: float): + """Add a timing measurement.""" + self.call_count += 1 + self.total_time += duration + self.min_time = min(self.min_time, duration) + self.max_time = max(self.max_time, duration) + self.times.append(duration) + + @property + def mean_time(self) -> float: + """Calculate mean time.""" + return self.total_time / self.call_count if self.call_count > 0 else 0.0 + + @property + def median_time(self) -> float: + """Calculate median time.""" + return statistics.median(self.times) if self.times else 0.0 + + def to_dict(self) -> Dict: + """Convert to dictionary for serialization.""" + return { + "call_count": self.call_count, + "total_time": self.total_time, + "min_time": self.min_time if self.min_time != float('inf') else 0.0, + "max_time": self.max_time, + "mean_time": self.mean_time, + "median_time": self.median_time + } + + +class Profiler: + """ + Global profiler for tracking timing statistics across tools and API calls. + + Usage: + profiler = Profiler() + + # Time a tool execution + with profiler.time_tool("web_search"): + # ... tool execution code ... + pass + + # Time an API call + with profiler.time_api_call(): + # ... API call code ... + pass + + # Get statistics + stats = profiler.get_statistics() + """ + + def __init__(self): + """Initialize the profiler.""" + self.tool_stats: Dict[str, ProfilingStats] = defaultdict(ProfilingStats) + self.api_stats: ProfilingStats = ProfilingStats() + self._enabled = True + + def enable(self): + """Enable profiling.""" + self._enabled = True + + def disable(self): + """Disable profiling.""" + self._enabled = False + + def reset(self): + """Reset all profiling data.""" + self.tool_stats.clear() + self.api_stats = ProfilingStats() + + def record_tool_timing(self, tool_name: str, duration: float): + """Record timing for a tool execution.""" + if self._enabled: + self.tool_stats[tool_name].add_timing(duration) + + def record_api_timing(self, duration: float): + """Record timing for an API call.""" + if self._enabled: + self.api_stats.add_timing(duration) + + def get_statistics(self) -> Dict: + """ + Get all profiling statistics. + + Returns: + Dictionary containing tool and API statistics + """ + return { + "tools": { + tool_name: stats.to_dict() + for tool_name, stats in sorted(self.tool_stats.items()) + }, + "api_calls": self.api_stats.to_dict() + } + + def print_statistics(self, detailed: bool = True): + """ + Print profiling statistics in a readable format. + + Args: + detailed: If True, show per-tool breakdown. If False, show summary only. 
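+ + Output is written with plain print(); use export_to_json() or export_to_file() when machine-readable output is needed.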
+ """ + print("\n" + "="*80) + print("📊 PROFILING STATISTICS") + print("="*80) + + # API Call Statistics + print("\n🔷 OpenAI API Calls:") + if self.api_stats.call_count > 0: + api_dict = self.api_stats.to_dict() + print(f" Total Calls: {api_dict['call_count']}") + print(f" Total Time: {api_dict['total_time']:.2f}s") + print(f" Min Time: {api_dict['min_time']:.2f}s") + print(f" Max Time: {api_dict['max_time']:.2f}s") + print(f" Mean Time: {api_dict['mean_time']:.2f}s") + print(f" Median Time: {api_dict['median_time']:.2f}s") + else: + print(" No API calls recorded") + + # Tool Statistics + print("\n🔧 Tool Executions:") + if self.tool_stats: + if detailed: + for tool_name in sorted(self.tool_stats.keys()): + stats_dict = self.tool_stats[tool_name].to_dict() + print(f"\n 📌 {tool_name}:") + print(f" Total Calls: {stats_dict['call_count']}") + print(f" Total Time: {stats_dict['total_time']:.2f}s") + print(f" Min Time: {stats_dict['min_time']:.2f}s") + print(f" Max Time: {stats_dict['max_time']:.2f}s") + print(f" Mean Time: {stats_dict['mean_time']:.2f}s") + print(f" Median Time: {stats_dict['median_time']:.2f}s") + + # Summary + total_tool_calls = sum(s.call_count for s in self.tool_stats.values()) + total_tool_time = sum(s.total_time for s in self.tool_stats.values()) + print(f"\n 📊 Summary:") + print(f" Total Tool Calls: {total_tool_calls}") + print(f" Total Tool Time: {total_tool_time:.2f}s") + print(f" Unique Tools Used: {len(self.tool_stats)}") + else: + print(" No tool executions recorded") + + # Overall Summary + total_api_time = self.api_stats.total_time + total_tool_time = sum(s.total_time for s in self.tool_stats.values()) + print(f"\n📈 Overall Summary:") + print(f" Total API Time: {total_api_time:.2f}s") + print(f" Total Tool Time: {total_tool_time:.2f}s") + print(f" Total Time: {total_api_time + total_tool_time:.2f}s") + print("="*80 + "\n") + + def export_to_json(self) -> str: + """Export statistics as JSON string.""" + import json + return json.dumps(self.get_statistics(), indent=2) + + def export_to_file(self, filepath: str): + """ + Export statistics to a JSON file. + + Args: + filepath: Path to output file + """ + import json + with open(filepath, 'w') as f: + json.dump(self.get_statistics(), f, indent=2) + print(f"📁 Profiling statistics exported to: {filepath}") + + +# Global profiler instance +_global_profiler: Optional[Profiler] = None + + +def get_profiler() -> Profiler: + """Get or create the global profiler instance.""" + global _global_profiler + if _global_profiler is None: + _global_profiler = Profiler() + return _global_profiler + + +def reset_profiler(): + """Reset the global profiler.""" + global _global_profiler + if _global_profiler is not None: + _global_profiler.reset() + + +class TimingContext: + """Context manager for timing operations.""" + + def __init__(self, profiler: Profiler, operation_type: str, operation_name: Optional[str] = None): + """ + Initialize timing context. 
+ + Args: + profiler: Profiler instance to record timing + operation_type: 'tool' or 'api' + operation_name: Name of the operation (required for tools) + """ + self.profiler = profiler + self.operation_type = operation_type + self.operation_name = operation_name + self.start_time = None + + def __enter__(self): + """Start timing.""" + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Stop timing and record.""" + duration = time.time() - self.start_time + + if self.operation_type == 'tool': + self.profiler.record_tool_timing(self.operation_name, duration) + elif self.operation_type == 'api': + self.profiler.record_api_timing(duration) + + return False # Don't suppress exceptions + + +def aggregate_profiling_stats(stats_list: List[Dict]) -> Dict: + """ + Aggregate multiple profiling statistics dictionaries into one. + + This is useful for batch processing where each worker process has its own + profiler instance that needs to be combined. + + Args: + stats_list: List of statistics dictionaries from get_statistics() + + Returns: + Dict: Aggregated statistics with combined tool and API call data + """ + aggregated = { + "tools": defaultdict(lambda: {"times": []}), + "api_calls": {"times": []} + } + + # Aggregate tool statistics + for stats in stats_list: + # Aggregate tool timings + for tool_name, tool_stats in stats.get("tools", {}).items(): + # Reconstruct individual timings from aggregated stats + # Since we have mean_time and call_count, we approximate + aggregated["tools"][tool_name]["times"].extend( + [tool_stats.get("mean_time", 0.0)] * tool_stats.get("call_count", 0) + ) + + # Aggregate API call timings + api_stats = stats.get("api_calls", {}) + if api_stats.get("call_count", 0) > 0: + aggregated["api_calls"]["times"].extend( + [api_stats.get("mean_time", 0.0)] * api_stats.get("call_count", 0) + ) + + # Calculate final statistics for tools + final_stats = {"tools": {}, "api_calls": {}} + + for tool_name, data in aggregated["tools"].items(): + times = data["times"] + if times: + final_stats["tools"][tool_name] = { + "call_count": len(times), + "total_time": sum(times), + "min_time": min(times), + "max_time": max(times), + "mean_time": statistics.mean(times), + "median_time": statistics.median(times) + } + + # Calculate final statistics for API calls + api_times = aggregated["api_calls"]["times"] + if api_times: + final_stats["api_calls"] = { + "call_count": len(api_times), + "total_time": sum(api_times), + "min_time": min(api_times), + "max_time": max(api_times), + "mean_time": statistics.mean(api_times), + "median_time": statistics.median(api_times) + } + else: + final_stats["api_calls"] = { + "call_count": 0, + "total_time": 0.0, + "min_time": 0.0, + "max_time": 0.0, + "mean_time": 0.0, + "median_time": 0.0 + } + + return final_stats + + +def print_aggregated_statistics(stats: Dict, detailed: bool = True): + """ + Print aggregated profiling statistics in a readable format. + + Args: + stats: Aggregated statistics dictionary from aggregate_profiling_stats() + detailed: If True, show per-tool breakdown. If False, show summary only. 
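+ + Note: min/max/median here are approximations when the stats come from aggregate_profiling_stats(), which reconstructs each source's timings as call_count copies of its mean_time.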
+ """ + print("\n" + "="*80) + print("📊 AGGREGATED PROFILING STATISTICS") + print("="*80) + + # API Call Statistics + print("\n🔷 OpenAI API Calls:") + api_stats = stats.get("api_calls", {}) + if api_stats.get("call_count", 0) > 0: + print(f" Total Calls: {api_stats['call_count']}") + print(f" Total Time: {api_stats['total_time']:.2f}s") + print(f" Min Time: {api_stats['min_time']:.2f}s") + print(f" Max Time: {api_stats['max_time']:.2f}s") + print(f" Mean Time: {api_stats['mean_time']:.2f}s") + print(f" Median Time: {api_stats['median_time']:.2f}s") + else: + print(" No API calls recorded") + + # Tool Statistics + print("\n🔧 Tool Executions:") + tool_stats = stats.get("tools", {}) + if tool_stats: + if detailed: + for tool_name in sorted(tool_stats.keys()): + stats_dict = tool_stats[tool_name] + print(f"\n 📌 {tool_name}:") + print(f" Total Calls: {stats_dict['call_count']}") + print(f" Total Time: {stats_dict['total_time']:.2f}s") + print(f" Min Time: {stats_dict['min_time']:.2f}s") + print(f" Max Time: {stats_dict['max_time']:.2f}s") + print(f" Mean Time: {stats_dict['mean_time']:.2f}s") + print(f" Median Time: {stats_dict['median_time']:.2f}s") + + # Summary + total_tool_calls = sum(s["call_count"] for s in tool_stats.values()) + total_tool_time = sum(s["total_time"] for s in tool_stats.values()) + print(f"\n 📊 Summary:") + print(f" Total Tool Calls: {total_tool_calls}") + print(f" Total Tool Time: {total_tool_time:.2f}s") + print(f" Unique Tools Used: {len(tool_stats)}") + else: + print(" No tool executions recorded") + + # Overall Summary + total_api_time = api_stats.get("total_time", 0.0) + total_tool_time = sum(s["total_time"] for s in tool_stats.values()) + print(f"\n📈 Overall Summary:") + print(f" Total API Time: {total_api_time:.2f}s") + print(f" Total Tool Time: {total_tool_time:.2f}s") + print(f" Total Time: {total_api_time + total_tool_time:.2f}s") + print("="*80 + "\n") diff --git a/run_agent.py b/run_agent.py index 97fb37087..526f30e2c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -45,6 +45,9 @@ else: from model_tools import get_tool_definitions, handle_function_call, check_toolset_requirements from tools.terminal_tool import cleanup_vm +# Import profiling +from profiling import get_profiler + class AIAgent: """ @@ -364,6 +367,10 @@ class AIAgent: Returns: Dict: Complete conversation result with final response and message history """ + # Reset profiler for this conversation to get fresh stats + from profiling import reset_profiler as reset_prof + reset_prof() + # Generate unique task_id if not provided to isolate VMs between concurrent tasks import uuid effective_task_id = task_id or str(uuid.uuid4()) @@ -419,6 +426,9 @@ class AIAgent: api_duration = time.time() - api_start_time print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s") + # Record API timing in profiler + get_profiler().record_api_timing(api_duration) + if self.verbose_logging: logging.debug(f"API Response received - Usage: {response.usage if hasattr(response, 'usage') else 'N/A'}") @@ -490,6 +500,9 @@ class AIAgent: tool_duration = time.time() - tool_start_time result_preview = function_result[:200] if len(function_result) > 200 else function_result + # Record tool timing in profiler + get_profiler().record_tool_timing(function_name, tool_duration) + if self.verbose_logging: logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s") logging.debug(f"Tool result preview: {result_preview}...") @@ -562,11 +575,15 @@ class AIAgent: if self.verbose_logging: logging.warning(f"Failed to 
cleanup VM for task {effective_task_id}: {e}") + # Get profiling statistics for this conversation + profiling_stats = get_profiler().get_statistics() + return { "final_response": final_response, "messages": messages, "api_calls": api_call_count, - "completed": completed + "completed": completed, + "profiling_stats": profiling_stats } def chat(self, message: str) -> str: @@ -594,7 +611,8 @@ def main( list_tools: bool = False, save_trajectories: bool = False, verbose: bool = False, - log_prefix_chars: int = 20 + log_prefix_chars: int = 20, + show_profiling: bool = True ): """ Main function for running the agent directly. @@ -613,6 +631,7 @@ def main( save_trajectories (bool): Save conversation trajectories to JSONL files. Defaults to False. verbose (bool): Enable verbose logging for debugging. Defaults to False. log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20. + show_profiling (bool): Display profiling statistics after conversation. Defaults to True. Toolset Examples: - "research": Web search, extract, crawl + vision tools @@ -763,7 +782,11 @@ def main( print(f"\n🎯 FINAL RESPONSE:") print("-" * 30) print(result['final_response']) - + + # Display profiling statistics if enabled + if show_profiling: + get_profiler().print_statistics(detailed=True) + print("\n👋 Agent execution completed!") diff --git a/safe_print.py b/safe_print.py new file mode 100644 index 000000000..4cddcc6ec --- /dev/null +++ b/safe_print.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +"""Simple safe print that tries rich, falls back to regular print.""" + +try: + from rich import print as rich_print + RICH_AVAILABLE = True +except ImportError: + RICH_AVAILABLE = False + + +def safe_print(*args, **kwargs): + """Try rich.print, fall back to regular print if it fails.""" + if RICH_AVAILABLE: + try: + rich_print(*args, **kwargs) + return + except Exception: + pass + # Fallback to regular print + print(*args, **kwargs)
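+ + +if __name__ == "__main__": + # Minimal smoke test: with rich installed the markup renders as bold green; + # without it, the fallback prints the tags literally. + safe_print("[bold green]safe_print self-test[/bold green]")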