mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Add RL training configuration and tools
- Updated `.env.example` to include Tinker and WandB API keys for reinforcement learning training.
- Enhanced `model_tools.py` to clarify configuration options and streamline the RL training process.
- Expanded `README.md` with detailed instructions for setting up RL training using Tinker and WandB.
- Modified `hermes_cli` files to integrate RL training tools and ensure proper configuration checks.
- Improved `rl_training_tool.py` to reflect changes in training parameters and configuration management.
This commit is contained in:
parent
f018999da9
commit
f6574978de
7 changed files with 169 additions and 65 deletions
|
|
@ -168,20 +168,22 @@ async def rl_get_current_config() -> str:
|
|||
"""
|
||||
Get the current environment configuration.
|
||||
|
||||
Returns only the fields that are safe to modify. Other fields
|
||||
(tokenizer_name, rollout_server_url, etc.) are fixed by the system.
|
||||
Returns all configurable fields for the selected environment.
|
||||
Each environment may have different configuration options.
|
||||
|
||||
Available fields:
|
||||
- group_size: Rollouts per prompt (4-16 typical)
|
||||
- max_token_length: Max generation tokens (2048-16384)
|
||||
- total_steps: Training steps (50-2000)
|
||||
- steps_per_eval: Steps between evaluations
|
||||
- use_wandb: Enable WandB logging
|
||||
Fields are divided into:
|
||||
- configurable_fields: Can be changed with rl_edit_config()
|
||||
- locked_fields: Infrastructure settings that cannot be changed
|
||||
|
||||
Common configurable fields include:
|
||||
- group_size: Rollouts per prompt
|
||||
- batch_size: Training batch size
|
||||
- wandb_name: WandB run name prefix
|
||||
- max_num_workers: Concurrent workers (-1 = auto)
|
||||
- system_prompt: Model instructions
|
||||
- And any environment-specific options
|
||||
|
||||
Returns:
|
||||
JSON string with current config fields and their values
|
||||
JSON string with configurable and locked fields
|
||||
"""
|
||||
result = await _make_request("GET", "/config")
|
||||
return json.dumps(result, indent=2)
|
||||
|
|
@ -191,21 +193,15 @@ async def rl_edit_config(field: str, value: Any) -> str:
|
|||
"""
|
||||
Update a configuration field.
|
||||
|
||||
Only exposed fields can be modified. Validates field name and type.
|
||||
Use rl_get_current_config() first to see available fields for the
|
||||
selected environment. Each environment has different options.
|
||||
|
||||
Locked fields (infrastructure settings) cannot be changed.
|
||||
|
||||
Args:
|
||||
field: Name of the field to update (e.g., "group_size", "total_steps")
|
||||
field: Name of the field to update (from rl_get_current_config)
|
||||
value: New value for the field
|
||||
|
||||
Valid fields:
|
||||
- group_size (int): Rollouts per prompt
|
||||
- max_token_length (int): Max generation tokens
|
||||
- total_steps (int): Training steps
|
||||
- steps_per_eval (int): Eval frequency
|
||||
- use_wandb (bool): Enable logging
|
||||
- wandb_name (str): Run name prefix
|
||||
- max_num_workers (int): Workers count
|
||||
|
||||
Returns:
|
||||
JSON string with updated config or error message
|
||||
"""
|
||||
|
|
@ -217,37 +213,28 @@ async def rl_edit_config(field: str, value: Any) -> str:
|
|||
# Training Management Tools
|
||||
# ============================================================================
|
||||
|
||||
async def rl_start_training(
|
||||
wandb_project: str = "rl-training",
|
||||
lora_rank: int = 32,
|
||||
learning_rate: float = 4e-5,
|
||||
) -> str:
|
||||
async def rl_start_training() -> str:
|
||||
"""
|
||||
Start a new RL training run with the current environment and config.
|
||||
|
||||
Requires an environment to be selected first using rl_select_environment().
|
||||
Use rl_edit_config() to set group_size, batch_size, wandb_project before starting.
|
||||
|
||||
WARNING: Training runs can take hours to days. Use rl_check_status() to
|
||||
monitor progress (recommended: check every 30 minutes at most).
|
||||
Most training parameters are fixed (lora_rank=32, learning_rate=4e-5, etc.)
|
||||
and cannot be changed.
|
||||
|
||||
Args:
|
||||
wandb_project: WandB project name for logging
|
||||
lora_rank: LoRA rank for training (default: 32)
|
||||
learning_rate: Learning rate (default: 4e-5)
|
||||
WARNING: Training runs take hours. Use rl_check_status() to monitor
|
||||
progress (recommended: check every 30 minutes at most).
|
||||
|
||||
Returns:
|
||||
JSON string with run_id and initial status
|
||||
|
||||
TIP: Before starting training:
|
||||
1. Test with rl_test_inference() to verify the environment works
|
||||
2. Start with fewer total_steps to validate the setup
|
||||
2. Configure group_size and batch_size appropriately
|
||||
3. Monitor WandB metrics for reward/mean and percent_correct
|
||||
"""
|
||||
result = await _make_request("POST", "/runs", {
|
||||
"wandb_project": wandb_project,
|
||||
"lora_rank": lora_rank,
|
||||
"learning_rate": learning_rate,
|
||||
})
|
||||
result = await _make_request("POST", "/runs", {})
|
||||
return json.dumps(result, indent=2)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue