Update RL tools and enhance configuration management

- Modified `model_tools.py` to update the default model IDs and add the new RL function `rl_test_inference`.
- Enhanced `README.md` with installation instructions for submodules and updated API key usage.
- Improved `rl_cli.py` to load configuration from `~/.hermes/config.yaml` and set the terminal working directory for RL tools.
- Updated `run_agent.py` to treat empty-string arguments as empty objects for better JSON validation.
- Refined the installation scripts to ensure submodules are cloned and installed correctly, improving the setup experience.
This commit is contained in:
teknium1 2026-02-04 13:57:59 -08:00
parent 12bbca95ec
commit 3c0d0dba49
7 changed files with 274 additions and 56 deletions

View file

@ -37,6 +37,7 @@ import subprocess
import sys
import time
import uuid
from datetime import datetime
import yaml
from dataclasses import dataclass, field
from pathlib import Path
@ -84,6 +85,7 @@ LOCKED_FIELDS = {
"weight": 1.0,
"num_requests_for_eval": 256,
"timeout": 3600,
"server_type": "sglang", # Tinker uses sglang for actual training
}
],
"tinker": {
@ -211,6 +213,9 @@ def _scan_environments() -> List[EnvironmentInfo]:
def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
"""
Dynamically import an environment and extract its config fields.
Uses config_init() to get the actual config class, with fallback to
directly importing BaseEnvConfig if config_init fails.
"""
try:
# Load the environment module
@ -230,15 +235,38 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
if not env_class:
return {}
# Call config_init to get the actual config
env_config, server_configs = env_class.config_init()
config_class = type(env_config)
# Try calling config_init to get the actual config class
config_class = None
try:
env_config, server_configs = env_class.config_init()
config_class = type(env_config)
except Exception as config_error:
# Fallback: try to import BaseEnvConfig directly from atroposlib
print(f"Note: config_init failed ({config_error}), using BaseEnvConfig defaults")
try:
from atroposlib.envs.base import BaseEnvConfig
config_class = BaseEnvConfig
except ImportError:
return {}
if not config_class:
return {}
# Helper to make values JSON-serializable (handle enums, etc.)
def make_serializable(val):
    """Convert a config default into a JSON-serializable value.

    Enum members are replaced by their underlying ``value``; ``None`` and
    any other object are returned unchanged.

    Args:
        val: The raw default (or locked) value taken from a config field.

    Returns:
        ``val.value`` when ``val`` is an ``enum.Enum`` member, otherwise
        ``val`` itself.
    """
    # Local import keeps this helper self-contained.
    from enum import Enum

    # Check the actual type instead of duck-typing on a ``value`` attribute:
    # the attribute check matched any object that happens to expose
    # ``.value`` and made the old name-based Enum fallback unreachable.
    if isinstance(val, Enum):
        return val.value
    return val
# Extract fields from the Pydantic model
fields = {}
for field_name, field_info in config_class.model_fields.items():
field_type = field_info.annotation
default = field_info.default
default = make_serializable(field_info.default)
description = field_info.description or ""
is_locked = field_name in LOCKED_FIELD_NAMES
@ -248,12 +276,15 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
if hasattr(field_type, "__origin__"):
type_name = str(field_type)
locked_value = LOCKED_FIELDS.get("env", {}).get(field_name, default)
current_value = make_serializable(locked_value) if is_locked else default
fields[field_name] = {
"type": type_name,
"default": default if default is not None else None,
"default": default,
"description": description,
"locked": is_locked,
"current_value": LOCKED_FIELDS.get("env", {}).get(field_name, default) if is_locked else default,
"current_value": current_value,
}
return fields
@ -315,7 +346,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):
trainer_log_file = open(trainer_log, "w")
run_state.trainer_process = subprocess.Popen(
["python", "launch_training.py", "--config", str(config_path)],
[sys.executable, "launch_training.py", "--config", str(config_path)],
stdout=trainer_log_file,
stderr=subprocess.STDOUT,
cwd=str(TINKER_ATROPOS_ROOT),
@ -355,7 +386,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):
env_log_file = open(env_log, "w")
run_state.env_process = subprocess.Popen(
["python", str(env_info.file_path), "serve", "--config", str(config_path)],
[sys.executable, str(env_info.file_path), "serve", "--config", str(config_path)],
stdout=env_log_file,
stderr=subprocess.STDOUT,
cwd=str(TINKER_ATROPOS_ROOT),
@ -543,17 +574,14 @@ async def rl_select_environment(name: str) -> str:
if not field_info.get("locked", False):
_current_config[field_name] = field_info.get("default")
configurable_count = sum(1 for f in config_fields.values() if not f.get("locked", False))
locked_count = sum(1 for f in config_fields.values() if f.get("locked", False))
# Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
_current_config["wandb_name"] = f"{name}-{timestamp}"
return json.dumps({
"message": f"Selected environment: {name}",
"environment": name,
"file_path": env_info.file_path,
"configurable_fields": configurable_count,
"locked_fields": locked_count,
"config": _current_config,
"tip": f"Use rl_get_current_config() to see all {configurable_count} configurable fields.",
}, indent=2)
@ -961,10 +989,11 @@ async def rl_list_runs() -> str:
# ============================================================================
# Test models at different scales for robustness testing
# These are cheap, capable models on OpenRouter for testing parsing/scoring
TEST_MODELS = [
{"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"},
{"id": "zhipu-ai/glm-4-flash", "name": "GLM-4 Flash", "scale": "medium"},
{"id": "minimax/minimax-m1", "name": "MiniMax M1", "scale": "large"},
{"id": "z-ai/glm-4.7-flash", "name": "GLM-4.7 Flash", "scale": "medium"},
{"id": "minimax/minimax-m2.1", "name": "MiniMax M2.1", "scale": "large"},
]
# Default test parameters - quick but representative
@ -1066,18 +1095,35 @@ async def rl_test_inference(
# Build the process command using Atropos's built-in CLI
# This runs the environment's actual code with OpenRouter as the inference backend
# We pass our locked settings + test-specific overrides via CLI args
cmd = [
"python", env_info.file_path, "process",
sys.executable, env_info.file_path, "process",
# Test-specific overrides
"--env.total_steps", str(num_steps),
"--env.group_size", str(group_size),
"--env.use_wandb", "false",
"--env.use_wandb", "false", # No wandb for quick tests
"--env.data_path_to_save_groups", str(output_file),
# Use locked settings from our config
"--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
"--env.max_token_length", str(LOCKED_FIELDS["env"]["max_token_length"]),
"--env.max_num_workers", str(LOCKED_FIELDS["env"]["max_num_workers"]),
"--env.max_batches_offpolicy", str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]),
# OpenRouter config for inference testing
# IMPORTANT: Use server_type=openai for OpenRouter (not sglang)
# sglang is only for actual training with Tinker's inference server
"--openai.base_url", "https://openrouter.ai/api/v1",
"--openai.api_key", api_key,
"--openai.model_name", model_id,
"--openai.server_type", "openai", # OpenRouter is OpenAI-compatible
"--openai.health_check", "false", # OpenRouter doesn't have health endpoint
]
print(f"Running: python {Path(env_info.file_path).name} process ...")
# Debug: Print the full command
cmd_str = " ".join(str(c) for c in cmd)
# Hide API key in printed output
cmd_display = cmd_str.replace(api_key, "***API_KEY***")
print(f"Command: {cmd_display}")
print(f"Working dir: {TINKER_ATROPOS_ROOT}")
print(f" {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")
model_results = {
@ -1105,12 +1151,44 @@ async def rl_test_inference(
timeout=600, # 10 minute timeout per model
)
# Decode output
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
# Write logs to files for inspection outside CLI
log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
with open(log_file, "w") as f:
f.write(f"Command: {cmd_display}\n")
f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
f.write(f"Return code: {process.returncode}\n")
f.write(f"\n{'='*60}\n")
f.write(f"STDOUT:\n{'='*60}\n")
f.write(stdout_text or "(empty)\n")
f.write(f"\n{'='*60}\n")
f.write(f"STDERR:\n{'='*60}\n")
f.write(stderr_text or "(empty)\n")
print(f" Log file: {log_file}")
# Print to console for immediate debugging
if stdout_text.strip():
print(f"\n--- STDOUT ---")
print(stdout_text[-2000:]) # Last 2000 chars
if stderr_text.strip():
print(f"\n--- STDERR ---")
print(stderr_text[-2000:]) # Last 2000 chars
if process.returncode != 0:
model_results["error"] = f"Process exited with code {process.returncode}"
model_results["stderr"] = stderr.decode()[-1000:]
print(f" Error: {model_results['error']}")
model_results["stderr"] = stderr_text[-1000:]
model_results["stdout"] = stdout_text[-1000:]
model_results["log_file"] = str(log_file)
print(f"\n ❌ Error: {model_results['error']}")
else:
print(f" Process completed successfully")
print(f"\n ✅ Process completed successfully")
print(f" Output file: {output_file}")
print(f" File exists: {output_file.exists()}")
# Parse the output JSONL file
if output_file.exists():