initial RL training tools and loop

This commit is contained in:
teknium1 2026-02-03 23:41:26 -08:00
parent 51a6b7d2b5
commit f018999da9
5 changed files with 1199 additions and 3 deletions

View file

@ -39,6 +39,21 @@ from tools.vision_tools import vision_analyze_tool, check_vision_requirements
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
# RL Training tools (Tinker-Atropos)
from tools.rl_training_tool import (
rl_list_environments,
rl_select_environment,
rl_get_current_config,
rl_edit_config,
rl_start_training,
rl_check_status,
rl_stop_training,
rl_get_results,
rl_test_inference,
rl_list_runs,
rl_health_check,
check_rl_api_keys,
)
# Cronjob management tools (CLI-only)
from tools.cronjob_tools import (
schedule_cronjob,
@ -128,6 +143,19 @@ TOOLSET_REQUIREMENTS = {
"setup_url": None,
"tools": ["skills_categories", "skills_list", "skill_view"],
},
"rl": {
"name": "RL Training (Tinker-Atropos)",
"env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"],
"check_fn": check_rl_api_keys,
"setup_url": "https://wandb.ai/authorize",
"tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs",
],
},
}
@ -471,6 +499,199 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]:
]]
def get_rl_tool_definitions() -> List[Dict[str, Any]]:
    """
    Build the RL training tool schemas in OpenAI function-calling format.

    Ten tools are described: listing/selecting environments, reading and
    editing the current configuration, starting/checking/stopping training
    runs, fetching results, test inference, and listing runs.

    Returns:
        List[Dict]: One ``{"type": "function", "function": {...}}`` entry per
        RL tool. A fresh list of fresh dicts is built on every call.
    """

    def make_tool(name, description, properties=None, required=None):
        # Wrap a single tool's metadata in the OpenAI "function" envelope.
        # Tools without arguments get an empty properties/required pair.
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": {} if properties is None else properties,
                    "required": [] if required is None else required,
                },
            },
        }

    def run_id_property(description):
        # Several tools take exactly one string argument called "run_id";
        # only its description differs between them.
        return {"run_id": {"type": "string", "description": description}}

    list_environments = make_tool(
        "rl_list_environments",
        "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).",
    )

    select_environment = make_tool(
        "rl_select_environment",
        "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.",
        properties={
            "name": {
                "type": "string",
                "description": "Name of the environment to select (from rl_list_environments)",
            },
        },
        required=["name"],
    )

    get_current_config = make_tool(
        "rl_get_current_config",
        "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.",
    )

    edit_config = make_tool(
        "rl_edit_config",
        "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).",
        properties={
            "field": {
                "type": "string",
                "description": "Name of the field to update",
            },
            # No "type" key here: the schema leaves the value untyped.
            "value": {
                "description": "New value for the field",
            },
        },
        required=["field", "value"],
    )

    start_training = make_tool(
        "rl_start_training",
        "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!",
        properties={
            "wandb_project": {
                "type": "string",
                "description": "WandB project name for logging",
                "default": "rl-training",
            },
            "lora_rank": {
                "type": "integer",
                "description": "LoRA rank for training",
                "default": 32,
            },
            "learning_rate": {
                "type": "number",
                "description": "Learning rate",
                "default": 4e-5,
            },
        },
    )

    check_status = make_tool(
        "rl_check_status",
        "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.",
        properties=run_id_property("The run ID from rl_start_training()"),
        required=["run_id"],
    )

    stop_training = make_tool(
        "rl_stop_training",
        "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.",
        properties=run_id_property("The run ID to stop"),
        required=["run_id"],
    )

    get_results = make_tool(
        "rl_get_results",
        "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.",
        properties=run_id_property("The run ID to get results for"),
        required=["run_id"],
    )

    test_inference = make_tool(
        "rl_test_inference",
        "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.",
        properties={
            "prompts": {
                "type": "array",
                "items": {"type": "string"},
                "description": "List of test prompts to run through the environment",
            },
            "max_tokens": {
                "type": "integer",
                "description": "Maximum tokens to generate per prompt",
                "default": 256,
            },
            "temperature": {
                "type": "number",
                "description": "Sampling temperature",
                "default": 1.0,
            },
        },
        required=["prompts"],
    )

    list_runs = make_tool(
        "rl_list_runs",
        "List all training runs (active and completed) with their status.",
    )

    # Order matches the order in which the tools are advertised elsewhere.
    return [
        list_environments,
        select_environment,
        get_current_config,
        edit_config,
        start_training,
        check_status,
        stop_training,
        get_results,
        test_inference,
        list_runs,
    ]
def get_all_tool_names() -> List[str]:
"""
Get the names of all available tools across all toolsets.
@ -519,6 +740,16 @@ def get_all_tool_names() -> List[str]:
"schedule_cronjob", "list_cronjobs", "remove_cronjob"
])
# RL Training tools
if check_rl_api_keys():
tool_names.extend([
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
])
return tool_names
@ -557,7 +788,18 @@ def get_toolset_for_tool(tool_name: str) -> str:
# Cronjob management tools
"schedule_cronjob": "cronjob_tools",
"list_cronjobs": "cronjob_tools",
"remove_cronjob": "cronjob_tools"
"remove_cronjob": "cronjob_tools",
# RL Training tools
"rl_list_environments": "rl_tools",
"rl_select_environment": "rl_tools",
"rl_get_current_config": "rl_tools",
"rl_edit_config": "rl_tools",
"rl_start_training": "rl_tools",
"rl_check_status": "rl_tools",
"rl_stop_training": "rl_tools",
"rl_get_results": "rl_tools",
"rl_test_inference": "rl_tools",
"rl_list_runs": "rl_tools",
}
return toolset_mapping.get(tool_name, "unknown")
@ -635,6 +877,11 @@ def get_tool_definitions(
for tool in get_cronjob_tool_definitions_formatted():
all_available_tools_map[tool["function"]["name"]] = tool
# RL Training tools
if check_rl_api_keys():
for tool in get_rl_tool_definitions():
all_available_tools_map[tool["function"]["name"]] = tool
# Determine which tools to include based on toolsets
tools_to_include = set()
@ -663,7 +910,14 @@ def get_tool_definitions(
"browser_press", "browser_close", "browser_get_images",
"browser_vision"
],
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
"rl_tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
]
}
legacy_tools = legacy_map.get(toolset_name, [])
tools_to_include.update(legacy_tools)
@ -708,7 +962,14 @@ def get_tool_definitions(
"browser_press", "browser_close", "browser_get_images",
"browser_vision"
],
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
"rl_tools": [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
]
}
legacy_tools = legacy_map.get(toolset_name, [])
tools_to_include.difference_update(legacy_tools)
@ -1018,6 +1279,89 @@ def handle_cronjob_function_call(
return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False)
def handle_rl_function_call(
    function_name: str,
    function_args: Dict[str, Any]
) -> str:
    """
    Dispatch a tool call to the matching RL training coroutine and run it.

    Each supported tool maps to an async ``rl_*`` function; the coroutine is
    driven to completion synchronously on the thread's event loop. The same
    loop is reused across calls (rather than a fresh loop per call) so that
    any loop-bound state held by the RL tools stays valid between calls.

    Args:
        function_name (str): Name of the RL function to call
        function_args (Dict): Arguments for the function; missing keys fall
            back to the same defaults advertised in the tool schema. ``None``
            is treated as an empty dict.

    Returns:
        str: Whatever the RL coroutine returns (annotated as a JSON string),
        or a JSON ``{"error": ...}`` string when ``function_name`` is not an
        RL tool.

    Raises:
        RuntimeError: If invoked from code that is already running inside this
            thread's event loop (``run_until_complete`` cannot be nested).
    """
    import asyncio

    args = function_args or {}

    # Lazy factories: the coroutine is only created for the selected tool,
    # so no un-awaited coroutine objects are left behind.
    coroutine_factories = {
        "rl_list_environments": lambda: rl_list_environments(),
        "rl_select_environment": lambda: rl_select_environment(
            name=args.get("name", "")
        ),
        "rl_get_current_config": lambda: rl_get_current_config(),
        "rl_edit_config": lambda: rl_edit_config(
            field=args.get("field", ""),
            value=args.get("value"),
        ),
        "rl_start_training": lambda: rl_start_training(
            wandb_project=args.get("wandb_project", "rl-training"),
            lora_rank=args.get("lora_rank", 32),
            learning_rate=args.get("learning_rate", 4e-5),
        ),
        "rl_check_status": lambda: rl_check_status(
            run_id=args.get("run_id", "")
        ),
        "rl_stop_training": lambda: rl_stop_training(
            run_id=args.get("run_id", "")
        ),
        "rl_get_results": lambda: rl_get_results(
            run_id=args.get("run_id", "")
        ),
        "rl_test_inference": lambda: rl_test_inference(
            prompts=args.get("prompts", []),
            max_tokens=args.get("max_tokens", 256),
            temperature=args.get("temperature", 1.0),
        ),
        "rl_list_runs": lambda: rl_list_runs(),
    }

    make_coroutine = coroutine_factories.get(function_name)
    if make_coroutine is None:
        # Reject unknown names before touching (or creating) an event loop.
        return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)

    # Reuse the thread's current loop when there is a usable one. A missing
    # loop raises RuntimeError (e.g. in worker threads); a loop that another
    # component already closed would fail in run_until_complete, so replace
    # it in both cases.
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    return loop.run_until_complete(make_coroutine())
def handle_function_call(
function_name: str,
function_args: Dict[str, Any],
@ -1081,6 +1425,16 @@ def handle_function_call(
elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]:
return handle_cronjob_function_call(function_name, function_args, task_id)
# Route RL training tools
elif function_name in [
"rl_list_environments", "rl_select_environment",
"rl_get_current_config", "rl_edit_config",
"rl_start_training", "rl_check_status",
"rl_stop_training", "rl_get_results",
"rl_test_inference", "rl_list_runs"
]:
return handle_rl_function_call(function_name, function_args)
else:
error_msg = f"Unknown function: {function_name}"
print(f"{error_msg}")