add prokletor formatter

This commit is contained in:
hjc-puro 2025-11-23 10:24:58 -05:00
parent e91d9e839a
commit 7d9a1e119d
2 changed files with 142 additions and 49 deletions

View file

@ -327,7 +327,9 @@ async def _process_single_prompt(
save_trajectories=False, # We handle saving ourselves
verbose_logging=config.get("verbose", False),
ephemeral_system_prompt=config.get("ephemeral_system_prompt"),
log_prefix_chars=config.get("log_prefix_chars", 100)
log_prefix_chars=config.get("log_prefix_chars", 100),
prokletor_client=config.get("prokletor_client"),
prokletor_formatter=config.get("prokletor_formatter")
)
# Run the agent with task_id to ensure each task gets its own isolated VM
@ -439,6 +441,8 @@ class BatchRunner:
max_tool_failure_rate: float = 0.5,
keep_recent_errors: int = 5,
min_tool_calls_for_rate: int = 10,
prokletor_client: str = None,
prokletor_formatter: str = None,
):
"""
Initialize the batch runner.
@ -459,6 +463,8 @@ class BatchRunner:
max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
keep_recent_errors (int): Number of recent errors to keep per tool (default: 5)
min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
prokletor_client (str): Name of the prokletor client to use
prokletor_formatter (str): Name of the prokletor formatter to use
"""
self.dataset_file = Path(dataset_file)
self.run_name = run_name
@ -475,6 +481,8 @@ class BatchRunner:
self.max_tool_failure_rate = max_tool_failure_rate
self.keep_recent_errors = keep_recent_errors
self.min_tool_calls_for_rate = min_tool_calls_for_rate
self.prokletor_client = prokletor_client
self.prokletor_formatter = prokletor_formatter
# Validate distribution
if not validate_distribution(distribution):
@ -719,7 +727,9 @@ class BatchRunner:
"api_key": self.api_key,
"verbose": self.verbose,
"ephemeral_system_prompt": self.ephemeral_system_prompt,
"log_prefix_chars": self.log_prefix_chars
"log_prefix_chars": self.log_prefix_chars,
"prokletor_client": self.prokletor_client,
"prokletor_formatter": self.prokletor_formatter
}
# Start workers
@ -882,6 +892,8 @@ def main(
max_tool_failure_rate: float = 0.5,
keep_recent_errors: int = 5,
min_tool_calls_for_rate: int = 10,
prokletor_client: str = None,
prokletor_formatter: str = None,
):
"""
Run batch processing of agent prompts from a dataset.
@ -904,6 +916,8 @@ def main(
max_tool_failure_rate (float): Maximum tool failure rate (0.0-1.0) before stopping (default: 0.5)
keep_recent_errors (int): Number of recent errors to keep per tool for reporting (default: 5)
min_tool_calls_for_rate (int): Minimum number of tool calls before checking failure rate (default: 10)
prokletor_client (str): Name of the prokletor client to use
prokletor_formatter (str): Name of the prokletor formatter to use
Examples:
# Basic usage
@ -953,7 +967,9 @@ def main(
max_tool_failures=max_tool_failures,
max_tool_failure_rate=max_tool_failure_rate,
keep_recent_errors=keep_recent_errors,
min_tool_calls_for_rate=min_tool_calls_for_rate
min_tool_calls_for_rate=min_tool_calls_for_rate,
prokletor_client=prokletor_client,
prokletor_formatter=prokletor_formatter
)
runner.run(resume=resume)

View file

@ -33,8 +33,14 @@ from datetime import datetime
from pathlib import Path
from rich import print
from prokletor.formatters.hermes_formatter import HermesToolFormatterWithReasoning
from prokletor.clients.hermes import HermesToolClientWithReasoning
from prokletor.formatters.hermes import HermesToolFormatterWithReasoning
from prokletor.formatters.hermes import HermesToolFormatterWithReasoning
from prokletor.clients.hermes import HermesToolClientWithReasoning, HermesToolClient
from prokletor.clients.claude import AsyncClaudeClient
try:
from anthropic import AsyncAnthropic
except ImportError:
AsyncAnthropic = None
# Load environment variables from .env file
from dotenv import load_dotenv
@ -76,6 +82,8 @@ class AIAgent:
verbose_logging: bool = False,
ephemeral_system_prompt: str = None,
log_prefix_chars: int = 100,
prokletor_client: str = None,
prokletor_formatter: str = None,
):
"""
Initialize the AI Agent.
@ -92,6 +100,8 @@ class AIAgent:
verbose_logging (bool): Enable verbose logging for debugging (default: False)
ephemeral_system_prompt (str): System prompt used during agent execution but NOT saved to trajectories (optional)
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses (default: 20)
prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient", "HermesToolClient")
prokletor_formatter (str): Name of the prokletor formatter to use (optional)
"""
self.model = model
self.max_iterations = max_iterations
@ -100,6 +110,8 @@ class AIAgent:
self.verbose_logging = verbose_logging
self.ephemeral_system_prompt = ephemeral_system_prompt
self.log_prefix_chars = log_prefix_chars
self.prokletor_client_name = prokletor_client
self.prokletor_formatter_name = prokletor_formatter
# Store toolset filtering options
self.enabled_toolsets = enabled_toolsets
@ -128,7 +140,7 @@ class AIAgent:
logging.getLogger('openai').setLevel(logging.WARNING)
logging.getLogger('httpx').setLevel(logging.WARNING)
# Initialize OpenAI client
# Initialize Client
client_kwargs = {}
if base_url:
client_kwargs["base_url"] = base_url
@ -138,16 +150,45 @@ class AIAgent:
client_kwargs["api_key"] = os.getenv("ANTHROPIC_API_KEY", "dummy-key")
try:
oai_client = AsyncOpenAI(**client_kwargs)
# self.client = oai_client
self.client = HermesToolClientWithReasoning(oai_client)
print(f"🧠 Wrapped OpenAI client with AsyncHermesToolClientWithReasoning")
if prokletor_client == "AsyncClaudeClient":
if AsyncAnthropic is None:
raise ImportError("anthropic package is required for AsyncClaudeClient")
# AsyncAnthropic kwargs
anthropic_kwargs = {k: v for k, v in client_kwargs.items() if k in ["api_key", "base_url", "timeout", "max_retries", "default_headers"]}
anthropic_client = AsyncAnthropic(**anthropic_kwargs)
self.client = AsyncClaudeClient(anthropic_client)
print(f"🧠 Wrapped Anthropic client with AsyncClaudeClient")
elif prokletor_client == "HermesToolClient":
oai_client = AsyncOpenAI(**client_kwargs)
self.client = HermesToolClient(oai_client)
print(f"🧠 Wrapped OpenAI client with HermesToolClient")
elif prokletor_client == "HermesToolClientWithReasoning":
oai_client = AsyncOpenAI(**client_kwargs)
self.client = HermesToolClientWithReasoning(oai_client)
print(f"🧠 Wrapped OpenAI client with HermesToolClientWithReasoning")
elif prokletor_client:
# Fallback for unknown client names or if user provides a custom one (future proofing?)
# For now, raise error or default to OpenAI
print(f"⚠️ Unknown prokletor_client '{prokletor_client}'. Defaulting to HermesToolClientWithReasoning.")
oai_client = AsyncOpenAI(**client_kwargs)
self.client = HermesToolClientWithReasoning(oai_client)
else:
# Default behavior
oai_client = AsyncOpenAI(**client_kwargs)
self.client = oai_client
print(f"🧠 Using raw OpenAI client (no prokletor wrapper)")
print(f"🤖 AI Agent initialized with model: {self.model}")
if base_url:
print(f"🔗 Using custom base URL: {base_url}")
except Exception as e:
raise RuntimeError(f"Failed to initialize OpenAI client: {e}")
raise RuntimeError(f"Failed to initialize client: {e}")
# Get available tools with filtering
self.tools = get_tool_definitions(
@ -223,7 +264,7 @@ class AIAgent:
# Use the client wrapper's format method if available to get the exact Hermes format
# This ensures batch runner also gets the correct formatting
if hasattr(self, 'client') and hasattr(self.client, 'format'):
formatted_messages = self.client.format(messages, self.tools, use_tool_role=True)
formatted_messages = self.client.format(messages, self.tools, render_final=True)
trajectory = []
for msg in formatted_messages:
@ -462,13 +503,23 @@ class AIAgent:
api_messages = [{"role": "system", "content": active_system_prompt}] + api_messages
# Make API call with tools
api_kwargs = {
"model": self.model,
"messages": api_messages,
"tools": self.tools if self.tools else None,
"timeout": 300.0, # 5 minute timeout for long-running agent tasks
}
response = await self.client.chat.completions.create(
model=self.model,
messages=api_messages,
tools=self.tools if self.tools else None,
timeout=300.0 # 5 minute timeout for long-running agent tasks
)
# Enable thinking by default for AsyncClaudeClient if using a supported model
if self.prokletor_client_name == "AsyncClaudeClient" and self.model.startswith("claude"):
api_kwargs["thinking"] = {
"type": "enabled",
"budget_tokens": 8000
}
# Ensure max_tokens is set higher than budget_tokens
api_kwargs["max_tokens"] = 16000
response = await self.client.chat.completions.create(**api_kwargs)
api_duration = time.time() - api_start_time
print(f"⏱️ OpenAI-compatible API call completed in {api_duration:.2f}s")
@ -572,7 +623,8 @@ class AIAgent:
"tool_calls": tool_calls_data
})
# Execute each tool call
# Execute tool calls concurrently
tool_tasks = []
for i, tool_call in enumerate(assistant_message.tool_calls, 1):
function_name = tool_call.function.name
@ -587,35 +639,54 @@ class AIAgent:
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
# Create coroutine for tool execution
task = handle_function_call(function_name, function_args, effective_task_id)
tool_tasks.append(task)
if tool_tasks:
tool_start_time = time.time()
# Execute the tool with task_id to isolate VMs between concurrent tasks
function_result = await handle_function_call(function_name, function_args, effective_task_id)
tool_duration = time.time() - tool_start_time
result_preview = function_result[:200] if len(function_result) > 200 else function_result
# Record tool timing in profiler
get_profiler().record_tool_timing(function_name, tool_duration)
if self.verbose_logging:
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
logging.debug(f"Tool result preview: {result_preview}...")
# Add tool result to conversation
# Note: thought_signature should NOT be in tool responses, only in assistant messages
messages.append({
"role": "tool",
"content": function_result,
"tool_call_id": tool_call.id
})
# Preview tool response
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
print(f" ✅ Tool {i} completed in {tool_duration:.2f}s - {response_preview}")
# Delay between tool calls
if self.tool_delay > 0 and i < len(assistant_message.tool_calls):
# Execute all tools concurrently
# We use return_exceptions=True to ensure one failure doesn't stop others
# Order of results corresponds to order of tasks
results = await asyncio.gather(*tool_tasks, return_exceptions=True)
tool_duration = time.time() - tool_start_time
# Process results
for i, (result, tool_call) in enumerate(zip(results, assistant_message.tool_calls), 1):
function_name = tool_call.function.name
# Handle exceptions from asyncio.gather
if isinstance(result, Exception):
function_result = json.dumps({"error": str(result)}, ensure_ascii=False)
print(f"❌ Tool {i} ({function_name}) failed: {result}")
else:
function_result = result
result_preview = function_result[:200] if len(function_result) > 200 else function_result
# Record tool timing in profiler (approximate since they ran in parallel)
get_profiler().record_tool_timing(function_name, tool_duration)
if self.verbose_logging:
logging.debug(f"Tool {function_name} completed in parallel batch")
logging.debug(f"Tool result preview: {result_preview}...")
# Add tool result to conversation
# Note: thought_signature should NOT be in tool responses, only in assistant messages
messages.append({
"role": "tool",
"content": function_result,
"tool_call_id": tool_call.id
})
# Preview tool response
response_preview = function_result[:self.log_prefix_chars] + "..." if len(function_result) > self.log_prefix_chars else function_result
print(f" ✅ Tool {i} completed - {response_preview}")
# Optional delay after batch execution
if self.tool_delay > 0:
await asyncio.sleep(self.tool_delay)
# Continue loop for next response
@ -668,7 +739,7 @@ class AIAgent:
# Use the client wrapper's format method if available to get the exact Hermes format
if hasattr(self, 'client') and hasattr(self.client, 'format'):
raise ValueError("reached this point")
formatted_messages = self.client.format(messages, self.tools, use_tool_role=True)
formatted_messages = self.client.format(messages, self.tools, render_final=True)
# We need to adapt this formatted list to the trajectory format expected by _save_trajectory
# Since _convert_to_trajectory_format expects raw OAI messages, we might need a different approach
@ -776,7 +847,9 @@ def main(
save_trajectories: bool = False,
verbose: bool = False,
log_prefix_chars: int = 20,
show_profiling: bool = True
show_profiling: bool = True,
prokletor_client: str = None,
prokletor_formatter: str = None,
):
"""
Main function for running the agent directly.
@ -796,6 +869,8 @@ def main(
verbose (bool): Enable verbose logging for debugging. Defaults to False.
log_prefix_chars (int): Number of characters to show in log previews for tool calls/responses. Defaults to 20.
show_profiling (bool): Display profiling statistics after conversation. Defaults to True.
prokletor_client (str): Name of the prokletor client to use (e.g., "AsyncClaudeClient")
prokletor_formatter (str): Name of the prokletor formatter to use
Toolset Examples:
- "research": Web search, extract, crawl + vision tools
@ -914,7 +989,9 @@ def main(
disabled_toolsets=disabled_toolsets_list,
save_trajectories=save_trajectories,
verbose_logging=verbose,
log_prefix_chars=log_prefix_chars
log_prefix_chars=log_prefix_chars,
prokletor_client=prokletor_client,
prokletor_formatter=prokletor_formatter
)
except RuntimeError as e:
print(f"❌ Failed to initialize agent: {e}")