diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py
index 8c18bee670..770c542c7b 100644
--- a/tools/rl_training_tool.py
+++ b/tools/rl_training_tool.py
@@ -1093,6 +1093,10 @@ async def rl_test_inference(
     # Output file for this test run
     output_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.jsonl"
 
+    # Generate unique run ID for wandb
+    test_run_id = str(uuid.uuid4())[:8]
+    wandb_run_name = f"test_inference_RSIAgent_{_current_env}_{test_run_id}"
+
     # Build the process command using Atropos's built-in CLI
     # This runs the environment's actual code with OpenRouter as the inference backend
     # We pass our locked settings + test-specific overrides via CLI args
@@ -1101,7 +1105,8 @@ async def rl_test_inference(
         # Test-specific overrides
         "--env.total_steps", str(num_steps),
         "--env.group_size", str(group_size),
-        "--env.use_wandb", "false",  # No wandb for quick tests
+        "--env.use_wandb", "true",  # Enable wandb for test tracking
+        "--env.wandb_name", wandb_run_name,
         "--env.data_path_to_save_groups", str(output_file),
         # Use locked settings from our config
         "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
@@ -1124,12 +1129,14 @@
     cmd_display = cmd_str.replace(api_key, "***API_KEY***")
     print(f"Command: {cmd_display}")
     print(f"Working dir: {TINKER_ATROPOS_ROOT}")
+    print(f"WandB run: {wandb_run_name}")
     print(f"  {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
 
     model_results = {
         "model": model_id,
         "name": model_info["name"],
         "scale": model_info["scale"],
+        "wandb_run": wandb_run_name,
         "output_file": str(output_file),
         "steps": [],
         "steps_tested": 0,
@@ -1138,7 +1145,7 @@
     }
 
     try:
-        # Run the process command
+        # Run the process command with real-time output streaming
         process = await asyncio.create_subprocess_exec(
             *cmd,
             stdout=asyncio.subprocess.PIPE,
@@ -1146,17 +1153,43 @@
             cwd=str(TINKER_ATROPOS_ROOT),
         )
 
-        stdout, stderr = await asyncio.wait_for(
-            process.communicate(),
-            timeout=600,  # 10 minute timeout per model
-        )
+        # Stream output in real-time while collecting for logs
+        stdout_lines = []
+        stderr_lines = []
+        log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
 
-        # Decode output
-        stdout_text = stdout.decode() if stdout else ""
-        stderr_text = stderr.decode() if stderr else ""
+        async def read_stream(stream, lines_list, prefix=""):
+            """Read stream line by line and print in real-time."""
+            while True:
+                line = await stream.readline()
+                if not line:
+                    break
+                decoded = line.decode().rstrip()
+                lines_list.append(decoded)
+                # Print progress-related lines in real-time
+                if any(kw in decoded.lower() for kw in ['processing', 'group', 'step', 'progress', '%', 'completed']):
+                    print(f"  {prefix}{decoded}")
+
+        # Read both streams concurrently with timeout
+        try:
+            await asyncio.wait_for(
+                asyncio.gather(
+                    read_stream(process.stdout, stdout_lines, "📊 "),
+                    read_stream(process.stderr, stderr_lines, "⚠️ "),
+                ),
+                timeout=600,  # 10 minute timeout per model
+            )
+        except asyncio.TimeoutError:
+            process.kill()
+            raise
+
+        await process.wait()
+
+        # Combine output for logging
+        stdout_text = "\n".join(stdout_lines)
+        stderr_text = "\n".join(stderr_lines)
 
         # Write logs to files for inspection outside CLI
-        log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
         with open(log_file, "w") as f:
             f.write(f"Command: {cmd_display}\n")
             f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
@@ -1170,21 +1203,17 @@ async def rl_test_inference(
         print(f"  Log file: {log_file}")
 
-        # Print to console for immediate debugging
-        if stdout_text.strip():
-            print(f"\n--- STDOUT ---")
-            print(stdout_text[-2000:])  # Last 2000 chars
-
-        if stderr_text.strip():
-            print(f"\n--- STDERR ---")
-            print(stderr_text[-2000:])  # Last 2000 chars
-
         if process.returncode != 0:
             model_results["error"] = f"Process exited with code {process.returncode}"
             model_results["stderr"] = stderr_text[-1000:]
             model_results["stdout"] = stdout_text[-1000:]
             model_results["log_file"] = str(log_file)
             print(f"\n  ❌ Error: {model_results['error']}")
+            # Print last few lines of stderr for debugging
+            if stderr_lines:
+                print(f"  Last errors:")
+                for line in stderr_lines[-5:]:
+                    print(f"    {line}")
         else:
             print(f"\n  ✅ Process completed successfully")
             print(f"  Output file: {output_file}")
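
Note: the core change above swaps process.communicate() for concurrent line-by-line draining of both pipes. A minimal, self-contained sketch of that pattern follows; the names (stream_command, drain, the "out>"/"err>" prefixes) are illustrative, not part of the patch:

    import asyncio

    async def stream_command(*cmd: str, timeout: float = 600.0):
        """Run cmd, echoing each output line live while collecting it."""
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        async def drain(stream, sink, prefix):
            # Draining both pipes concurrently avoids the deadlock that
            # communicate() otherwise guards against: a full, unread pipe
            # buffer blocks the child's writes.
            while True:
                line = await stream.readline()
                if not line:  # EOF
                    break
                text = line.decode().rstrip()
                sink.append(text)
                print(f"{prefix}{text}")

        out, err = [], []
        try:
            await asyncio.wait_for(
                asyncio.gather(
                    drain(process.stdout, out, "out> "),
                    drain(process.stderr, err, "err> "),
                ),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            process.kill()  # don't leave an orphaned child on timeout
            raise
        await process.wait()  # reap the child; returncode is set after this
        return process.returncode, out, err

    # Usage:
    #   rc, out, err = asyncio.run(stream_command("python", "-V"))

As in the patch, the streams are read to EOF before process.wait(), so the return code is only inspected once all output has been captured.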