mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
initial RL training tools and loop
This commit is contained in:
parent
51a6b7d2b5
commit
f018999da9
5 changed files with 1199 additions and 3 deletions
360
model_tools.py
360
model_tools.py
|
|
@ -39,6 +39,21 @@ from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
|||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from tools.rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_test_inference,
|
||||
rl_list_runs,
|
||||
rl_health_check,
|
||||
check_rl_api_keys,
|
||||
)
|
||||
# Cronjob management tools (CLI-only)
|
||||
from tools.cronjob_tools import (
|
||||
schedule_cronjob,
|
||||
|
|
@ -128,6 +143,19 @@ TOOLSET_REQUIREMENTS = {
|
|||
"setup_url": None,
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
},
|
||||
"rl": {
|
||||
"name": "RL Training (Tinker-Atropos)",
|
||||
"env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"],
|
||||
"check_fn": check_rl_api_keys,
|
||||
"setup_url": "https://wandb.ai/authorize",
|
||||
"tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs",
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -471,6 +499,199 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]:
|
|||
]]
|
||||
|
||||
|
||||
def get_rl_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for RL training tools in OpenAI's expected format.
|
||||
|
||||
These tools enable running RL training through Tinker-Atropos.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of RL tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_list_environments",
|
||||
"description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_select_environment",
|
||||
"description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the environment to select (from rl_list_environments)"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_get_current_config",
|
||||
"description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_edit_config",
|
||||
"description": "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"description": "Name of the field to update"
|
||||
},
|
||||
"value": {
|
||||
"description": "New value for the field"
|
||||
}
|
||||
},
|
||||
"required": ["field", "value"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_start_training",
|
||||
"description": "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"wandb_project": {
|
||||
"type": "string",
|
||||
"description": "WandB project name for logging",
|
||||
"default": "rl-training"
|
||||
},
|
||||
"lora_rank": {
|
||||
"type": "integer",
|
||||
"description": "LoRA rank for training",
|
||||
"default": 32
|
||||
},
|
||||
"learning_rate": {
|
||||
"type": "number",
|
||||
"description": "Learning rate",
|
||||
"default": 4e-5
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_check_status",
|
||||
"description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"run_id": {
|
||||
"type": "string",
|
||||
"description": "The run ID from rl_start_training()"
|
||||
}
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_stop_training",
|
||||
"description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"run_id": {
|
||||
"type": "string",
|
||||
"description": "The run ID to stop"
|
||||
}
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_get_results",
|
||||
"description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"run_id": {
|
||||
"type": "string",
|
||||
"description": "The run ID to get results for"
|
||||
}
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_test_inference",
|
||||
"description": "Test inference + verifier on sample prompts WITHOUT full training. Use to validate environments before committing to long training runs. Tests data loading, inference, and verifier logic.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"prompts": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "List of test prompts to run through the environment"
|
||||
},
|
||||
"max_tokens": {
|
||||
"type": "integer",
|
||||
"description": "Maximum tokens to generate per prompt",
|
||||
"default": 256
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"description": "Sampling temperature",
|
||||
"default": 1.0
|
||||
}
|
||||
},
|
||||
"required": ["prompts"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_list_runs",
|
||||
"description": "List all training runs (active and completed) with their status.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""
|
||||
Get the names of all available tools across all toolsets.
|
||||
|
|
@ -519,6 +740,16 @@ def get_all_tool_names() -> List[str]:
|
|||
"schedule_cronjob", "list_cronjobs", "remove_cronjob"
|
||||
])
|
||||
|
||||
# RL Training tools
|
||||
if check_rl_api_keys():
|
||||
tool_names.extend([
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
])
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
|
|
@ -557,7 +788,18 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
|||
# Cronjob management tools
|
||||
"schedule_cronjob": "cronjob_tools",
|
||||
"list_cronjobs": "cronjob_tools",
|
||||
"remove_cronjob": "cronjob_tools"
|
||||
"remove_cronjob": "cronjob_tools",
|
||||
# RL Training tools
|
||||
"rl_list_environments": "rl_tools",
|
||||
"rl_select_environment": "rl_tools",
|
||||
"rl_get_current_config": "rl_tools",
|
||||
"rl_edit_config": "rl_tools",
|
||||
"rl_start_training": "rl_tools",
|
||||
"rl_check_status": "rl_tools",
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_test_inference": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
}
|
||||
|
||||
return toolset_mapping.get(tool_name, "unknown")
|
||||
|
|
@ -635,6 +877,11 @@ def get_tool_definitions(
|
|||
for tool in get_cronjob_tool_definitions_formatted():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# RL Training tools
|
||||
if check_rl_api_keys():
|
||||
for tool in get_rl_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Determine which tools to include based on toolsets
|
||||
tools_to_include = set()
|
||||
|
||||
|
|
@ -663,7 +910,14 @@ def get_tool_definitions(
|
|||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.update(legacy_tools)
|
||||
|
|
@ -708,7 +962,14 @@ def get_tool_definitions(
|
|||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.difference_update(legacy_tools)
|
||||
|
|
@ -1018,6 +1279,89 @@ def handle_cronjob_function_call(
|
|||
return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_rl_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Handle function calls for RL training tools.
|
||||
|
||||
These tools communicate with the RL API server to manage training runs.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the RL function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
# Run async functions in event loop
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if function_name == "rl_list_environments":
|
||||
return loop.run_until_complete(rl_list_environments())
|
||||
|
||||
elif function_name == "rl_select_environment":
|
||||
return loop.run_until_complete(
|
||||
rl_select_environment(name=function_args.get("name", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_get_current_config":
|
||||
return loop.run_until_complete(rl_get_current_config())
|
||||
|
||||
elif function_name == "rl_edit_config":
|
||||
return loop.run_until_complete(
|
||||
rl_edit_config(
|
||||
field=function_args.get("field", ""),
|
||||
value=function_args.get("value")
|
||||
)
|
||||
)
|
||||
|
||||
elif function_name == "rl_start_training":
|
||||
return loop.run_until_complete(
|
||||
rl_start_training(
|
||||
wandb_project=function_args.get("wandb_project", "rl-training"),
|
||||
lora_rank=function_args.get("lora_rank", 32),
|
||||
learning_rate=function_args.get("learning_rate", 4e-5)
|
||||
)
|
||||
)
|
||||
|
||||
elif function_name == "rl_check_status":
|
||||
return loop.run_until_complete(
|
||||
rl_check_status(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_stop_training":
|
||||
return loop.run_until_complete(
|
||||
rl_stop_training(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_get_results":
|
||||
return loop.run_until_complete(
|
||||
rl_get_results(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_test_inference":
|
||||
return loop.run_until_complete(
|
||||
rl_test_inference(
|
||||
prompts=function_args.get("prompts", []),
|
||||
max_tokens=function_args.get("max_tokens", 256),
|
||||
temperature=function_args.get("temperature", 1.0)
|
||||
)
|
||||
)
|
||||
|
||||
elif function_name == "rl_list_runs":
|
||||
return loop.run_until_complete(rl_list_runs())
|
||||
|
||||
return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
|
|
@ -1081,6 +1425,16 @@ def handle_function_call(
|
|||
elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]:
|
||||
return handle_cronjob_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route RL training tools
|
||||
elif function_name in [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_test_inference", "rl_list_runs"
|
||||
]:
|
||||
return handle_rl_function_call(function_name, function_args)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue