Refactor batch processing with rich progress tracking and update logging in AIAgent

- Replaced tqdm with rich for enhanced visual progress tracking in batch processing.
- Adjusted logging levels in AIAgent to suppress asyncio debug messages.
- Modified datagen script to reduce number of workers for improved performance.
This commit is contained in:
teknium 2026-01-14 14:02:59 +00:00
parent 6e3dbb8d8b
commit b32cc4b09d
4 changed files with 37 additions and 12 deletions

View file

@ -30,7 +30,8 @@ from datetime import datetime
from multiprocessing import Pool, Manager, Lock
import traceback
from tqdm import tqdm
from rich.progress import Progress, SpinnerColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn
from rich.console import Console
import fire
from run_agent import AIAgent
@ -643,14 +644,36 @@ class BatchRunner:
print(f"✅ Created {len(tasks)} batch tasks")
print(f"🚀 Starting parallel batch processing...\n")
# Use imap_unordered with tqdm for progress tracking
results = list(tqdm(
pool.imap_unordered(_process_batch_worker, tasks),
total=len(tasks),
desc="📦 Batches",
unit="batch",
ncols=80
))
# Use rich Progress for better visual tracking with persistent bottom bar
# redirect_stdout/stderr lets rich manage all output so progress bar stays clean
results = []
console = Console(force_terminal=True)
with Progress(
SpinnerColumn(),
TextColumn("[bold blue]📦 Batches"),
BarColumn(bar_width=40),
MofNCompleteColumn(),
TextColumn(""),
TimeRemainingColumn(),
console=console,
refresh_per_second=2,
transient=False,
redirect_stdout=False,
redirect_stderr=False,
) as progress:
task = progress.add_task("Processing", total=len(tasks))
# Temporarily suppress DEBUG logging to avoid bar interference
root_logger = logging.getLogger()
original_level = root_logger.level
root_logger.setLevel(logging.WARNING)
try:
for result in pool.imap_unordered(_process_batch_worker, tasks):
results.append(result)
progress.update(task, advance=1)
finally:
root_logger.setLevel(original_level)
# Aggregate all batch statistics and update checkpoint
all_completed_prompts = list(completed_prompts_set)

View file

@ -7,3 +7,4 @@ tenacity
python-dotenv
fire
httpx
rich

View file

@ -127,7 +127,8 @@ class AIAgent:
logging.getLogger('openai._base_client').setLevel(logging.WARNING)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('httpcore').setLevel(logging.WARNING)
print("🔍 Verbose logging enabled (OpenAI/httpx internal logs suppressed)")
logging.getLogger('asyncio').setLevel(logging.WARNING) # Suppress asyncio debug
print("🔍 Verbose logging enabled (OpenAI/httpx/asyncio internal logs suppressed)")
else:
# Set logging to INFO level for important messages only
logging.basicConfig(

View file

@ -17,10 +17,10 @@ python batch_runner.py \
--model="z-ai/glm-4.7" \
--base_url="https://openrouter.ai/api/v1" \
--providers_allowed="gmicloud,siliconflow,atlas-cloud,z-ai,novita" \
--num_workers=25 \
--num_workers=1 \
--max_turns=25 \
--verbose \
--ephemeral_system_prompt="When generating an image for the user view the image by using the vision_analyze tool to ensure it is what the user wanted. If it isn't feel free to retry a few times. If none are perfect, choose the best option that is the closest match, and explain its imperfections. If the image generation tool fails, try again a few times. If the vision analyze tool fails, provide the image to the user and explain it is your best effort attempt." \
2>&1 | tee "$LOG_FILE"
echo "✅ Log saved to: $LOG_FILE"
# --verbose \