Refactor BatchRunner and AIAgent for enhanced reasoning and tool management, improved tool definitions for fileops

- Updated `ALL_POSSIBLE_TOOLS` to auto-derive from `TOOL_TO_TOOLSET_MAP` for consistent schema.
- Introduced `_extract_reasoning_stats` function to track reasoning coverage in assistant turns.
- Enhanced `_process_batch_worker` to discard prompts with no reasoning and aggregate reasoning statistics.
- Updated documentation and comments for clarity on new features and changes.
This commit is contained in:
teknium 2026-02-08 20:19:14 +00:00
parent f12ea1bc02
commit dd70d57b9b
4 changed files with 277 additions and 90 deletions

View file

@ -41,24 +41,17 @@ from toolset_distributions import (
sample_toolsets_from_distribution,
validate_distribution
)
from model_tools import TOOL_TO_TOOLSET_MAP
# Global configuration for worker processes
_WORKER_CONFIG = {}
# All possible tools - used to ensure consistent schema across all trajectory entries
# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
ALL_POSSIBLE_TOOLS = {
'terminal', 'web_search', 'web_extract',
'vision_analyze', 'image_generate', 'mixture_of_agents',
# Skills tools
'skills_categories', 'skills_list', 'skill_view',
# Browser automation tools
'browser_navigate', 'browser_snapshot', 'browser_click',
'browser_type', 'browser_scroll', 'browser_back',
'browser_press', 'browser_close', 'browser_get_images',
'browser_vision'
}
# All possible tools - auto-derived from the master mapping in model_tools.py.
# This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP.
# Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for
# filtering corrupted entries during trajectory combination.
ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())
# Default stats for tools that weren't used
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
@ -200,6 +193,42 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
return tool_stats
def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
"""
Count how many assistant turns have reasoning vs no reasoning.
Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field
(native thinking tokens). Returns counts for tracking reasoning coverage.
Args:
messages: Message history
Returns:
Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning'
"""
total = 0
with_reasoning = 0
for msg in messages:
if msg.get("role") != "assistant":
continue
total += 1
content = msg.get("content", "") or ""
has_scratchpad = "<REASONING_SCRATCHPAD>" in content
has_native_reasoning = bool(msg.get("reasoning", "").strip()) if msg.get("reasoning") else False
if has_scratchpad or has_native_reasoning:
with_reasoning += 1
return {
"total_assistant_turns": total,
"turns_with_reasoning": with_reasoning,
"turns_without_reasoning": total - with_reasoning,
"has_any_reasoning": with_reasoning > 0,
}
def _process_single_prompt(
prompt_index: int,
prompt_data: Dict[str, Any],
@ -255,6 +284,9 @@ def _process_single_prompt(
# Extract tool usage statistics
tool_stats = _extract_tool_stats(result["messages"])
# Extract reasoning coverage stats
reasoning_stats = _extract_reasoning_stats(result["messages"])
# Convert to trajectory format (using existing method)
trajectory = agent._convert_to_trajectory_format(
result["messages"],
@ -267,6 +299,7 @@ def _process_single_prompt(
"prompt_index": prompt_index,
"trajectory": trajectory,
"tool_stats": tool_stats,
"reasoning_stats": reasoning_stats,
"completed": result["completed"],
"partial": result.get("partial", False),
"api_calls": result["api_calls"],
@ -335,7 +368,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
# Initialize aggregated stats for this batch
batch_tool_stats = {}
batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
completed_in_batch = []
discarded_no_reasoning = 0
# Process each prompt sequentially in this batch
for prompt_index, prompt_data in prompts_to_process:
@ -349,6 +384,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
# Save trajectory if successful
if result["success"] and result["trajectory"]:
# Discard samples with zero reasoning across all turns
reasoning = result.get("reasoning_stats", {})
if not reasoning.get("has_any_reasoning", True):
print(f" 🚫 Prompt {prompt_index} discarded (no reasoning in any turn)")
discarded_no_reasoning += 1
continue
# Get and normalize tool stats for consistent schema across all entries
raw_tool_stats = result.get("tool_stats", {})
tool_stats = _normalize_tool_stats(raw_tool_stats)
@ -389,6 +431,10 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
batch_tool_stats[tool_name]["success"] += stats["success"]
batch_tool_stats[tool_name]["failure"] += stats["failure"]
# Aggregate reasoning stats
for key in batch_reasoning_stats:
batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
if result["success"] and result["trajectory"]:
completed_in_batch.append(prompt_index)
@ -404,6 +450,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
"processed": len(prompts_to_process),
"skipped": len(batch_data) - len(prompts_to_process),
"tool_stats": batch_tool_stats,
"reasoning_stats": batch_reasoning_stats,
"discarded_no_reasoning": discarded_no_reasoning,
"completed_prompts": completed_in_batch
}
@ -434,6 +482,7 @@ class BatchRunner:
max_tokens: int = None,
reasoning_config: Dict[str, Any] = None,
prefill_messages: List[Dict[str, Any]] = None,
max_samples: int = None,
):
"""
Initialize the batch runner.
@ -458,6 +507,7 @@ class BatchRunner:
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming)
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
"""
self.dataset_file = Path(dataset_file)
self.batch_size = batch_size
@ -478,6 +528,7 @@ class BatchRunner:
self.max_tokens = max_tokens
self.reasoning_config = reasoning_config
self.prefill_messages = prefill_messages
self.max_samples = max_samples
# Validate distribution
if not validate_distribution(distribution):
@ -493,8 +544,12 @@ class BatchRunner:
# Statistics file
self.stats_file = self.output_dir / "statistics.json"
# Load dataset
# Load dataset (and optionally truncate to max_samples)
self.dataset = self._load_dataset()
if self.max_samples and self.max_samples < len(self.dataset):
full_count = len(self.dataset)
self.dataset = self.dataset[:self.max_samples]
print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")
# Create batches
self.batches = self._create_batches()
@ -812,6 +867,8 @@ class BatchRunner:
# Aggregate all batch statistics and update checkpoint
all_completed_prompts = list(completed_prompts_set)
total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
for batch_result in results:
# Add newly completed prompts
all_completed_prompts.extend(batch_result.get("completed_prompts", []))
@ -828,6 +885,10 @@ class BatchRunner:
total_tool_stats[tool_name]["count"] += stats["count"]
total_tool_stats[tool_name]["success"] += stats["success"]
total_tool_stats[tool_name]["failure"] += stats["failure"]
# Aggregate reasoning stats
for key in total_reasoning_stats:
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
# Save final checkpoint
checkpoint_data["completed_prompts"] = all_completed_prompts
@ -850,15 +911,8 @@ class BatchRunner:
combined_file = self.output_dir / "trajectories.jsonl"
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
VALID_TOOLS = {'web_search', 'web_extract', 'terminal', 'vision_analyze',
'image_generate', 'mixture_of_agents',
# Skills tools
'skills_categories', 'skills_list', 'skill_view',
# Browser automation tools
'browser_navigate', 'browser_snapshot', 'browser_click',
'browser_type', 'browser_scroll', 'browser_back',
'browser_press', 'browser_close', 'browser_get_images',
'browser_vision'}
# Valid tools auto-derived from model_tools.py — no manual updates needed
VALID_TOOLS = ALL_POSSIBLE_TOOLS
total_entries = 0
filtered_entries = 0
@ -907,7 +961,8 @@ class BatchRunner:
"model": self.model,
"completed_at": datetime.now().isoformat(),
"duration_seconds": round(time.time() - start_time, 2),
"tool_statistics": total_tool_stats
"tool_statistics": total_tool_stats,
"reasoning_statistics": total_reasoning_stats,
}
with open(self.stats_file, 'w', encoding='utf-8') as f:
@ -945,6 +1000,25 @@ class BatchRunner:
else:
print("No tool calls were made during this run.")
# Print reasoning coverage stats
total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results)
print(f"\n🧠 Reasoning Coverage:")
print("-" * 70)
total_turns = total_reasoning_stats["total_assistant_turns"]
with_reasoning = total_reasoning_stats["turns_with_reasoning"]
without_reasoning = total_reasoning_stats["turns_without_reasoning"]
if total_turns > 0:
pct_with = round(with_reasoning / total_turns * 100, 1)
pct_without = round(without_reasoning / total_turns * 100, 1)
print(f" Total assistant turns: {total_turns:,}")
print(f" With reasoning: {with_reasoning:,} ({pct_with}%)")
print(f" Without reasoning: {without_reasoning:,} ({pct_without}%)")
else:
print(" No assistant turns recorded.")
if total_discarded > 0:
print(f" 🚫 Samples discarded (zero reasoning): {total_discarded:,}")
print(f"\n💾 Results saved to: {self.output_dir}")
print(f" - Trajectories: trajectories.jsonl (combined)")
print(f" - Individual batches: batch_*.jsonl (for debugging)")
@ -975,6 +1049,7 @@ def main(
reasoning_effort: str = None,
reasoning_disabled: bool = False,
prefill_messages_file: str = None,
max_samples: int = None,
):
"""
Run batch processing of agent prompts from a dataset.
@ -1002,6 +1077,7 @@ def main(
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
Examples:
# Basic usage
@ -1110,6 +1186,7 @@ def main(
max_tokens=max_tokens,
reasoning_config=reasoning_config,
prefill_messages=prefill_messages,
max_samples=max_samples,
)
runner.run(resume=resume)