Mirror of https://github.com/NousResearch/hermes-agent.git, synced 2026-04-26 01:01:40 +00:00
Update RL tools and enhance configuration management
- Modified `model_tools.py` to update the default model IDs and add the new RL function `rl_test_inference`.
- Enhanced `README.md` with installation instructions for submodules and updated API key usage.
- Improved `rl_cli.py` to load configuration from `~/.hermes/config.yaml` and to set the terminal working directory for RL tools (see the sketch just below this list).
- Updated `run_agent.py` to handle empty string arguments as empty objects for better JSON validation (sketched after the commit metadata below).
- Refined the installation scripts to ensure submodules are cloned and installed correctly, improving the setup experience.
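For reference, loading a user-level YAML config such as `~/.hermes/config.yaml` is typically only a few lines. The snippet below is a minimal sketch, not the actual `rl_cli.py` code (which is not shown on this page); the `load_hermes_config` name is hypothetical and PyYAML is assumed:

import yaml
from pathlib import Path

def load_hermes_config() -> dict:
    # Hypothetical helper: read ~/.hermes/config.yaml if present,
    # returning an empty dict when the file is missing or empty.
    path = Path.home() / ".hermes" / "config.yaml"
    if not path.exists():
        return {}
    with open(path) as f:
        return yaml.safe_load(f) or {}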
parent 12bbca95ec
commit 3c0d0dba49
7 changed files with 274 additions and 56 deletions
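The `run_agent.py` change is likewise not among the hunks below. A minimal sketch of the described behavior, treating an empty or whitespace-only argument string as `{}` before JSON validation; the `parse_tool_arguments` helper name is hypothetical:

import json

def parse_tool_arguments(raw: str) -> dict:
    # Hypothetical sketch: some models emit "" instead of "{}" for
    # no-argument tool calls; normalize before validating as JSON.
    if raw is None or not raw.strip():
        return {}
    return json.loads(raw)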
@@ -37,6 +37,7 @@ import subprocess
 import sys
 import time
 import uuid
+from datetime import datetime
 import yaml
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -84,6 +85,7 @@ LOCKED_FIELDS = {
             "weight": 1.0,
             "num_requests_for_eval": 256,
             "timeout": 3600,
+            "server_type": "sglang",  # Tinker uses sglang for actual training
         }
     ],
     "tinker": {
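For orientation, the accesses elsewhere in this diff (`LOCKED_FIELDS["env"][...]`, `LOCKED_FIELDS.get("env", {})`, the server list above, and the `tinker` key) imply roughly the shape below. This is an inferred sketch, not the file's actual literal; the `"servers"` key name and the placeholder values are assumptions:

# Assumed overall shape of LOCKED_FIELDS, inferred from how it is read in this diff.
LOCKED_FIELDS = {
    "env": {
        "tokenizer_name": "...",       # forced tokenizer for all runs (placeholder)
        "max_token_length": 4096,      # placeholder values
        "max_num_workers": 8,
        "max_batches_offpolicy": 1,
    },
    "servers": [                       # key name assumed
        {
            "weight": 1.0,
            "num_requests_for_eval": 256,
            "timeout": 3600,
            "server_type": "sglang",   # Tinker uses sglang for actual training
        }
    ],
    "tinker": {
        # Tinker-specific locked settings (not shown in this hunk)
    },
}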
@@ -211,6 +213,9 @@ def _scan_environments() -> List[EnvironmentInfo]:
 def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
     """
     Dynamically import an environment and extract its config fields.
+
+    Uses config_init() to get the actual config class, with fallback to
+    directly importing BaseEnvConfig if config_init fails.
     """
     try:
         # Load the environment module
@@ -230,15 +235,38 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
         if not env_class:
             return {}

-        # Call config_init to get the actual config
-        env_config, server_configs = env_class.config_init()
-        config_class = type(env_config)
+        # Try calling config_init to get the actual config class
+        config_class = None
+        try:
+            env_config, server_configs = env_class.config_init()
+            config_class = type(env_config)
+        except Exception as config_error:
+            # Fallback: try to import BaseEnvConfig directly from atroposlib
+            print(f"Note: config_init failed ({config_error}), using BaseEnvConfig defaults")
+            try:
+                from atroposlib.envs.base import BaseEnvConfig
+                config_class = BaseEnvConfig
+            except ImportError:
+                return {}
+
+        if not config_class:
+            return {}
+
+        # Helper to make values JSON-serializable (handle enums, etc.)
+        def make_serializable(val):
+            if val is None:
+                return None
+            if hasattr(val, 'value'):  # Enum
+                return val.value
+            if hasattr(val, 'name') and hasattr(val, '__class__') and 'Enum' in str(type(val)):
+                return val.name
+            return val

         # Extract fields from the Pydantic model
         fields = {}
         for field_name, field_info in config_class.model_fields.items():
             field_type = field_info.annotation
-            default = field_info.default
+            default = make_serializable(field_info.default)
             description = field_info.description or ""

             is_locked = field_name in LOCKED_FIELD_NAMES
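A quick standalone illustration of what the `make_serializable` helper does with an enum default, using the same logic as above; the toy `Mode` enum is invented for the demo:

from enum import Enum

class Mode(Enum):
    FAST = "fast"
    SLOW = "slow"

# Mirrors the helper's core behavior: enums collapse to their .value,
# plain values and None pass through untouched.
def make_serializable(val):
    if val is None:
        return None
    if hasattr(val, 'value'):  # Enum
        return val.value
    return val

assert make_serializable(Mode.FAST) == "fast"
assert make_serializable(42) == 42
assert make_serializable(None) is None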
@@ -248,12 +276,15 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]:
             if hasattr(field_type, "__origin__"):
                 type_name = str(field_type)

+            locked_value = LOCKED_FIELDS.get("env", {}).get(field_name, default)
+            current_value = make_serializable(locked_value) if is_locked else default
+
             fields[field_name] = {
                 "type": type_name,
-                "default": default if default is not None else None,
+                "default": default,
                 "description": description,
                 "locked": is_locked,
-                "current_value": LOCKED_FIELDS.get("env", {}).get(field_name, default) if is_locked else default,
+                "current_value": current_value,
             }

         return fields
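After this change, a locked field's entry in the returned mapping carries the serialized locked value. An illustrative entry (values invented; `tokenizer_name` is one of the real locked field names seen elsewhere in this diff):

# Illustrative shape of one entry in the returned `fields` mapping:
fields["tokenizer_name"] = {
    "type": "str",
    "default": "default-tokenizer",           # made JSON-safe by make_serializable
    "description": "Tokenizer used by the environment",
    "locked": True,
    "current_value": "tokenizer-from-LOCKED_FIELDS",
}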
@@ -315,7 +346,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):

     trainer_log_file = open(trainer_log, "w")
     run_state.trainer_process = subprocess.Popen(
-        ["python", "launch_training.py", "--config", str(config_path)],
+        [sys.executable, "launch_training.py", "--config", str(config_path)],
         stdout=trainer_log_file,
         stderr=subprocess.STDOUT,
         cwd=str(TINKER_ATROPOS_ROOT),
@@ -355,7 +386,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path):

     env_log_file = open(env_log, "w")
     run_state.env_process = subprocess.Popen(
-        ["python", str(env_info.file_path), "serve", "--config", str(config_path)],
+        [sys.executable, str(env_info.file_path), "serve", "--config", str(config_path)],
         stdout=env_log_file,
         stderr=subprocess.STDOUT,
         cwd=str(TINKER_ATROPOS_ROOT),
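Both `Popen` call sites above switch from the bare string `"python"` to `sys.executable`. The difference matters inside a virtualenv: `"python"` is resolved through `PATH` and may pick up a different interpreter, while `sys.executable` is the absolute path of the interpreter currently running, so child processes see the same environment and installed packages. A minimal standalone illustration:

import subprocess
import sys

# "python" resolves via PATH and may not be the interpreter (or virtualenv)
# running this script; sys.executable always is.
print(sys.executable)  # e.g. /path/to/venv/bin/python
subprocess.run([sys.executable, "-c", "import sys; print(sys.version)"])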
@@ -543,17 +574,14 @@ async def rl_select_environment(name: str) -> str:
         if not field_info.get("locked", False):
             _current_config[field_name] = field_info.get("default")

-    configurable_count = sum(1 for f in config_fields.values() if not f.get("locked", False))
-    locked_count = sum(1 for f in config_fields.values() if f.get("locked", False))
-
+    # Auto-set wandb_name to "{env_name}-DATETIME" to avoid overlaps
+    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
+    _current_config["wandb_name"] = f"{name}-{timestamp}"

     return json.dumps({
         "message": f"Selected environment: {name}",
         "environment": name,
         "file_path": env_info.file_path,
-        "configurable_fields": configurable_count,
-        "locked_fields": locked_count,
         "config": _current_config,
-        "tip": f"Use rl_get_current_config() to see all {configurable_count} configurable fields.",
     }, indent=2)

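For reference, the auto-generated run name follows this pattern (the `gsm8k` environment name is illustrative):

from datetime import datetime

name = "gsm8k"  # illustrative environment name
timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
print(f"{name}-{timestamp}")  # e.g. gsm8k-20260426-010140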
@@ -961,10 +989,11 @@ async def rl_list_runs() -> str:
 # ============================================================================

 # Test models at different scales for robustness testing
+# These are cheap, capable models on OpenRouter for testing parsing/scoring
 TEST_MODELS = [
     {"id": "qwen/qwen3-8b", "name": "Qwen3 8B", "scale": "small"},
-    {"id": "zhipu-ai/glm-4-flash", "name": "GLM-4 Flash", "scale": "medium"},
-    {"id": "minimax/minimax-m1", "name": "MiniMax M1", "scale": "large"},
+    {"id": "z-ai/glm-4.7-flash", "name": "GLM-4.7 Flash", "scale": "medium"},
+    {"id": "minimax/minimax-m2.1", "name": "MiniMax M2.1", "scale": "large"},
 ]

 # Default test parameters - quick but representative
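Given the list's structure, a harness would iterate it roughly as below; the loop body is an assumption, since the real driver (`rl_test_inference`, later in this diff) shells out to the environment's `process` CLI per model:

# Hypothetical iteration over TEST_MODELS; the real loop lives in
# rl_test_inference and launches a subprocess for each model.
for model in TEST_MODELS:
    print(f"Testing {model['name']} ({model['scale']}): {model['id']}")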
@@ -1066,18 +1095,35 @@ async def rl_test_inference(

     # Build the process command using Atropos's built-in CLI
+    # This runs the environment's actual code with OpenRouter as the inference backend
+    # We pass our locked settings + test-specific overrides via CLI args
     cmd = [
-        "python", env_info.file_path, "process",
+        sys.executable, env_info.file_path, "process",
+        # Test-specific overrides
         "--env.total_steps", str(num_steps),
         "--env.group_size", str(group_size),
-        "--env.use_wandb", "false",
+        "--env.use_wandb", "false",  # No wandb for quick tests
         "--env.data_path_to_save_groups", str(output_file),
+        # Use locked settings from our config
+        "--env.tokenizer_name", LOCKED_FIELDS["env"]["tokenizer_name"],
+        "--env.max_token_length", str(LOCKED_FIELDS["env"]["max_token_length"]),
+        "--env.max_num_workers", str(LOCKED_FIELDS["env"]["max_num_workers"]),
+        "--env.max_batches_offpolicy", str(LOCKED_FIELDS["env"]["max_batches_offpolicy"]),
         # OpenRouter config for inference testing
+        # IMPORTANT: Use server_type=openai for OpenRouter (not sglang)
+        # sglang is only for actual training with Tinker's inference server
         "--openai.base_url", "https://openrouter.ai/api/v1",
         "--openai.api_key", api_key,
         "--openai.model_name", model_id,
+        "--openai.server_type", "openai",  # OpenRouter is OpenAI-compatible
+        "--openai.health_check", "false",  # OpenRouter doesn't have health endpoint
     ]

-    print(f"Running: python {Path(env_info.file_path).name} process ...")
+    # Debug: Print the full command
+    cmd_str = " ".join(str(c) for c in cmd)
+    # Hide API key in printed output
+    cmd_display = cmd_str.replace(api_key, "***API_KEY***")
+    print(f"Command: {cmd_display}")
+    print(f"Working dir: {TINKER_ATROPOS_ROOT}")
     print(f"  {num_steps} steps × {group_size} completions = {total_rollouts_per_model} rollouts")

     model_results = {
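The code that actually launches `cmd` sits between this hunk and the next and is not shown. Given the awaited `stdout`/`stderr` capture and the `timeout=600` argument visible in the next hunk, it plausibly uses asyncio's subprocess API, along the lines of this sketch (an assumption, not the file's verbatim code):

import asyncio

async def run_model_test(cmd: list[str], cwd: str) -> tuple[bytes, bytes, int]:
    # Hypothetical launcher matching the variables captured in the next hunk.
    process = await asyncio.create_subprocess_exec(
        *cmd,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        cwd=cwd,
    )
    stdout, stderr = await asyncio.wait_for(
        process.communicate(),
        timeout=600,  # 10 minute timeout per model
    )
    return stdout, stderr, process.returncode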
@@ -1105,12 +1151,44 @@ async def rl_test_inference(
             timeout=600,  # 10 minute timeout per model
         )

+        # Decode output
+        stdout_text = stdout.decode() if stdout else ""
+        stderr_text = stderr.decode() if stderr else ""
+
+        # Write logs to files for inspection outside CLI
+        log_file = test_output_dir / f"test_{_current_env}_{model_safe_name}.log"
+        with open(log_file, "w") as f:
+            f.write(f"Command: {cmd_display}\n")
+            f.write(f"Working dir: {TINKER_ATROPOS_ROOT}\n")
+            f.write(f"Return code: {process.returncode}\n")
+            f.write(f"\n{'='*60}\n")
+            f.write(f"STDOUT:\n{'='*60}\n")
+            f.write(stdout_text or "(empty)\n")
+            f.write(f"\n{'='*60}\n")
+            f.write(f"STDERR:\n{'='*60}\n")
+            f.write(stderr_text or "(empty)\n")
+
+        print(f"  Log file: {log_file}")
+
+        # Print to console for immediate debugging
+        if stdout_text.strip():
+            print(f"\n--- STDOUT ---")
+            print(stdout_text[-2000:])  # Last 2000 chars
+
+        if stderr_text.strip():
+            print(f"\n--- STDERR ---")
+            print(stderr_text[-2000:])  # Last 2000 chars
+
         if process.returncode != 0:
             model_results["error"] = f"Process exited with code {process.returncode}"
-            model_results["stderr"] = stderr.decode()[-1000:]
-            print(f"  Error: {model_results['error']}")
+            model_results["stderr"] = stderr_text[-1000:]
+            model_results["stdout"] = stdout_text[-1000:]
+            model_results["log_file"] = str(log_file)
+            print(f"\n  ❌ Error: {model_results['error']}")
         else:
-            print(f"  Process completed successfully")
+            print(f"\n  ✅ Process completed successfully")
+            print(f"  Output file: {output_file}")
+            print(f"  File exists: {output_file.exists()}")

         # Parse the output JSONL file
         if output_file.exists():
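The hunk ends just before the JSONL parsing. Atropos's `process` command saves rollout groups as JSON Lines; a generic way to read such a file back, as a sketch rather than the repo's actual parser:

import json
from pathlib import Path

def read_jsonl(path: Path) -> list[dict]:
    """Read one JSON object per non-empty line; skip lines that fail to parse."""
    groups = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                groups.append(json.loads(line))
            except json.JSONDecodeError:
                continue
    return groups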