mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Refactor BatchRunner and AIAgent for enhanced reasoning and tool management, improved tool definitions for fileops
- Updated `ALL_POSSIBLE_TOOLS` to auto-derive from `TOOL_TO_TOOLSET_MAP` for consistent schema. - Introduced `_extract_reasoning_stats` function to track reasoning coverage in assistant turns. - Enhanced `_process_batch_worker` to discard prompts with no reasoning and aggregate reasoning statistics. - Updated documentation and comments for clarity on new features and changes.
This commit is contained in:
parent
f12ea1bc02
commit
dd70d57b9b
4 changed files with 277 additions and 90 deletions
125
batch_runner.py
125
batch_runner.py
|
|
@ -41,24 +41,17 @@ from toolset_distributions import (
|
|||
sample_toolsets_from_distribution,
|
||||
validate_distribution
|
||||
)
|
||||
from model_tools import TOOL_TO_TOOLSET_MAP
|
||||
|
||||
|
||||
# Global configuration for worker processes
|
||||
_WORKER_CONFIG = {}
|
||||
|
||||
# All possible tools - used to ensure consistent schema across all trajectory entries
|
||||
# This is required because Arrow/Parquet (used by HuggingFace datasets) needs identical schemas
|
||||
ALL_POSSIBLE_TOOLS = {
|
||||
'terminal', 'web_search', 'web_extract',
|
||||
'vision_analyze', 'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'
|
||||
}
|
||||
# All possible tools - auto-derived from the master mapping in model_tools.py.
|
||||
# This stays in sync automatically when new tools are added to TOOL_TO_TOOLSET_MAP.
|
||||
# Used for consistent schema in Arrow/Parquet (HuggingFace datasets) and for
|
||||
# filtering corrupted entries during trajectory combination.
|
||||
ALL_POSSIBLE_TOOLS = set(TOOL_TO_TOOLSET_MAP.keys())
|
||||
|
||||
# Default stats for tools that weren't used
|
||||
DEFAULT_TOOL_STATS = {'count': 0, 'success': 0, 'failure': 0}
|
||||
|
|
@ -200,6 +193,42 @@ def _extract_tool_stats(messages: List[Dict[str, Any]]) -> Dict[str, Dict[str, i
|
|||
return tool_stats
|
||||
|
||||
|
||||
def _extract_reasoning_stats(messages: List[Dict[str, Any]]) -> Dict[str, int]:
|
||||
"""
|
||||
Count how many assistant turns have reasoning vs no reasoning.
|
||||
|
||||
Checks for <REASONING_SCRATCHPAD> in content or a non-empty 'reasoning' field
|
||||
(native thinking tokens). Returns counts for tracking reasoning coverage.
|
||||
|
||||
Args:
|
||||
messages: Message history
|
||||
|
||||
Returns:
|
||||
Dict with 'total_assistant_turns', 'turns_with_reasoning', 'turns_without_reasoning'
|
||||
"""
|
||||
total = 0
|
||||
with_reasoning = 0
|
||||
|
||||
for msg in messages:
|
||||
if msg.get("role") != "assistant":
|
||||
continue
|
||||
total += 1
|
||||
|
||||
content = msg.get("content", "") or ""
|
||||
has_scratchpad = "<REASONING_SCRATCHPAD>" in content
|
||||
has_native_reasoning = bool(msg.get("reasoning", "").strip()) if msg.get("reasoning") else False
|
||||
|
||||
if has_scratchpad or has_native_reasoning:
|
||||
with_reasoning += 1
|
||||
|
||||
return {
|
||||
"total_assistant_turns": total,
|
||||
"turns_with_reasoning": with_reasoning,
|
||||
"turns_without_reasoning": total - with_reasoning,
|
||||
"has_any_reasoning": with_reasoning > 0,
|
||||
}
|
||||
|
||||
|
||||
def _process_single_prompt(
|
||||
prompt_index: int,
|
||||
prompt_data: Dict[str, Any],
|
||||
|
|
@ -255,6 +284,9 @@ def _process_single_prompt(
|
|||
# Extract tool usage statistics
|
||||
tool_stats = _extract_tool_stats(result["messages"])
|
||||
|
||||
# Extract reasoning coverage stats
|
||||
reasoning_stats = _extract_reasoning_stats(result["messages"])
|
||||
|
||||
# Convert to trajectory format (using existing method)
|
||||
trajectory = agent._convert_to_trajectory_format(
|
||||
result["messages"],
|
||||
|
|
@ -267,6 +299,7 @@ def _process_single_prompt(
|
|||
"prompt_index": prompt_index,
|
||||
"trajectory": trajectory,
|
||||
"tool_stats": tool_stats,
|
||||
"reasoning_stats": reasoning_stats,
|
||||
"completed": result["completed"],
|
||||
"partial": result.get("partial", False),
|
||||
"api_calls": result["api_calls"],
|
||||
|
|
@ -335,7 +368,9 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
|||
|
||||
# Initialize aggregated stats for this batch
|
||||
batch_tool_stats = {}
|
||||
batch_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
|
||||
completed_in_batch = []
|
||||
discarded_no_reasoning = 0
|
||||
|
||||
# Process each prompt sequentially in this batch
|
||||
for prompt_index, prompt_data in prompts_to_process:
|
||||
|
|
@ -349,6 +384,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
|||
|
||||
# Save trajectory if successful
|
||||
if result["success"] and result["trajectory"]:
|
||||
# Discard samples with zero reasoning across all turns
|
||||
reasoning = result.get("reasoning_stats", {})
|
||||
if not reasoning.get("has_any_reasoning", True):
|
||||
print(f" 🚫 Prompt {prompt_index} discarded (no reasoning in any turn)")
|
||||
discarded_no_reasoning += 1
|
||||
continue
|
||||
|
||||
# Get and normalize tool stats for consistent schema across all entries
|
||||
raw_tool_stats = result.get("tool_stats", {})
|
||||
tool_stats = _normalize_tool_stats(raw_tool_stats)
|
||||
|
|
@ -389,6 +431,10 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
|||
batch_tool_stats[tool_name]["success"] += stats["success"]
|
||||
batch_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
# Aggregate reasoning stats
|
||||
for key in batch_reasoning_stats:
|
||||
batch_reasoning_stats[key] += result.get("reasoning_stats", {}).get(key, 0)
|
||||
|
||||
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
|
||||
if result["success"] and result["trajectory"]:
|
||||
completed_in_batch.append(prompt_index)
|
||||
|
|
@ -404,6 +450,8 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
|
|||
"processed": len(prompts_to_process),
|
||||
"skipped": len(batch_data) - len(prompts_to_process),
|
||||
"tool_stats": batch_tool_stats,
|
||||
"reasoning_stats": batch_reasoning_stats,
|
||||
"discarded_no_reasoning": discarded_no_reasoning,
|
||||
"completed_prompts": completed_in_batch
|
||||
}
|
||||
|
||||
|
|
@ -434,6 +482,7 @@ class BatchRunner:
|
|||
max_tokens: int = None,
|
||||
reasoning_config: Dict[str, Any] = None,
|
||||
prefill_messages: List[Dict[str, Any]] = None,
|
||||
max_samples: int = None,
|
||||
):
|
||||
"""
|
||||
Initialize the batch runner.
|
||||
|
|
@ -458,6 +507,7 @@ class BatchRunner:
|
|||
max_tokens (int): Maximum tokens for model responses (optional, uses model default if not set)
|
||||
reasoning_config (Dict): OpenRouter reasoning config override (e.g. {"effort": "none"} to disable thinking)
|
||||
prefill_messages (List[Dict]): Messages to prepend as prefilled conversation context (few-shot priming)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
"""
|
||||
self.dataset_file = Path(dataset_file)
|
||||
self.batch_size = batch_size
|
||||
|
|
@ -478,6 +528,7 @@ class BatchRunner:
|
|||
self.max_tokens = max_tokens
|
||||
self.reasoning_config = reasoning_config
|
||||
self.prefill_messages = prefill_messages
|
||||
self.max_samples = max_samples
|
||||
|
||||
# Validate distribution
|
||||
if not validate_distribution(distribution):
|
||||
|
|
@ -493,8 +544,12 @@ class BatchRunner:
|
|||
# Statistics file
|
||||
self.stats_file = self.output_dir / "statistics.json"
|
||||
|
||||
# Load dataset
|
||||
# Load dataset (and optionally truncate to max_samples)
|
||||
self.dataset = self._load_dataset()
|
||||
if self.max_samples and self.max_samples < len(self.dataset):
|
||||
full_count = len(self.dataset)
|
||||
self.dataset = self.dataset[:self.max_samples]
|
||||
print(f"✂️ Truncated dataset from {full_count} to {self.max_samples} samples (--max_samples)")
|
||||
|
||||
# Create batches
|
||||
self.batches = self._create_batches()
|
||||
|
|
@ -812,6 +867,8 @@ class BatchRunner:
|
|||
|
||||
# Aggregate all batch statistics and update checkpoint
|
||||
all_completed_prompts = list(completed_prompts_set)
|
||||
total_reasoning_stats = {"total_assistant_turns": 0, "turns_with_reasoning": 0, "turns_without_reasoning": 0}
|
||||
|
||||
for batch_result in results:
|
||||
# Add newly completed prompts
|
||||
all_completed_prompts.extend(batch_result.get("completed_prompts", []))
|
||||
|
|
@ -828,6 +885,10 @@ class BatchRunner:
|
|||
total_tool_stats[tool_name]["count"] += stats["count"]
|
||||
total_tool_stats[tool_name]["success"] += stats["success"]
|
||||
total_tool_stats[tool_name]["failure"] += stats["failure"]
|
||||
|
||||
# Aggregate reasoning stats
|
||||
for key in total_reasoning_stats:
|
||||
total_reasoning_stats[key] += batch_result.get("reasoning_stats", {}).get(key, 0)
|
||||
|
||||
# Save final checkpoint
|
||||
checkpoint_data["completed_prompts"] = all_completed_prompts
|
||||
|
|
@ -850,15 +911,8 @@ class BatchRunner:
|
|||
combined_file = self.output_dir / "trajectories.jsonl"
|
||||
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
|
||||
|
||||
VALID_TOOLS = {'web_search', 'web_extract', 'terminal', 'vision_analyze',
|
||||
'image_generate', 'mixture_of_agents',
|
||||
# Skills tools
|
||||
'skills_categories', 'skills_list', 'skill_view',
|
||||
# Browser automation tools
|
||||
'browser_navigate', 'browser_snapshot', 'browser_click',
|
||||
'browser_type', 'browser_scroll', 'browser_back',
|
||||
'browser_press', 'browser_close', 'browser_get_images',
|
||||
'browser_vision'}
|
||||
# Valid tools auto-derived from model_tools.py — no manual updates needed
|
||||
VALID_TOOLS = ALL_POSSIBLE_TOOLS
|
||||
|
||||
total_entries = 0
|
||||
filtered_entries = 0
|
||||
|
|
@ -907,7 +961,8 @@ class BatchRunner:
|
|||
"model": self.model,
|
||||
"completed_at": datetime.now().isoformat(),
|
||||
"duration_seconds": round(time.time() - start_time, 2),
|
||||
"tool_statistics": total_tool_stats
|
||||
"tool_statistics": total_tool_stats,
|
||||
"reasoning_statistics": total_reasoning_stats,
|
||||
}
|
||||
|
||||
with open(self.stats_file, 'w', encoding='utf-8') as f:
|
||||
|
|
@ -945,6 +1000,25 @@ class BatchRunner:
|
|||
else:
|
||||
print("No tool calls were made during this run.")
|
||||
|
||||
# Print reasoning coverage stats
|
||||
total_discarded = sum(r.get("discarded_no_reasoning", 0) for r in results)
|
||||
|
||||
print(f"\n🧠 Reasoning Coverage:")
|
||||
print("-" * 70)
|
||||
total_turns = total_reasoning_stats["total_assistant_turns"]
|
||||
with_reasoning = total_reasoning_stats["turns_with_reasoning"]
|
||||
without_reasoning = total_reasoning_stats["turns_without_reasoning"]
|
||||
if total_turns > 0:
|
||||
pct_with = round(with_reasoning / total_turns * 100, 1)
|
||||
pct_without = round(without_reasoning / total_turns * 100, 1)
|
||||
print(f" Total assistant turns: {total_turns:,}")
|
||||
print(f" With reasoning: {with_reasoning:,} ({pct_with}%)")
|
||||
print(f" Without reasoning: {without_reasoning:,} ({pct_without}%)")
|
||||
else:
|
||||
print(" No assistant turns recorded.")
|
||||
if total_discarded > 0:
|
||||
print(f" 🚫 Samples discarded (zero reasoning): {total_discarded:,}")
|
||||
|
||||
print(f"\n💾 Results saved to: {self.output_dir}")
|
||||
print(f" - Trajectories: trajectories.jsonl (combined)")
|
||||
print(f" - Individual batches: batch_*.jsonl (for debugging)")
|
||||
|
|
@ -975,6 +1049,7 @@ def main(
|
|||
reasoning_effort: str = None,
|
||||
reasoning_disabled: bool = False,
|
||||
prefill_messages_file: str = None,
|
||||
max_samples: int = None,
|
||||
):
|
||||
"""
|
||||
Run batch processing of agent prompts from a dataset.
|
||||
|
|
@ -1002,6 +1077,7 @@ def main(
|
|||
reasoning_effort (str): OpenRouter reasoning effort level: "xhigh", "high", "medium", "low", "minimal", "none" (default: "xhigh")
|
||||
reasoning_disabled (bool): Completely disable reasoning/thinking tokens (default: False)
|
||||
prefill_messages_file (str): Path to JSON file containing prefill messages (list of {role, content} dicts)
|
||||
max_samples (int): Only process the first N samples from the dataset (optional, processes all if not set)
|
||||
|
||||
Examples:
|
||||
# Basic usage
|
||||
|
|
@ -1110,6 +1186,7 @@ def main(
|
|||
max_tokens=max_tokens,
|
||||
reasoning_config=reasoning_config,
|
||||
prefill_messages=prefill_messages,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
|
||||
runner.run(resume=resume)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue