mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge pull request #15 from NousResearch/rl-capabilities
Rl capabilities && File Operator Tools
This commit is contained in:
commit
8dd38318fc
22 changed files with 4910 additions and 24 deletions
18
.env.example
18
.env.example
|
|
@ -165,3 +165,21 @@ IMAGE_TOOLS_DEBUG=false
|
|||
# CONTEXT_COMPRESSION_ENABLED=true # Enable auto-compression (default: true)
|
||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||
# CONTEXT_COMPRESSION_MODEL=google/gemini-2.0-flash-001 # Fast model for summaries
|
||||
|
||||
# =============================================================================
|
||||
# RL TRAINING (Tinker + Atropos)
|
||||
# =============================================================================
|
||||
# Run reinforcement learning training on language models using the Tinker API.
|
||||
# Requires the rl-server to be running (from tinker-atropos package).
|
||||
|
||||
# Tinker API Key - RL training service
|
||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
||||
TINKER_API_KEY=
|
||||
|
||||
# Weights & Biases API Key - Experiment tracking and metrics
|
||||
# Get at: https://wandb.ai/authorize
|
||||
WANDB_API_KEY=
|
||||
|
||||
# RL API Server URL (default: http://localhost:8080)
|
||||
# Change if running the rl-server on a different host/port
|
||||
# RL_API_URL=http://localhost:8080
|
||||
|
|
|
|||
3
.gitmodules
vendored
3
.gitmodules
vendored
|
|
@ -1,3 +1,6 @@
|
|||
[submodule "mini-swe-agent"]
|
||||
path = mini-swe-agent
|
||||
url = https://github.com/SWE-agent/mini-swe-agent
|
||||
[submodule "tinker-atropos"]
|
||||
path = tinker-atropos
|
||||
url = https://github.com/nousresearch/tinker-atropos
|
||||
|
|
|
|||
59
README.md
59
README.md
|
|
@ -15,7 +15,7 @@ irm https://raw.githubusercontent.com/NousResearch/hermes-agent/main/scripts/ins
|
|||
```
|
||||
|
||||
The installer will:
|
||||
- Clone to `~/.hermes-agent`
|
||||
- Clone to `~/.hermes-agent` (with submodules: mini-swe-agent, tinker-atropos)
|
||||
- Create a virtual environment
|
||||
- Install all dependencies
|
||||
- Run the interactive setup wizard
|
||||
|
|
@ -74,6 +74,7 @@ You need at least one LLM provider:
|
|||
| Web scraping | [Firecrawl](https://firecrawl.dev/) | `FIRECRAWL_API_KEY` |
|
||||
| Browser automation | [Browserbase](https://browserbase.com/) | `BROWSERBASE_API_KEY`, `BROWSERBASE_PROJECT_ID` |
|
||||
| Image generation | [FAL](https://fal.ai/) | `FAL_KEY` |
|
||||
| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
|
||||
| Messaging | Telegram, Discord | `TELEGRAM_BOT_TOKEN`, `DISCORD_BOT_TOKEN` |
|
||||
|
||||
---
|
||||
|
|
@ -270,6 +271,55 @@ When enabled, you'll see messages like:
|
|||
|
||||
See [docs/messaging.md](docs/messaging.md) for WhatsApp and advanced setup.
|
||||
|
||||
### 🤖 RL Training (Tinker + Atropos)
|
||||
|
||||
Train language models with reinforcement learning using the Tinker API and Atropos framework.
|
||||
|
||||
#### Requirements
|
||||
|
||||
1. **API Keys:** Add to `~/.hermes/.env`:
|
||||
```bash
|
||||
TINKER_API_KEY=your-tinker-key # Get from https://tinker-console.thinkingmachines.ai/keys
|
||||
WANDB_API_KEY=your-wandb-key # Get from https://wandb.ai/authorize
|
||||
OPENROUTER_API_KEY=your-key # Optional: for rl_test_inference
|
||||
```
|
||||
|
||||
2. **That's it!** tinker-atropos is included as a submodule - no separate installation needed.
|
||||
|
||||
#### Using RL Tools
|
||||
|
||||
The agent can now use RL training tools:
|
||||
|
||||
```
|
||||
You: Start training on GSM8k with group_size=16
|
||||
|
||||
Agent: I'll set up an RL training run on the GSM8k environment...
|
||||
[Uses rl_list_environments, rl_select_environment, rl_edit_config, rl_start_training]
|
||||
```
|
||||
|
||||
#### Available RL Tools
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `rl_list_environments` | List available RL environments |
|
||||
| `rl_select_environment` | Select an environment for training |
|
||||
| `rl_get_current_config` | View all configurable options |
|
||||
| `rl_edit_config` | Change a configuration value |
|
||||
| `rl_test_inference` | Test environment with OpenRouter (pre-training validation) |
|
||||
| `rl_start_training` | Start a training run |
|
||||
| `rl_check_status` | Check training progress |
|
||||
| `rl_stop_training` | Stop a running training |
|
||||
| `rl_get_results` | Fetch WandB metrics |
|
||||
| `rl_list_runs` | List active training runs |
|
||||
|
||||
#### Dedicated RL CLI
|
||||
|
||||
For extended RL workflows with longer timeouts:
|
||||
|
||||
```bash
|
||||
python rl_cli.py --model "anthropic/claude-sonnet-4-20250514"
|
||||
```
|
||||
|
||||
### ⏰ Scheduled Tasks (Cron)
|
||||
|
||||
Schedule tasks to run automatically:
|
||||
|
|
@ -378,7 +428,7 @@ skills/
|
|||
If you prefer not to use the installer:
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
# Clone the repository (with submodules)
|
||||
git clone --recurse-submodules https://github.com/NousResearch/hermes-agent.git
|
||||
cd hermes-agent
|
||||
|
||||
|
|
@ -389,6 +439,11 @@ cd hermes-agent
|
|||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
pip install -e ".[all]"
|
||||
|
||||
# Install submodules (required for terminal and RL tools)
|
||||
pip install -e "./mini-swe-agent" # Terminal tool backend
|
||||
pip install -e "./tinker-atropos" # RL training backend
|
||||
|
||||
hermes setup
|
||||
```
|
||||
|
||||
|
|
|
|||
|
|
@ -151,6 +151,20 @@ OPTIONAL_ENV_VARS = {
|
|||
"tools": ["image_generate"],
|
||||
"password": True,
|
||||
},
|
||||
"TINKER_API_KEY": {
|
||||
"description": "Tinker API key for RL training",
|
||||
"prompt": "Tinker API key",
|
||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
||||
"tools": ["rl_start_training", "rl_check_status", "rl_stop_training"],
|
||||
"password": True,
|
||||
},
|
||||
"WANDB_API_KEY": {
|
||||
"description": "Weights & Biases API key for experiment tracking",
|
||||
"prompt": "WandB API key",
|
||||
"url": "https://wandb.ai/authorize",
|
||||
"tools": ["rl_get_results", "rl_check_status"],
|
||||
"password": True,
|
||||
},
|
||||
"OPENAI_BASE_URL": {
|
||||
"description": "Custom OpenAI-compatible API endpoint URL",
|
||||
"prompt": "API base URL (e.g., https://api.example.com/v1)",
|
||||
|
|
|
|||
|
|
@ -167,6 +167,13 @@ def run_doctor(args):
|
|||
else:
|
||||
check_warn("git not found", "(optional)")
|
||||
|
||||
# ripgrep (optional, for faster file search)
|
||||
if shutil.which("rg"):
|
||||
check_ok("ripgrep (rg)", "(faster file search)")
|
||||
else:
|
||||
check_warn("ripgrep (rg) not found", "(file search uses grep fallback)")
|
||||
check_info("Install for faster search: sudo apt install ripgrep")
|
||||
|
||||
# Docker (optional)
|
||||
terminal_env = os.getenv("TERMINAL_ENV", "local")
|
||||
if terminal_env == "docker":
|
||||
|
|
|
|||
|
|
@ -186,6 +186,14 @@ def _print_setup_summary(config: dict, hermes_home):
|
|||
else:
|
||||
tool_status.append(("Image Generation", False, "FAL_KEY"))
|
||||
|
||||
# Tinker + WandB (RL training)
|
||||
if get_env_value('TINKER_API_KEY') and get_env_value('WANDB_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", True, None))
|
||||
elif get_env_value('TINKER_API_KEY'):
|
||||
tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY"))
|
||||
else:
|
||||
tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY"))
|
||||
|
||||
# Terminal (always available if system deps met)
|
||||
tool_status.append(("Terminal/Commands", True, None))
|
||||
|
||||
|
|
@ -932,6 +940,47 @@ def run_setup_wizard(args):
|
|||
if api_key:
|
||||
save_env_value("FAL_KEY", api_key)
|
||||
print_success(" Configured ✓")
|
||||
print()
|
||||
|
||||
# Tinker + WandB - RL Training
|
||||
print_info("─" * 50)
|
||||
print(color(" RL Training (Tinker + WandB)", Colors.CYAN))
|
||||
print_info(" Enables: rl_start_training, rl_check_status, rl_get_results tools")
|
||||
print_info(" Use case: Run reinforcement learning training via Tinker API")
|
||||
tinker_configured = get_env_value('TINKER_API_KEY')
|
||||
wandb_configured = get_env_value('WANDB_API_KEY')
|
||||
|
||||
if tinker_configured and wandb_configured:
|
||||
print_success(" Status: Configured ✓")
|
||||
if prompt_yes_no(" Update RL training credentials?", False):
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
print_success(" Updated")
|
||||
else:
|
||||
if tinker_configured:
|
||||
print_warning(" Status: Tinker configured, WandB missing")
|
||||
elif wandb_configured:
|
||||
print_warning(" Status: WandB configured, Tinker missing")
|
||||
else:
|
||||
print_warning(" Status: Not configured (tools will be disabled)")
|
||||
|
||||
if prompt_yes_no(" Set up RL Training?", False):
|
||||
print_info(" Get Tinker key at: https://tinker-console.thinkingmachines.ai/keys")
|
||||
print_info(" Get WandB key at: https://wandb.ai/authorize")
|
||||
api_key = prompt(" Tinker API key", password=True)
|
||||
if api_key:
|
||||
save_env_value("TINKER_API_KEY", api_key)
|
||||
wandb_key = prompt(" WandB API key", password=True)
|
||||
if wandb_key:
|
||||
save_env_value("WANDB_API_KEY", wandb_key)
|
||||
if api_key and wandb_key:
|
||||
print_success(" Configured ✓")
|
||||
else:
|
||||
print_warning(" Partially configured (both keys required)")
|
||||
|
||||
# =========================================================================
|
||||
# Save config and show summary
|
||||
|
|
|
|||
|
|
@ -74,6 +74,8 @@ def show_status(args):
|
|||
"Firecrawl": "FIRECRAWL_API_KEY",
|
||||
"Browserbase": "BROWSERBASE_API_KEY",
|
||||
"FAL": "FAL_KEY",
|
||||
"Tinker": "TINKER_API_KEY",
|
||||
"WandB": "WANDB_API_KEY",
|
||||
}
|
||||
|
||||
for name, env_var in keys.items():
|
||||
|
|
|
|||
598
model_tools.py
598
model_tools.py
|
|
@ -33,12 +33,29 @@ from typing import Dict, Any, List, Optional, Tuple
|
|||
|
||||
from tools.web_tools import web_search_tool, web_extract_tool, web_crawl_tool, check_firecrawl_api_key
|
||||
from tools.terminal_tool import terminal_tool, check_terminal_requirements, TERMINAL_TOOL_DESCRIPTION, cleanup_vm
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from tools.file_tools import read_file_tool, write_file_tool, patch_tool, search_tool
|
||||
from tools import check_file_requirements
|
||||
# Hecate/MorphCloud terminal tool (cloud VMs) - available as alternative backend
|
||||
from tools.terminal_hecate import terminal_hecate_tool, check_hecate_requirements, TERMINAL_HECATE_DESCRIPTION
|
||||
from tools.vision_tools import vision_analyze_tool, check_vision_requirements
|
||||
from tools.mixture_of_agents_tool import mixture_of_agents_tool, check_moa_requirements
|
||||
from tools.image_generation_tool import image_generate_tool, check_image_generation_requirements
|
||||
from tools.skills_tool import skills_categories, skills_list, skill_view, check_skills_requirements, SKILLS_TOOL_DESCRIPTION
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from tools.rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_list_runs,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
)
|
||||
# Cronjob management tools (CLI-only)
|
||||
from tools.cronjob_tools import (
|
||||
schedule_cronjob,
|
||||
|
|
@ -128,6 +145,26 @@ TOOLSET_REQUIREMENTS = {
|
|||
"setup_url": None,
|
||||
"tools": ["skills_categories", "skills_list", "skill_view"],
|
||||
},
|
||||
"rl": {
|
||||
"name": "RL Training (Tinker-Atropos)",
|
||||
"env_vars": ["TINKER_API_KEY", "WANDB_API_KEY"],
|
||||
"check_fn": check_rl_api_keys,
|
||||
"setup_url": "https://wandb.ai/authorize",
|
||||
"tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference",
|
||||
],
|
||||
},
|
||||
"file": {
|
||||
"name": "File Operations (read, write, patch, search)",
|
||||
"env_vars": [], # Uses terminal backend, no additional requirements
|
||||
"check_fn": check_file_requirements,
|
||||
"setup_url": None,
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -471,6 +508,340 @@ def get_cronjob_tool_definitions_formatted() -> List[Dict[str, Any]]:
|
|||
]]
|
||||
|
||||
|
||||
def get_rl_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for RL training tools in OpenAI's expected format.
|
||||
|
||||
These tools enable running RL training through Tinker-Atropos.
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of RL tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_list_environments",
|
||||
"description": "List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_select_environment",
|
||||
"description": "Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Name of the environment to select (from rl_list_environments)"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_get_current_config",
|
||||
"description": "Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_edit_config",
|
||||
"description": "Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_rate) are locked.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {
|
||||
"type": "string",
|
||||
"description": "Name of the field to update (get available fields from rl_get_current_config)"
|
||||
},
|
||||
"value": {
|
||||
"description": "New value for the field"
|
||||
}
|
||||
},
|
||||
"required": ["field", "value"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_start_training",
|
||||
"description": "Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training takes hours.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_check_status",
|
||||
"description": "Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"run_id": {
|
||||
"type": "string",
|
||||
"description": "The run ID from rl_start_training()"
|
||||
}
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_stop_training",
|
||||
"description": "Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"run_id": {
|
||||
"type": "string",
|
||||
"description": "The run ID to stop"
|
||||
}
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_get_results",
|
||||
"description": "Get final results and metrics for a completed training run. Returns final metrics and path to trained weights.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"run_id": {
|
||||
"type": "string",
|
||||
"description": "The run ID to get results for"
|
||||
}
|
||||
},
|
||||
"required": ["run_id"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_list_runs",
|
||||
"description": "List all training runs (active and completed) with their status.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "rl_test_inference",
|
||||
"description": "Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps × 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, inference parsing, and verifier logic. Use BEFORE training to catch issues.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"num_steps": {
|
||||
"type": "integer",
|
||||
"description": "Number of steps to run (default: 3, recommended max for testing)",
|
||||
"default": 3
|
||||
},
|
||||
"group_size": {
|
||||
"type": "integer",
|
||||
"description": "Completions per step (default: 16, like training)",
|
||||
"default": 16
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Optional list of OpenRouter model IDs. Default: qwen/qwen3-8b, z-ai/glm-4.7-flash, minimax/minimax-m2.1"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def get_file_tool_definitions() -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get tool definitions for file manipulation tools in OpenAI's expected format.
|
||||
|
||||
File tools operate via the terminal backend and support any environment
|
||||
(local, docker, singularity, ssh, modal).
|
||||
|
||||
Returns:
|
||||
List[Dict]: List of file tool definitions compatible with OpenAI API
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file with pagination support. Returns content with line numbers in 'LINE_NUM|CONTENT' format. For binary files (images), returns base64-encoded data. If file not found, suggests similar filenames.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to read (absolute or relative)"
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Line number to start reading from (1-indexed, default: 1)",
|
||||
"default": 1,
|
||||
"minimum": 1
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of lines to read (default: 500, max: 2000)",
|
||||
"default": 500,
|
||||
"maximum": 2000
|
||||
}
|
||||
},
|
||||
"required": ["path"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "write_file",
|
||||
"description": "Write content to a file. Creates parent directories automatically. Returns bytes written and lint check results for supported languages.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Path to the file to write (will be created if doesn't exist)"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "Content to write to the file"
|
||||
}
|
||||
},
|
||||
"required": ["path", "content"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "patch",
|
||||
"description": "Modify files using either simple string replacement or V4A patch format. Mode 'replace' does find-and-replace with fuzzy matching. Mode 'patch' applies multi-file changes using V4A format (*** Begin/End Patch). Auto-runs syntax checks on modified files.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"mode": {
|
||||
"type": "string",
|
||||
"enum": ["replace", "patch"],
|
||||
"description": "Edit mode: 'replace' for string replacement, 'patch' for V4A patch format",
|
||||
"default": "replace"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path (required for 'replace' mode)"
|
||||
},
|
||||
"old_string": {
|
||||
"type": "string",
|
||||
"description": "Text to find and replace (required for 'replace' mode). Must be unique in file unless replace_all=true"
|
||||
},
|
||||
"new_string": {
|
||||
"type": "string",
|
||||
"description": "Replacement text (required for 'replace' mode)"
|
||||
},
|
||||
"replace_all": {
|
||||
"type": "boolean",
|
||||
"description": "Replace all occurrences instead of requiring unique match (default: false)",
|
||||
"default": False
|
||||
},
|
||||
"patch": {
|
||||
"type": "string",
|
||||
"description": "V4A format patch content (required for 'patch' mode). Format: *** Begin Patch / *** Update File: path / @@ context @@ / -removed / +added / *** End Patch"
|
||||
}
|
||||
},
|
||||
"required": ["mode"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"description": "Search for content in files or search for files by name. Use target='content' to search inside files (like grep), or target='files' to find files by name pattern (like glob/find). Results sorted by modification time (newest first).",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "For target='content': regex pattern to search for. For target='files': glob pattern (e.g., '*.py', '*config*')"
|
||||
},
|
||||
"target": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files"],
|
||||
"description": "Search mode: 'content' searches inside files, 'files' searches for files by name",
|
||||
"default": "content"
|
||||
},
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "Directory or file to search in (default: current directory)",
|
||||
"default": "."
|
||||
},
|
||||
"file_glob": {
|
||||
"type": "string",
|
||||
"description": "Filter files by pattern when target='content' (e.g., '*.py' to only search Python files)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Maximum number of results (default: 50)",
|
||||
"default": 50
|
||||
},
|
||||
"offset": {
|
||||
"type": "integer",
|
||||
"description": "Skip first N results for pagination (default: 0)",
|
||||
"default": 0
|
||||
},
|
||||
"output_mode": {
|
||||
"type": "string",
|
||||
"enum": ["content", "files_only", "count"],
|
||||
"description": "For target='content': 'content' shows matches, 'files_only' shows file paths, 'count' shows match counts per file",
|
||||
"default": "content"
|
||||
},
|
||||
"context": {
|
||||
"type": "integer",
|
||||
"description": "Lines of context around matches (only for target='content', output_mode='content')",
|
||||
"default": 0
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def get_all_tool_names() -> List[str]:
|
||||
"""
|
||||
Get the names of all available tools across all toolsets.
|
||||
|
|
@ -519,6 +890,22 @@ def get_all_tool_names() -> List[str]:
|
|||
"schedule_cronjob", "list_cronjobs", "remove_cronjob"
|
||||
])
|
||||
|
||||
# RL Training tools
|
||||
if check_rl_api_keys():
|
||||
tool_names.extend([
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
])
|
||||
|
||||
# File manipulation tools (use terminal backend)
|
||||
if check_file_requirements():
|
||||
tool_names.extend([
|
||||
"read_file", "write_file", "patch", "search"
|
||||
])
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
|
|
@ -557,7 +944,22 @@ def get_toolset_for_tool(tool_name: str) -> str:
|
|||
# Cronjob management tools
|
||||
"schedule_cronjob": "cronjob_tools",
|
||||
"list_cronjobs": "cronjob_tools",
|
||||
"remove_cronjob": "cronjob_tools"
|
||||
"remove_cronjob": "cronjob_tools",
|
||||
# RL Training tools
|
||||
"rl_list_environments": "rl_tools",
|
||||
"rl_select_environment": "rl_tools",
|
||||
"rl_get_current_config": "rl_tools",
|
||||
"rl_edit_config": "rl_tools",
|
||||
"rl_start_training": "rl_tools",
|
||||
"rl_check_status": "rl_tools",
|
||||
"rl_stop_training": "rl_tools",
|
||||
"rl_get_results": "rl_tools",
|
||||
"rl_list_runs": "rl_tools",
|
||||
# File manipulation tools
|
||||
"read_file": "file_tools",
|
||||
"write_file": "file_tools",
|
||||
"patch": "file_tools",
|
||||
"search": "file_tools",
|
||||
}
|
||||
|
||||
return toolset_mapping.get(tool_name, "unknown")
|
||||
|
|
@ -635,6 +1037,16 @@ def get_tool_definitions(
|
|||
for tool in get_cronjob_tool_definitions_formatted():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# RL Training tools
|
||||
if check_rl_api_keys():
|
||||
for tool in get_rl_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# File manipulation tools (use terminal backend)
|
||||
if check_file_requirements():
|
||||
for tool in get_file_tool_definitions():
|
||||
all_available_tools_map[tool["function"]["name"]] = tool
|
||||
|
||||
# Determine which tools to include based on toolsets
|
||||
tools_to_include = set()
|
||||
|
||||
|
|
@ -663,7 +1075,15 @@ def get_tool_definitions(
|
|||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.update(legacy_tools)
|
||||
|
|
@ -708,7 +1128,15 @@ def get_tool_definitions(
|
|||
"browser_press", "browser_close", "browser_get_images",
|
||||
"browser_vision"
|
||||
],
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]
|
||||
"cronjob_tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"rl_tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"file_tools": ["read_file", "write_file", "patch", "search"]
|
||||
}
|
||||
legacy_tools = legacy_map.get(toolset_name, [])
|
||||
tools_to_include.difference_update(legacy_tools)
|
||||
|
|
@ -1018,6 +1446,147 @@ def handle_cronjob_function_call(
|
|||
return json.dumps({"error": f"Unknown cronjob function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_rl_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Handle function calls for RL training tools.
|
||||
|
||||
These tools communicate with the RL API server to manage training runs.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the RL function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
# Run async functions in event loop
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
if function_name == "rl_list_environments":
|
||||
return loop.run_until_complete(rl_list_environments())
|
||||
|
||||
elif function_name == "rl_select_environment":
|
||||
return loop.run_until_complete(
|
||||
rl_select_environment(name=function_args.get("name", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_get_current_config":
|
||||
return loop.run_until_complete(rl_get_current_config())
|
||||
|
||||
elif function_name == "rl_edit_config":
|
||||
return loop.run_until_complete(
|
||||
rl_edit_config(
|
||||
field=function_args.get("field", ""),
|
||||
value=function_args.get("value")
|
||||
)
|
||||
)
|
||||
|
||||
elif function_name == "rl_start_training":
|
||||
return loop.run_until_complete(rl_start_training())
|
||||
|
||||
elif function_name == "rl_check_status":
|
||||
return loop.run_until_complete(
|
||||
rl_check_status(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_stop_training":
|
||||
return loop.run_until_complete(
|
||||
rl_stop_training(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_get_results":
|
||||
return loop.run_until_complete(
|
||||
rl_get_results(run_id=function_args.get("run_id", ""))
|
||||
)
|
||||
|
||||
elif function_name == "rl_list_runs":
|
||||
return loop.run_until_complete(rl_list_runs())
|
||||
|
||||
elif function_name == "rl_test_inference":
|
||||
return loop.run_until_complete(
|
||||
rl_test_inference(
|
||||
num_steps=function_args.get("num_steps", 3),
|
||||
group_size=function_args.get("group_size", 16),
|
||||
models=function_args.get("models"),
|
||||
)
|
||||
)
|
||||
|
||||
return json.dumps({"error": f"Unknown RL function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_file_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
task_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Handle function calls for file manipulation tools.
|
||||
|
||||
These tools use the terminal backend for all operations, supporting
|
||||
local, docker, singularity, ssh, and modal environments.
|
||||
|
||||
Args:
|
||||
function_name (str): Name of the file function to call
|
||||
function_args (Dict): Arguments for the function
|
||||
task_id (str): Task identifier for environment isolation
|
||||
|
||||
Returns:
|
||||
str: Function result as JSON string
|
||||
"""
|
||||
# Determine task_id to use
|
||||
tid = task_id or "default"
|
||||
|
||||
if function_name == "read_file":
|
||||
return read_file_tool(
|
||||
path=function_args.get("path", ""),
|
||||
offset=function_args.get("offset", 1),
|
||||
limit=function_args.get("limit", 500),
|
||||
task_id=tid
|
||||
)
|
||||
|
||||
elif function_name == "write_file":
|
||||
return write_file_tool(
|
||||
path=function_args.get("path", ""),
|
||||
content=function_args.get("content", ""),
|
||||
task_id=tid
|
||||
)
|
||||
|
||||
elif function_name == "patch":
|
||||
return patch_tool(
|
||||
mode=function_args.get("mode", "replace"),
|
||||
path=function_args.get("path"),
|
||||
old_string=function_args.get("old_string"),
|
||||
new_string=function_args.get("new_string"),
|
||||
replace_all=function_args.get("replace_all", False),
|
||||
patch=function_args.get("patch"),
|
||||
task_id=tid
|
||||
)
|
||||
|
||||
elif function_name == "search":
|
||||
return search_tool(
|
||||
pattern=function_args.get("pattern", ""),
|
||||
target=function_args.get("target", "content"),
|
||||
path=function_args.get("path", "."),
|
||||
file_glob=function_args.get("file_glob"),
|
||||
limit=function_args.get("limit", 50),
|
||||
offset=function_args.get("offset", 0),
|
||||
output_mode=function_args.get("output_mode", "content"),
|
||||
context=function_args.get("context", 0),
|
||||
task_id=tid
|
||||
)
|
||||
|
||||
return json.dumps({"error": f"Unknown file function: {function_name}"}, ensure_ascii=False)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
|
|
@ -1081,6 +1650,20 @@ def handle_function_call(
|
|||
elif function_name in ["schedule_cronjob", "list_cronjobs", "remove_cronjob"]:
|
||||
return handle_cronjob_function_call(function_name, function_args, task_id)
|
||||
|
||||
# Route RL training tools
|
||||
elif function_name in [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
]:
|
||||
return handle_rl_function_call(function_name, function_args)
|
||||
|
||||
# Route file manipulation tools
|
||||
elif function_name in ["read_file", "write_file", "patch", "search"]:
|
||||
return handle_file_function_call(function_name, function_args, task_id)
|
||||
|
||||
else:
|
||||
error_msg = f"Unknown function: {function_name}"
|
||||
print(f"❌ {error_msg}")
|
||||
|
|
@ -1152,6 +1735,12 @@ def get_available_toolsets() -> Dict[str, Dict[str, Any]]:
|
|||
"tools": ["schedule_cronjob", "list_cronjobs", "remove_cronjob"],
|
||||
"description": "Schedule and manage automated tasks (cronjobs) - only available in interactive CLI mode",
|
||||
"requirements": ["HERMES_INTERACTIVE=1 (set automatically by cli.py)"]
|
||||
},
|
||||
"file_tools": {
|
||||
"available": check_file_requirements(),
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
"description": "File manipulation tools: read/write files, search content/files, patch with fuzzy matching",
|
||||
"requirements": ["Terminal backend available (local/docker/ssh/singularity/modal)"]
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1172,7 +1761,8 @@ def check_toolset_requirements() -> Dict[str, bool]:
|
|||
"image_tools": check_image_generation_requirements(),
|
||||
"skills_tools": check_skills_requirements(),
|
||||
"browser_tools": check_browser_requirements(),
|
||||
"cronjob_tools": check_cronjob_requirements()
|
||||
"cronjob_tools": check_cronjob_requirements(),
|
||||
"file_tools": check_file_requirements()
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
448
rl_cli.py
Normal file
448
rl_cli.py
Normal file
|
|
@ -0,0 +1,448 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
RL Training CLI Runner
|
||||
|
||||
Dedicated CLI runner for RL training workflows with:
|
||||
- Extended timeouts for long-running training
|
||||
- RL-focused system prompts
|
||||
- Full toolset including RL training tools
|
||||
- Special handling for 30-minute check intervals
|
||||
|
||||
Usage:
|
||||
python rl_cli.py "Train a model on GSM8k for math reasoning"
|
||||
python rl_cli.py --interactive
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
Environment Variables:
|
||||
TINKER_API_KEY: API key for Tinker service (required)
|
||||
WANDB_API_KEY: API key for WandB metrics (required)
|
||||
OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import yaml
|
||||
|
||||
# Load environment variables from .env file
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load from ~/.hermes/.env first, then local .env
|
||||
hermes_env_path = Path.home() / '.hermes' / '.env'
|
||||
local_env_path = Path(__file__).parent / '.env'
|
||||
|
||||
if hermes_env_path.exists():
|
||||
load_dotenv(dotenv_path=hermes_env_path)
|
||||
print(f"✅ Loaded environment variables from {hermes_env_path}")
|
||||
elif local_env_path.exists():
|
||||
load_dotenv(dotenv_path=local_env_path)
|
||||
print(f"✅ Loaded environment variables from {local_env_path}")
|
||||
|
||||
# Set terminal working directory to tinker-atropos submodule
|
||||
# This ensures terminal commands run in the right context for RL work
|
||||
tinker_atropos_dir = Path(__file__).parent / 'tinker-atropos'
|
||||
if tinker_atropos_dir.exists():
|
||||
os.environ['TERMINAL_CWD'] = str(tinker_atropos_dir)
|
||||
os.environ['HERMES_QUIET'] = '1' # Disable temp subdirectory creation
|
||||
print(f"📂 Terminal working directory: {tinker_atropos_dir}")
|
||||
else:
|
||||
# Fall back to hermes-agent directory if submodule not found
|
||||
os.environ['TERMINAL_CWD'] = str(Path(__file__).parent)
|
||||
os.environ['HERMES_QUIET'] = '1'
|
||||
print(f"⚠️ tinker-atropos submodule not found, using: {Path(__file__).parent}")
|
||||
|
||||
# Import agent and tools
|
||||
from run_agent import AIAgent
|
||||
from model_tools import get_tool_definitions, check_toolset_requirements
|
||||
from tools.rl_training_tool import check_rl_api_keys, get_missing_keys
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Loading
|
||||
# ============================================================================
|
||||
|
||||
DEFAULT_MODEL = "anthropic/claude-opus-4.5"
|
||||
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
|
||||
|
||||
|
||||
def load_hermes_config() -> dict:
|
||||
"""
|
||||
Load configuration from ~/.hermes/config.yaml.
|
||||
|
||||
Returns:
|
||||
dict: Configuration with model, base_url, etc.
|
||||
"""
|
||||
config_path = Path.home() / '.hermes' / 'config.yaml'
|
||||
|
||||
config = {
|
||||
"model": DEFAULT_MODEL,
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
}
|
||||
|
||||
if config_path.exists():
|
||||
try:
|
||||
with open(config_path, "r") as f:
|
||||
file_config = yaml.safe_load(f) or {}
|
||||
|
||||
# Get model from config
|
||||
if "model" in file_config:
|
||||
if isinstance(file_config["model"], str):
|
||||
config["model"] = file_config["model"]
|
||||
elif isinstance(file_config["model"], dict):
|
||||
config["model"] = file_config["model"].get("default", DEFAULT_MODEL)
|
||||
|
||||
# Get base_url if specified
|
||||
if "base_url" in file_config:
|
||||
config["base_url"] = file_config["base_url"]
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Warning: Failed to load config.yaml: {e}")
|
||||
|
||||
return config
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RL-Specific Configuration
|
||||
# ============================================================================
|
||||
|
||||
# Extended timeouts for long-running RL operations
|
||||
RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows
|
||||
|
||||
# RL-focused system prompt
|
||||
RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models.
|
||||
|
||||
## Your Capabilities
|
||||
|
||||
You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos:
|
||||
|
||||
1. **DISCOVER**: Use `rl_list_environments` to see available RL environments
|
||||
2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards)
|
||||
3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format
|
||||
4. **CREATE**: Copy existing environments as templates, modify for your needs
|
||||
5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training
|
||||
6. **TEST**: Always use `rl_test_inference` before full training to validate your setup
|
||||
7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor
|
||||
8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance
|
||||
|
||||
## Environment Files
|
||||
|
||||
Environment files are located in: `tinker-atropos/tinker_atropos/environments/`
|
||||
|
||||
Study existing environments to learn patterns. Look for:
|
||||
- `load_dataset()` calls - how data is loaded
|
||||
- `score_answer()` / `score()` - verification logic
|
||||
- `get_next_item()` - prompt formatting
|
||||
- `system_prompt` - instruction format
|
||||
- `config_init()` - default configuration
|
||||
|
||||
## Creating New Environments
|
||||
|
||||
To create a new environment:
|
||||
1. Read an existing environment file (e.g., gsm8k_tinker.py)
|
||||
2. Use terminal to explore the target dataset format
|
||||
3. Copy the environment file as a template
|
||||
4. Modify the dataset loading, prompt formatting, and verifier logic
|
||||
5. Test with `rl_test_inference` before training
|
||||
|
||||
## Important Guidelines
|
||||
|
||||
- **Always test before training**: Training runs take hours - verify everything works first
|
||||
- **Monitor metrics**: Check WandB for reward/mean and percent_correct
|
||||
- **Status check intervals**: Wait at least 30 minutes between status checks
|
||||
- **Early stopping**: Stop training early if metrics look bad or stagnant
|
||||
- **Iterate quickly**: Start with small total_steps to validate, then scale up
|
||||
|
||||
## Available Toolsets
|
||||
|
||||
You have access to:
|
||||
- **RL tools**: Environment discovery, config management, training, testing
|
||||
- **Terminal**: Run commands, inspect files, explore datasets
|
||||
- **Web**: Search for information, documentation, papers
|
||||
- **File tools**: Read and modify code files
|
||||
|
||||
When asked to train a model, follow this workflow:
|
||||
1. List available environments
|
||||
2. Select and configure the appropriate environment
|
||||
3. Test with sample prompts
|
||||
4. Start training with conservative settings
|
||||
5. Monitor progress and adjust as needed
|
||||
"""
|
||||
|
||||
# Toolsets to enable for RL workflows
|
||||
RL_TOOLSETS = ["terminal", "web", "rl"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Functions
|
||||
# ============================================================================
|
||||
|
||||
def check_requirements():
|
||||
"""Check that all required environment variables and services are available."""
|
||||
errors = []
|
||||
|
||||
# Check API keys
|
||||
if not os.getenv("OPENROUTER_API_KEY"):
|
||||
errors.append("OPENROUTER_API_KEY not set - required for agent")
|
||||
|
||||
missing_rl_keys = get_missing_keys()
|
||||
if missing_rl_keys:
|
||||
errors.append(f"Missing RL API keys: {', '.join(missing_rl_keys)}")
|
||||
|
||||
if errors:
|
||||
print("❌ Missing requirements:")
|
||||
for error in errors:
|
||||
print(f" - {error}")
|
||||
print("\nPlease set these environment variables in your .env file or shell.")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def check_tinker_atropos():
|
||||
"""Check if tinker-atropos submodule is properly set up."""
|
||||
tinker_path = Path(__file__).parent / "tinker-atropos"
|
||||
|
||||
if not tinker_path.exists():
|
||||
return False, "tinker-atropos submodule not found. Run: git submodule update --init"
|
||||
|
||||
envs_path = tinker_path / "tinker_atropos" / "environments"
|
||||
if not envs_path.exists():
|
||||
return False, f"environments directory not found at {envs_path}"
|
||||
|
||||
env_files = list(envs_path.glob("*.py"))
|
||||
env_files = [f for f in env_files if not f.name.startswith("_")]
|
||||
|
||||
return True, {"path": str(tinker_path), "environments_count": len(env_files)}
|
||||
|
||||
|
||||
def list_environments_sync():
|
||||
"""List available environments (synchronous wrapper)."""
|
||||
from tools.rl_training_tool import rl_list_environments
|
||||
import json
|
||||
|
||||
async def _list():
|
||||
result = await rl_list_environments()
|
||||
return json.loads(result)
|
||||
|
||||
return asyncio.run(_list())
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Main CLI
|
||||
# ============================================================================
|
||||
|
||||
def main(
|
||||
task: str = None,
|
||||
model: str = None,
|
||||
api_key: str = None,
|
||||
base_url: str = None,
|
||||
max_iterations: int = RL_MAX_ITERATIONS,
|
||||
interactive: bool = False,
|
||||
list_environments: bool = False,
|
||||
check_server: bool = False,
|
||||
verbose: bool = False,
|
||||
save_trajectories: bool = True,
|
||||
):
|
||||
"""
|
||||
RL Training CLI - Dedicated runner for RL training workflows.
|
||||
|
||||
Args:
|
||||
task: The training task/goal (e.g., "Train a model on GSM8k for math")
|
||||
model: Model to use for the agent (reads from ~/.hermes/config.yaml if not provided)
|
||||
api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided)
|
||||
base_url: API base URL (reads from config or defaults to OpenRouter)
|
||||
max_iterations: Maximum agent iterations (default: 200 for long workflows)
|
||||
interactive: Run in interactive mode (multiple conversations)
|
||||
list_environments: Just list available RL environments and exit
|
||||
check_server: Check if RL API server is running and exit
|
||||
verbose: Enable verbose logging
|
||||
save_trajectories: Save conversation trajectories (default: True for RL)
|
||||
|
||||
Examples:
|
||||
# Train on a specific environment
|
||||
python rl_cli.py "Train a model on GSM8k math problems"
|
||||
|
||||
# Interactive mode
|
||||
python rl_cli.py --interactive
|
||||
|
||||
# List available environments
|
||||
python rl_cli.py --list-environments
|
||||
|
||||
# Check server status
|
||||
python rl_cli.py --check-server
|
||||
"""
|
||||
# Load config from ~/.hermes/config.yaml
|
||||
config = load_hermes_config()
|
||||
|
||||
# Use config values if not explicitly provided
|
||||
if model is None:
|
||||
model = config["model"]
|
||||
if base_url is None:
|
||||
base_url = config["base_url"]
|
||||
|
||||
print("🎯 RL Training Agent")
|
||||
print("=" * 60)
|
||||
|
||||
# Handle setup check
|
||||
if check_server:
|
||||
print("\n🔍 Checking tinker-atropos setup...")
|
||||
ok, result = check_tinker_atropos()
|
||||
if ok:
|
||||
print("✅ tinker-atropos submodule found")
|
||||
print(f" Path: {result.get('path')}")
|
||||
print(f" Environments found: {result.get('environments_count', 0)}")
|
||||
|
||||
# Also check API keys
|
||||
missing = get_missing_keys()
|
||||
if missing:
|
||||
print(f"\n⚠️ Missing API keys: {', '.join(missing)}")
|
||||
print(" Add them to ~/.hermes/.env")
|
||||
else:
|
||||
print("✅ API keys configured")
|
||||
else:
|
||||
print(f"❌ tinker-atropos not set up: {result}")
|
||||
print("\nTo set up:")
|
||||
print(" git submodule update --init")
|
||||
print(" pip install -e ./tinker-atropos")
|
||||
return
|
||||
|
||||
# Handle environment listing
|
||||
if list_environments:
|
||||
print("\n📋 Available RL Environments:")
|
||||
print("-" * 40)
|
||||
try:
|
||||
data = list_environments_sync()
|
||||
if "error" in data:
|
||||
print(f"❌ Error: {data['error']}")
|
||||
return
|
||||
|
||||
envs = data.get("environments", [])
|
||||
if not envs:
|
||||
print("No environments found.")
|
||||
print("\nMake sure tinker-atropos is set up:")
|
||||
print(" git submodule update --init")
|
||||
return
|
||||
|
||||
for env in envs:
|
||||
print(f"\n 📦 {env['name']}")
|
||||
print(f" Class: {env['class_name']}")
|
||||
print(f" Path: {env['file_path']}")
|
||||
if env.get('description'):
|
||||
desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '')
|
||||
print(f" Description: {desc}")
|
||||
|
||||
print(f"\n📊 Total: {len(envs)} environments")
|
||||
print("\nUse `rl_select_environment(name)` to select an environment for training.")
|
||||
except Exception as e:
|
||||
print(f"❌ Error listing environments: {e}")
|
||||
print("\nMake sure tinker-atropos is set up:")
|
||||
print(" git submodule update --init")
|
||||
print(" pip install -e ./tinker-atropos")
|
||||
return
|
||||
|
||||
# Check requirements
|
||||
if not check_requirements():
|
||||
sys.exit(1)
|
||||
|
||||
# Set default task if none provided
|
||||
if not task and not interactive:
|
||||
print("\n⚠️ No task provided. Use --interactive for interactive mode or provide a task.")
|
||||
print("\nExamples:")
|
||||
print(' python rl_cli.py "Train a model on GSM8k math problems"')
|
||||
print(' python rl_cli.py "Create an RL environment for code generation"')
|
||||
print(' python rl_cli.py --interactive')
|
||||
return
|
||||
|
||||
# Get API key
|
||||
api_key = api_key or os.getenv("OPENROUTER_API_KEY")
|
||||
if not api_key:
|
||||
print("❌ No API key provided. Set OPENROUTER_API_KEY or pass --api-key")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"\n🤖 Model: {model}")
|
||||
print(f"🔧 Max iterations: {max_iterations}")
|
||||
print(f"📁 Toolsets: {', '.join(RL_TOOLSETS)}")
|
||||
print("=" * 60)
|
||||
|
||||
# Create agent with RL configuration
|
||||
agent = AIAgent(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
max_iterations=max_iterations,
|
||||
enabled_toolsets=RL_TOOLSETS,
|
||||
save_trajectories=save_trajectories,
|
||||
verbose_logging=verbose,
|
||||
quiet_mode=False,
|
||||
ephemeral_system_prompt=RL_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
if interactive:
|
||||
# Interactive mode - multiple conversations
|
||||
print("\n🔄 Interactive RL Training Mode")
|
||||
print("Type 'quit' or 'exit' to end the session.")
|
||||
print("Type 'status' to check active training runs.")
|
||||
print("-" * 40)
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n🎯 RL Task> ").strip()
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
if user_input.lower() in ('quit', 'exit', 'q'):
|
||||
print("\n👋 Goodbye!")
|
||||
break
|
||||
|
||||
if user_input.lower() == 'status':
|
||||
# Quick status check
|
||||
from tools.rl_training_tool import rl_list_runs
|
||||
import json
|
||||
result = asyncio.run(rl_list_runs())
|
||||
runs = json.loads(result)
|
||||
if isinstance(runs, list) and runs:
|
||||
print("\n📊 Active Runs:")
|
||||
for run in runs:
|
||||
print(f" - {run['run_id']}: {run['environment']} ({run['status']})")
|
||||
else:
|
||||
print("\nNo active runs.")
|
||||
continue
|
||||
|
||||
# Run the agent
|
||||
print("\n" + "=" * 60)
|
||||
response = agent.run_conversation(user_input)
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 Interrupted. Goodbye!")
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
if verbose:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
else:
|
||||
# Single task mode
|
||||
print(f"\n📝 Task: {task}")
|
||||
print("-" * 40)
|
||||
|
||||
try:
|
||||
response = agent.run_conversation(task)
|
||||
print("\n" + "=" * 60)
|
||||
print("✅ Task completed")
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n⚠️ Interrupted by user")
|
||||
except Exception as e:
|
||||
print(f"\n❌ Error: {e}")
|
||||
if verbose:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(main)
|
||||
|
|
@ -1764,10 +1764,16 @@ class AIAgent:
|
|||
self._invalid_tool_retries = 0
|
||||
|
||||
# Validate tool call arguments are valid JSON
|
||||
# Handle empty strings as empty objects (common model quirk)
|
||||
invalid_json_args = []
|
||||
for tc in assistant_message.tool_calls:
|
||||
args = tc.function.arguments
|
||||
# Treat empty/whitespace strings as empty object
|
||||
if not args or not args.strip():
|
||||
tc.function.arguments = "{}"
|
||||
continue
|
||||
try:
|
||||
json.loads(tc.function.arguments)
|
||||
json.loads(args)
|
||||
except json.JSONDecodeError as e:
|
||||
invalid_json_args.append((tc.function.name, str(e)))
|
||||
|
||||
|
|
|
|||
|
|
@ -128,6 +128,78 @@ function Test-Node {
|
|||
return $true # Don't fail - Node is optional
|
||||
}
|
||||
|
||||
function Test-Ripgrep {
|
||||
Write-Info "Checking ripgrep (optional, for faster file search)..."
|
||||
|
||||
if (Get-Command rg -ErrorAction SilentlyContinue) {
|
||||
$version = rg --version | Select-Object -First 1
|
||||
Write-Success "$version found"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
|
||||
Write-Warning "ripgrep not found (file search will use findstr fallback)"
|
||||
|
||||
# Check what package managers are available
|
||||
$hasWinget = Get-Command winget -ErrorAction SilentlyContinue
|
||||
$hasChoco = Get-Command choco -ErrorAction SilentlyContinue
|
||||
$hasScoop = Get-Command scoop -ErrorAction SilentlyContinue
|
||||
|
||||
# Offer to install
|
||||
Write-Host ""
|
||||
$response = Read-Host "Would you like to install ripgrep? (faster search, recommended) [Y/n]"
|
||||
|
||||
if ($response -eq "" -or $response -match "^[Yy]") {
|
||||
Write-Info "Installing ripgrep..."
|
||||
|
||||
if ($hasWinget) {
|
||||
try {
|
||||
winget install BurntSushi.ripgrep.MSVC --silent 2>&1 | Out-Null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "ripgrep installed via winget"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
|
||||
if ($hasChoco) {
|
||||
try {
|
||||
choco install ripgrep -y 2>&1 | Out-Null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "ripgrep installed via chocolatey"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
|
||||
if ($hasScoop) {
|
||||
try {
|
||||
scoop install ripgrep 2>&1 | Out-Null
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "ripgrep installed via scoop"
|
||||
$script:HasRipgrep = $true
|
||||
return $true
|
||||
}
|
||||
} catch { }
|
||||
}
|
||||
|
||||
Write-Warning "Auto-install failed. You can install manually:"
|
||||
} else {
|
||||
Write-Info "Skipping ripgrep installation. To install manually:"
|
||||
}
|
||||
|
||||
# Show manual install instructions
|
||||
Write-Info " winget install BurntSushi.ripgrep.MSVC"
|
||||
Write-Info " Or: choco install ripgrep"
|
||||
Write-Info " Or: scoop install ripgrep"
|
||||
Write-Info " Or download from: https://github.com/BurntSushi/ripgrep/releases"
|
||||
|
||||
$script:HasRipgrep = $false
|
||||
return $true # Don't fail - ripgrep is optional
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
|
|
@ -150,14 +222,15 @@ function Install-Repository {
|
|||
}
|
||||
} else {
|
||||
# Try SSH first (for private repo access), fall back to HTTPS
|
||||
# Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos
|
||||
Write-Info "Trying SSH clone..."
|
||||
$sshResult = git clone --branch $Branch $RepoUrlSsh $InstallDir 2>&1
|
||||
$sshResult = git clone --branch $Branch --recurse-submodules $RepoUrlSsh $InstallDir 2>&1
|
||||
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "Cloned via SSH"
|
||||
} else {
|
||||
Write-Info "SSH failed, trying HTTPS..."
|
||||
$httpsResult = git clone --branch $Branch $RepoUrlHttps $InstallDir 2>&1
|
||||
$httpsResult = git clone --branch $Branch --recurse-submodules $RepoUrlHttps $InstallDir 2>&1
|
||||
|
||||
if ($LASTEXITCODE -eq 0) {
|
||||
Write-Success "Cloned via HTTPS"
|
||||
|
|
@ -171,6 +244,13 @@ function Install-Repository {
|
|||
}
|
||||
}
|
||||
|
||||
# Ensure submodules are initialized and updated (for existing installs or if --recurse failed)
|
||||
Write-Info "Initializing submodules (mini-swe-agent, tinker-atropos)..."
|
||||
Push-Location $InstallDir
|
||||
git submodule update --init --recursive
|
||||
Pop-Location
|
||||
Write-Success "Submodules ready"
|
||||
|
||||
Write-Success "Repository ready"
|
||||
}
|
||||
|
||||
|
|
@ -208,15 +288,43 @@ function Install-Dependencies {
|
|||
& .\venv\Scripts\Activate.ps1
|
||||
}
|
||||
|
||||
# Install main package
|
||||
try {
|
||||
pip install -e ".[all]" 2>&1 | Out-Null
|
||||
} catch {
|
||||
pip install -e "." | Out-Null
|
||||
}
|
||||
|
||||
Write-Success "Main package installed"
|
||||
|
||||
# Install submodules
|
||||
Write-Info "Installing mini-swe-agent (terminal tool backend)..."
|
||||
if (Test-Path "mini-swe-agent\pyproject.toml") {
|
||||
try {
|
||||
pip install -e ".\mini-swe-agent" 2>&1 | Out-Null
|
||||
Write-Success "mini-swe-agent installed"
|
||||
} catch {
|
||||
Write-Warning "mini-swe-agent install failed (terminal tools may not work)"
|
||||
}
|
||||
} else {
|
||||
Write-Warning "mini-swe-agent not found (run: git submodule update --init)"
|
||||
}
|
||||
|
||||
Write-Info "Installing tinker-atropos (RL training backend)..."
|
||||
if (Test-Path "tinker-atropos\pyproject.toml") {
|
||||
try {
|
||||
pip install -e ".\tinker-atropos" 2>&1 | Out-Null
|
||||
Write-Success "tinker-atropos installed"
|
||||
} catch {
|
||||
Write-Warning "tinker-atropos install failed (RL tools may not work)"
|
||||
}
|
||||
} else {
|
||||
Write-Warning "tinker-atropos not found (run: git submodule update --init)"
|
||||
}
|
||||
|
||||
Pop-Location
|
||||
|
||||
Write-Success "Dependencies installed"
|
||||
Write-Success "All dependencies installed"
|
||||
}
|
||||
|
||||
function Set-PathVariable {
|
||||
|
|
@ -369,6 +477,20 @@ function Write-Completion {
|
|||
Write-Host ""
|
||||
Write-Host "⚡ Restart your terminal for PATH changes to take effect" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
|
||||
# Show notes about optional tools
|
||||
if (-not $HasNode) {
|
||||
Write-Host "Note: Node.js was not found. Browser automation tools" -ForegroundColor Yellow
|
||||
Write-Host "will have limited functionality." -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
}
|
||||
|
||||
if (-not $HasRipgrep) {
|
||||
Write-Host "Note: ripgrep (rg) was not found. File search will use" -ForegroundColor Yellow
|
||||
Write-Host "findstr as a fallback. For faster search:" -ForegroundColor Yellow
|
||||
Write-Host " winget install BurntSushi.ripgrep.MSVC" -ForegroundColor Yellow
|
||||
Write-Host ""
|
||||
}
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
|
|
@ -380,7 +502,8 @@ function Main {
|
|||
|
||||
if (-not (Test-Python)) { exit 1 }
|
||||
if (-not (Test-Git)) { exit 1 }
|
||||
Test-Node # Optional, doesn't fail
|
||||
Test-Node # Optional, doesn't fail
|
||||
Test-Ripgrep # Optional, doesn't fail
|
||||
|
||||
Install-Repository
|
||||
Install-Venv
|
||||
|
|
|
|||
|
|
@ -271,6 +271,120 @@ check_node() {
|
|||
# Don't exit - Node is optional
|
||||
}
|
||||
|
||||
check_ripgrep() {
|
||||
log_info "Checking ripgrep (optional, for faster file search)..."
|
||||
|
||||
if command -v rg &> /dev/null; then
|
||||
RG_VERSION=$(rg --version | head -1)
|
||||
log_success "$RG_VERSION found"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
|
||||
log_warn "ripgrep not found (file search will use grep fallback)"
|
||||
|
||||
# Offer to install
|
||||
echo ""
|
||||
read -p "Would you like to install ripgrep? (faster search, recommended) [Y/n] " -n 1 -r
|
||||
echo
|
||||
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
log_info "Installing ripgrep..."
|
||||
|
||||
# Check if we can use sudo
|
||||
CAN_SUDO=false
|
||||
if command -v sudo &> /dev/null; then
|
||||
# Check if user has sudo access (without actually running sudo)
|
||||
if sudo -n true 2>/dev/null || sudo -v 2>/dev/null; then
|
||||
CAN_SUDO=true
|
||||
fi
|
||||
fi
|
||||
|
||||
case "$OS" in
|
||||
linux)
|
||||
if [ "$CAN_SUDO" = true ]; then
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
if sudo apt install -y ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
;;
|
||||
fedora)
|
||||
if sudo dnf install -y ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
;;
|
||||
arch)
|
||||
if sudo pacman -S --noconfirm ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
else
|
||||
log_warn "sudo not available - cannot auto-install system packages"
|
||||
# Try cargo as fallback if available
|
||||
if command -v cargo &> /dev/null; then
|
||||
log_info "Trying cargo install (no sudo required)..."
|
||||
if cargo install ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed via cargo"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
macos)
|
||||
if command -v brew &> /dev/null; then
|
||||
if brew install ripgrep 2>/dev/null; then
|
||||
log_success "ripgrep installed"
|
||||
HAS_RIPGREP=true
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
log_warn "Auto-install failed. You can install manually later:"
|
||||
else
|
||||
log_info "Skipping ripgrep installation. To install manually:"
|
||||
fi
|
||||
|
||||
# Show manual install instructions
|
||||
case "$OS" in
|
||||
linux)
|
||||
case "$DISTRO" in
|
||||
ubuntu|debian)
|
||||
log_info " sudo apt install ripgrep"
|
||||
;;
|
||||
fedora)
|
||||
log_info " sudo dnf install ripgrep"
|
||||
;;
|
||||
arch)
|
||||
log_info " sudo pacman -S ripgrep"
|
||||
;;
|
||||
*)
|
||||
log_info " https://github.com/BurntSushi/ripgrep#installation"
|
||||
;;
|
||||
esac
|
||||
# Show cargo alternative for users without sudo
|
||||
if command -v cargo &> /dev/null; then
|
||||
log_info " Or without sudo: cargo install ripgrep"
|
||||
fi
|
||||
;;
|
||||
macos)
|
||||
log_info " brew install ripgrep"
|
||||
;;
|
||||
esac
|
||||
|
||||
HAS_RIPGREP=false
|
||||
# Don't exit - ripgrep is optional (grep fallback exists)
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# Installation
|
||||
# ============================================================================
|
||||
|
|
@ -292,12 +406,13 @@ clone_repo() {
|
|||
fi
|
||||
else
|
||||
# Try SSH first (for private repo access), fall back to HTTPS
|
||||
# Use --recurse-submodules to also clone mini-swe-agent and tinker-atropos
|
||||
log_info "Trying SSH clone..."
|
||||
if git clone --branch "$BRANCH" "$REPO_URL_SSH" "$INSTALL_DIR" 2>/dev/null; then
|
||||
if git clone --branch "$BRANCH" --recurse-submodules "$REPO_URL_SSH" "$INSTALL_DIR" 2>/dev/null; then
|
||||
log_success "Cloned via SSH"
|
||||
else
|
||||
log_info "SSH failed, trying HTTPS..."
|
||||
if git clone --branch "$BRANCH" "$REPO_URL_HTTPS" "$INSTALL_DIR"; then
|
||||
if git clone --branch "$BRANCH" --recurse-submodules "$REPO_URL_HTTPS" "$INSTALL_DIR"; then
|
||||
log_success "Cloned via HTTPS"
|
||||
else
|
||||
log_error "Failed to clone repository"
|
||||
|
|
@ -310,6 +425,12 @@ clone_repo() {
|
|||
fi
|
||||
|
||||
cd "$INSTALL_DIR"
|
||||
|
||||
# Ensure submodules are initialized and updated (for existing installs or if --recurse failed)
|
||||
log_info "Initializing submodules (mini-swe-agent, tinker-atropos)..."
|
||||
git submodule update --init --recursive
|
||||
log_success "Submodules ready"
|
||||
|
||||
log_success "Repository ready"
|
||||
}
|
||||
|
||||
|
|
@ -343,10 +464,29 @@ install_deps() {
|
|||
source venv/bin/activate
|
||||
fi
|
||||
|
||||
# Install the package in editable mode with all extras
|
||||
# Install the main package in editable mode with all extras
|
||||
pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null
|
||||
|
||||
log_success "Dependencies installed"
|
||||
log_success "Main package installed"
|
||||
|
||||
# Install submodules
|
||||
log_info "Installing mini-swe-agent (terminal tool backend)..."
|
||||
if [ -d "mini-swe-agent" ] && [ -f "mini-swe-agent/pyproject.toml" ]; then
|
||||
pip install -e "./mini-swe-agent" > /dev/null 2>&1 || log_warn "mini-swe-agent install failed (terminal tools may not work)"
|
||||
log_success "mini-swe-agent installed"
|
||||
else
|
||||
log_warn "mini-swe-agent not found (run: git submodule update --init)"
|
||||
fi
|
||||
|
||||
log_info "Installing tinker-atropos (RL training backend)..."
|
||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
||||
pip install -e "./tinker-atropos" > /dev/null 2>&1 || log_warn "tinker-atropos install failed (RL tools may not work)"
|
||||
log_success "tinker-atropos installed"
|
||||
else
|
||||
log_warn "tinker-atropos not found (run: git submodule update --init)"
|
||||
fi
|
||||
|
||||
log_success "All dependencies installed"
|
||||
}
|
||||
|
||||
setup_path() {
|
||||
|
|
@ -514,6 +654,15 @@ print_success() {
|
|||
echo "if you need full browser support."
|
||||
echo -e "${NC}"
|
||||
fi
|
||||
|
||||
# Show ripgrep note if not installed
|
||||
if [ "$HAS_RIPGREP" = false ]; then
|
||||
echo -e "${YELLOW}"
|
||||
echo "Note: ripgrep (rg) was not found. File search will use"
|
||||
echo "grep as a fallback. For faster search in large codebases,"
|
||||
echo "install ripgrep: sudo apt install ripgrep (or brew install ripgrep)"
|
||||
echo -e "${NC}"
|
||||
fi
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
|
|
@ -527,6 +676,7 @@ main() {
|
|||
check_python
|
||||
check_git
|
||||
check_node
|
||||
check_ripgrep
|
||||
|
||||
clone_repo
|
||||
setup_venv
|
||||
|
|
|
|||
|
|
@ -80,6 +80,53 @@ pip install -e ".[all]" > /dev/null 2>&1 || pip install -e "." > /dev/null
|
|||
|
||||
echo -e "${GREEN}✓${NC} Dependencies installed"
|
||||
|
||||
# ============================================================================
|
||||
# Optional: ripgrep (for faster file search)
|
||||
# ============================================================================
|
||||
|
||||
echo -e "${CYAN}→${NC} Checking ripgrep (optional, for faster search)..."
|
||||
|
||||
if command -v rg &> /dev/null; then
|
||||
echo -e "${GREEN}✓${NC} ripgrep found"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} ripgrep not found (file search will use grep fallback)"
|
||||
read -p "Install ripgrep for faster search? [Y/n] " -n 1 -r
|
||||
echo
|
||||
if [[ $REPLY =~ ^[Yy]$ ]] || [[ -z $REPLY ]]; then
|
||||
INSTALLED=false
|
||||
|
||||
# Check if sudo is available
|
||||
if command -v sudo &> /dev/null && sudo -n true 2>/dev/null; then
|
||||
if command -v apt &> /dev/null; then
|
||||
sudo apt install -y ripgrep && INSTALLED=true
|
||||
elif command -v dnf &> /dev/null; then
|
||||
sudo dnf install -y ripgrep && INSTALLED=true
|
||||
fi
|
||||
fi
|
||||
|
||||
# Try brew (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v brew &> /dev/null; then
|
||||
brew install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
# Try cargo (no sudo needed)
|
||||
if [ "$INSTALLED" = false ] && command -v cargo &> /dev/null; then
|
||||
echo -e "${CYAN}→${NC} Trying cargo install (no sudo required)..."
|
||||
cargo install ripgrep && INSTALLED=true
|
||||
fi
|
||||
|
||||
if [ "$INSTALLED" = true ]; then
|
||||
echo -e "${GREEN}✓${NC} ripgrep installed"
|
||||
else
|
||||
echo -e "${YELLOW}⚠${NC} Auto-install failed. Install options:"
|
||||
echo " sudo apt install ripgrep # Debian/Ubuntu"
|
||||
echo " brew install ripgrep # macOS"
|
||||
echo " cargo install ripgrep # With Rust (no sudo)"
|
||||
echo " https://github.com/BurntSushi/ripgrep#installation"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
# ============================================================================
|
||||
# Environment file
|
||||
# ============================================================================
|
||||
|
|
|
|||
1
tinker-atropos
Submodule
1
tinker-atropos
Submodule
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 65f084ee8054a5d02aeac76e24ed60388511c82b
|
||||
|
|
@ -95,6 +95,38 @@ from .cronjob_tools import (
|
|||
REMOVE_CRONJOB_SCHEMA
|
||||
)
|
||||
|
||||
# RL Training tools (Tinker-Atropos)
|
||||
from .rl_training_tool import (
|
||||
rl_list_environments,
|
||||
rl_select_environment,
|
||||
rl_get_current_config,
|
||||
rl_edit_config,
|
||||
rl_start_training,
|
||||
rl_check_status,
|
||||
rl_stop_training,
|
||||
rl_get_results,
|
||||
rl_list_runs,
|
||||
rl_test_inference,
|
||||
check_rl_api_keys,
|
||||
get_missing_keys,
|
||||
)
|
||||
|
||||
# File manipulation tools (read, write, patch, search)
|
||||
from .file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
patch_tool,
|
||||
search_tool,
|
||||
get_file_tools,
|
||||
clear_file_ops_cache,
|
||||
)
|
||||
|
||||
# File tools have no external requirements - they use the terminal backend
|
||||
def check_file_requirements():
|
||||
"""File tools only require terminal backend to be available."""
|
||||
from .terminal_tool import check_terminal_requirements
|
||||
return check_terminal_requirements()
|
||||
|
||||
__all__ = [
|
||||
# Web tools
|
||||
'web_search_tool',
|
||||
|
|
@ -152,5 +184,26 @@ __all__ = [
|
|||
'SCHEDULE_CRONJOB_SCHEMA',
|
||||
'LIST_CRONJOBS_SCHEMA',
|
||||
'REMOVE_CRONJOB_SCHEMA',
|
||||
# RL Training tools
|
||||
'rl_list_environments',
|
||||
'rl_select_environment',
|
||||
'rl_get_current_config',
|
||||
'rl_edit_config',
|
||||
'rl_start_training',
|
||||
'rl_check_status',
|
||||
'rl_stop_training',
|
||||
'rl_get_results',
|
||||
'rl_list_runs',
|
||||
'rl_test_inference',
|
||||
'check_rl_api_keys',
|
||||
'get_missing_keys',
|
||||
# File manipulation tools
|
||||
'read_file_tool',
|
||||
'write_file_tool',
|
||||
'patch_tool',
|
||||
'search_tool',
|
||||
'get_file_tools',
|
||||
'clear_file_ops_cache',
|
||||
'check_file_requirements',
|
||||
]
|
||||
|
||||
|
|
|
|||
937
tools/file_operations.py
Normal file
937
tools/file_operations.py
Normal file
|
|
@ -0,0 +1,937 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
File Operations Module
|
||||
|
||||
Provides file manipulation capabilities (read, write, patch, search) that work
|
||||
across all terminal backends (local, docker, singularity, ssh, modal).
|
||||
|
||||
The key insight is that all file operations can be expressed as shell commands,
|
||||
so we wrap the terminal backend's execute() interface to provide a unified file API.
|
||||
|
||||
Usage:
|
||||
from tools.file_operations import ShellFileOperations
|
||||
from tools.terminal_tool import _active_environments
|
||||
|
||||
# Get file operations for a terminal environment
|
||||
file_ops = ShellFileOperations(terminal_env)
|
||||
|
||||
# Read a file
|
||||
result = file_ops.read_file("/path/to/file.py")
|
||||
|
||||
# Write a file
|
||||
result = file_ops.write_file("/path/to/new.py", "print('hello')")
|
||||
|
||||
# Search for content
|
||||
result = file_ops.search("TODO", path=".", file_glob="*.py")
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import uuid
|
||||
import difflib
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Result Data Classes
|
||||
# =============================================================================
|
||||
|
||||
@dataclass
|
||||
class ReadResult:
|
||||
"""Result from reading a file."""
|
||||
content: str = ""
|
||||
total_lines: int = 0
|
||||
file_size: int = 0
|
||||
truncated: bool = False
|
||||
hint: Optional[str] = None
|
||||
is_binary: bool = False
|
||||
is_image: bool = False
|
||||
base64_content: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
dimensions: Optional[str] = None # For images: "WIDTHxHEIGHT"
|
||||
error: Optional[str] = None
|
||||
similar_files: List[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {k: v for k, v in self.__dict__.items() if v is not None and v != [] and v != ""}
|
||||
|
||||
|
||||
@dataclass
|
||||
class WriteResult:
|
||||
"""Result from writing a file."""
|
||||
bytes_written: int = 0
|
||||
dirs_created: bool = False
|
||||
error: Optional[str] = None
|
||||
warning: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {k: v for k, v in self.__dict__.items() if v is not None}
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchResult:
|
||||
"""Result from patching a file."""
|
||||
success: bool = False
|
||||
diff: str = ""
|
||||
files_modified: List[str] = field(default_factory=list)
|
||||
files_created: List[str] = field(default_factory=list)
|
||||
files_deleted: List[str] = field(default_factory=list)
|
||||
lint: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
result = {"success": self.success}
|
||||
if self.diff:
|
||||
result["diff"] = self.diff
|
||||
if self.files_modified:
|
||||
result["files_modified"] = self.files_modified
|
||||
if self.files_created:
|
||||
result["files_created"] = self.files_created
|
||||
if self.files_deleted:
|
||||
result["files_deleted"] = self.files_deleted
|
||||
if self.lint:
|
||||
result["lint"] = self.lint
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchMatch:
|
||||
"""A single search match."""
|
||||
path: str
|
||||
line_number: int
|
||||
content: str
|
||||
mtime: float = 0.0 # Modification time for sorting
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""Result from searching."""
|
||||
matches: List[SearchMatch] = field(default_factory=list)
|
||||
files: List[str] = field(default_factory=list)
|
||||
counts: Dict[str, int] = field(default_factory=dict)
|
||||
total_count: int = 0
|
||||
truncated: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
result = {"total_count": self.total_count}
|
||||
if self.matches:
|
||||
result["matches"] = [
|
||||
{"path": m.path, "line": m.line_number, "content": m.content}
|
||||
for m in self.matches
|
||||
]
|
||||
if self.files:
|
||||
result["files"] = self.files
|
||||
if self.counts:
|
||||
result["counts"] = self.counts
|
||||
if self.truncated:
|
||||
result["truncated"] = True
|
||||
if self.error:
|
||||
result["error"] = self.error
|
||||
return result
|
||||
|
||||
|
||||
@dataclass
|
||||
class LintResult:
|
||||
"""Result from linting a file."""
|
||||
success: bool = True
|
||||
skipped: bool = False
|
||||
output: str = ""
|
||||
message: str = ""
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
if self.skipped:
|
||||
return {"status": "skipped", "message": self.message}
|
||||
return {
|
||||
"status": "ok" if self.success else "error",
|
||||
"output": self.output
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteResult:
|
||||
"""Result from executing a shell command."""
|
||||
stdout: str = ""
|
||||
exit_code: int = 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Abstract Interface
|
||||
# =============================================================================
|
||||
|
||||
class FileOperations(ABC):
|
||||
"""Abstract interface for file operations across terminal backends."""
|
||||
|
||||
@abstractmethod
|
||||
def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult:
|
||||
"""Read a file with pagination support."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def write_file(self, path: str, content: str) -> WriteResult:
|
||||
"""Write content to a file, creating directories as needed."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def patch_replace(self, path: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> PatchResult:
|
||||
"""Replace text in a file using fuzzy matching."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def patch_v4a(self, patch_content: str) -> PatchResult:
|
||||
"""Apply a V4A format patch."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def search(self, pattern: str, path: str = ".", target: str = "content",
|
||||
file_glob: Optional[str] = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0) -> SearchResult:
|
||||
"""Search for content or files."""
|
||||
...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Shell-based Implementation
|
||||
# =============================================================================
|
||||
|
||||
# Binary file extensions (fast path check)
|
||||
BINARY_EXTENSIONS = {
|
||||
# Images
|
||||
'.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.ico', '.tiff', '.tif',
|
||||
'.svg', # SVG is text but often treated as binary
|
||||
# Audio/Video
|
||||
'.mp3', '.mp4', '.wav', '.avi', '.mov', '.mkv', '.flac', '.ogg', '.webm',
|
||||
# Archives
|
||||
'.zip', '.tar', '.gz', '.bz2', '.xz', '.7z', '.rar',
|
||||
# Documents
|
||||
'.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx',
|
||||
# Compiled/Binary
|
||||
'.exe', '.dll', '.so', '.dylib', '.o', '.a', '.pyc', '.pyo', '.class',
|
||||
'.wasm', '.bin',
|
||||
# Fonts
|
||||
'.ttf', '.otf', '.woff', '.woff2', '.eot',
|
||||
# Other
|
||||
'.db', '.sqlite', '.sqlite3',
|
||||
}
|
||||
|
||||
# Image extensions (subset of binary that we can return as base64)
|
||||
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.gif', '.webp', '.bmp', '.ico'}
|
||||
|
||||
# Linters by file extension
|
||||
LINTERS = {
|
||||
'.py': 'python -m py_compile {file} 2>&1',
|
||||
'.js': 'node --check {file} 2>&1',
|
||||
'.ts': 'npx tsc --noEmit {file} 2>&1',
|
||||
'.go': 'go vet {file} 2>&1',
|
||||
'.rs': 'rustfmt --check {file} 2>&1',
|
||||
}
|
||||
|
||||
# Max limits for read operations
|
||||
MAX_LINES = 2000
|
||||
MAX_LINE_LENGTH = 2000
|
||||
MAX_FILE_SIZE = 50 * 1024 # 50KB
|
||||
|
||||
|
||||
class ShellFileOperations(FileOperations):
|
||||
"""
|
||||
File operations implemented via shell commands.
|
||||
|
||||
Works with ANY terminal backend that has execute(command, cwd) method.
|
||||
This includes local, docker, singularity, ssh, and modal environments.
|
||||
"""
|
||||
|
||||
def __init__(self, terminal_env, cwd: str = None):
|
||||
"""
|
||||
Initialize file operations with a terminal environment.
|
||||
|
||||
Args:
|
||||
terminal_env: Any object with execute(command, cwd) method.
|
||||
Returns {"output": str, "returncode": int}
|
||||
cwd: Working directory (defaults to env's cwd or /tmp)
|
||||
"""
|
||||
self.env = terminal_env
|
||||
# Determine cwd from various possible sources
|
||||
self.cwd = cwd or getattr(terminal_env, 'cwd', None) or \
|
||||
getattr(getattr(terminal_env, 'config', None), 'cwd', None) or '/tmp'
|
||||
|
||||
# Cache for command availability checks
|
||||
self._command_cache: Dict[str, bool] = {}
|
||||
|
||||
def _exec(self, command: str, cwd: str = None, timeout: int = None) -> ExecuteResult:
|
||||
"""Execute command via terminal backend."""
|
||||
kwargs = {}
|
||||
if timeout:
|
||||
kwargs['timeout'] = timeout
|
||||
|
||||
result = self.env.execute(command, cwd=cwd or self.cwd, **kwargs)
|
||||
return ExecuteResult(
|
||||
stdout=result.get("output", ""),
|
||||
exit_code=result.get("returncode", 0)
|
||||
)
|
||||
|
||||
def _has_command(self, cmd: str) -> bool:
|
||||
"""Check if a command exists in the environment (cached)."""
|
||||
if cmd not in self._command_cache:
|
||||
result = self._exec(f"command -v {cmd} >/dev/null 2>&1 && echo 'yes'")
|
||||
self._command_cache[cmd] = result.stdout.strip() == 'yes'
|
||||
return self._command_cache[cmd]
|
||||
|
||||
def _is_likely_binary(self, path: str, content_sample: str = None) -> bool:
|
||||
"""
|
||||
Check if a file is likely binary.
|
||||
|
||||
Uses extension check (fast) + content analysis (fallback).
|
||||
"""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
if ext in BINARY_EXTENSIONS:
|
||||
return True
|
||||
|
||||
# Content analysis: >30% non-printable chars = binary
|
||||
if content_sample:
|
||||
if not content_sample:
|
||||
return False
|
||||
non_printable = sum(1 for c in content_sample[:1000]
|
||||
if ord(c) < 32 and c not in '\n\r\t')
|
||||
return non_printable / min(len(content_sample), 1000) > 0.30
|
||||
|
||||
return False
|
||||
|
||||
def _is_image(self, path: str) -> bool:
|
||||
"""Check if file is an image we can return as base64."""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
return ext in IMAGE_EXTENSIONS
|
||||
|
||||
def _add_line_numbers(self, content: str, start_line: int = 1) -> str:
|
||||
"""Add line numbers to content in LINE_NUM|CONTENT format."""
|
||||
lines = content.split('\n')
|
||||
numbered = []
|
||||
for i, line in enumerate(lines, start=start_line):
|
||||
# Truncate long lines
|
||||
if len(line) > MAX_LINE_LENGTH:
|
||||
line = line[:MAX_LINE_LENGTH] + "... [truncated]"
|
||||
numbered.append(f"{i:6d}|{line}")
|
||||
return '\n'.join(numbered)
|
||||
|
||||
def _expand_path(self, path: str) -> str:
|
||||
"""
|
||||
Expand shell-style paths like ~ and ~user to absolute paths.
|
||||
|
||||
This must be done BEFORE shell escaping, since ~ doesn't expand
|
||||
inside single quotes.
|
||||
"""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
# Handle ~ and ~user
|
||||
if path.startswith('~'):
|
||||
# Get home directory via the terminal environment
|
||||
result = self._exec("echo $HOME")
|
||||
if result.exit_code == 0 and result.stdout.strip():
|
||||
home = result.stdout.strip()
|
||||
if path == '~':
|
||||
return home
|
||||
elif path.startswith('~/'):
|
||||
return home + path[1:] # Replace ~ with home
|
||||
# ~username format - let shell expand it
|
||||
expand_result = self._exec(f"echo {path}")
|
||||
if expand_result.exit_code == 0:
|
||||
return expand_result.stdout.strip()
|
||||
|
||||
return path
|
||||
|
||||
def _escape_shell_arg(self, arg: str) -> str:
|
||||
"""Escape a string for safe use in shell commands."""
|
||||
# Use single quotes and escape any single quotes in the string
|
||||
return "'" + arg.replace("'", "'\"'\"'") + "'"
|
||||
|
||||
def _unified_diff(self, old_content: str, new_content: str, filename: str) -> str:
|
||||
"""Generate unified diff between old and new content."""
|
||||
old_lines = old_content.splitlines(keepends=True)
|
||||
new_lines = new_content.splitlines(keepends=True)
|
||||
diff = difflib.unified_diff(
|
||||
old_lines, new_lines,
|
||||
fromfile=f"a/{filename}",
|
||||
tofile=f"b/{filename}"
|
||||
)
|
||||
return ''.join(diff)
|
||||
|
||||
# =========================================================================
|
||||
# READ Implementation
|
||||
# =========================================================================
|
||||
|
||||
def read_file(self, path: str, offset: int = 1, limit: int = 500) -> ReadResult:
|
||||
"""
|
||||
Read a file with pagination, binary detection, and line numbers.
|
||||
|
||||
Args:
|
||||
path: File path (absolute or relative to cwd)
|
||||
offset: Line number to start from (1-indexed, default 1)
|
||||
limit: Maximum lines to return (default 500, max 2000)
|
||||
|
||||
Returns:
|
||||
ReadResult with content, metadata, or error info
|
||||
"""
|
||||
# Expand ~ and other shell paths
|
||||
path = self._expand_path(path)
|
||||
|
||||
# Clamp limit
|
||||
limit = min(limit, MAX_LINES)
|
||||
|
||||
# Check if file exists and get metadata
|
||||
stat_cmd = f"stat -c '%s' {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
stat_result = self._exec(stat_cmd)
|
||||
|
||||
if stat_result.exit_code != 0:
|
||||
# File not found - try to suggest similar files
|
||||
return self._suggest_similar_files(path)
|
||||
|
||||
try:
|
||||
file_size = int(stat_result.stdout.strip())
|
||||
except ValueError:
|
||||
file_size = 0
|
||||
|
||||
# Check if file is too large
|
||||
if file_size > MAX_FILE_SIZE:
|
||||
# Still try to read, but warn
|
||||
pass
|
||||
|
||||
# Check if it's an image - return base64
|
||||
if self._is_image(path):
|
||||
return self._read_image(path)
|
||||
|
||||
# Read a sample to check for binary content
|
||||
sample_cmd = f"head -c 1000 {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
sample_result = self._exec(sample_cmd)
|
||||
|
||||
if self._is_likely_binary(path, sample_result.stdout):
|
||||
return ReadResult(
|
||||
is_binary=True,
|
||||
file_size=file_size,
|
||||
error="Binary file - cannot display as text. Use appropriate tools to handle this file type."
|
||||
)
|
||||
|
||||
# Read with pagination using sed
|
||||
end_line = offset + limit - 1
|
||||
read_cmd = f"sed -n '{offset},{end_line}p' {self._escape_shell_arg(path)}"
|
||||
read_result = self._exec(read_cmd)
|
||||
|
||||
if read_result.exit_code != 0:
|
||||
return ReadResult(error=f"Failed to read file: {read_result.stdout}")
|
||||
|
||||
# Get total line count
|
||||
wc_cmd = f"wc -l < {self._escape_shell_arg(path)}"
|
||||
wc_result = self._exec(wc_cmd)
|
||||
try:
|
||||
total_lines = int(wc_result.stdout.strip())
|
||||
except ValueError:
|
||||
total_lines = 0
|
||||
|
||||
# Check if truncated
|
||||
truncated = total_lines > end_line
|
||||
hint = None
|
||||
if truncated:
|
||||
hint = f"Use offset={end_line + 1} to continue reading (showing {offset}-{end_line} of {total_lines} lines)"
|
||||
|
||||
return ReadResult(
|
||||
content=self._add_line_numbers(read_result.stdout, offset),
|
||||
total_lines=total_lines,
|
||||
file_size=file_size,
|
||||
truncated=truncated,
|
||||
hint=hint
|
||||
)
|
||||
|
||||
def _read_image(self, path: str) -> ReadResult:
|
||||
"""Read an image file, returning base64 content."""
|
||||
# Get file size
|
||||
stat_cmd = f"stat -c '%s' {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
stat_result = self._exec(stat_cmd)
|
||||
try:
|
||||
file_size = int(stat_result.stdout.strip())
|
||||
except ValueError:
|
||||
file_size = 0
|
||||
|
||||
# Get base64 content
|
||||
b64_cmd = f"base64 -w 0 {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
b64_result = self._exec(b64_cmd, timeout=30)
|
||||
|
||||
if b64_result.exit_code != 0:
|
||||
return ReadResult(
|
||||
is_image=True,
|
||||
is_binary=True,
|
||||
file_size=file_size,
|
||||
error=f"Failed to read image: {b64_result.stdout}"
|
||||
)
|
||||
|
||||
# Try to get dimensions (requires ImageMagick)
|
||||
dimensions = None
|
||||
if self._has_command('identify'):
|
||||
dim_cmd = f"identify -format '%wx%h' {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
dim_result = self._exec(dim_cmd)
|
||||
if dim_result.exit_code == 0:
|
||||
dimensions = dim_result.stdout.strip()
|
||||
|
||||
# Determine MIME type from extension
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
mime_types = {
|
||||
'.png': 'image/png',
|
||||
'.jpg': 'image/jpeg',
|
||||
'.jpeg': 'image/jpeg',
|
||||
'.gif': 'image/gif',
|
||||
'.webp': 'image/webp',
|
||||
'.bmp': 'image/bmp',
|
||||
'.ico': 'image/x-icon',
|
||||
}
|
||||
mime_type = mime_types.get(ext, 'application/octet-stream')
|
||||
|
||||
return ReadResult(
|
||||
is_image=True,
|
||||
is_binary=True,
|
||||
file_size=file_size,
|
||||
base64_content=b64_result.stdout,
|
||||
mime_type=mime_type,
|
||||
dimensions=dimensions
|
||||
)
|
||||
|
||||
def _suggest_similar_files(self, path: str) -> ReadResult:
|
||||
"""Suggest similar files when the requested file is not found."""
|
||||
# Get directory and filename
|
||||
dir_path = os.path.dirname(path) or "."
|
||||
filename = os.path.basename(path)
|
||||
|
||||
# List files in directory
|
||||
ls_cmd = f"ls -1 {self._escape_shell_arg(dir_path)} 2>/dev/null | head -20"
|
||||
ls_result = self._exec(ls_cmd)
|
||||
|
||||
similar = []
|
||||
if ls_result.exit_code == 0 and ls_result.stdout.strip():
|
||||
files = ls_result.stdout.strip().split('\n')
|
||||
# Simple similarity: files that share some characters with the target
|
||||
for f in files:
|
||||
# Check if filenames share significant overlap
|
||||
common = set(filename.lower()) & set(f.lower())
|
||||
if len(common) >= len(filename) * 0.5: # 50% character overlap
|
||||
similar.append(os.path.join(dir_path, f))
|
||||
|
||||
return ReadResult(
|
||||
error=f"File not found: {path}",
|
||||
similar_files=similar[:5] # Limit to 5 suggestions
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# WRITE Implementation
|
||||
# =========================================================================
|
||||
|
||||
def write_file(self, path: str, content: str) -> WriteResult:
|
||||
"""
|
||||
Write content to a file, creating parent directories as needed.
|
||||
|
||||
Uses heredoc with unique marker for safe shell execution.
|
||||
|
||||
Args:
|
||||
path: File path to write
|
||||
content: Content to write
|
||||
|
||||
Returns:
|
||||
WriteResult with bytes written or error
|
||||
"""
|
||||
# Expand ~ and other shell paths
|
||||
path = self._expand_path(path)
|
||||
|
||||
# Create parent directories
|
||||
parent = os.path.dirname(path)
|
||||
dirs_created = False
|
||||
|
||||
if parent:
|
||||
mkdir_cmd = f"mkdir -p {self._escape_shell_arg(parent)}"
|
||||
mkdir_result = self._exec(mkdir_cmd)
|
||||
if mkdir_result.exit_code == 0:
|
||||
dirs_created = True
|
||||
|
||||
# Generate unique marker for heredoc that won't appear in content
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
while marker in content:
|
||||
marker = f"HERMES_EOF_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Write using heredoc with single-quoted marker (prevents all expansion)
|
||||
# The single quotes around the marker prevent variable expansion
|
||||
write_cmd = f"cat > {self._escape_shell_arg(path)} << '{marker}'\n{content}\n{marker}"
|
||||
write_result = self._exec(write_cmd)
|
||||
|
||||
if write_result.exit_code != 0:
|
||||
return WriteResult(error=f"Failed to write file: {write_result.stdout}")
|
||||
|
||||
# Get bytes written
|
||||
stat_cmd = f"stat -c '%s' {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
stat_result = self._exec(stat_cmd)
|
||||
|
||||
try:
|
||||
bytes_written = int(stat_result.stdout.strip())
|
||||
except ValueError:
|
||||
bytes_written = len(content.encode('utf-8'))
|
||||
|
||||
return WriteResult(
|
||||
bytes_written=bytes_written,
|
||||
dirs_created=dirs_created
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# PATCH Implementation (Replace Mode)
|
||||
# =========================================================================
|
||||
|
||||
def patch_replace(self, path: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> PatchResult:
|
||||
"""
|
||||
Replace text in a file using fuzzy matching.
|
||||
|
||||
Args:
|
||||
path: File path to modify
|
||||
old_string: Text to find (must be unique unless replace_all=True)
|
||||
new_string: Replacement text
|
||||
replace_all: If True, replace all occurrences
|
||||
|
||||
Returns:
|
||||
PatchResult with diff and lint results
|
||||
"""
|
||||
# Expand ~ and other shell paths
|
||||
path = self._expand_path(path)
|
||||
|
||||
# Read current content
|
||||
read_cmd = f"cat {self._escape_shell_arg(path)} 2>/dev/null"
|
||||
read_result = self._exec(read_cmd)
|
||||
|
||||
if read_result.exit_code != 0:
|
||||
return PatchResult(error=f"Failed to read file: {path}")
|
||||
|
||||
content = read_result.stdout
|
||||
|
||||
# Import and use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
content, old_string, new_string, replace_all
|
||||
)
|
||||
|
||||
if error:
|
||||
return PatchResult(error=error)
|
||||
|
||||
if match_count == 0:
|
||||
return PatchResult(error=f"Could not find match for old_string in {path}")
|
||||
|
||||
# Write back
|
||||
write_result = self.write_file(path, new_content)
|
||||
if write_result.error:
|
||||
return PatchResult(error=f"Failed to write changes: {write_result.error}")
|
||||
|
||||
# Generate diff
|
||||
diff = self._unified_diff(content, new_content, path)
|
||||
|
||||
# Auto-lint
|
||||
lint_result = self._check_lint(path)
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=diff,
|
||||
files_modified=[path],
|
||||
lint=lint_result.to_dict() if lint_result else None
|
||||
)
|
||||
|
||||
def patch_v4a(self, patch_content: str) -> PatchResult:
|
||||
"""
|
||||
Apply a V4A format patch.
|
||||
|
||||
V4A format:
|
||||
*** Begin Patch
|
||||
*** Update File: path/to/file.py
|
||||
@@ context hint @@
|
||||
context line
|
||||
-removed line
|
||||
+added line
|
||||
*** End Patch
|
||||
|
||||
Args:
|
||||
patch_content: V4A format patch string
|
||||
|
||||
Returns:
|
||||
PatchResult with changes made
|
||||
"""
|
||||
# Import patch parser
|
||||
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
|
||||
|
||||
operations, parse_error = parse_v4a_patch(patch_content)
|
||||
if parse_error:
|
||||
return PatchResult(error=f"Failed to parse patch: {parse_error}")
|
||||
|
||||
# Apply operations
|
||||
result = apply_v4a_operations(operations, self)
|
||||
return result
|
||||
|
||||
def _check_lint(self, path: str) -> LintResult:
|
||||
"""
|
||||
Run syntax check on a file after editing.
|
||||
|
||||
Args:
|
||||
path: File path to lint
|
||||
|
||||
Returns:
|
||||
LintResult with status and any errors
|
||||
"""
|
||||
ext = os.path.splitext(path)[1].lower()
|
||||
|
||||
if ext not in LINTERS:
|
||||
return LintResult(skipped=True, message=f"No linter for {ext} files")
|
||||
|
||||
# Check if linter command is available
|
||||
linter_cmd = LINTERS[ext]
|
||||
# Extract the base command (first word)
|
||||
base_cmd = linter_cmd.split()[0]
|
||||
|
||||
if not self._has_command(base_cmd):
|
||||
return LintResult(skipped=True, message=f"{base_cmd} not available")
|
||||
|
||||
# Run linter
|
||||
cmd = linter_cmd.format(file=self._escape_shell_arg(path))
|
||||
result = self._exec(cmd, timeout=30)
|
||||
|
||||
return LintResult(
|
||||
success=result.exit_code == 0,
|
||||
output=result.stdout.strip() if result.stdout.strip() else ""
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# SEARCH Implementation
|
||||
# =========================================================================
|
||||
|
||||
def search(self, pattern: str, path: str = ".", target: str = "content",
|
||||
file_glob: Optional[str] = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0) -> SearchResult:
|
||||
"""
|
||||
Search for content or files.
|
||||
|
||||
Args:
|
||||
pattern: Regex (for content) or glob pattern (for files)
|
||||
path: Directory/file to search (default: cwd)
|
||||
target: "content" (grep) or "files" (glob)
|
||||
file_glob: File pattern filter for content search (e.g., "*.py")
|
||||
limit: Max results (default 50)
|
||||
offset: Skip first N results
|
||||
output_mode: "content", "files_only", or "count"
|
||||
context: Lines of context around matches
|
||||
|
||||
Returns:
|
||||
SearchResult with matches or file list
|
||||
"""
|
||||
# Expand ~ and other shell paths
|
||||
path = self._expand_path(path)
|
||||
|
||||
if target == "files":
|
||||
return self._search_files(pattern, path, limit, offset)
|
||||
else:
|
||||
return self._search_content(pattern, path, file_glob, limit, offset,
|
||||
output_mode, context)
|
||||
|
||||
def _search_files(self, pattern: str, path: str, limit: int, offset: int) -> SearchResult:
|
||||
"""Search for files by name pattern (glob-like)."""
|
||||
# Check if find is available (not on Windows without Git Bash/WSL)
|
||||
if not self._has_command('find'):
|
||||
return SearchResult(
|
||||
error="File search requires 'find' command. "
|
||||
"On Windows, use Git Bash, WSL, or install Unix tools."
|
||||
)
|
||||
|
||||
# Auto-prepend **/ for recursive search if not already present
|
||||
if not pattern.startswith('**/') and '/' not in pattern:
|
||||
search_pattern = pattern
|
||||
else:
|
||||
search_pattern = pattern.split('/')[-1]
|
||||
|
||||
# Use find with modification time sorting
|
||||
# -printf '%T@ %p\n' outputs: timestamp path
|
||||
# sort -rn sorts by timestamp descending (newest first)
|
||||
cmd = f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " \
|
||||
f"-printf '%T@ %p\\n' 2>/dev/null | sort -rn | tail -n +{offset + 1} | head -n {limit}"
|
||||
|
||||
result = self._exec(cmd, timeout=60)
|
||||
|
||||
if result.exit_code != 0 and not result.stdout.strip():
|
||||
# Try without -printf (BSD find compatibility)
|
||||
cmd_simple = f"find {self._escape_shell_arg(path)} -type f -name {self._escape_shell_arg(search_pattern)} " \
|
||||
f"2>/dev/null | head -n {limit + offset} | tail -n +{offset + 1}"
|
||||
result = self._exec(cmd_simple, timeout=60)
|
||||
|
||||
files = []
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if not line:
|
||||
continue
|
||||
# Parse "timestamp path" format
|
||||
parts = line.split(' ', 1)
|
||||
if len(parts) == 2 and parts[0].replace('.', '').isdigit():
|
||||
files.append(parts[1])
|
||||
else:
|
||||
files.append(line)
|
||||
|
||||
return SearchResult(
|
||||
files=files,
|
||||
total_count=len(files)
|
||||
)
|
||||
|
||||
def _search_content(self, pattern: str, path: str, file_glob: Optional[str],
|
||||
limit: int, offset: int, output_mode: str, context: int) -> SearchResult:
|
||||
"""Search for content inside files (grep-like)."""
|
||||
# Try ripgrep first (fast), fallback to grep (slower but works)
|
||||
if self._has_command('rg'):
|
||||
return self._search_with_rg(pattern, path, file_glob, limit, offset,
|
||||
output_mode, context)
|
||||
elif self._has_command('grep'):
|
||||
return self._search_with_grep(pattern, path, file_glob, limit, offset,
|
||||
output_mode, context)
|
||||
else:
|
||||
# Neither rg nor grep available (Windows without Git Bash, etc.)
|
||||
return SearchResult(
|
||||
error="Content search requires ripgrep (rg) or grep. "
|
||||
"Install ripgrep: https://github.com/BurntSushi/ripgrep#installation"
|
||||
)
|
||||
|
||||
def _search_with_rg(self, pattern: str, path: str, file_glob: Optional[str],
|
||||
limit: int, offset: int, output_mode: str, context: int) -> SearchResult:
|
||||
"""Search using ripgrep."""
|
||||
cmd_parts = ["rg", "--line-number", "--no-heading"]
|
||||
|
||||
# Add context if requested
|
||||
if context > 0:
|
||||
cmd_parts.extend(["-C", str(context)])
|
||||
|
||||
# Add file glob filter
|
||||
if file_glob:
|
||||
cmd_parts.extend(["--glob", file_glob])
|
||||
|
||||
# Output mode handling
|
||||
if output_mode == "files_only":
|
||||
cmd_parts.append("-l") # Files only
|
||||
elif output_mode == "count":
|
||||
cmd_parts.append("-c") # Count per file
|
||||
|
||||
# Add pattern and path
|
||||
cmd_parts.append(self._escape_shell_arg(pattern))
|
||||
cmd_parts.append(self._escape_shell_arg(path))
|
||||
|
||||
# Limit results
|
||||
cmd_parts.extend(["|", "head", "-n", str(limit + offset)])
|
||||
|
||||
cmd = " ".join(cmd_parts)
|
||||
result = self._exec(cmd, timeout=60)
|
||||
|
||||
# Parse results based on output mode
|
||||
if output_mode == "files_only":
|
||||
files = [f for f in result.stdout.strip().split('\n') if f][offset:]
|
||||
return SearchResult(files=files[:limit], total_count=len(files))
|
||||
|
||||
elif output_mode == "count":
|
||||
counts = {}
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if ':' in line:
|
||||
parts = line.rsplit(':', 1)
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
counts[parts[0]] = int(parts[1])
|
||||
except ValueError:
|
||||
pass
|
||||
return SearchResult(counts=counts, total_count=sum(counts.values()))
|
||||
|
||||
else:
|
||||
# Parse content matches
|
||||
matches = []
|
||||
for line in result.stdout.strip().split('\n')[offset:]:
|
||||
if not line:
|
||||
continue
|
||||
# Format: file:line:content
|
||||
parts = line.split(':', 2)
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
matches.append(SearchMatch(
|
||||
path=parts[0],
|
||||
line_number=int(parts[1]),
|
||||
content=parts[2][:500] # Truncate long lines
|
||||
))
|
||||
except ValueError:
|
||||
# Line number not an int, skip
|
||||
pass
|
||||
|
||||
return SearchResult(
|
||||
matches=matches[:limit],
|
||||
total_count=len(matches),
|
||||
truncated=len(matches) > limit
|
||||
)
|
||||
|
||||
def _search_with_grep(self, pattern: str, path: str, file_glob: Optional[str],
|
||||
limit: int, offset: int, output_mode: str, context: int) -> SearchResult:
|
||||
"""Fallback search using grep."""
|
||||
cmd_parts = ["grep", "-rn"]
|
||||
|
||||
# Add context if requested
|
||||
if context > 0:
|
||||
cmd_parts.extend(["-C", str(context)])
|
||||
|
||||
# Add file pattern filter
|
||||
if file_glob:
|
||||
cmd_parts.extend(["--include", file_glob])
|
||||
|
||||
# Output mode handling
|
||||
if output_mode == "files_only":
|
||||
cmd_parts.append("-l")
|
||||
elif output_mode == "count":
|
||||
cmd_parts.append("-c")
|
||||
|
||||
# Add pattern and path
|
||||
cmd_parts.append(self._escape_shell_arg(pattern))
|
||||
cmd_parts.append(self._escape_shell_arg(path))
|
||||
|
||||
# Limit and offset
|
||||
cmd_parts.extend(["|", "tail", "-n", f"+{offset + 1}", "|", "head", "-n", str(limit)])
|
||||
|
||||
cmd = " ".join(cmd_parts)
|
||||
result = self._exec(cmd, timeout=60)
|
||||
|
||||
# Parse results (same format as rg)
|
||||
if output_mode == "files_only":
|
||||
files = [f for f in result.stdout.strip().split('\n') if f]
|
||||
return SearchResult(files=files, total_count=len(files))
|
||||
|
||||
elif output_mode == "count":
|
||||
counts = {}
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if ':' in line:
|
||||
parts = line.rsplit(':', 1)
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
counts[parts[0]] = int(parts[1])
|
||||
except ValueError:
|
||||
pass
|
||||
return SearchResult(counts=counts, total_count=sum(counts.values()))
|
||||
|
||||
else:
|
||||
matches = []
|
||||
for line in result.stdout.strip().split('\n'):
|
||||
if not line:
|
||||
continue
|
||||
parts = line.split(':', 2)
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
matches.append(SearchMatch(
|
||||
path=parts[0],
|
||||
line_number=int(parts[1]),
|
||||
content=parts[2][:500]
|
||||
))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return SearchResult(
|
||||
matches=matches,
|
||||
total_count=len(matches)
|
||||
)
|
||||
113
tools/file_tools.py
Normal file
113
tools/file_tools.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
#!/usr/bin/env python3
|
||||
"""File Tools Module - LLM agent file manipulation tools."""
|
||||
|
||||
import json
|
||||
import threading
|
||||
from typing import Optional
|
||||
from tools.file_operations import ShellFileOperations
|
||||
|
||||
_file_ops_lock = threading.Lock()
|
||||
_file_ops_cache: dict = {}
|
||||
|
||||
|
||||
def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
||||
"""Get or create ShellFileOperations for a terminal environment."""
|
||||
from tools.terminal_tool import _active_environments, _env_lock, _LocalEnvironment
|
||||
|
||||
with _file_ops_lock:
|
||||
if task_id in _file_ops_cache:
|
||||
return _file_ops_cache[task_id]
|
||||
|
||||
with _env_lock:
|
||||
if task_id not in _active_environments:
|
||||
import os
|
||||
env = _LocalEnvironment(cwd=os.getcwd(), timeout=60)
|
||||
_active_environments[task_id] = env
|
||||
terminal_env = _active_environments[task_id]
|
||||
|
||||
file_ops = ShellFileOperations(terminal_env)
|
||||
_file_ops_cache[task_id] = file_ops
|
||||
return file_ops
|
||||
|
||||
|
||||
def clear_file_ops_cache(task_id: str = None):
|
||||
"""Clear the file operations cache."""
|
||||
with _file_ops_lock:
|
||||
if task_id:
|
||||
_file_ops_cache.pop(task_id, None)
|
||||
else:
|
||||
_file_ops_cache.clear()
|
||||
|
||||
|
||||
def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = "default") -> str:
|
||||
"""Read a file with pagination and line numbers."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.read_file(path, offset, limit)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
"""Write content to a file."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.write_file(path, content)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||
new_string: str = None, replace_all: bool = False, patch: str = None,
|
||||
task_id: str = "default") -> str:
|
||||
"""Patch a file using replace mode or V4A patch format."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
|
||||
if mode == "replace":
|
||||
if not path:
|
||||
return json.dumps({"error": "path required"})
|
||||
if old_string is None or new_string is None:
|
||||
return json.dumps({"error": "old_string and new_string required"})
|
||||
result = file_ops.patch_replace(path, old_string, new_string, replace_all)
|
||||
elif mode == "patch":
|
||||
if not patch:
|
||||
return json.dumps({"error": "patch content required"})
|
||||
result = file_ops.patch_v4a(patch)
|
||||
else:
|
||||
return json.dumps({"error": f"Unknown mode: {mode}"})
|
||||
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
def search_tool(pattern: str, target: str = "content", path: str = ".",
|
||||
file_glob: str = None, limit: int = 50, offset: int = 0,
|
||||
output_mode: str = "content", context: int = 0,
|
||||
task_id: str = "default") -> str:
|
||||
"""Search for content or files."""
|
||||
try:
|
||||
file_ops = _get_file_ops(task_id)
|
||||
result = file_ops.search(
|
||||
pattern=pattern, path=path, target=target, file_glob=file_glob,
|
||||
limit=limit, offset=offset, output_mode=output_mode, context=context
|
||||
)
|
||||
return json.dumps(result.to_dict(), ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return json.dumps({"error": str(e)}, ensure_ascii=False)
|
||||
|
||||
|
||||
FILE_TOOLS = [
|
||||
{"name": "read_file", "function": read_file_tool},
|
||||
{"name": "write_file", "function": write_file_tool},
|
||||
{"name": "patch", "function": patch_tool},
|
||||
{"name": "search", "function": search_tool}
|
||||
]
|
||||
|
||||
|
||||
def get_file_tools():
|
||||
"""Get the list of file tool definitions."""
|
||||
return FILE_TOOLS
|
||||
478
tools/fuzzy_match.py
Normal file
478
tools/fuzzy_match.py
Normal file
|
|
@ -0,0 +1,478 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Fuzzy Matching Module for File Operations
|
||||
|
||||
Implements a multi-strategy matching chain to robustly find and replace text,
|
||||
accommodating variations in whitespace, indentation, and escaping common
|
||||
in LLM-generated code.
|
||||
|
||||
The 9-strategy chain (inspired by OpenCode):
|
||||
1. Exact match - Direct string comparison
|
||||
2. Line-trimmed - Strip leading/trailing whitespace per line
|
||||
3. Block anchor - Match first+last lines, use similarity for middle
|
||||
4. Whitespace normalized - Collapse multiple spaces/tabs to single space
|
||||
5. Indentation flexible - Ignore indentation differences entirely
|
||||
6. Escape normalized - Convert \\n literals to actual newlines
|
||||
7. Trimmed boundary - Trim first/last line whitespace only
|
||||
8. Context-aware - 50% line similarity threshold
|
||||
9. Multi-occurrence - For replace_all flag
|
||||
|
||||
Usage:
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
|
||||
new_content, match_count, error = fuzzy_find_and_replace(
|
||||
content="def foo():\\n pass",
|
||||
old_string="def foo():",
|
||||
new_string="def bar():",
|
||||
replace_all=False
|
||||
)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple, Optional, List, Callable
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
|
||||
def fuzzy_find_and_replace(content: str, old_string: str, new_string: str,
|
||||
replace_all: bool = False) -> Tuple[str, int, Optional[str]]:
|
||||
"""
|
||||
Find and replace text using a chain of increasingly fuzzy matching strategies.
|
||||
|
||||
Args:
|
||||
content: The file content to search in
|
||||
old_string: The text to find
|
||||
new_string: The replacement text
|
||||
replace_all: If True, replace all occurrences; if False, require uniqueness
|
||||
|
||||
Returns:
|
||||
Tuple of (new_content, match_count, error_message)
|
||||
- If successful: (modified_content, number_of_replacements, None)
|
||||
- If failed: (original_content, 0, error_description)
|
||||
"""
|
||||
if not old_string:
|
||||
return content, 0, "old_string cannot be empty"
|
||||
|
||||
if old_string == new_string:
|
||||
return content, 0, "old_string and new_string are identical"
|
||||
|
||||
# Try each matching strategy in order
|
||||
strategies: List[Tuple[str, Callable]] = [
|
||||
("exact", _strategy_exact),
|
||||
("line_trimmed", _strategy_line_trimmed),
|
||||
("whitespace_normalized", _strategy_whitespace_normalized),
|
||||
("indentation_flexible", _strategy_indentation_flexible),
|
||||
("escape_normalized", _strategy_escape_normalized),
|
||||
("trimmed_boundary", _strategy_trimmed_boundary),
|
||||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, old_string)
|
||||
|
||||
if matches:
|
||||
# Found matches with this strategy
|
||||
if len(matches) > 1 and not replace_all:
|
||||
return content, 0, (
|
||||
f"Found {len(matches)} matches for old_string. "
|
||||
f"Provide more context to make it unique, or use replace_all=True."
|
||||
)
|
||||
|
||||
# Perform replacement
|
||||
new_content = _apply_replacements(content, matches, new_string)
|
||||
return new_content, len(matches), None
|
||||
|
||||
# No strategy found a match
|
||||
return content, 0, "Could not find a match for old_string in the file"
|
||||
|
||||
|
||||
def _apply_replacements(content: str, matches: List[Tuple[int, int]], new_string: str) -> str:
|
||||
"""
|
||||
Apply replacements at the given positions.
|
||||
|
||||
Args:
|
||||
content: Original content
|
||||
matches: List of (start, end) positions to replace
|
||||
new_string: Replacement text
|
||||
|
||||
Returns:
|
||||
Content with replacements applied
|
||||
"""
|
||||
# Sort matches by position (descending) to replace from end to start
|
||||
# This preserves positions of earlier matches
|
||||
sorted_matches = sorted(matches, key=lambda x: x[0], reverse=True)
|
||||
|
||||
result = content
|
||||
for start, end in sorted_matches:
|
||||
result = result[:start] + new_string + result[end:]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Matching Strategies
|
||||
# =============================================================================
|
||||
|
||||
def _strategy_exact(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""Strategy 1: Exact string match."""
|
||||
matches = []
|
||||
start = 0
|
||||
while True:
|
||||
pos = content.find(pattern, start)
|
||||
if pos == -1:
|
||||
break
|
||||
matches.append((pos, pos + len(pattern)))
|
||||
start = pos + 1
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_line_trimmed(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 2: Match with line-by-line whitespace trimming.
|
||||
|
||||
Strips leading/trailing whitespace from each line before matching.
|
||||
"""
|
||||
# Normalize pattern and content by trimming each line
|
||||
pattern_lines = [line.strip() for line in pattern.split('\n')]
|
||||
pattern_normalized = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
content_normalized_lines = [line.strip() for line in content_lines]
|
||||
|
||||
# Build mapping from normalized positions back to original positions
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_normalized_lines,
|
||||
pattern, pattern_normalized
|
||||
)
|
||||
|
||||
|
||||
def _strategy_whitespace_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 3: Collapse multiple whitespace to single space.
|
||||
"""
|
||||
def normalize(s):
|
||||
# Collapse multiple spaces/tabs to single space, preserve newlines
|
||||
return re.sub(r'[ \t]+', ' ', s)
|
||||
|
||||
pattern_normalized = normalize(pattern)
|
||||
content_normalized = normalize(content)
|
||||
|
||||
# Find in normalized, map back to original
|
||||
matches_in_normalized = _strategy_exact(content_normalized, pattern_normalized)
|
||||
|
||||
if not matches_in_normalized:
|
||||
return []
|
||||
|
||||
# Map positions back to original content
|
||||
return _map_normalized_positions(content, content_normalized, matches_in_normalized)
|
||||
|
||||
|
||||
def _strategy_indentation_flexible(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 4: Ignore indentation differences entirely.
|
||||
|
||||
Strips all leading whitespace from lines before matching.
|
||||
"""
|
||||
def strip_indent(s):
|
||||
return '\n'.join(line.lstrip() for line in s.split('\n'))
|
||||
|
||||
pattern_stripped = strip_indent(pattern)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
content_stripped_lines = [line.lstrip() for line in content_lines]
|
||||
pattern_lines = [line.lstrip() for line in pattern.split('\n')]
|
||||
|
||||
return _find_normalized_matches(
|
||||
content, content_lines, content_stripped_lines,
|
||||
pattern, '\n'.join(pattern_lines)
|
||||
)
|
||||
|
||||
|
||||
def _strategy_escape_normalized(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 5: Convert escape sequences to actual characters.
|
||||
|
||||
Handles \\n -> newline, \\t -> tab, etc.
|
||||
"""
|
||||
def unescape(s):
|
||||
# Convert common escape sequences
|
||||
return s.replace('\\n', '\n').replace('\\t', '\t').replace('\\r', '\r')
|
||||
|
||||
pattern_unescaped = unescape(pattern)
|
||||
|
||||
if pattern_unescaped == pattern:
|
||||
# No escapes to convert, skip this strategy
|
||||
return []
|
||||
|
||||
return _strategy_exact(content, pattern_unescaped)
|
||||
|
||||
|
||||
def _strategy_trimmed_boundary(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 6: Trim whitespace from first and last lines only.
|
||||
|
||||
Useful when the pattern boundaries have whitespace differences.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
# Trim only first and last lines
|
||||
pattern_lines[0] = pattern_lines[0].strip()
|
||||
if len(pattern_lines) > 1:
|
||||
pattern_lines[-1] = pattern_lines[-1].strip()
|
||||
|
||||
modified_pattern = '\n'.join(pattern_lines)
|
||||
|
||||
content_lines = content.split('\n')
|
||||
|
||||
# Search through content for matching block
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
# Trim first and last of this block
|
||||
check_lines = block_lines.copy()
|
||||
check_lines[0] = check_lines[0].strip()
|
||||
if len(check_lines) > 1:
|
||||
check_lines[-1] = check_lines[-1].strip()
|
||||
|
||||
if '\n'.join(check_lines) == modified_pattern:
|
||||
# Found match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_block_anchor(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 7: Match by anchoring on first and last lines.
|
||||
|
||||
If first and last lines match exactly, accept middle with 70% similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
if len(pattern_lines) < 2:
|
||||
return [] # Need at least 2 lines for anchoring
|
||||
|
||||
first_line = pattern_lines[0].strip()
|
||||
last_line = pattern_lines[-1].strip()
|
||||
|
||||
content_lines = content.split('\n')
|
||||
matches = []
|
||||
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
# Check if first and last lines match
|
||||
if (content_lines[i].strip() == first_line and
|
||||
content_lines[i + pattern_line_count - 1].strip() == last_line):
|
||||
|
||||
# Check middle similarity
|
||||
if pattern_line_count <= 2:
|
||||
# Only first and last, they match
|
||||
similarity = 1.0
|
||||
else:
|
||||
content_middle = '\n'.join(content_lines[i+1:i+pattern_line_count-1])
|
||||
pattern_middle = '\n'.join(pattern_lines[1:-1])
|
||||
similarity = SequenceMatcher(None, content_middle, pattern_middle).ratio()
|
||||
|
||||
if similarity >= 0.70:
|
||||
# Calculate positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _strategy_context_aware(content: str, pattern: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Strategy 8: Line-by-line similarity with 50% threshold.
|
||||
|
||||
Finds blocks where at least 50% of lines have high similarity.
|
||||
"""
|
||||
pattern_lines = pattern.split('\n')
|
||||
content_lines = content.split('\n')
|
||||
|
||||
if not pattern_lines:
|
||||
return []
|
||||
|
||||
matches = []
|
||||
pattern_line_count = len(pattern_lines)
|
||||
|
||||
for i in range(len(content_lines) - pattern_line_count + 1):
|
||||
block_lines = content_lines[i:i + pattern_line_count]
|
||||
|
||||
# Calculate line-by-line similarity
|
||||
high_similarity_count = 0
|
||||
for p_line, c_line in zip(pattern_lines, block_lines):
|
||||
sim = SequenceMatcher(None, p_line.strip(), c_line.strip()).ratio()
|
||||
if sim >= 0.80:
|
||||
high_similarity_count += 1
|
||||
|
||||
# Need at least 50% of lines to have high similarity
|
||||
if high_similarity_count >= len(pattern_lines) * 0.5:
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + pattern_line_count]) - 1
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Helper Functions
|
||||
# =============================================================================
|
||||
|
||||
def _find_normalized_matches(content: str, content_lines: List[str],
|
||||
content_normalized_lines: List[str],
|
||||
pattern: str, pattern_normalized: str) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Find matches in normalized content and map back to original positions.
|
||||
|
||||
Args:
|
||||
content: Original content string
|
||||
content_lines: Original content split by lines
|
||||
content_normalized_lines: Normalized content lines
|
||||
pattern: Original pattern
|
||||
pattern_normalized: Normalized pattern
|
||||
|
||||
Returns:
|
||||
List of (start, end) positions in the original content
|
||||
"""
|
||||
pattern_norm_lines = pattern_normalized.split('\n')
|
||||
num_pattern_lines = len(pattern_norm_lines)
|
||||
|
||||
matches = []
|
||||
|
||||
for i in range(len(content_normalized_lines) - num_pattern_lines + 1):
|
||||
# Check if this block matches
|
||||
block = '\n'.join(content_normalized_lines[i:i + num_pattern_lines])
|
||||
|
||||
if block == pattern_normalized:
|
||||
# Found a match - calculate original positions
|
||||
start_pos = sum(len(line) + 1 for line in content_lines[:i])
|
||||
end_pos = sum(len(line) + 1 for line in content_lines[:i + num_pattern_lines]) - 1
|
||||
|
||||
# Handle case where end is past content
|
||||
if end_pos >= len(content):
|
||||
end_pos = len(content)
|
||||
|
||||
matches.append((start_pos, end_pos))
|
||||
|
||||
return matches
|
||||
|
||||
|
||||
def _map_normalized_positions(original: str, normalized: str,
|
||||
normalized_matches: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Map positions from normalized string back to original.
|
||||
|
||||
This is a best-effort mapping that works for whitespace normalization.
|
||||
"""
|
||||
if not normalized_matches:
|
||||
return []
|
||||
|
||||
# Build character mapping from normalized to original
|
||||
orig_to_norm = [] # orig_to_norm[i] = position in normalized
|
||||
|
||||
orig_idx = 0
|
||||
norm_idx = 0
|
||||
|
||||
while orig_idx < len(original) and norm_idx < len(normalized):
|
||||
if original[orig_idx] == normalized[norm_idx]:
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t' and normalized[norm_idx] == ' ':
|
||||
# Original has space/tab, normalized collapsed to space
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
# Don't advance norm_idx yet - wait until all whitespace consumed
|
||||
if orig_idx < len(original) and original[orig_idx] not in ' \t':
|
||||
norm_idx += 1
|
||||
elif original[orig_idx] in ' \t':
|
||||
# Extra whitespace in original
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
else:
|
||||
# Mismatch - shouldn't happen with our normalization
|
||||
orig_to_norm.append(norm_idx)
|
||||
orig_idx += 1
|
||||
|
||||
# Fill remaining
|
||||
while orig_idx < len(original):
|
||||
orig_to_norm.append(len(normalized))
|
||||
orig_idx += 1
|
||||
|
||||
# Reverse mapping: for each normalized position, find original range
|
||||
norm_to_orig_start = {}
|
||||
norm_to_orig_end = {}
|
||||
|
||||
for orig_pos, norm_pos in enumerate(orig_to_norm):
|
||||
if norm_pos not in norm_to_orig_start:
|
||||
norm_to_orig_start[norm_pos] = orig_pos
|
||||
norm_to_orig_end[norm_pos] = orig_pos
|
||||
|
||||
# Map matches
|
||||
original_matches = []
|
||||
for norm_start, norm_end in normalized_matches:
|
||||
# Find original start
|
||||
if norm_start in norm_to_orig_start:
|
||||
orig_start = norm_to_orig_start[norm_start]
|
||||
else:
|
||||
# Find nearest
|
||||
orig_start = min(i for i, n in enumerate(orig_to_norm) if n >= norm_start)
|
||||
|
||||
# Find original end
|
||||
if norm_end - 1 in norm_to_orig_end:
|
||||
orig_end = norm_to_orig_end[norm_end - 1] + 1
|
||||
else:
|
||||
orig_end = orig_start + (norm_end - norm_start)
|
||||
|
||||
# Expand to include trailing whitespace that was normalized
|
||||
while orig_end < len(original) and original[orig_end] in ' \t':
|
||||
orig_end += 1
|
||||
|
||||
original_matches.append((orig_start, min(orig_end, len(original))))
|
||||
|
||||
return original_matches
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Utility Functions
|
||||
# =============================================================================
|
||||
|
||||
def find_best_match(content: str, pattern: str) -> Optional[Tuple[int, int, str]]:
|
||||
"""
|
||||
Find the best match for a pattern and return the strategy name.
|
||||
|
||||
Returns:
|
||||
Tuple of (start, end, strategy_name) or None if no match
|
||||
"""
|
||||
strategies = [
|
||||
("exact", _strategy_exact),
|
||||
("line_trimmed", _strategy_line_trimmed),
|
||||
("whitespace_normalized", _strategy_whitespace_normalized),
|
||||
("indentation_flexible", _strategy_indentation_flexible),
|
||||
("escape_normalized", _strategy_escape_normalized),
|
||||
("trimmed_boundary", _strategy_trimmed_boundary),
|
||||
("block_anchor", _strategy_block_anchor),
|
||||
("context_aware", _strategy_context_aware),
|
||||
]
|
||||
|
||||
for strategy_name, strategy_fn in strategies:
|
||||
matches = strategy_fn(content, pattern)
|
||||
if matches:
|
||||
return (matches[0][0], matches[0][1], strategy_name)
|
||||
|
||||
return None
|
||||
439
tools/patch_parser.py
Normal file
439
tools/patch_parser.py
Normal file
|
|
@ -0,0 +1,439 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
V4A Patch Format Parser
|
||||
|
||||
Parses the V4A patch format used by codex, cline, and other coding agents.
|
||||
|
||||
V4A Format:
|
||||
*** Begin Patch
|
||||
*** Update File: path/to/file.py
|
||||
@@ optional context hint @@
|
||||
context line (space prefix)
|
||||
-removed line (minus prefix)
|
||||
+added line (plus prefix)
|
||||
*** Add File: path/to/new.py
|
||||
+new file content
|
||||
+line 2
|
||||
*** Delete File: path/to/old.py
|
||||
*** Move File: old/path.py -> new/path.py
|
||||
*** End Patch
|
||||
|
||||
Usage:
|
||||
from tools.patch_parser import parse_v4a_patch, apply_v4a_operations
|
||||
|
||||
operations, error = parse_v4a_patch(patch_content)
|
||||
if error:
|
||||
print(f"Parse error: {error}")
|
||||
else:
|
||||
result = apply_v4a_operations(operations, file_ops)
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Any
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class OperationType(Enum):
|
||||
ADD = "add"
|
||||
UPDATE = "update"
|
||||
DELETE = "delete"
|
||||
MOVE = "move"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunkLine:
|
||||
"""A single line in a patch hunk."""
|
||||
prefix: str # ' ', '-', or '+'
|
||||
content: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Hunk:
|
||||
"""A group of changes within a file."""
|
||||
context_hint: Optional[str] = None
|
||||
lines: List[HunkLine] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PatchOperation:
|
||||
"""A single operation in a V4A patch."""
|
||||
operation: OperationType
|
||||
file_path: str
|
||||
new_path: Optional[str] = None # For move operations
|
||||
hunks: List[Hunk] = field(default_factory=list)
|
||||
content: Optional[str] = None # For add file operations
|
||||
|
||||
|
||||
def parse_v4a_patch(patch_content: str) -> Tuple[List[PatchOperation], Optional[str]]:
|
||||
"""
|
||||
Parse a V4A format patch.
|
||||
|
||||
Args:
|
||||
patch_content: The patch text in V4A format
|
||||
|
||||
Returns:
|
||||
Tuple of (operations, error_message)
|
||||
- If successful: (list_of_operations, None)
|
||||
- If failed: ([], error_description)
|
||||
"""
|
||||
lines = patch_content.split('\n')
|
||||
operations: List[PatchOperation] = []
|
||||
|
||||
# Find patch boundaries
|
||||
start_idx = None
|
||||
end_idx = None
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
if '*** Begin Patch' in line or '***Begin Patch' in line:
|
||||
start_idx = i
|
||||
elif '*** End Patch' in line or '***End Patch' in line:
|
||||
end_idx = i
|
||||
break
|
||||
|
||||
if start_idx is None:
|
||||
# Try to parse without explicit begin marker
|
||||
start_idx = -1
|
||||
|
||||
if end_idx is None:
|
||||
end_idx = len(lines)
|
||||
|
||||
# Parse operations between boundaries
|
||||
i = start_idx + 1
|
||||
current_op: Optional[PatchOperation] = None
|
||||
current_hunk: Optional[Hunk] = None
|
||||
|
||||
while i < end_idx:
|
||||
line = lines[i]
|
||||
|
||||
# Check for file operation markers
|
||||
update_match = re.match(r'\*\*\*\s*Update\s+File:\s*(.+)', line)
|
||||
add_match = re.match(r'\*\*\*\s*Add\s+File:\s*(.+)', line)
|
||||
delete_match = re.match(r'\*\*\*\s*Delete\s+File:\s*(.+)', line)
|
||||
move_match = re.match(r'\*\*\*\s*Move\s+File:\s*(.+?)\s*->\s*(.+)', line)
|
||||
|
||||
if update_match:
|
||||
# Save previous operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.UPDATE,
|
||||
file_path=update_match.group(1).strip()
|
||||
)
|
||||
current_hunk = None
|
||||
|
||||
elif add_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.ADD,
|
||||
file_path=add_match.group(1).strip()
|
||||
)
|
||||
current_hunk = Hunk()
|
||||
|
||||
elif delete_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.DELETE,
|
||||
file_path=delete_match.group(1).strip()
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif move_match:
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
current_op = PatchOperation(
|
||||
operation=OperationType.MOVE,
|
||||
file_path=move_match.group(1).strip(),
|
||||
new_path=move_match.group(2).strip()
|
||||
)
|
||||
operations.append(current_op)
|
||||
current_op = None
|
||||
current_hunk = None
|
||||
|
||||
elif line.startswith('@@'):
|
||||
# Context hint / hunk marker
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
|
||||
# Extract context hint
|
||||
hint_match = re.match(r'@@\s*(.+?)\s*@@', line)
|
||||
hint = hint_match.group(1) if hint_match else None
|
||||
current_hunk = Hunk(context_hint=hint)
|
||||
|
||||
elif current_op and line:
|
||||
# Parse hunk line
|
||||
if current_hunk is None:
|
||||
current_hunk = Hunk()
|
||||
|
||||
if line.startswith('+'):
|
||||
current_hunk.lines.append(HunkLine('+', line[1:]))
|
||||
elif line.startswith('-'):
|
||||
current_hunk.lines.append(HunkLine('-', line[1:]))
|
||||
elif line.startswith(' '):
|
||||
current_hunk.lines.append(HunkLine(' ', line[1:]))
|
||||
elif line.startswith('\\'):
|
||||
# "\ No newline at end of file" marker - skip
|
||||
pass
|
||||
else:
|
||||
# Treat as context line (implicit space prefix)
|
||||
current_hunk.lines.append(HunkLine(' ', line))
|
||||
|
||||
i += 1
|
||||
|
||||
# Don't forget the last operation
|
||||
if current_op:
|
||||
if current_hunk and current_hunk.lines:
|
||||
current_op.hunks.append(current_hunk)
|
||||
operations.append(current_op)
|
||||
|
||||
return operations, None
|
||||
|
||||
|
||||
def apply_v4a_operations(operations: List[PatchOperation],
|
||||
file_ops: Any) -> 'PatchResult':
|
||||
"""
|
||||
Apply V4A patch operations using a file operations interface.
|
||||
|
||||
Args:
|
||||
operations: List of PatchOperation from parse_v4a_patch
|
||||
file_ops: Object with read_file, write_file methods
|
||||
|
||||
Returns:
|
||||
PatchResult with results of all operations
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from tools.file_operations import PatchResult
|
||||
|
||||
files_modified = []
|
||||
files_created = []
|
||||
files_deleted = []
|
||||
all_diffs = []
|
||||
errors = []
|
||||
|
||||
for op in operations:
|
||||
try:
|
||||
if op.operation == OperationType.ADD:
|
||||
result = _apply_add(op, file_ops)
|
||||
if result[0]:
|
||||
files_created.append(op.file_path)
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to add {op.file_path}: {result[1]}")
|
||||
|
||||
elif op.operation == OperationType.DELETE:
|
||||
result = _apply_delete(op, file_ops)
|
||||
if result[0]:
|
||||
files_deleted.append(op.file_path)
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to delete {op.file_path}: {result[1]}")
|
||||
|
||||
elif op.operation == OperationType.MOVE:
|
||||
result = _apply_move(op, file_ops)
|
||||
if result[0]:
|
||||
files_modified.append(f"{op.file_path} -> {op.new_path}")
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to move {op.file_path}: {result[1]}")
|
||||
|
||||
elif op.operation == OperationType.UPDATE:
|
||||
result = _apply_update(op, file_ops)
|
||||
if result[0]:
|
||||
files_modified.append(op.file_path)
|
||||
all_diffs.append(result[1])
|
||||
else:
|
||||
errors.append(f"Failed to update {op.file_path}: {result[1]}")
|
||||
|
||||
except Exception as e:
|
||||
errors.append(f"Error processing {op.file_path}: {str(e)}")
|
||||
|
||||
# Run lint on all modified/created files
|
||||
lint_results = {}
|
||||
for f in files_modified + files_created:
|
||||
if hasattr(file_ops, '_check_lint'):
|
||||
lint_result = file_ops._check_lint(f)
|
||||
lint_results[f] = lint_result.to_dict()
|
||||
|
||||
combined_diff = '\n'.join(all_diffs)
|
||||
|
||||
if errors:
|
||||
return PatchResult(
|
||||
success=False,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None,
|
||||
error='; '.join(errors)
|
||||
)
|
||||
|
||||
return PatchResult(
|
||||
success=True,
|
||||
diff=combined_diff,
|
||||
files_modified=files_modified,
|
||||
files_created=files_created,
|
||||
files_deleted=files_deleted,
|
||||
lint=lint_results if lint_results else None
|
||||
)
|
||||
|
||||
|
||||
def _apply_add(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an add file operation."""
|
||||
# Extract content from hunks (all + lines)
|
||||
content_lines = []
|
||||
for hunk in op.hunks:
|
||||
for line in hunk.lines:
|
||||
if line.prefix == '+':
|
||||
content_lines.append(line.content)
|
||||
|
||||
content = '\n'.join(content_lines)
|
||||
|
||||
result = file_ops.write_file(op.file_path, content)
|
||||
if result.error:
|
||||
return False, result.error
|
||||
|
||||
diff = f"--- /dev/null\n+++ b/{op.file_path}\n"
|
||||
diff += '\n'.join(f"+{line}" for line in content_lines)
|
||||
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_delete(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a delete file operation."""
|
||||
# Read file first for diff
|
||||
read_result = file_ops.read_file(op.file_path)
|
||||
|
||||
if read_result.error and "not found" in read_result.error.lower():
|
||||
# File doesn't exist, nothing to delete
|
||||
return True, f"# {op.file_path} already deleted or doesn't exist"
|
||||
|
||||
# Delete by writing empty and then removing
|
||||
# Use shell command via the underlying environment
|
||||
rm_result = file_ops._exec(f"rm -f {file_ops._escape_shell_arg(op.file_path)}")
|
||||
|
||||
if rm_result.exit_code != 0:
|
||||
return False, rm_result.stdout
|
||||
|
||||
diff = f"--- a/{op.file_path}\n+++ /dev/null\n# File deleted"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_move(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply a move file operation."""
|
||||
# Use shell mv command
|
||||
mv_result = file_ops._exec(
|
||||
f"mv {file_ops._escape_shell_arg(op.file_path)} {file_ops._escape_shell_arg(op.new_path)}"
|
||||
)
|
||||
|
||||
if mv_result.exit_code != 0:
|
||||
return False, mv_result.stdout
|
||||
|
||||
diff = f"# Moved: {op.file_path} -> {op.new_path}"
|
||||
return True, diff
|
||||
|
||||
|
||||
def _apply_update(op: PatchOperation, file_ops: Any) -> Tuple[bool, str]:
|
||||
"""Apply an update file operation."""
|
||||
# Read current content
|
||||
read_result = file_ops.read_file(op.file_path, limit=10000)
|
||||
|
||||
if read_result.error:
|
||||
return False, f"Cannot read file: {read_result.error}"
|
||||
|
||||
# Parse content (remove line numbers)
|
||||
current_lines = []
|
||||
for line in read_result.content.split('\n'):
|
||||
if '|' in line:
|
||||
# Line format: " 123|content"
|
||||
parts = line.split('|', 1)
|
||||
if len(parts) == 2:
|
||||
current_lines.append(parts[1])
|
||||
else:
|
||||
current_lines.append(line)
|
||||
else:
|
||||
current_lines.append(line)
|
||||
|
||||
current_content = '\n'.join(current_lines)
|
||||
|
||||
# Apply each hunk
|
||||
new_content = current_content
|
||||
|
||||
for hunk in op.hunks:
|
||||
# Build search pattern from context and removed lines
|
||||
search_lines = []
|
||||
replace_lines = []
|
||||
|
||||
for line in hunk.lines:
|
||||
if line.prefix == ' ':
|
||||
search_lines.append(line.content)
|
||||
replace_lines.append(line.content)
|
||||
elif line.prefix == '-':
|
||||
search_lines.append(line.content)
|
||||
elif line.prefix == '+':
|
||||
replace_lines.append(line.content)
|
||||
|
||||
if search_lines:
|
||||
search_pattern = '\n'.join(search_lines)
|
||||
replacement = '\n'.join(replace_lines)
|
||||
|
||||
# Use fuzzy matching
|
||||
from tools.fuzzy_match import fuzzy_find_and_replace
|
||||
new_content, count, error = fuzzy_find_and_replace(
|
||||
new_content, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
if error and count == 0:
|
||||
# Try with context hint if available
|
||||
if hunk.context_hint:
|
||||
# Find the context hint location and search nearby
|
||||
hint_pos = new_content.find(hunk.context_hint)
|
||||
if hint_pos != -1:
|
||||
# Search in a window around the hint
|
||||
window_start = max(0, hint_pos - 500)
|
||||
window_end = min(len(new_content), hint_pos + 2000)
|
||||
window = new_content[window_start:window_end]
|
||||
|
||||
window_new, count, error = fuzzy_find_and_replace(
|
||||
window, search_pattern, replacement, replace_all=False
|
||||
)
|
||||
|
||||
if count > 0:
|
||||
new_content = new_content[:window_start] + window_new + new_content[window_end:]
|
||||
error = None
|
||||
|
||||
if error:
|
||||
return False, f"Could not apply hunk: {error}"
|
||||
|
||||
# Write new content
|
||||
write_result = file_ops.write_file(op.file_path, new_content)
|
||||
if write_result.error:
|
||||
return False, write_result.error
|
||||
|
||||
# Generate diff
|
||||
import difflib
|
||||
diff_lines = difflib.unified_diff(
|
||||
current_content.splitlines(keepends=True),
|
||||
new_content.splitlines(keepends=True),
|
||||
fromfile=f"a/{op.file_path}",
|
||||
tofile=f"b/{op.file_path}"
|
||||
)
|
||||
diff = ''.join(diff_lines)
|
||||
|
||||
return True, diff
|
||||
1321
tools/rl_training_tool.py
Normal file
1321
tools/rl_training_tool.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -35,6 +35,7 @@ DISTRIBUTIONS = {
|
|||
"vision": 100,
|
||||
"image_gen": 100,
|
||||
"terminal": 100,
|
||||
"file": 100,
|
||||
"moa": 100,
|
||||
"browser": 100
|
||||
}
|
||||
|
|
@ -66,10 +67,11 @@ DISTRIBUTIONS = {
|
|||
|
||||
# Scientific problem solving focused distribution
|
||||
"science": {
|
||||
"description": "Scientific research with web, terminal, and browser capabilities",
|
||||
"description": "Scientific research with web, terminal, file, and browser capabilities",
|
||||
"toolsets": {
|
||||
"web": 94, # 94% chance of web tools
|
||||
"terminal": 94, # 94% chance of terminal tools
|
||||
"file": 94, # 94% chance of file tools
|
||||
"vision": 65, # 65% chance of vision tools
|
||||
"browser": 50, # 50% chance of browser for accessing papers/databases
|
||||
"image_gen": 15, # 15% chance of image generation tools
|
||||
|
|
@ -79,9 +81,10 @@ DISTRIBUTIONS = {
|
|||
|
||||
# Development-focused distribution
|
||||
"development": {
|
||||
"description": "Terminal and reasoning with occasional web lookup",
|
||||
"description": "Terminal, file tools, and reasoning with occasional web lookup",
|
||||
"toolsets": {
|
||||
"terminal": 80, # 80% chance of terminal tools
|
||||
"file": 80, # 80% chance of file tools (read, write, patch, search)
|
||||
"moa": 60, # 60% chance of reasoning tools
|
||||
"web": 30, # 30% chance of web tools
|
||||
"vision": 10 # 10% chance of vision tools
|
||||
|
|
@ -108,6 +111,7 @@ DISTRIBUTIONS = {
|
|||
"vision": 50,
|
||||
"image_gen": 50,
|
||||
"terminal": 50,
|
||||
"file": 50,
|
||||
"moa": 50,
|
||||
"browser": 50
|
||||
}
|
||||
|
|
@ -123,17 +127,19 @@ DISTRIBUTIONS = {
|
|||
|
||||
# Terminal only
|
||||
"terminal_only": {
|
||||
"description": "Only terminal tool for code execution tasks",
|
||||
"description": "Terminal and file tools for code execution tasks",
|
||||
"toolsets": {
|
||||
"terminal": 100
|
||||
"terminal": 100,
|
||||
"file": 100
|
||||
}
|
||||
},
|
||||
|
||||
# Terminal + web (common for coding tasks that need docs)
|
||||
"terminal_web": {
|
||||
"description": "Terminal with web search for documentation lookup",
|
||||
"description": "Terminal and file tools with web search for documentation lookup",
|
||||
"toolsets": {
|
||||
"terminal": 100,
|
||||
"file": 100,
|
||||
"web": 100
|
||||
}
|
||||
},
|
||||
|
|
@ -188,9 +194,10 @@ DISTRIBUTIONS = {
|
|||
|
||||
# Terminal-focused tasks distribution (for nous-terminal-tasks.jsonl)
|
||||
"terminal_tasks": {
|
||||
"description": "Terminal-focused distribution with high terminal availability, occasional other tools",
|
||||
"description": "Terminal-focused distribution with high terminal/file availability, occasional other tools",
|
||||
"toolsets": {
|
||||
"terminal": 97, # 97% - terminal almost always available
|
||||
"file": 97, # 97% - file tools almost always available
|
||||
"web": 15, # 15% - web search/scrape for documentation
|
||||
"browser": 10, # 10% - browser occasionally for web interaction
|
||||
"vision": 8, # 8% - vision analysis rarely
|
||||
|
|
@ -200,10 +207,11 @@ DISTRIBUTIONS = {
|
|||
|
||||
# Mixed browser+terminal tasks distribution (for mixed-browser-terminal-tasks.jsonl)
|
||||
"mixed_tasks": {
|
||||
"description": "Mixed distribution with high browser and terminal availability for complex tasks",
|
||||
"description": "Mixed distribution with high browser, terminal, and file availability for complex tasks",
|
||||
"toolsets": {
|
||||
"browser": 92, # 92% - browser tools highly available
|
||||
"terminal": 92, # 92% - terminal highly available
|
||||
"terminal": 92, # 92% - terminal highly available
|
||||
"file": 92, # 92% - file tools highly available
|
||||
"web": 35, # 35% - web search/scrape fairly common
|
||||
"vision": 15, # 15% - vision analysis occasionally
|
||||
"image_gen": 15 # 15% - image generation occasionally
|
||||
|
|
|
|||
26
toolsets.py
26
toolsets.py
|
|
@ -90,12 +90,30 @@ TOOLSETS = {
|
|||
"includes": []
|
||||
},
|
||||
|
||||
"rl": {
|
||||
"description": "RL training tools for running reinforcement learning on Tinker-Atropos",
|
||||
"tools": [
|
||||
"rl_list_environments", "rl_select_environment",
|
||||
"rl_get_current_config", "rl_edit_config",
|
||||
"rl_start_training", "rl_check_status",
|
||||
"rl_stop_training", "rl_get_results",
|
||||
"rl_list_runs", "rl_test_inference"
|
||||
],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
"file": {
|
||||
"description": "File manipulation tools: read, write, patch (with fuzzy matching), and search (content + files)",
|
||||
"tools": ["read_file", "write_file", "patch", "search"],
|
||||
"includes": []
|
||||
},
|
||||
|
||||
# Scenario-specific toolsets
|
||||
|
||||
"debugging": {
|
||||
"description": "Debugging and troubleshooting toolkit",
|
||||
"tools": ["terminal"],
|
||||
"includes": ["web"] # For searching error messages and solutions
|
||||
"includes": ["web", "file"] # For searching error messages and solutions, and file operations
|
||||
},
|
||||
|
||||
"safe": {
|
||||
|
|
@ -115,6 +133,8 @@ TOOLSETS = {
|
|||
"web_search", "web_extract",
|
||||
# Terminal
|
||||
"terminal",
|
||||
# File manipulation
|
||||
"read_file", "write_file", "patch", "search",
|
||||
# Vision
|
||||
"vision_analyze",
|
||||
# Image generation
|
||||
|
|
@ -143,6 +163,8 @@ TOOLSETS = {
|
|||
"tools": [
|
||||
# Terminal - enabled with dangerous command approval system
|
||||
"terminal",
|
||||
# File manipulation
|
||||
"read_file", "write_file", "patch", "search",
|
||||
# Web tools
|
||||
"web_search", "web_extract",
|
||||
# Vision - analyze images sent by users
|
||||
|
|
@ -177,6 +199,8 @@ TOOLSETS = {
|
|||
"web_search", "web_extract",
|
||||
# Terminal - only for trusted personal accounts
|
||||
"terminal",
|
||||
# File manipulation
|
||||
"read_file", "write_file", "patch", "search",
|
||||
# Vision
|
||||
"vision_analyze",
|
||||
# Skills
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue