Add RL training configuration and tools

- Updated `.env.example` to include Tinker and WandB API keys for reinforcement learning training.
- Enhanced `model_tools.py` to clarify configuration options and streamline the RL training process.
- Expanded `README.md` with detailed instructions for setting up RL training using Tinker and WandB.
- Modified `hermes_cli` files to integrate RL training tools and ensure proper configuration checks.
- Improved `rl_training_tool.py` to reflect changes in training parameters and configuration management.
This commit is contained in:
teknium1 2026-02-04 09:36:51 -08:00
parent f018999da9
commit f6574978de
7 changed files with 169 additions and 65 deletions

View file

@ -168,20 +168,22 @@ async def rl_get_current_config() -> str:
"""
Get the current environment configuration.
Returns only the fields that are safe to modify. Other fields
(tokenizer_name, rollout_server_url, etc.) are fixed by the system.
Returns all configurable fields for the selected environment.
Each environment may have different configuration options.
Available fields:
- group_size: Rollouts per prompt (4-16 typical)
- max_token_length: Max generation tokens (2048-16384)
- total_steps: Training steps (50-2000)
- steps_per_eval: Steps between evaluations
- use_wandb: Enable WandB logging
Fields are divided into:
- configurable_fields: Can be changed with rl_edit_config()
- locked_fields: Infrastructure settings that cannot be changed
Common configurable fields include:
- group_size: Rollouts per prompt
- batch_size: Training batch size
- wandb_name: WandB run name prefix
- max_num_workers: Concurrent workers (-1 = auto)
- system_prompt: Model instructions
- And any environment-specific options
Returns:
JSON string with current config fields and their values
JSON string with configurable and locked fields
"""
result = await _make_request("GET", "/config")
return json.dumps(result, indent=2)
@ -191,21 +193,15 @@ async def rl_edit_config(field: str, value: Any) -> str:
"""
Update a configuration field.
Only exposed fields can be modified. Validates field name and type.
Use rl_get_current_config() first to see available fields for the
selected environment. Each environment has different options.
Locked fields (infrastructure settings) cannot be changed.
Args:
field: Name of the field to update (e.g., "group_size", "total_steps")
field: Name of the field to update (from rl_get_current_config)
value: New value for the field
Valid fields:
- group_size (int): Rollouts per prompt
- max_token_length (int): Max generation tokens
- total_steps (int): Training steps
- steps_per_eval (int): Eval frequency
- use_wandb (bool): Enable logging
- wandb_name (str): Run name prefix
- max_num_workers (int): Workers count
Returns:
JSON string with updated config or error message
"""
@ -217,37 +213,28 @@ async def rl_edit_config(field: str, value: Any) -> str:
# Training Management Tools
# ============================================================================
async def rl_start_training(
wandb_project: str = "rl-training",
lora_rank: int = 32,
learning_rate: float = 4e-5,
) -> str:
async def rl_start_training() -> str:
"""
Start a new RL training run with the current environment and config.
Requires an environment to be selected first using rl_select_environment().
Use rl_edit_config() to set group_size, batch_size, wandb_project before starting.
WARNING: Training runs can take hours to days. Use rl_check_status() to
monitor progress (recommended: check every 30 minutes at most).
Most training parameters are fixed (lora_rank=32, learning_rate=4e-5, etc.)
and cannot be changed.
Args:
wandb_project: WandB project name for logging
lora_rank: LoRA rank for training (default: 32)
learning_rate: Learning rate (default: 4e-5)
WARNING: Training runs take hours. Use rl_check_status() to monitor
progress (recommended: check every 30 minutes at most).
Returns:
JSON string with run_id and initial status
TIP: Before starting training:
1. Test with rl_test_inference() to verify the environment works
2. Start with fewer total_steps to validate the setup
2. Configure group_size and batch_size appropriately
3. Monitor WandB metrics for reward/mean and percent_correct
"""
result = await _make_request("POST", "/runs", {
"wandb_project": wandb_project,
"lora_rank": lora_rank,
"learning_rate": learning_rate,
})
result = await _make_request("POST", "/runs", {})
return json.dumps(result, indent=2)