mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Add RL training configuration and tools

- Updated `.env.example` to include Tinker and WandB API keys for reinforcement learning training.
- Enhanced `model_tools.py` to clarify configuration options and streamline the RL training process.
- Expanded `README.md` with detailed instructions for setting up RL training using Tinker and WandB.
- Modified `hermes_cli` files to integrate RL training tools and ensure proper configuration checks.
- Improved `rl_training_tool.py` to reflect changes in training parameters and configuration management.

parent f018999da9 · commit f6574978de
7 changed files with 169 additions and 65 deletions
.env.example (+18)

@@ -165,3 +165,21 @@ IMAGE_TOOLS_DEBUG=false
 # CONTEXT_COMPRESSION_ENABLED=true # Enable auto-compression (default: true)
 # CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
 # CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001 # Fast model for summaries
+
+# =============================================================================
+# RL TRAINING (Tinker + Atropos)
+# =============================================================================
+# Run reinforcement learning training on language models using the Tinker API.
+# Requires the rl-server to be running (from tinker-atropos package).
+
+# Tinker API Key - RL training service
+# Get at: https://tinker-console.thinkingmachines.ai/keys
+TINKER_API_KEY=
+
+# Weights & Biases API Key - Experiment tracking and metrics
+# Get at: https://wandb.ai/authorize
+WANDB_API_KEY=
+
+# RL API Server URL (default: http://localhost:8080)
+# Change if running the rl-server on a different host/port
+# RL_API_URL=http://localhost:8080
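The RL tools are expected to resolve these variables at runtime. A minimal sketch of that lookup, with the `RL_API_URL` fallback documented above (the helper name `load_rl_settings` is illustrative, not from the codebase):

```python
import os

# Illustrative helper (not part of the codebase): gathers the three
# settings documented in .env.example above. Only RL_API_URL has a
# default; the two API keys must be supplied by the user.
def load_rl_settings(env=None):
    env = os.environ if env is None else env
    return {
        "tinker_api_key": env.get("TINKER_API_KEY", ""),
        "wandb_api_key": env.get("WANDB_API_KEY", ""),
        "rl_api_url": env.get("RL_API_URL", "http://localhost:8080"),
    }
```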
README.md (+56)

@@ -74,6 +74,7 @@ You need at least one LLM provider:
 | Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
 | Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
 | Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
+| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
 | Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` |

 ---

@@ -270,6 +271,61 @@ When enabled, you'll see messages like:
 See [docs/messaging.md](docs/messaging.md) for WhatsApp and advanced setup.

+### 🤖 RL Training (Tinker + Atropos)
+
+Train language models with reinforcement learning using the Tinker API and Atropos framework.
+
+#### Requirements
+
+1. **API Keys:** Add to `~/.hermes/.env`:
+
+   ```bash
+   TINKER_API_KEY=your-tinker-key  # Get from https://tinker-console.thinkingmachines.ai/keys
+   WANDB_API_KEY=your-wandb-key    # Get from https://wandb.ai/authorize
+   ```
+
+2. **Install tinker-atropos** (in a separate directory):
+
+   ```bash
+   cd ~/tinker-atropos
+   pip install -e .
+   ```
+
+3. **Start the RL API server:**
+
+   ```bash
+   rl-server  # Runs on port 8080 by default
+   ```
+#### Using RL Tools
+
+The agent can now use RL training tools:
+
+```
+You: Start training on GSM8k with group_size=16
+
+Agent: I'll set up an RL training run on the GSM8k environment...
+[Uses rl_list_environments, rl_select_environment, rl_edit_config, rl_start_training]
+```
+
+#### Available RL Tools
+
+| Tool | Description |
+|------|-------------|
+| `rl_list_environments` | List available RL environments |
+| `rl_select_environment` | Select an environment for training |
+| `rl_get_current_config` | View all configurable options |
+| `rl_edit_config` | Change a configuration value |
+| `rl_start_training` | Start a training run |
+| `rl_check_status` | Check training progress |
+| `rl_stop_training` | Stop a running training |
+| `rl_get_results` | Fetch WandB metrics |
+#### Dedicated RL CLI
+
+For extended RL workflows with longer timeouts:
+
+```bash
+python rl_cli.py --model "anthropic/claude-sonnet-4-20250514"
+```
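The `--model` flag is the only option documented here. A hypothetical sketch of a matching parser, purely for illustration — the real `rl_cli.py` may accept more flags and use a different default:

```python
import argparse

# Hypothetical parser mirroring the documented invocation of rl_cli.py;
# the default model string is taken from the README example above.
def build_parser():
    parser = argparse.ArgumentParser(prog="rl_cli.py")
    parser.add_argument(
        "--model",
        default="anthropic/claude-sonnet-4-20250514",
        help="Model identifier used to drive the RL workflow",
    )
    return parser
```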
 ### ⏰ Scheduled Tasks (Cron)

 Schedule tasks to run automatically:
@@ -151,6 +151,20 @@ OPTIONAL_ENV_VARS = {
         "tools": ["image_generate"],
         "password": True,
     },
+    "TINKER_API_KEY": {
+        "description": "Tinker API key for RL training",
+        "prompt": "Tinker API key",
+        "url": "https://tinker-console.thinkingmachines.ai/keys",
+        "tools": ["rl_start_training", "rl_check_status", "rl_stop_training"],
+        "password": True,
+    },
+    "WANDB_API_KEY": {
+        "description": "Weights & Biases API key for experiment tracking",
+        "prompt": "WandB API key",
+        "url": "https://wandb.ai/authorize",
+        "tools": ["rl_get_results", "rl_check_status"],
+        "password": True,
+    },
     "OPENAI_BASE_URL": {
         "description": "Custom OpenAI-compatible API endpoint URL",
         "prompt": "API base URL (e.g., https://api.example.com/v1)",
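Each `OPTIONAL_ENV_VARS` entry lists the tools its key unlocks. A sketch of how such a mapping could drive tool gating — only the two entries added by this commit are shown, and the `disabled_rl_tools` helper is an assumption, not code from the repo:

```python
# Subset of the OPTIONAL_ENV_VARS data from the diff above, reduced to
# the "tools" lists. The aggregation helper below is illustrative.
RL_ENV_VARS = {
    "TINKER_API_KEY": ["rl_start_training", "rl_check_status", "rl_stop_training"],
    "WANDB_API_KEY": ["rl_get_results", "rl_check_status"],
}

def disabled_rl_tools(env):
    # A tool is disabled if any key that lists it is missing from env.
    missing = set()
    for key, tools in RL_ENV_VARS.items():
        if not env.get(key):
            missing.update(tools)
    return sorted(missing)
```

With only `TINKER_API_KEY` set, the WandB-backed tools stay disabled, matching the "both keys required" behavior in the setup wizard below.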
@@ -186,6 +186,14 @@ def _print_setup_summary(config: dict, hermes_home):
     else:
         tool_status.append(("Image Generation", False, "FAL_KEY"))

+    # Tinker + WandB (RL training)
+    if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'):
+        tool_status.append(("RL Training (Tinker)", True, None))
+    elif get_env_value('TINKER_API_KEY'):
+        tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY"))
+    else:
+        tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY"))
+
     # Terminal (always available if system deps met)
     tool_status.append(("Terminal/Commands", True, None))
@@ -932,6 +940,47 @@ def run_setup_wizard(args):
         if api_key:
             save_env_value("FAL_KEY", api_key)
             print_success("  Configured ✓")
+    print()
+
+    # Tinker + WandB - RL Training
+    print_info("─" * 50)
+    print(color("  RL Training (Tinker + WandB)", Colors.CYAN))
+    print_info("  Enables: rl_start_training, rl_check_status, rl_get_results tools")
+    print_info("  Use case: Run reinforcement learning training via Tinker API")
+    tinker_configured = get_env_value('TINKER_API_KEY')
+    wandb_configured = get_env_value('WANDB_API_KEY')
+
+    if tinker_configured and wandb_configured:
+        print_success("  Status: Configured ✓")
+        if prompt_yes_no("  Update RL training credentials?", False):
+            api_key = prompt("  Tinker API key", password=True)
+            if api_key:
+                save_env_value("TINKER_API_KEY", api_key)
+            wandb_key = prompt("  WandB API key", password=True)
+            if wandb_key:
+                save_env_value("WANDB_API_KEY", wandb_key)
+            print_success("  Updated")
+    else:
+        if tinker_configured:
+            print_warning("  Status: Tinker configured, WandB missing")
+        elif wandb_configured:
+            print_warning("  Status: WandB configured, Tinker missing")
+        else:
+            print_warning("  Status: Not configured (tools will be disabled)")
+
+        if prompt_yes_no("  Set up RL Training?", False):
+            print_info("  Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys")
+            print_info("  Get WandB key at: https://wandb.ai/authorize")
+            api_key = prompt("  Tinker API key", password=True)
+            if api_key:
+                save_env_value("TINKER_API_KEY", api_key)
+            wandb_key = prompt("  WandB API key", password=True)
+            if wandb_key:
+                save_env_value("WANDB_API_KEY", wandb_key)
+            if api_key and wandb_key:
+                print_success("  Configured ✓")
+            else:
+                print_warning("  Partially configured (both keys required)")
+
     # =========================================================================
     # Save config and show summary
@@ -74,6 +74,8 @@ def show_status(args):
         "Firecrawl": "FIRECRAWL_API_KEY",
         "Browserbase": "BROWSERBASE_API_KEY",
         "FAL": "FAL_KEY",
+        "Tinker": "TINKER_API_KEY",
+        "WandB": "WANDB_API_KEY",
     }

     for name, env_var in keys.items():
@@ -554,13 +554,13 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "rl_edit_config",
-            "description": "Update a configuration field. Valid fields: group_size (int), max_token_length (int), total_steps (int), steps_per_eval (int), use_wandb (bool), wandb_name (str), max_num_workers (int).",
+            "description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.",
             "parameters": {
                 "type": "object",
                 "properties": {
                     "field": {
                         "type": "string",
-                        "description": "Name of the field to update"
+                        "description": "Name of the field to update (get available fields from rl_get_current_config)"
                     },
                     "value": {
                         "description": "New value for the field"
@@ -574,26 +574,10 @@ def get_rl_tool_definitions() -> List[Dict[str, Any]]:
         "type": "function",
         "function": {
             "name": "rl_start_training",
-            "description": "Start a new RL training run. WARNING: Training can take hours. Use rl_check_status() to monitor (30-minute intervals recommended). Test with rl_test_inference() first!",
+            "description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours. Test with rl_test_inference() first!",
             "parameters": {
                 "type": "object",
-                "properties": {
-                    "wandb_project": {
-                        "type": "string",
-                        "description": "WandB project name for logging",
-                        "default": "rl-training"
-                    },
-                    "lora_rank": {
-                        "type": "integer",
-                        "description": "LoRA rank for training",
-                        "default": 32
-                    },
-                    "learning_rate": {
-                        "type": "number",
-                        "description": "Learning rate",
-                        "default": 4e-5
-                    }
-                },
+                "properties": {},
                 "required": []
             }
         }
@@ -1324,13 +1308,7 @@ def handle_rl_function_call(
             )

         elif function_name == "rl_start_training":
-            return loop.run_until_complete(
-                rl_start_training(
-                    wandb_project=function_args.get("wandb_project", "rl-training"),
-                    lora_rank=function_args.get("lora_rank", 32),
-                    learning_rate=function_args.get("learning_rate", 4e-5)
-                )
-            )
+            return loop.run_until_complete(rl_start_training())

         elif function_name == "rl_check_status":
             return loop.run_until_complete(
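After this change the handler forwards no arguments at all. A runnable sketch of the simplified dispatch pattern — the stub `rl_start_training` stands in for the real coroutine, which calls the rl-server:

```python
import asyncio

async def rl_start_training():
    # Stand-in for the real coroutine, which POSTs /runs to the rl-server.
    return '{"run_id": "demo"}'

def handle_rl_function_call(function_name, function_args):
    # function_args is ignored for rl_start_training after this commit;
    # all tunables now flow through rl_edit_config instead.
    loop = asyncio.new_event_loop()
    try:
        if function_name == "rl_start_training":
            return loop.run_until_complete(rl_start_training())
        raise ValueError(f"unknown RL tool: {function_name}")
    finally:
        loop.close()
```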
@@ -168,20 +168,22 @@ async def rl_get_current_config() -> str:
     """
     Get the current environment configuration.

-    Returns only the fields that are safe to modify. Other fields
-    (tokenizer_name, rollout_server_url, etc.) are fixed by the system.
+    Returns all configurable fields for the selected environment.
+    Each environment may have different configuration options.

-    Available fields:
-    - group_size: Rollouts per prompt (4-16 typical)
-    - max_token_length: Max generation tokens (2048-16384)
-    - total_steps: Training steps (50-2000)
-    - steps_per_eval: Steps between evaluations
-    - use_wandb: Enable WandB logging
+    Fields are divided into:
+    - configurable_fields: Can be changed with rl_edit_config()
+    - locked_fields: Infrastructure settings that cannot be changed
+
+    Common configurable fields include:
+    - group_size: Rollouts per prompt
+    - batch_size: Training batch size
     - wandb_name: WandB run name prefix
-    - max_num_workers: Concurrent workers (-1 = auto)
+    - system_prompt: Model instructions
+    - And any environment-specific options

     Returns:
-        JSON string with current config fields and their values
+        JSON string with configurable and locked fields
     """
     result = await _make_request("GET", "/config")
     return json.dumps(result, indent=2)
@@ -191,21 +193,15 @@ async def rl_edit_config(field: str, value: Any) -> str:
     """
    Update a configuration field.

-    Only exposed fields can be modified. Validates field name and type.
+    Use rl_get_current_config() first to see available fields for the
+    selected environment. Each environment has different options.
+
+    Locked fields (infrastructure settings) cannot be changed.

     Args:
-        field: Name of the field to update (e.g., "group_size", "total_steps")
+        field: Name of the field to update (from rl_get_current_config)
         value: New value for the field

-    Valid fields:
-    - group_size (int): Rollouts per prompt
-    - max_token_length (int): Max generation tokens
-    - total_steps (int): Training steps
-    - steps_per_eval (int): Eval frequency
-    - use_wandb (bool): Enable logging
-    - wandb_name (str): Run name prefix
-    - max_num_workers (int): Workers count
-
     Returns:
         JSON string with updated config or error message
     """
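The configurable-versus-locked contract described in this docstring can be sketched as a local check. The locked-field names follow the tool description earlier in the diff, and this validation is illustrative only — the real `rl_edit_config` sends the edit to the rl-server, which enforces the split:

```python
# Illustrative validation of the locked/configurable split; names of
# locked fields are taken from the tool description in this commit.
LOCKED_FIELDS = {"tokenizer_name", "rollout_server_url", "lora_rank", "learning_rate"}

def edit_config(config, field, value):
    if field in LOCKED_FIELDS:
        return {"error": f"'{field}' is locked"}
    if field not in config:
        return {"error": f"unknown field '{field}'"}
    config[field] = value
    return {"status": "ok", "config": config}
```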
@@ -217,37 +213,28 @@
 # Training Management Tools
 # ============================================================================

-async def rl_start_training(
-    wandb_project: str = "rl-training",
-    lora_rank: int = 32,
-    learning_rate: float = 4e-5,
-) -> str:
+async def rl_start_training() -> str:
     """
     Start a new RL training run with the current environment and config.

     Requires an environment to be selected first using rl_select_environment().
+    Use rl_edit_config() to set group_size, batch_size, wandb_project before starting.

-    WARNING: Training runs can take hours to days. Use rl_check_status() to
-    monitor progress (recommended: check every 30 minutes at most).
+    Most training parameters are fixed (lora_rank=32, learning_rate=4e-5, etc.)
+    and cannot be changed.

-    Args:
-        wandb_project: WandB project name for logging
-        lora_rank: LoRA rank for training (default: 32)
-        learning_rate: Learning rate (default: 4e-5)
+    WARNING: Training runs take hours. Use rl_check_status() to monitor
+    progress (recommended: check every 30 minutes at most).

     Returns:
         JSON string with run_id and initial status

     TIP: Before starting training:
     1. Test with rl_test_inference() to verify the environment works
-    2. Start with fewer total_steps to validate the setup
+    2. Configure group_size and batch_size appropriately
     3. Monitor WandB metrics for reward/mean and percent_correct
     """
-    result = await _make_request("POST", "/runs", {
-        "wandb_project": wandb_project,
-        "lora_rank": lora_rank,
-        "learning_rate": learning_rate,
-    })
+    result = await _make_request("POST", "/runs", {})
     return json.dumps(result, indent=2)