diff --git a/batch_runner.py b/batch_runner.py index 30c8cfad3c..49a91c6f69 100644 --- a/batch_runner.py +++ b/batch_runner.py @@ -327,7 +327,9 @@ async def _process_single_prompt( save_trajectories=False, # We handle saving ourselves verbose_logging=config.get("verbose", False), ephemeral_system_prompt=config.get("ephemeral_system_prompt"), - log_prefix_chars=config.get("log_prefix_chars", 100) + log_prefix_chars=config.get("log_prefix_chars", 100), + prokletor_client=config.get("prokletor_client"), + prokletor_formatter=config.get("prokletor_formatter") ) # Run the agent with task_id to ensure each task gets its own isolated VM @@ -439,6 +441,8 @@ class BatchRunner: max_tool_failure_rate: float = 0.5, keep_recent_errors: int = 5, min_tool_calls_for_rate: int = 10, + prokletor_client: str = None, + prokletor_formatter: str = None, ): """ Initialize the batch runner. @@ -459,6 +463,8 @@ class BatchRunner: max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5) keep_recent_errors (int): Number of recent errors to keep per tool (default: 5) min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10) + prokletor_client (str): Name of the prokletor client to use + prokletor_formatter (str): Name of the prokletor formatter to use """ self.dataset_file = Path(dataset_file) self.run_name = run_name @@ -475,6 +481,8 @@ class BatchRunner: self.max_tool_failure_rate = max_tool_failure_rate self.keep_recent_errors = keep_recent_errors self.min_tool_calls_for_rate = min_tool_calls_for_rate + self.prokletor_client = prokletor_client + self.prokletor_formatter = prokletor_formatter # Validate distribution if not validate_distribution(distribution): @@ -719,7 +727,9 @@ class BatchRunner: "api_key": self.api_key, "verbose": self.verbose, "ephemeral_system_prompt": self.ephemeral_system_prompt, - "log_prefix_chars": self.log_prefix_chars + "log_prefix_chars": self.log_prefix_chars, + 
"prokletor_client": self.prokletor_client, + "prokletor_formatter": self.prokletor_formatter } # Start workers @@ -882,6 +892,8 @@ def main( max_tool_failure_rate: float = 0.5, keep_recent_errors: int = 5, min_tool_calls_for_rate: int = 10, + prokletor_client: str = None, + prokletor_formatter: str = None, ): """ Run batch processing of agent prompts from a dataset. @@ -904,6 +916,8 @@ def main( max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5) keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5) min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10) + prokletor_client (str): Name of the prokletor client to use + prokletor_formatter (str): Name of the prokletor formatter to use Examples: # Basic usage @@ -953,7 +967,9 @@ def main( max_tool_failures=max_tool_failures, max_tool_failure_rate=max_tool_failure_rate, keep_recent_errors=keep_recent_errors, - min_tool_calls_for_rate=min_tool_calls_for_rate + min_tool_calls_for_rate=min_tool_calls_for_rate, + prokletor_client=prokletor_client, + prokletor_formatter=prokletor_formatter ) runner.run(resume=resume) diff --git a/run_agent.py b/run_agent.py index 27e16a0c60..f7da65d88c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -33,8 +33,14 @@ from datetime import datetime from pathlib import Path from rich import print -from prokletor.formatters.hermes_formatter import HermesToolFormatterWithReasoning -from prokletor.clients.hermes import HermesToolClientWithReasoning +from prokletor.formatters.hermes import HermesToolFormatterWithReasoning +from prokletor.formatters.hermes import HermesToolFormatterWithReasoning +from prokletor.clients.hermes import HermesToolClientWithReasoning, HermesToolClient +from prokletor.clients.claude import AsyncClaudeClient +try: + from anthropic import AsyncAnthropic +except ImportError: + AsyncAnthropic = None # Load environment variables from .env file from 
dotenv import load_dotenv @@ -76,6 +82,8 @@ class AIAgent: verbose_logging: bool = False, ephemeral_system_prompt: str = None, log_prefix_chars: int = 100, + prokletor_client: str = None, + prokletor_formatter: str = None, ): """ Initialize the AI Agent. @@ -92,6 +100,8 @@ class AIAgent: verbose_logging (bool): Enable verbose logging for debugging (default: False) ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional) log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20) + prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient", "HermesToolClient") + prokletor_formatter (str): Name of the prokletor formatter to use (optional) """ self.model = model self.max_iterations = max_iterations @@ -100,6 +110,8 @@ class AIAgent: self.verbose_logging = verbose_logging self.ephemeral_system_prompt = ephemeral_system_prompt self.log_prefix_chars = log_prefix_chars + self.prokletor_client_name = prokletor_client + self.prokletor_formatter_name = prokletor_formatter # Store toolset filtering options self.enabled_toolsets = enabled_toolsets @@ -128,7 +140,7 @@ class AIAgent: logging.getLogger('openai').setLevel(logging.WARNING) logging.getLogger('httpx').setLevel(logging.WARNING) - # Initialize OpenAI client + # Initialize Client client_kwargs = {} if base_url: client_kwargs["base_url"] = base_url @@ -138,16 +150,45 @@ class AIAgent: client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key") try: - oai_client = AsyncOpenAI(**client_kwargs) - # self.client = oai_client - self.client = HermesToolClientWithReasoning(oai_client) - print(f"🧠 Wrapped OpenAI client with AsyncHermesToolClientWithReasoning") + if prokletor_client == "AsyncClaudeClient": + if AsyncAnthropic is None: + raise ImportError("anthropic package is required for AsyncClaudeClient") + + # AsyncAnthropic kwargs + anthropic_kwargs = {k: v for k, v in 
client_kwargs.items() if k in ["api_key", "base_url", "timeout", "max_retries", "default_headers"]} + + anthropic_client = AsyncAnthropic(**anthropic_kwargs) + self.client = AsyncClaudeClient(anthropic_client) + print(f"🧠 Wrapped Anthropic client with AsyncClaudeClient") + + elif prokletor_client == "HermesToolClient": + oai_client = AsyncOpenAI(**client_kwargs) + self.client = HermesToolClient(oai_client) + print(f"🧠 Wrapped OpenAI client with HermesToolClient") + + elif prokletor_client == "HermesToolClientWithReasoning": + oai_client = AsyncOpenAI(**client_kwargs) + self.client = HermesToolClientWithReasoning(oai_client) + print(f"🧠 Wrapped OpenAI client with HermesToolClientWithReasoning") + + elif prokletor_client: + # Fallback for unknown client names (future proofing for custom clients) + # Warn and fall back to HermesToolClientWithReasoning + print(f"⚠️ Unknown prokletor_client '{prokletor_client}'. Defaulting to HermesToolClientWithReasoning.") + oai_client = AsyncOpenAI(**client_kwargs) + self.client = HermesToolClientWithReasoning(oai_client) + + else: + # Default behavior + oai_client = AsyncOpenAI(**client_kwargs) + self.client = oai_client + print(f"🧠 Using raw OpenAI client (no prokletor wrapper)") print(f"🤖 AI Agent initialized with model: {self.model}") if base_url: print(f"🔗 Using custom base URL: {base_url}") except Exception as e: - raise RuntimeError(f"Failed to initialize OpenAI client: {e}") + raise RuntimeError(f"Failed to initialize client: {e}") # Get available tools with filtering self.tools = get_tool_definitions( @@ -223,7 +264,7 @@ class AIAgent: # Use the client wrapper's format method if available to get the exact Hermes format # This ensures batch runner also gets the correct formatting if hasattr(self, 'client') and hasattr(self.client, 'format'): - formatted_messages = self.client.format(messages, self.tools, use_tool_role=True) + formatted_messages = self.client.format(messages, self.tools, render_final=True) trajectory = [] 
for msg in formatted_messages: @@ -462,13 +503,23 @@ class AIAgent: api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages # Make API call with tools + api_kwargs = { + "model": self.model, + "messages": api_messages, + "tools": self.tools if self.tools else None, + "timeout": 300.0, # 5 minute timeout for long-running agent tasks + } - response = await self.client.chat.completions.create( - model=self.model, - messages=api_messages, - tools=self.tools if self.tools else None, - timeout=300.0 # 5 minute timeout for long-running agent tasks - ) + # Enable thinking by default for AsyncClaudeClient if using a supported model + if self.prokletor_client_name == "AsyncClaudeClient" and self.model.startswith("claude"): + api_kwargs["thinking"] = { + "type": "enabled", + "budget_tokens": 8000 + } + # Ensure max_tokens is set higher than budget_tokens + api_kwargs["max_tokens"] = 16000 + + response = await self.client.chat.completions.create(**api_kwargs) api_duration = time.time() - api_start_time print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s") @@ -572,7 +623,8 @@ class AIAgent: "tool_calls": tool_calls_data }) - # Execute each tool call + # Execute tool calls concurrently + tool_tasks = [] for i, tool_call in enumerate(assistant_message.tool_calls, 1): function_name = tool_call.function.name @@ -587,35 +639,54 @@ class AIAgent: args_preview = args_str[:self.log_prefix_chars] + "..." 
if len(args_str) > self.log_prefix_chars else args_str print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}") + # Create coroutine for tool execution + task = handle_function_call(function_name, function_args, effective_task_id) + tool_tasks.append(task) + + if tool_tasks: tool_start_time = time.time() - - # Execute the tool with task_id to isolate VMs between concurrent tasks - function_result = await handle_function_call(function_name, function_args, effective_task_id) - - tool_duration = time.time() - tool_start_time - result_preview = function_result[:200] if len(function_result) > 200 else function_result - - # Record tool timing in profiler - get_profiler().record_tool_timing(function_name, tool_duration) - - if self.verbose_logging: - logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s") - logging.debug(f"Tool result preview: {result_preview}...") - - # Add tool result to conversation - # Note: thought_signature should NOT be in tool responses, only in assistant messages - messages.append({ - "role": "tool", - "content": function_result, - "tool_call_id": tool_call.id - }) - - # Preview tool response - response_preview = function_result[:self.log_prefix_chars] + "..." 
if len(function_result) > self.log_prefix_chars else function_result - print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}") - # Delay between tool calls - if self.tool_delay > 0 and i < len(assistant_message.tool_calls): + # Execute all tools concurrently + # We use return_exceptions=True to ensure one failure doesn't stop others + # Order of results corresponds to order of tasks + results = await asyncio.gather(*tool_tasks, return_exceptions=True) + + tool_duration = time.time() - tool_start_time + + # Process results + for i, (result, tool_call) in enumerate(zip(results, assistant_message.tool_calls), 1): + function_name = tool_call.function.name + + # Handle exceptions from asyncio.gather + if isinstance(result, Exception): + function_result = json.dumps({"error": str(result)}, ensure_ascii=False) + print(f"❌ Tool {i} ({function_name}) failed: {result}") + else: + function_result = result + + result_preview = function_result[:200] if len(function_result) > 200 else function_result + + # Record tool timing in profiler (approximate since they ran in parallel) + get_profiler().record_tool_timing(function_name, tool_duration) + + if self.verbose_logging: + logging.debug(f"Tool {function_name} completed in parallel batch") + logging.debug(f"Tool result preview: {result_preview}...") + + # Add tool result to conversation + # Note: thought_signature should NOT be in tool responses, only in assistant messages + messages.append({ + "role": "tool", + "content": function_result, + "tool_call_id": tool_call.id + }) + + # Preview tool response + response_preview = function_result[:self.log_prefix_chars] + "..." 
if len(function_result) > self.log_prefix_chars else function_result + print(f" ✅ Tool {i} completed - {response_preview}") + + # Optional delay after batch execution + if self.tool_delay > 0: await asyncio.sleep(self.tool_delay) # Continue loop for next response @@ -668,7 +739,6 @@ class AIAgent: # Use the client wrapper's format method if available to get the exact Hermes format if hasattr(self, 'client') and hasattr(self.client, 'format'): - raise ValueError("reached this point") - formatted_messages = self.client.format(messages, self.tools, use_tool_role=True) + formatted_messages = self.client.format(messages, self.tools, render_final=True) # We need to adapt this formatted list to the trajectory format expected by _save_trajectory # Since _convert_to_trajectory_format expects raw OAI messages, we might need a different approach @@ -776,7 +847,9 @@ def main( save_trajectories: bool = False, verbose: bool = False, log_prefix_chars: int = 20, - show_profiling: bool = True + show_profiling: bool = True, + prokletor_client: str = None, + prokletor_formatter: str = None, ): """ Main function for running the agent directly. @@ -796,6 +869,8 @@ def main( verbose (bool): Enable verbose logging for debugging. Defaults to False. log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20. show_profiling (bool): Display profiling statistics after conversation. Defaults to True. 
+ prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient") + prokletor_formatter (str): Name of the prokletor formatter to use Toolset Examples: - "research": Web search, extract, crawl + vision tools @@ -914,7 +989,9 @@ def main( disabled_toolsets=disabled_toolsets_list, save_trajectories=save_trajectories, verbose_logging=verbose, - log_prefix_chars=log_prefix_chars + log_prefix_chars=log_prefix_chars, + prokletor_client=prokletor_client, + prokletor_formatter=prokletor_formatter ) except RuntimeError as e: print(f"❌ Failed to initialize agent: {e}")