Enhance batch processing and image generation tools

- Updated batch processing to include robust resume functionality by scanning completed prompts based on content rather than indices, improving recovery from failures.
- Implemented retry logic for image downloads with exponential backoff to handle transient failures effectively.
- Refined image generation tool to utilize the FLUX 2 Pro model, updating descriptions and parameters for clarity and consistency.
- Added new configuration scripts for GLM 4.7 and Imagen tasks, enhancing usability and logging capabilities.
- Removed outdated scripts and test files to streamline the codebase.
This commit is contained in:
teknium 2026-01-18 10:11:59 +00:00
parent b32cc4b09d
commit 6eb76c7c1a
14 changed files with 293 additions and 233 deletions

View file

@ -379,8 +379,13 @@ def _process_batch_worker(args: Tuple) -> Dict[str, Any]:
batch_tool_stats[tool_name]["success"] += stats["success"]
batch_tool_stats[tool_name]["failure"] += stats["failure"]
completed_in_batch.append(prompt_index)
print(f" ✅ Prompt {prompt_index} completed")
# Only mark as completed if successfully saved (failed prompts can be retried on resume)
if result["success"] and result["trajectory"]:
completed_in_batch.append(prompt_index)
status = "⚠️ partial" if result.get("partial") else ""
print(f" {status} Prompt {prompt_index} completed")
else:
print(f" ❌ Prompt {prompt_index} failed (will retry on resume)")
print(f"✅ Batch {batch_num}: Completed ({len(prompts_to_process)} prompts processed)")
@ -578,6 +583,83 @@ class BatchRunner:
with open(self.checkpoint_file, 'w', encoding='utf-8') as f:
json.dump(checkpoint_data, f, indent=2, ensure_ascii=False)
def _scan_completed_prompts_by_content(self) -> set:
"""
Scan all batch files and extract completed prompts by their actual content.
This provides a more robust resume mechanism that matches on prompt text
rather than indices, allowing recovery even if indices don't match.
Returns:
set: Set of prompt texts that have been successfully processed
"""
completed_prompts = set()
batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
if not batch_files:
return completed_prompts
print(f"📂 Scanning {len(batch_files)} batch files for completed prompts...")
for batch_file in batch_files:
try:
with open(batch_file, 'r', encoding='utf-8') as f:
for line in f:
try:
entry = json.loads(line.strip())
# Skip failed entries - we want to retry these
if entry.get("failed", False):
continue
# Extract the human/user prompt from conversations
conversations = entry.get("conversations", [])
for msg in conversations:
if msg.get("from") == "human":
prompt_text = msg.get("value", "").strip()
if prompt_text:
completed_prompts.add(prompt_text)
break # Only need the first human message
except json.JSONDecodeError:
continue
except Exception as e:
print(f" ⚠️ Warning: Error reading {batch_file.name}: {e}")
return completed_prompts
def _filter_dataset_by_completed(self, completed_prompts: set) -> Tuple[List[Dict], List[int]]:
"""
Filter the dataset to exclude prompts that have already been completed.
Args:
completed_prompts: Set of prompt texts that have been completed
Returns:
Tuple of (filtered_dataset, skipped_indices)
"""
filtered_dataset = []
skipped_indices = []
for idx, entry in enumerate(self.dataset):
# Extract prompt from the dataset entry
prompt_text = entry.get("prompt", "").strip()
# Also check conversations format
if not prompt_text:
conversations = entry.get("conversations", [])
for msg in conversations:
role = msg.get("role") or msg.get("from")
if role in ("user", "human"):
prompt_text = (msg.get("content") or msg.get("value", "")).strip()
break
if prompt_text in completed_prompts:
skipped_indices.append(idx)
else:
# Keep original index for tracking
filtered_dataset.append((idx, entry))
return filtered_dataset, skipped_indices
def run(self, resume: bool = False):
"""
@ -590,17 +672,48 @@ class BatchRunner:
print("🚀 Starting Batch Processing")
print("=" * 70)
# Load checkpoint
checkpoint_data = self._load_checkpoint() if resume else {
# Smart resume: scan batch files by content to find completed prompts
completed_prompt_texts = set()
if resume:
completed_prompt_texts = self._scan_completed_prompts_by_content()
if completed_prompt_texts:
print(f" Found {len(completed_prompt_texts)} already-completed prompts by content matching")
# Filter dataset to only include unprocessed prompts
if resume and completed_prompt_texts:
filtered_entries, skipped_indices = self._filter_dataset_by_completed(completed_prompt_texts)
if not filtered_entries:
print("\n✅ All prompts have already been processed!")
return
# Recreate batches from filtered entries (keeping original indices for tracking)
batches_to_process = []
for i in range(0, len(filtered_entries), self.batch_size):
batch = filtered_entries[i:i + self.batch_size]
batches_to_process.append(batch)
self.batches = batches_to_process
# Print prominent resume summary
print("\n" + "=" * 70)
print("📊 RESUME SUMMARY")
print("=" * 70)
print(f" Original dataset size: {len(self.dataset):,} prompts")
print(f" Already completed: {len(skipped_indices):,} prompts")
print(f" ─────────────────────────────────────────")
print(f" 🎯 RESUMING WITH: {len(filtered_entries):,} prompts")
print(f" New batches created: {len(batches_to_process)}")
print("=" * 70 + "\n")
# Initialize checkpoint data (needed for saving at the end)
checkpoint_data = {
"run_name": self.run_name,
"completed_prompts": [],
"batch_stats": {},
"last_updated": None
}
if resume and checkpoint_data.get("completed_prompts"):
print(f"📂 Resuming from checkpoint ({len(checkpoint_data['completed_prompts'])} prompts already completed)")
# Prepare configuration for workers
config = {
"distribution": self.distribution,
@ -617,8 +730,8 @@ class BatchRunner:
"provider_sort": self.provider_sort,
}
# Get completed prompts set
completed_prompts_set = set(checkpoint_data.get("completed_prompts", []))
# For backward compatibility, still track by index (but this is secondary to content matching)
completed_prompts_set = set()
# Aggregate statistics across all batches
total_tool_stats = {}
@ -709,45 +822,51 @@ class BatchRunner:
stats["success_rate"] = 0.0
stats["failure_rate"] = 0.0
# Combine all batch files into a single trajectories.jsonl file
# Combine ALL batch files in directory into a single trajectories.jsonl file
# This includes both old batches (from previous runs) and new batches (from resume)
# Also filter out corrupted entries (where model generated invalid tool names)
combined_file = self.output_dir / "trajectories.jsonl"
print(f"\n📦 Combining batch files into {combined_file.name}...")
print(f"\n📦 Combining ALL batch files into {combined_file.name}...")
VALID_TOOLS = {'web_search', 'web_extract', 'web_crawl', 'terminal', 'vision_analyze',
'image_generate', 'mixture_of_agents'}
total_entries = 0
filtered_entries = 0
batch_files_found = 0
# Find ALL batch files in the output directory (handles resume merging old + new)
all_batch_files = sorted(self.output_dir.glob("batch_*.jsonl"))
with open(combined_file, 'w', encoding='utf-8') as outfile:
for batch_num in range(len(self.batches)):
batch_file = self.output_dir / f"batch_{batch_num}.jsonl"
if batch_file.exists():
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
total_entries += 1
try:
data = json.loads(line)
tool_stats = data.get('tool_stats', {})
# Check for invalid tool names (model hallucinations)
invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
if invalid_tools:
filtered_entries += 1
invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
continue
outfile.write(line)
except json.JSONDecodeError:
for batch_file in all_batch_files:
batch_files_found += 1
batch_num = batch_file.stem.split("_")[1] # Extract batch number for logging
with open(batch_file, 'r', encoding='utf-8') as infile:
for line in infile:
total_entries += 1
try:
data = json.loads(line)
tool_stats = data.get('tool_stats', {})
# Check for invalid tool names (model hallucinations)
invalid_tools = [k for k in tool_stats.keys() if k not in VALID_TOOLS]
if invalid_tools:
filtered_entries += 1
print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})")
invalid_preview = invalid_tools[0][:50] + "..." if len(invalid_tools[0]) > 50 else invalid_tools[0]
print(f" ⚠️ Filtering corrupted entry (batch {batch_num}): invalid tool '{invalid_preview}'")
continue
outfile.write(line)
except json.JSONDecodeError:
filtered_entries += 1
print(f" ⚠️ Filtering invalid JSON entry (batch {batch_num})")
if filtered_entries > 0:
print(f"⚠️ Filtered {filtered_entries} corrupted entries out of {total_entries} total")
print(f"✅ Combined {len(self.batches)} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
print(f"✅ Combined {batch_files_found} batch files into trajectories.jsonl ({total_entries - filtered_entries} entries)")
# Save final statistics
final_stats = {
@ -769,8 +888,9 @@ class BatchRunner:
print("\n" + "=" * 70)
print("📊 BATCH PROCESSING COMPLETE")
print("=" * 70)
print(f"✅ Total prompts processed: {len(self.dataset)}")
print(f"✅ Total batches: {len(self.batches)}")
print(f"✅ Prompts processed this run: {sum(r.get('processed', 0) for r in results)}")
print(f"✅ Total trajectories in merged file: {total_entries - filtered_entries}")
print(f"✅ Total batch files merged: {batch_files_found}")
print(f"⏱️ Total duration: {round(time.time() - start_time, 2)}s")
print(f"\n📈 Tool Usage Statistics:")
print("-" * 70)