mirror of https://github.com/NousResearch/hermes-agent.git
Enhance batch processing and image generation tools
- Updated batch processing with robust resume functionality that scans completed prompts by content rather than by index, improving recovery from failures.
- Implemented retry logic for image downloads with exponential backoff to handle transient failures.
- Refined the image generation tool to use the FLUX 2 Pro model, updating descriptions and parameters for clarity and consistency.
- Added new configuration scripts for GLM 4.7 and Imagen tasks, improving usability and logging.
- Removed outdated scripts and test files to streamline the codebase.
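The exponential-backoff retry for image downloads is described above but not shown in the excerpt below. A minimal sketch of the idea, with hypothetical names (`download_image_with_retry`, `max_retries`, `base_delay`) rather than the actual implementation:

    import time
    import urllib.request

    def download_image_with_retry(url: str, dest: str, max_retries: int = 3,
                                  base_delay: float = 1.0) -> bool:
        """Download an image, retrying transient failures with exponential backoff."""
        for attempt in range(max_retries):
            try:
                urllib.request.urlretrieve(url, dest)
                return True
            except OSError as e:  # URLError subclasses OSError
                if attempt == max_retries - 1:
                    print(f"❌ Download failed after {max_retries} attempts: {e}")
                    return False
                delay = base_delay * (2 ** attempt)  # 1s, 2s, 4s, ...
                print(f"⚠️ Transient failure ({e}), retrying in {delay}s...")
                time.sleep(delay)
        return False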
parent b32cc4b09d
commit 6eb76c7c1a
14 changed files with 293 additions and 233 deletions
batch_runner.py (192 changes)
@@ -379,8 +379,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
                 batch_tool_stats[tool_name]["success"] += stats["success"]
                 batch_tool_stats[tool_name]["failure"] += stats["failure"]
 
-            completed_in_batch.append(prompt_index)
-            print(f"   ✅ Prompt {prompt_index} completed")
+            # Only mark as completed if successfully saved (failed prompts can be retried on resume)
+            if result["success"] and result["trajectory"]:
+                completed_in_batch.append(prompt_index)
+                status = "⚠️ partial" if result.get("partial") else "✅"
+                print(f"   {status} Prompt {prompt_index} completed")
+            else:
+                print(f"   ❌ Prompt {prompt_index} failed (will retry on resume)")
 
     print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
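The new gating assumes each worker result is a dict exposing `success`, `trajectory`, and an optional `partial` flag. The exact shape is not shown in this diff, but the checks above imply something like:

    # Illustrative only; shape inferred from the checks in the hunk above.
    result = {
        "success": True,                       # prompt finished and its trajectory was saved
        "trajectory": {"conversations": []},   # falsy when nothing was written
        "partial": False,                      # True when saved despite some failed tool calls
    }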
@@ -578,6 +583,83 @@ class BatchRunner:
         with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
             json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
 
+    def _scan_completed_prompts_by_content(self) -> set:
+        """
+        Scan all batch files and extract completed prompts by their actual content.
+
+        This provides a more robust resume mechanism that matches on prompt text
+        rather than indices, allowing recovery even if indices don't match.
+
+        Returns:
+            set: Set of prompt texts that have been successfully processed
+        """
+        completed_prompts = set()
+        batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
+
+        if not batch_files:
+            return completed_prompts
+
+        print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...")
+
+        for batch_file in batch_files:
+            try:
+                with open(batch_file, 'r', encoding='utf-8') as f:
+                    for line in f:
+                        try:
+                            entry = json.loads(line.strip())
+
+                            # Skip failed entries - we want to retry these
+                            if entry.get("failed", False):
+                                continue
+
+                            # Extract the human/user prompt from conversations
+                            conversations = entry.get("conversations", [])
+                            for msg in conversations:
+                                if msg.get("from") == "human":
+                                    prompt_text = msg.get("value", "").strip()
+                                    if prompt_text:
+                                        completed_prompts.add(prompt_text)
+                                    break  # Only need the first human message
+                        except json.JSONDecodeError:
+                            continue
+            except Exception as e:
+                print(f"   ⚠️ Warning: Error reading {batch_file.name}: {e}")
+
+        return completed_prompts
+
+    def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
+        """
+        Filter the dataset to exclude prompts that have already been completed.
+
+        Args:
+            completed_prompts: Set of prompt texts that have been completed
+
+        Returns:
+            Tuple of (filtered_dataset, skipped_indices)
+        """
+        filtered_dataset = []
+        skipped_indices = []
+
+        for idx, entry in enumerate(self.dataset):
+            # Extract prompt from the dataset entry
+            prompt_text = entry.get("prompt", "").strip()
+
+            # Also check conversations format
+            if not prompt_text:
+                conversations = entry.get("conversations", [])
+                for msg in conversations:
+                    role = msg.get("role") or msg.get("from")
+                    if role in ("user", "human"):
+                        prompt_text = (msg.get("content") or msg.get("value", "")).strip()
+                        break
+
+            if prompt_text in completed_prompts:
+                skipped_indices.append(idx)
+            else:
+                # Keep original index for tracking
+                filtered_dataset.append((idx, entry))
+
+        return filtered_dataset, skipped_indices
+
     def run(self, resume: bool = False):
         """
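Content matching depends on the batch files' JSONL schema, where each line carries a ShareGPT-style `conversations` list with `from`/`value` fields. An entry the scanner would count as completed looks roughly like this (the prompt text is illustrative):

    entry = {
        "conversations": [
            {"from": "human", "value": "Find the release date of FLUX 2 Pro."},
            {"from": "gpt", "value": "..."},
        ],
        "tool_stats": {"web_search": {"success": 1, "failure": 0}},
        # entries carrying "failed": True are skipped so they can be retried
    }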
@@ -590,17 +672,48 @@ class BatchRunner:
         print("🚀 Starting Batch Processing")
         print("=" * 70)
 
-        # Load checkpoint
-        checkpoint_data = self._load_checkpoint() if resume else {
+        # Smart resume: scan batch files by content to find completed prompts
+        completed_prompt_texts = set()
+        if resume:
+            completed_prompt_texts = self._scan_completed_prompts_by_content()
+            if completed_prompt_texts:
+                print(f"   Found {len(completed_prompt_texts)} already-completed prompts by content matching")
+
+        # Filter dataset to only include unprocessed prompts
+        if resume and completed_prompt_texts:
+            filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
+
+            if not filtered_entries:
+                print("\n✅ All prompts have already been processed!")
+                return
+
+            # Recreate batches from filtered entries (keeping original indices for tracking)
+            batches_to_process = []
+            for i in range(0, len(filtered_entries), self.batch_size):
+                batch = filtered_entries[i:i + self.batch_size]
+                batches_to_process.append(batch)
+
+            self.batches = batches_to_process
+
+            # Print prominent resume summary
+            print("\n" + "=" * 70)
+            print("📊 RESUME SUMMARY")
+            print("=" * 70)
+            print(f"   Original dataset size: {len(self.dataset):,} prompts")
+            print(f"   Already completed: {len(skipped_indices):,} prompts")
+            print(f"   ─────────────────────────────────────────")
+            print(f"   🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
+            print(f"   New batches created: {len(batches_to_process)}")
+            print("=" * 70 + "\n")
+
+        # Initialize checkpoint data (needed for saving at the end)
+        checkpoint_data = {
             "run_name": self.run_name,
             "completed_prompts": [],
             "batch_stats": {},
             "last_updated": None
         }
 
         if resume and checkpoint_data.get("completed_prompts"):
             print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
 
         # Prepare configuration for workers
         config = {
             "distribution": self.distribution,
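With this in place, resuming is just a matter of the flag. The constructor arguments below are hypothetical; only `run(resume=True)` matches a signature shown in this diff:

    # Hypothetical setup; only run(resume=True) is confirmed by this diff.
    runner = BatchRunner(dataset_path="prompts.jsonl", output_dir="runs/demo", batch_size=8)
    runner.run(resume=True)  # scans batch_*.jsonl, filters completed prompts, rebuilds batches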
@@ -617,8 +730,8 @@ class BatchRunner:
             "provider_sort": self.provider_sort,
         }
 
-        # Get completed prompts set
-        completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
+        # For backward compatibility, still track by index (but this is secondary to content matching)
+        completed_prompts_set = set()
 
         # Aggregate statistics across all batches
         total_tool_stats = {}
@@ -709,45 +822,51 @@ class BatchRunner:
             stats["success_rate"] = 0.0
             stats["failure_rate"] = 0.0
 
-        # Combine all batch files into a single trajectories.jsonl file
+        # Combine ALL batch files in directory into a single trajectories.jsonl file
+        # This includes both old batches (from previous runs) and new batches (from resume)
         # Also filter out corrupted entries (where model generated invalid tool names)
         combined_file = self.output_dir / "trajectories.jsonl"
-        print(f"\n📦 Combining batch files into {combined_file.name}...")
+        print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
 
         VALID_TOOLS = {'web_search', 'web_extract', 'web_crawl', 'terminal', 'vision_analyze',
                        'image_generate', 'mixture_of_agents'}
 
         total_entries = 0
         filtered_entries = 0
+        batch_files_found = 0
+
+        # Find ALL batch files in the output directory (handles resume merging old + new)
+        all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
 
         with open(combined_file, 'w', encoding='utf-8') as outfile:
-            for batch_num in range(len(self.batches)):
-                batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
-                if batch_file.exists():
-                    with open(batch_file, 'r', encoding='utf-8') as infile:
-                        for line in infile:
-                            total_entries += 1
-                            try:
-                                data = json.loads(line)
-                                tool_stats = data.get('tool_stats', {})
-
-                                # Check for invalid tool names (model hallucinations)
-                                invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
-
-                                if invalid_tools:
-                                    filtered_entries += 1
-                                    invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
-                                    print(f"   ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
-                                    continue
-
-                                outfile.write(line)
-                            except json.JSONDecodeError:
-                                filtered_entries += 1
-                                print(f"   ⚠️ Filtering invalid JSON entry (batch {batch_num})")
+            for batch_file in all_batch_files:
+                batch_files_found += 1
+                batch_num = batch_file.stem.split("_")[1]  # Extract batch number for logging
+
+                with open(batch_file, 'r', encoding='utf-8') as infile:
+                    for line in infile:
+                        total_entries += 1
+                        try:
+                            data = json.loads(line)
+                            tool_stats = data.get('tool_stats', {})
+
+                            # Check for invalid tool names (model hallucinations)
+                            invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
+
+                            if invalid_tools:
+                                filtered_entries += 1
+                                invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
+                                print(f"   ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
+                                continue
+
+                            outfile.write(line)
+                        except json.JSONDecodeError:
+                            filtered_entries += 1
+                            print(f"   ⚠️ Filtering invalid JSON entry (batch {batch_num})")
 
         if filtered_entries > 0:
             print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
-        print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
+        print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
 
         # Save final statistics
         final_stats = {
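The merge-time filter drops any entry whose `tool_stats` names a tool outside `VALID_TOOLS`, which catches model hallucinations. A self-contained illustration (the entry itself is made up):

    VALID_TOOLS = {'web_search', 'web_extract', 'web_crawl', 'terminal', 'vision_analyze',
                   'image_generate', 'mixture_of_agents'}

    entry = {"tool_stats": {"web_serach": {"success": 1, "failure": 0}}}  # misspelled tool
    invalid_tools = [k for k in entry["tool_stats"] if k not in VALID_TOOLS]
    assert invalid_tools == ["web_serach"]  # this entry would be dropped from trajectories.jsonl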
@@ -769,8 +888,9 @@ class BatchRunner:
         print("\n" + "=" * 70)
         print("📊 BATCH PROCESSING COMPLETE")
         print("=" * 70)
-        print(f"✅ Total prompts processed: {len(self.dataset)}")
-        print(f"✅ Total batches: {len(self.batches)}")
+        print(f"✅ Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
+        print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
+        print(f"✅ Total batch files merged: {batch_files_found}")
         print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
         print(f"\n📈 Tool Usage Statistics:")
         print("-" * 70)