mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-27 06:11:40 +00:00
chore: remove Atropos RL environments and tinker-atropos integration (#26106)
* chore: remove Atropos RL environments, tools, tests, skill, and tinker-atropos submodule Delete: - environments/ (43 files — base env, agent loop, tool call parsers, benchmarks) - rl_cli.py (standalone RL training CLI) - tools/rl_training_tool.py (all 10 rl_* tools) - tests: test_rl_training_tool, test_tool_call_parsers, test_managed_server_tool_support, test_agent_loop, test_agent_loop_vllm, test_agent_loop_tool_calling, test_terminalbench2_env_security - optional-skills/mlops/hermes-atropos-environments/ - tinker-atropos git submodule + .gitmodules * chore: remove RL/Atropos references from Python source - toolsets.py: remove rl toolset block + update comment - model_tools.py: remove rl_tools group + update async bridging comment - hermes_cli/tools_config.py: remove RL display entry, _DEFAULT_OFF_TOOLSETS, setup block, and rl_training post-setup handler - tools/budget_config.py: remove RL environment reference in docstring - tests/test_model_tools.py: remove rl_tools from expected groups - tests/run_agent/test_streaming_tool_call_repair.py: fix stale cross-reference * chore: remove rl/yc-bench extras and tinker-atropos refs from pyproject.toml - Remove rl extra (atroposlib, tinker, fastapi, uvicorn, wandb) - Remove yc-bench extra - Remove rl_cli from py-modules - Remove [tool.ty.src] exclude for tinker-atropos - Remove [tool.ruff] exclude for tinker-atropos - Regenerate uv.lock * chore: remove tinker-atropos from install/setup scripts - setup-hermes.sh: remove entire tinker-atropos submodule install block - scripts/install.sh: remove both tinker-atropos blocks (Termux + standard) - scripts/install.ps1: remove tinker-atropos block - nix/hermes-agent.nix: remove tinker-atropos pip install line * chore: remove RL references from cli-config.yaml.example * docs: remove Atropos/RL references from README, CONTRIBUTING, AGENTS.md * docs: remove RL/Atropos references from website - Delete: environments.md, rl-training.md, mlops-hermes-atropos-environments.md - sidebars.ts: remove rl-training and environments sidebar entries - optional-skills-catalog.md: remove hermes-atropos-environments row - tools-reference.md: remove entire rl toolset section - toolsets-reference.md: remove rl row + update example - integrations/index.md: remove RL Training bullet - architecture.md: remove environments/ from tree + RL section - contributing.md: remove tinker-atropos setup - updating.md: remove tinker-atropos install + stale submodule update * chore: remove remaining RL/Atropos stragglers - hermes_cli/config.py: remove TINKER_API_KEY + WANDB_API_KEY env var defs - hermes_cli/doctor.py: remove Submodules check section (tinker-atropos) - hermes_cli/setup.py: remove RL Training status check - hermes_cli/status.py: remove Tinker + WandB from API key status display - agent/display.py: remove both rl_* tool preview/activity blocks - website/docs: remove RL references from providers.md + env-variables.md - tests: remove TINKER_API_KEY from conftest, set_config_value, setup_script * chore: remove RL training section from .env.example
This commit is contained in:
parent
d364132114
commit
5af672c753
97 changed files with 18 additions and 15690 deletions
18
.env.example
18
.env.example
|
|
@ -394,24 +394,6 @@ IMAGE_TOOLS_DEBUG=false
|
||||||
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
# CONTEXT_COMPRESSION_THRESHOLD=0.85 # Compress at 85% of context limit
|
||||||
# Model is set via compression.summary_model in config.yaml (default: google/gemini-3-flash-preview)
|
# Model is set via compression.summary_model in config.yaml (default: google/gemini-3-flash-preview)
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# RL TRAINING (Tinker + Atropos)
|
|
||||||
# =============================================================================
|
|
||||||
# Run reinforcement learning training on language models using the Tinker API.
|
|
||||||
# Requires the rl-server to be running (from tinker-atropos package).
|
|
||||||
|
|
||||||
# Tinker API Key - RL training service
|
|
||||||
# Get at: https://tinker-console.thinkingmachines.ai/keys
|
|
||||||
# TINKER_API_KEY=
|
|
||||||
|
|
||||||
# Weights & Biases API Key - Experiment tracking and metrics
|
|
||||||
# Get at: https://wandb.ai/authorize
|
|
||||||
# WANDB_API_KEY=
|
|
||||||
|
|
||||||
# RL API Server URL (default: http://localhost:8080)
|
|
||||||
# Change if running the rl-server on a different host/port
|
|
||||||
# RL_API_URL=http://localhost:8080
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# SKILLS HUB (GitHub integration for skill search/install/publish)
|
# SKILLS HUB (GitHub integration for skill search/install/publish)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
|
||||||
3
.gitmodules
vendored
3
.gitmodules
vendored
|
|
@ -1,3 +0,0 @@
|
||||||
[submodule "tinker-atropos"]
|
|
||||||
path = tinker-atropos
|
|
||||||
url = https://github.com/nousresearch/tinker-atropos
|
|
||||||
|
|
@ -56,7 +56,6 @@ hermes-agent/
|
||||||
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
├── tui_gateway/ # Python JSON-RPC backend for the TUI
|
||||||
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
├── acp_adapter/ # ACP server (VS Code / Zed / JetBrains integration)
|
||||||
├── cron/ # Scheduler — jobs.py, scheduler.py
|
├── cron/ # Scheduler — jobs.py, scheduler.py
|
||||||
├── environments/ # RL training environments (Atropos)
|
|
||||||
├── scripts/ # run_tests.sh, release.py, auxiliary scripts
|
├── scripts/ # run_tests.sh, release.py, auxiliary scripts
|
||||||
├── website/ # Docusaurus docs site
|
├── website/ # Docusaurus docs site
|
||||||
└── tests/ # Pytest suite (~17k tests across ~900 files as of May 2026)
|
└── tests/ # Pytest suite (~17k tests across ~900 files as of May 2026)
|
||||||
|
|
|
||||||
|
|
@ -91,9 +91,6 @@ export VIRTUAL_ENV="$(pwd)/venv"
|
||||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||||
uv pip install -e ".[all,dev]"
|
uv pip install -e ".[all,dev]"
|
||||||
|
|
||||||
# Optional: RL training submodule
|
|
||||||
# git submodule update --init tinker-atropos && uv pip install -e "./tinker-atropos"
|
|
||||||
|
|
||||||
# Optional: browser tools
|
# Optional: browser tools
|
||||||
npm install
|
npm install
|
||||||
```
|
```
|
||||||
|
|
@ -196,7 +193,6 @@ hermes-agent/
|
||||||
│
|
│
|
||||||
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
|
├── skills/ # Bundled skills (copied to ~/.hermes/skills/ on install)
|
||||||
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
|
├── optional-skills/ # Official optional skills (discoverable via hub, not activated by default)
|
||||||
├── environments/ # RL training environments (Atropos integration)
|
|
||||||
├── tests/ # Test suite
|
├── tests/ # Test suite
|
||||||
├── website/ # Documentation site (hermes-agent.nousresearch.com)
|
├── website/ # Documentation site (hermes-agent.nousresearch.com)
|
||||||
│
|
│
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@ Use any model you want — [Nous Portal](https://portal.nousresearch.com), [Open
|
||||||
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
|
<tr><td><b>Scheduled automations</b></td><td>Built-in cron scheduler with delivery to any platform. Daily reports, nightly backups, weekly audits — all in natural language, running unattended.</td></tr>
|
||||||
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
|
<tr><td><b>Delegates and parallelizes</b></td><td>Spawn isolated subagents for parallel workstreams. Write Python scripts that call tools via RPC, collapsing multi-step pipelines into zero-context-cost turns.</td></tr>
|
||||||
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Seven terminal backends — local, Docker, SSH, Singularity, Modal, Daytona, and Vercel Sandbox. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
|
<tr><td><b>Runs anywhere, not just your laptop</b></td><td>Seven terminal backends — local, Docker, SSH, Singularity, Modal, Daytona, and Vercel Sandbox. Daytona and Modal offer serverless persistence — your agent's environment hibernates when idle and wakes on demand, costing nearly nothing between sessions. Run it on a $5 VPS or a GPU cluster.</td></tr>
|
||||||
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, Atropos RL environments, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
<tr><td><b>Research-ready</b></td><td>Batch trajectory generation, trajectory compression for training the next generation of tool-calling models.</td></tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
@ -175,8 +175,6 @@ uv pip install -e ".[all,dev]"
|
||||||
scripts/run_tests.sh
|
scripts/run_tests.sh
|
||||||
```
|
```
|
||||||
|
|
||||||
> **RL Training (optional):** The RL/Atropos integration (`environments/`) — see [`CONTRIBUTING.md`](https://github.com/NousResearch/hermes-agent/blob/main/CONTRIBUTING.md#development-setup) for the full setup.
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Community
|
## Community
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,7 @@
|
||||||
<tr><td><b>定时自动化</b></td><td>内置 cron 调度器,支持向任何平台投递。日报、夜间备份、周审计——全部用自然语言描述,无人值守运行。</td></tr>
|
<tr><td><b>定时自动化</b></td><td>内置 cron 调度器,支持向任何平台投递。日报、夜间备份、周审计——全部用自然语言描述,无人值守运行。</td></tr>
|
||||||
<tr><td><b>委派与并行</b></td><td>生成隔离子代理处理并行工作流。编写 Python 脚本通过 RPC 调用工具,将多步管道压缩为零上下文开销的轮次。</td></tr>
|
<tr><td><b>委派与并行</b></td><td>生成隔离子代理处理并行工作流。编写 Python 脚本通过 RPC 调用工具,将多步管道压缩为零上下文开销的轮次。</td></tr>
|
||||||
<tr><td><b>随处运行</b></td><td>六种终端后端——本地、Docker、SSH、Daytona、Singularity 和 Modal。Daytona 和 Modal 提供 Serverless 持久化——代理环境空闲时休眠、按需唤醒,空闲期间几乎零成本。$5 VPS 或 GPU 集群都能跑。</td></tr>
|
<tr><td><b>随处运行</b></td><td>六种终端后端——本地、Docker、SSH、Daytona、Singularity 和 Modal。Daytona 和 Modal 提供 Serverless 持久化——代理环境空闲时休眠、按需唤醒,空闲期间几乎零成本。$5 VPS 或 GPU 集群都能跑。</td></tr>
|
||||||
<tr><td><b>研究就绪</b></td><td>批量轨迹生成、Atropos RL 环境、轨迹压缩——用于训练下一代工具调用模型。</td></tr>
|
<tr><td><b>研究就绪</b></td><td>批量轨迹生成、轨迹压缩——用于训练下一代工具调用模型。</td></tr>
|
||||||
</table>
|
</table>
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
@ -161,12 +161,6 @@ uv pip install -e ".[all,dev]"
|
||||||
python -m pytest tests/ -q
|
python -m pytest tests/ -q
|
||||||
```
|
```
|
||||||
|
|
||||||
> **RL 训练(可选):** 如需参与 RL/Tinker-Atropos 集成开发:
|
|
||||||
> ```bash
|
|
||||||
> git submodule update --init tinker-atropos
|
|
||||||
> uv pip install -e "./tinker-atropos"
|
|
||||||
> ```
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 社区
|
## 社区
|
||||||
|
|
|
||||||
|
|
@ -240,21 +240,6 @@ def build_tool_preview(tool_name: str, args: dict, max_len: int | None = None) -
|
||||||
msg = msg[:17] + "..."
|
msg = msg[:17] + "..."
|
||||||
return f"to {target}: \"{msg}\""
|
return f"to {target}: \"{msg}\""
|
||||||
|
|
||||||
if tool_name.startswith("rl_"):
|
|
||||||
rl_previews = {
|
|
||||||
"rl_list_environments": "listing envs",
|
|
||||||
"rl_select_environment": args.get("name", ""),
|
|
||||||
"rl_get_current_config": "reading config",
|
|
||||||
"rl_edit_config": f"{args.get('field', '')}={args.get('value', '')}",
|
|
||||||
"rl_start_training": "starting",
|
|
||||||
"rl_check_status": args.get("run_id", "")[:16],
|
|
||||||
"rl_stop_training": f"stopping {args.get('run_id', '')[:16]}",
|
|
||||||
"rl_get_results": args.get("run_id", "")[:16],
|
|
||||||
"rl_list_runs": "listing runs",
|
|
||||||
"rl_test_inference": f"{args.get('num_steps', 3)} steps",
|
|
||||||
}
|
|
||||||
return rl_previews.get(tool_name)
|
|
||||||
|
|
||||||
key = primary_args.get(tool_name)
|
key = primary_args.get(tool_name)
|
||||||
if not key:
|
if not key:
|
||||||
for fallback_key in ("query", "text", "command", "path", "name", "prompt", "code", "goal"):
|
for fallback_key in ("query", "text", "command", "path", "name", "prompt", "code", "goal"):
|
||||||
|
|
@ -981,15 +966,6 @@ def get_cute_tool_message(
|
||||||
if action == "list":
|
if action == "list":
|
||||||
return _wrap(f"┊ ⏰ cron listing {dur}")
|
return _wrap(f"┊ ⏰ cron listing {dur}")
|
||||||
return _wrap(f"┊ ⏰ cron {action} {args.get('job_id', '')} {dur}")
|
return _wrap(f"┊ ⏰ cron {action} {args.get('job_id', '')} {dur}")
|
||||||
if tool_name.startswith("rl_"):
|
|
||||||
rl = {
|
|
||||||
"rl_list_environments": "list envs", "rl_select_environment": f"select {args.get('name', '')}",
|
|
||||||
"rl_get_current_config": "get config", "rl_edit_config": f"set {args.get('field', '?')}",
|
|
||||||
"rl_start_training": "start training", "rl_check_status": f"status {args.get('run_id', '?')[:12]}",
|
|
||||||
"rl_stop_training": f"stop {args.get('run_id', '?')[:12]}", "rl_get_results": f"results {args.get('run_id', '?')[:12]}",
|
|
||||||
"rl_list_runs": "list runs", "rl_test_inference": "test inference",
|
|
||||||
}
|
|
||||||
return _wrap(f"┊ 🧪 rl {rl.get(tool_name, tool_name.replace('rl_', ''))} {dur}")
|
|
||||||
if tool_name == "execute_code":
|
if tool_name == "execute_code":
|
||||||
code = args.get("code", "")
|
code = args.get("code", "")
|
||||||
first_line = code.strip().split("\n")[0] if code.strip() else ""
|
first_line = code.strip().split("\n")[0] if code.strip() else ""
|
||||||
|
|
|
||||||
|
|
@ -457,7 +457,7 @@ prompt_caching:
|
||||||
# Two stores: MEMORY.md (agent's notes) and USER.md (user profile).
|
# Two stores: MEMORY.md (agent's notes) and USER.md (user profile).
|
||||||
# Character limits keep the memory small and focused. The agent manages
|
# Character limits keep the memory small and focused. The agent manages
|
||||||
# pruning -- when at the limit, it must consolidate or replace entries.
|
# pruning -- when at the limit, it must consolidate or replace entries.
|
||||||
# Disabled by default in batch_runner and RL environments.
|
# Disabled by default in batch_runner.
|
||||||
#
|
#
|
||||||
memory:
|
memory:
|
||||||
# Agent's personal notes: environment facts, conventions, things learned
|
# Agent's personal notes: environment facts, conventions, things learned
|
||||||
|
|
@ -715,10 +715,9 @@ platform_toolsets:
|
||||||
# todo - todo (in-memory task planning, no deps)
|
# todo - todo (in-memory task planning, no deps)
|
||||||
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX/MISTRAL key)
|
# tts - text_to_speech (Edge TTS free, or ELEVENLABS/OPENAI/MINIMAX/MISTRAL key)
|
||||||
# cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks)
|
# cronjob - cronjob (create/list/update/pause/resume/run/remove scheduled tasks)
|
||||||
# rl - rl_list_environments, rl_start_training, etc. (requires TINKER_API_KEY)
|
|
||||||
#
|
#
|
||||||
# PRESETS (curated bundles):
|
# PRESETS (curated bundles):
|
||||||
# hermes-cli - All of the above except rl + send_message
|
# hermes-cli - All of the above except send_message
|
||||||
# hermes-telegram - terminal, file, web, vision, image_gen, tts, browser,
|
# hermes-telegram - terminal, file, web, vision, image_gen, tts, browser,
|
||||||
# skills, todo, cronjob, send_message
|
# skills, todo, cronjob, send_message
|
||||||
# hermes-discord - Same as hermes-telegram
|
# hermes-discord - Same as hermes-telegram
|
||||||
|
|
@ -744,7 +743,6 @@ platform_toolsets:
|
||||||
# session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization)
|
# session_search - Search and recall past conversations (FTS5 + Gemini Flash summarization)
|
||||||
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax, Mistral)
|
# tts - Text-to-speech (Edge TTS free, ElevenLabs, OpenAI, MiniMax, Mistral)
|
||||||
# cronjob - Schedule and manage automated tasks (CLI-only)
|
# cronjob - Schedule and manage automated tasks (CLI-only)
|
||||||
# rl - RL training tools (Tinker-Atropos)
|
|
||||||
#
|
#
|
||||||
# Composite toolsets:
|
# Composite toolsets:
|
||||||
# debugging - terminal + web + file (for troubleshooting)
|
# debugging - terminal + web + file (for troubleshooting)
|
||||||
|
|
|
||||||
|
|
@ -1,324 +0,0 @@
|
||||||
# Hermes-Agent Atropos Environments
|
|
||||||
|
|
||||||
This directory contains the integration layer between **hermes-agent's** tool-calling capabilities and the **Atropos** RL training framework. It provides everything needed to run agentic LLMs through multi-turn tool-calling loops, score their output with arbitrary reward functions, and feed results into Atropos for training or evaluation.
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
```
|
|
||||||
Atropos Framework
|
|
||||||
┌───────────────────────┐
|
|
||||||
│ BaseEnv │ (atroposlib)
|
|
||||||
│ - Server management │
|
|
||||||
│ - Worker scheduling │
|
|
||||||
│ - Wandb logging │
|
|
||||||
│ - CLI (serve/process/ │
|
|
||||||
│ evaluate) │
|
|
||||||
└───────────┬───────────┘
|
|
||||||
│ inherits
|
|
||||||
┌───────────┴───────────┐
|
|
||||||
│ HermesAgentBaseEnv │ hermes_base_env.py
|
|
||||||
│ - Terminal backend │
|
|
||||||
│ - Tool resolution │
|
|
||||||
│ - Agent loop │
|
|
||||||
│ - ToolContext │
|
|
||||||
│ - Async patches │
|
|
||||||
└───────────┬───────────┘
|
|
||||||
│ inherits
|
|
||||||
┌─────────────────┼─────────────────┐
|
|
||||||
│ │ │
|
|
||||||
TerminalTestEnv HermesSweEnv TerminalBench2EvalEnv
|
|
||||||
(stack testing) (SWE training) (TB2 benchmark eval)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Inheritance Chain
|
|
||||||
|
|
||||||
**BaseEnv** (from `atroposlib`) is the Atropos base class. It provides:
|
|
||||||
- Server management (OpenAI-compatible API servers, VLLM, SGLang)
|
|
||||||
- Worker scheduling for parallel rollouts
|
|
||||||
- Wandb integration for metrics and rollout logging
|
|
||||||
- CLI interface with three subcommands: `serve`, `process`, `evaluate`
|
|
||||||
- `evaluate_log()` for saving eval results to JSON + samples.jsonl
|
|
||||||
|
|
||||||
**HermesAgentBaseEnv** (`hermes_base_env.py`) extends BaseEnv with hermes-agent specifics:
|
|
||||||
- Sets `os.environ["TERMINAL_ENV"]` to configure the terminal backend (local, docker, ssh, singularity, modal, daytona, vercel_sandbox)
|
|
||||||
- Resolves hermes-agent toolsets via `_resolve_tools_for_group()` (calls `get_tool_definitions()` which queries `tools/registry.py`)
|
|
||||||
- Implements `collect_trajectory()` which runs the full agent loop and computes rewards
|
|
||||||
- Supports two-phase operation (Phase 1: OpenAI server, Phase 2: VLLM ManagedServer)
|
|
||||||
- Applies monkey patches for async-safe tool operation at import time
|
|
||||||
|
|
||||||
Concrete environments inherit from `HermesAgentBaseEnv` and implement:
|
|
||||||
- `setup()` -- Load dataset, initialize state
|
|
||||||
- `get_next_item()` -- Return the next item for rollout
|
|
||||||
- `format_prompt()` -- Convert a dataset item into the user message
|
|
||||||
- `compute_reward()` -- Score the rollout using ToolContext
|
|
||||||
- `evaluate()` -- Periodic evaluation logic
|
|
||||||
|
|
||||||
## Core Components
|
|
||||||
|
|
||||||
### Agent Loop (`agent_loop.py`)
|
|
||||||
|
|
||||||
`HermesAgentLoop` is the reusable multi-turn agent engine. It runs the same pattern as hermes-agent's `run_agent.py`:
|
|
||||||
|
|
||||||
1. Send messages + tools to the API via `server.chat_completion()`
|
|
||||||
2. If the response contains `tool_calls`, execute each one via `handle_function_call()` (which delegates to `tools/registry.py`'s `dispatch()`)
|
|
||||||
3. Append tool results to the conversation and go back to step 1
|
|
||||||
4. If the response has no tool_calls, the agent is done
|
|
||||||
|
|
||||||
Tool calls are executed in a thread pool (`run_in_executor`) so backends that use `asyncio.run()` internally (Modal, Docker) don't deadlock inside Atropos's event loop.
|
|
||||||
|
|
||||||
Returns an `AgentResult` containing the full conversation history, turn count, reasoning content per turn, tool errors, and optional ManagedServer state (for Phase 2).
|
|
||||||
|
|
||||||
### Tool Context (`tool_context.py`)
|
|
||||||
|
|
||||||
`ToolContext` is a per-rollout handle that gives reward/verification functions direct access to **all** hermes-agent tools, scoped to the rollout's `task_id`. The same `task_id` means the terminal/browser session is the SAME one the model used during its rollout -- all state (files, processes, browser tabs) is preserved.
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def compute_reward(self, item, result, ctx: ToolContext):
|
|
||||||
# Run tests in the model's terminal sandbox
|
|
||||||
test = ctx.terminal("pytest -v")
|
|
||||||
if test["exit_code"] == 0:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Check if a file was created
|
|
||||||
content = ctx.read_file("/workspace/solution.py")
|
|
||||||
if content.get("content"):
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
# Download files locally for verification (binary-safe)
|
|
||||||
ctx.download_file("/remote/output.bin", "/local/output.bin")
|
|
||||||
|
|
||||||
return 0.0
|
|
||||||
```
|
|
||||||
|
|
||||||
Available methods:
|
|
||||||
- **Terminal**: `terminal(command, timeout)` -- run shell commands
|
|
||||||
- **Files**: `read_file(path)`, `write_file(path, content)`, `search(query, path)`
|
|
||||||
- **Transfers**: `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` -- binary-safe file transfers between host and sandbox
|
|
||||||
- **Web**: `web_search(query)`, `web_extract(urls)`
|
|
||||||
- **Browser**: `browser_navigate(url)`, `browser_snapshot()`
|
|
||||||
- **Generic**: `call_tool(name, args)` -- call any hermes-agent tool by name
|
|
||||||
- **Cleanup**: `cleanup()` -- release all resources (called automatically after `compute_reward`)
|
|
||||||
|
|
||||||
### Patches (`patches.py`)
|
|
||||||
|
|
||||||
**Problem**: Some hermes-agent tools use `asyncio.run()` internally (e.g., the Modal backend). This crashes when called from inside Atropos's event loop because `asyncio.run()` cannot be nested.
|
|
||||||
|
|
||||||
**Solution**: `ModalEnvironment` uses a dedicated `_AsyncWorker` background thread with its own event loop. The calling code sees a sync interface, but internally all async Modal SDK calls happen on the worker thread so they don't conflict with Atropos's loop. This is built directly into `tools/environments/modal.py` — no monkey-patching required.
|
|
||||||
|
|
||||||
`patches.py` is now a no-op (kept for backward compatibility with imports).
|
|
||||||
|
|
||||||
### Tool Call Parsers (`tool_call_parsers/`)
|
|
||||||
|
|
||||||
Client-side parsers that extract structured `tool_calls` from raw model output text. Used in **Phase 2** (VLLM server type) where ManagedServer's `/generate` endpoint returns raw text without tool call parsing.
|
|
||||||
|
|
||||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's `extract_tool_calls()` logic. No VLLM dependency -- only standard library (`re`, `json`, `uuid`) and `openai` types.
|
|
||||||
|
|
||||||
Available parsers:
|
|
||||||
- `hermes` -- Hermes/ChatML `<tool_call>` XML format
|
|
||||||
- `mistral` -- Mistral `[TOOL_CALLS]` format
|
|
||||||
- `llama3_json` -- Llama 3 JSON tool calling
|
|
||||||
- `qwen` -- Qwen tool calling format
|
|
||||||
- `qwen3_coder` -- Qwen3 Coder format
|
|
||||||
- `deepseek_v3` -- DeepSeek V3 format
|
|
||||||
- `deepseek_v3_1` -- DeepSeek V3.1 format
|
|
||||||
- `kimi_k2` -- Kimi K2 format
|
|
||||||
- `longcat` -- Longcat format
|
|
||||||
- `glm45` / `glm47` -- GLM model formats
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
```python
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
|
|
||||||
parser = get_parser("hermes")
|
|
||||||
content, tool_calls = parser.parse(raw_model_output)
|
|
||||||
```
|
|
||||||
|
|
||||||
In Phase 1 (OpenAI server type), these parsers are not needed -- the server handles tool call parsing natively.
|
|
||||||
|
|
||||||
## Two-Phase Operation
|
|
||||||
|
|
||||||
### Phase 1: OpenAI Server (Evaluation / SFT Data Generation)
|
|
||||||
|
|
||||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
|
||||||
|
|
||||||
- Good for: evaluation, SFT data generation, testing
|
|
||||||
- Run with: `serve` (with `run-api`), `process`, or `evaluate` subcommands
|
|
||||||
- Placeholder tokens are created for the Atropos pipeline
|
|
||||||
|
|
||||||
### Phase 2: VLLM ManagedServer (Full RL Training)
|
|
||||||
|
|
||||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. Client-side tool call parser (from `tool_call_parsers/`) reconstructs structured `tool_calls` from raw output.
|
|
||||||
|
|
||||||
- Good for: full RL training with GRPO/PPO
|
|
||||||
- Run with: `serve` subcommand
|
|
||||||
- Real tokens, masks, and logprobs flow through the pipeline
|
|
||||||
|
|
||||||
## Directory Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
environments/
|
|
||||||
├── README.md # This file
|
|
||||||
├── __init__.py # Package exports
|
|
||||||
├── hermes_base_env.py # Abstract base (HermesAgentBaseEnv)
|
|
||||||
├── agent_loop.py # Multi-turn agent engine (HermesAgentLoop)
|
|
||||||
├── tool_context.py # Per-rollout tool access for reward functions
|
|
||||||
├── patches.py # Async-safety patches for Modal backend
|
|
||||||
│
|
|
||||||
├── tool_call_parsers/ # Phase 2 client-side parsers
|
|
||||||
│ ├── __init__.py # Registry + base class
|
|
||||||
│ ├── hermes_parser.py
|
|
||||||
│ ├── mistral_parser.py
|
|
||||||
│ ├── llama_parser.py
|
|
||||||
│ ├── qwen_parser.py
|
|
||||||
│ ├── qwen3_coder_parser.py
|
|
||||||
│ ├── deepseek_v3_parser.py
|
|
||||||
│ ├── deepseek_v3_1_parser.py
|
|
||||||
│ ├── kimi_k2_parser.py
|
|
||||||
│ ├── longcat_parser.py
|
|
||||||
│ ├── glm45_parser.py
|
|
||||||
│ └── glm47_parser.py
|
|
||||||
│
|
|
||||||
├── terminal_test_env/ # Stack validation environment
|
|
||||||
│ └── terminal_test_env.py
|
|
||||||
│
|
|
||||||
├── hermes_swe_env/ # SWE-bench style training environment
|
|
||||||
│ └── hermes_swe_env.py
|
|
||||||
│
|
|
||||||
└── benchmarks/ # Evaluation benchmarks
|
|
||||||
├── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
|
|
||||||
│ └── terminalbench2_env.py
|
|
||||||
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
|
|
||||||
│ └── tblite_env.py
|
|
||||||
└── yc_bench/ # Long-horizon strategic benchmark
|
|
||||||
└── yc_bench_env.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Concrete Environments
|
|
||||||
|
|
||||||
### TerminalTestEnv (`terminal_test_env/`)
|
|
||||||
|
|
||||||
A self-contained environment with inline tasks (no external dataset needed) for validating the full stack end-to-end. Each task asks the model to create a file at a known path, and the verifier checks the content matches.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Serve mode (needs run-api)
|
|
||||||
run-api
|
|
||||||
python environments/terminal_test_env/terminal_test_env.py serve
|
|
||||||
|
|
||||||
# Process mode (no run-api, saves to JSONL)
|
|
||||||
python environments/terminal_test_env/terminal_test_env.py process \
|
|
||||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
|
||||||
```
|
|
||||||
|
|
||||||
### HermesSweEnv (`hermes_swe_env/`)
|
|
||||||
|
|
||||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
|
||||||
--openai.model_name YourModel \
|
|
||||||
--env.dataset_name bigcode/humanevalpack \
|
|
||||||
--env.terminal_backend modal
|
|
||||||
```
|
|
||||||
|
|
||||||
### TerminalBench2EvalEnv (`benchmarks/terminalbench_2/`)
|
|
||||||
|
|
||||||
**Eval-only** environment for the Terminal-Bench 2.0 benchmark (89 tasks). Each task gets a pre-built Docker Hub image, a natural language instruction, and a test suite. The agent uses terminal + file tools to solve the task, then the test suite verifies correctness.
|
|
||||||
|
|
||||||
Follows the standard Atropos eval pattern (like GPQA, MMLU, etc.):
|
|
||||||
- Run via `evaluate` subcommand (no `run-api` needed)
|
|
||||||
- `setup()` loads the dataset, `evaluate()` runs all tasks
|
|
||||||
- `rollout_and_score_eval()` handles per-task agent loop + test verification
|
|
||||||
- Downloads verifier output locally for reliable reward checking (Harbor pattern)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run full benchmark
|
|
||||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
--openai.model_name anthropic/claude-opus-4.6
|
|
||||||
|
|
||||||
# Run subset of tasks
|
|
||||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
--openai.model_name anthropic/claude-opus-4.6 \
|
|
||||||
--env.task_filter fix-git,git-multibranch
|
|
||||||
|
|
||||||
# Skip specific tasks
|
|
||||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
--openai.model_name anthropic/claude-opus-4.6 \
|
|
||||||
--env.skip_tasks heavy-task,slow-task
|
|
||||||
```
|
|
||||||
|
|
||||||
## Creating a New Environment
|
|
||||||
|
|
||||||
### Training Environment
|
|
||||||
|
|
||||||
1. Create a new directory under `environments/`
|
|
||||||
2. Create your env file inheriting from `HermesAgentBaseEnv`
|
|
||||||
3. Implement the four abstract methods + `evaluate()`
|
|
||||||
|
|
||||||
```python
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
|
|
||||||
class MyEnvConfig(HermesAgentEnvConfig):
|
|
||||||
pass # Add custom fields as needed
|
|
||||||
|
|
||||||
class MyEnv(HermesAgentBaseEnv):
|
|
||||||
name = "my-env"
|
|
||||||
env_config_cls = MyEnvConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls):
|
|
||||||
env_config = MyEnvConfig(
|
|
||||||
enabled_toolsets=["terminal", "file"],
|
|
||||||
terminal_backend="modal",
|
|
||||||
# ... other config
|
|
||||||
)
|
|
||||||
server_configs = [APIServerConfig(...)]
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
self.dataset = load_dataset(...)
|
|
||||||
self.iter = 0
|
|
||||||
|
|
||||||
async def get_next_item(self):
|
|
||||||
item = self.dataset[self.iter % len(self.dataset)]
|
|
||||||
self.iter += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def format_prompt(self, item):
|
|
||||||
return item["instruction"]
|
|
||||||
|
|
||||||
async def compute_reward(self, item, result, ctx):
|
|
||||||
# ctx gives you full tool access to the rollout's sandbox
|
|
||||||
test = ctx.terminal("pytest -v")
|
|
||||||
return 1.0 if test["exit_code"] == 0 else 0.0
|
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs):
|
|
||||||
# Periodic evaluation logic
|
|
||||||
...
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
MyEnv.cli()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Eval-Only Environment (Benchmark)
|
|
||||||
|
|
||||||
For eval benchmarks, follow the pattern in `terminalbench2_env.py`:
|
|
||||||
1. Create under `environments/benchmarks/your-benchmark/`
|
|
||||||
2. Inherit from `HermesAgentBaseEnv`
|
|
||||||
3. Set eval-only config: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
|
|
||||||
4. Stub the training methods (`collect_trajectories`, `score`)
|
|
||||||
5. Implement `rollout_and_score_eval()` and `evaluate()`
|
|
||||||
6. Run with `evaluate` subcommand
|
|
||||||
|
|
||||||
## Key Config Fields
|
|
||||||
|
|
||||||
| Field | Description | Default |
|
|
||||||
|-------|-------------|---------|
|
|
||||||
| `enabled_toolsets` | Which hermes toolsets to enable | `None` (all) |
|
|
||||||
| `disabled_toolsets` | Toolsets to disable | `None` |
|
|
||||||
| `distribution` | Probabilistic toolset distribution name | `None` |
|
|
||||||
| `max_agent_turns` | Max LLM calls per rollout | `30` |
|
|
||||||
| `agent_temperature` | Sampling temperature | `1.0` |
|
|
||||||
| `terminal_backend` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` | `local` |
|
|
||||||
| `system_prompt` | System message for the agent | `None` |
|
|
||||||
| `tool_call_parser` | Parser name for Phase 2 | `hermes` |
|
|
||||||
| `eval_handling` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` | `STOP_TRAIN` |
|
|
||||||
|
|
@ -1,36 +0,0 @@
|
||||||
"""
|
|
||||||
Hermes-Agent Atropos Environments
|
|
||||||
|
|
||||||
Provides a layered integration between hermes-agent's tool-calling capabilities
|
|
||||||
and the Atropos RL training framework.
|
|
||||||
|
|
||||||
Core layers:
|
|
||||||
- agent_loop: Reusable multi-turn agent loop with standard OpenAI-spec tool calling
|
|
||||||
- tool_context: Per-rollout tool access handle for reward/verification functions
|
|
||||||
- hermes_base_env: Abstract base environment (BaseEnv subclass) for Atropos
|
|
||||||
- tool_call_parsers: Client-side tool call parser registry for Phase 2 (VLLM /generate)
|
|
||||||
|
|
||||||
Concrete environments:
|
|
||||||
- terminal_test_env/: Simple file-creation tasks for testing the stack
|
|
||||||
- hermes_swe_env/: SWE-bench style tasks with Modal sandboxes
|
|
||||||
|
|
||||||
Benchmarks (eval-only):
|
|
||||||
- benchmarks/terminalbench_2/: Terminal-Bench 2.0 evaluation
|
|
||||||
"""
|
|
||||||
|
|
||||||
try:
|
|
||||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
except ImportError:
|
|
||||||
# atroposlib not installed — environments are unavailable but
|
|
||||||
# submodules like tool_call_parsers can still be imported directly.
|
|
||||||
pass
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AgentResult",
|
|
||||||
"HermesAgentLoop",
|
|
||||||
"ToolContext",
|
|
||||||
"HermesAgentBaseEnv",
|
|
||||||
"HermesAgentEnvConfig",
|
|
||||||
]
|
|
||||||
|
|
@ -1,534 +0,0 @@
|
||||||
"""
|
|
||||||
HermesAgentLoop -- Reusable Multi-Turn Agent Engine
|
|
||||||
|
|
||||||
Runs the hermes-agent tool-calling loop using standard OpenAI-spec tool calling.
|
|
||||||
Works with any server that returns ChatCompletion objects with tool_calls:
|
|
||||||
- Phase 1: OpenAI server type (VLLM, SGLang, OpenRouter, OpenAI API)
|
|
||||||
- Phase 2: ManagedServer with client-side tool call parser
|
|
||||||
|
|
||||||
The loop passes tools= and checks response.choices[0].message.tool_calls,
|
|
||||||
identical to hermes-agent's run_agent.py. Tool execution is dispatched via
|
|
||||||
handle_function_call() from model_tools.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import concurrent.futures
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List, Optional, Set
|
|
||||||
|
|
||||||
from model_tools import handle_function_call
|
|
||||||
from tools.terminal_tool import get_active_env
|
|
||||||
from tools.tool_result_storage import maybe_persist_tool_result, enforce_turn_budget
|
|
||||||
|
|
||||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
|
||||||
# (e.g., the Modal/Docker/Daytona terminal backends). Running them in a separate
|
|
||||||
# thread gives them a clean event loop so they don't deadlock inside Atropos's loop.
|
|
||||||
# Size must be large enough for concurrent eval tasks (e.g., 89 TB2 tasks all
|
|
||||||
# making tool calls). Too small = thread pool starvation, tasks queue for minutes.
|
|
||||||
# Resized at runtime by HermesAgentBaseEnv.__init__ via resize_tool_pool().
|
|
||||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=128)
|
|
||||||
|
|
||||||
|
|
||||||
def resize_tool_pool(max_workers: int):
|
|
||||||
"""
|
|
||||||
Replace the global tool executor with a new one of the given size.
|
|
||||||
|
|
||||||
Called by HermesAgentBaseEnv.__init__ based on config.tool_pool_size.
|
|
||||||
Safe to call before any tasks are submitted.
|
|
||||||
"""
|
|
||||||
global _tool_executor
|
|
||||||
old_executor = _tool_executor
|
|
||||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
|
||||||
old_executor.shutdown(wait=False)
|
|
||||||
logger.info("Tool thread pool resized to %d workers", max_workers)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolError:
|
|
||||||
"""Record of a tool execution error during the agent loop."""
|
|
||||||
|
|
||||||
turn: int # Which turn the error occurred on
|
|
||||||
tool_name: str # Which tool was called
|
|
||||||
arguments: str # The arguments passed (truncated)
|
|
||||||
error: str # The error message
|
|
||||||
tool_result: str # The raw result returned to the model
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentResult:
|
|
||||||
"""Result of running the agent loop."""
|
|
||||||
|
|
||||||
# Full conversation history in OpenAI message format
|
|
||||||
messages: List[Dict[str, Any]]
|
|
||||||
# ManagedServer.get_state() if available (Phase 2), None otherwise
|
|
||||||
managed_state: Optional[Dict[str, Any]] = None
|
|
||||||
# How many LLM calls were made
|
|
||||||
turns_used: int = 0
|
|
||||||
# True if model stopped calling tools naturally (vs hitting max_turns)
|
|
||||||
finished_naturally: bool = False
|
|
||||||
# Extracted reasoning content per turn (from PR #297 helpers)
|
|
||||||
reasoning_per_turn: List[Optional[str]] = field(default_factory=list)
|
|
||||||
# Tool errors encountered during the loop
|
|
||||||
tool_errors: List[ToolError] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_reasoning_from_message(message) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Extract reasoning content from a ChatCompletion message.
|
|
||||||
|
|
||||||
Handles multiple provider formats:
|
|
||||||
1. message.reasoning_content field (some providers)
|
|
||||||
2. message.reasoning field (some providers)
|
|
||||||
3. message.reasoning_details[].text (OpenRouter style)
|
|
||||||
|
|
||||||
Note: <think> block extraction from content is NOT done here -- that's
|
|
||||||
handled by the response already in Phase 1 (server does it) or by
|
|
||||||
ManagedServer's patch in Phase 2.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message: The assistant message from ChatCompletion response
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Extracted reasoning text, or None if not found
|
|
||||||
"""
|
|
||||||
# Check reasoning_content field (common across providers)
|
|
||||||
if hasattr(message, "reasoning_content") and message.reasoning_content:
|
|
||||||
return message.reasoning_content
|
|
||||||
|
|
||||||
# Check reasoning field
|
|
||||||
if hasattr(message, "reasoning") and message.reasoning:
|
|
||||||
return message.reasoning
|
|
||||||
|
|
||||||
# Check reasoning_details (OpenRouter style)
|
|
||||||
if hasattr(message, "reasoning_details") and message.reasoning_details:
|
|
||||||
for detail in message.reasoning_details:
|
|
||||||
if hasattr(detail, "text") and detail.text:
|
|
||||||
return detail.text
|
|
||||||
if isinstance(detail, dict) and detail.get("text"):
|
|
||||||
return detail["text"]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class HermesAgentLoop:
|
|
||||||
"""
|
|
||||||
Runs hermes-agent's tool-calling loop using standard OpenAI-spec tool calling.
|
|
||||||
|
|
||||||
Same pattern as run_agent.py:
|
|
||||||
- Pass tools= to the API
|
|
||||||
- Check response.choices[0].message.tool_calls
|
|
||||||
- Dispatch via handle_function_call()
|
|
||||||
|
|
||||||
Works identically with any server type -- OpenAI, VLLM, SGLang, OpenRouter,
|
|
||||||
or ManagedServer with a parser. The server determines how tool_calls get
|
|
||||||
populated on the response.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
server,
|
|
||||||
tool_schemas: List[Dict[str, Any]],
|
|
||||||
valid_tool_names: Set[str],
|
|
||||||
max_turns: int = 30,
|
|
||||||
task_id: Optional[str] = None,
|
|
||||||
temperature: float = 1.0,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
extra_body: Optional[Dict[str, Any]] = None,
|
|
||||||
budget_config: Optional["BudgetConfig"] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the agent loop.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
server: Server object with chat_completion() method (OpenAIServer,
|
|
||||||
ManagedServer, ServerManager, etc.)
|
|
||||||
tool_schemas: OpenAI-format tool definitions from get_tool_definitions()
|
|
||||||
valid_tool_names: Set of tool names the model is allowed to call
|
|
||||||
max_turns: Maximum number of LLM calls before stopping
|
|
||||||
task_id: Unique ID for terminal/browser session isolation
|
|
||||||
temperature: Sampling temperature for generation
|
|
||||||
max_tokens: Max tokens per generation (None for server default)
|
|
||||||
extra_body: Extra parameters passed to the OpenAI client's create() call.
|
|
||||||
Used for OpenRouter provider preferences, transforms, etc.
|
|
||||||
e.g. {"provider": {"ignore": ["DeepInfra"]}}
|
|
||||||
budget_config: Tool result persistence budget. Controls per-tool
|
|
||||||
thresholds, per-turn aggregate budget, and preview size.
|
|
||||||
If None, uses DEFAULT_BUDGET (current hardcoded values).
|
|
||||||
"""
|
|
||||||
from tools.budget_config import DEFAULT_BUDGET
|
|
||||||
self.server = server
|
|
||||||
self.tool_schemas = tool_schemas
|
|
||||||
self.valid_tool_names = valid_tool_names
|
|
||||||
self.max_turns = max_turns
|
|
||||||
self.task_id = task_id or str(uuid.uuid4())
|
|
||||||
self.temperature = temperature
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
self.extra_body = extra_body
|
|
||||||
self.budget_config = budget_config or DEFAULT_BUDGET
|
|
||||||
|
|
||||||
async def run(self, messages: List[Dict[str, Any]]) -> AgentResult:
|
|
||||||
"""
|
|
||||||
Execute the full agent loop using standard OpenAI tool calling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Initial conversation messages (system + user).
|
|
||||||
Modified in-place as the conversation progresses.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
AgentResult with full conversation history, managed state, and metadata
|
|
||||||
"""
|
|
||||||
reasoning_per_turn = []
|
|
||||||
tool_errors: List[ToolError] = []
|
|
||||||
|
|
||||||
# Per-loop TodoStore for the todo tool (ephemeral, dies with the loop)
|
|
||||||
from tools.todo_tool import TodoStore, todo_tool as _todo_tool
|
|
||||||
_todo_store = TodoStore()
|
|
||||||
|
|
||||||
# Extract user task from first user message for browser_snapshot context
|
|
||||||
_user_task = None
|
|
||||||
for msg in messages:
|
|
||||||
if msg.get("role") == "user":
|
|
||||||
content = msg.get("content", "")
|
|
||||||
if isinstance(content, str) and content.strip():
|
|
||||||
_user_task = content.strip()[:500] # Cap to avoid huge strings
|
|
||||||
break
|
|
||||||
|
|
||||||
import time as _time
|
|
||||||
|
|
||||||
for turn in range(self.max_turns):
|
|
||||||
turn_start = _time.monotonic()
|
|
||||||
|
|
||||||
# Build the chat_completion kwargs
|
|
||||||
chat_kwargs = {
|
|
||||||
"messages": messages,
|
|
||||||
"n": 1,
|
|
||||||
"temperature": self.temperature,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Only pass tools if we have them
|
|
||||||
if self.tool_schemas:
|
|
||||||
chat_kwargs["tools"] = self.tool_schemas
|
|
||||||
|
|
||||||
# Only pass max_tokens if explicitly set
|
|
||||||
if self.max_tokens is not None:
|
|
||||||
chat_kwargs["max_tokens"] = self.max_tokens
|
|
||||||
|
|
||||||
# Inject extra_body for provider-specific params (e.g., OpenRouter
|
|
||||||
# provider preferences like banned/preferred providers, transforms)
|
|
||||||
if self.extra_body:
|
|
||||||
chat_kwargs["extra_body"] = self.extra_body
|
|
||||||
|
|
||||||
# Make the API call -- standard OpenAI spec
|
|
||||||
api_start = _time.monotonic()
|
|
||||||
try:
|
|
||||||
response = await self.server.chat_completion(**chat_kwargs)
|
|
||||||
except Exception as e:
|
|
||||||
api_elapsed = _time.monotonic() - api_start
|
|
||||||
logger.error("API call failed on turn %d (%.1fs): %s", turn + 1, api_elapsed, e)
|
|
||||||
return AgentResult(
|
|
||||||
messages=messages,
|
|
||||||
managed_state=self._get_managed_state(),
|
|
||||||
turns_used=turn + 1,
|
|
||||||
finished_naturally=False,
|
|
||||||
reasoning_per_turn=reasoning_per_turn,
|
|
||||||
tool_errors=tool_errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
api_elapsed = _time.monotonic() - api_start
|
|
||||||
|
|
||||||
if not response or not response.choices:
|
|
||||||
logger.warning("Empty response on turn %d (api=%.1fs)", turn + 1, api_elapsed)
|
|
||||||
return AgentResult(
|
|
||||||
messages=messages,
|
|
||||||
managed_state=self._get_managed_state(),
|
|
||||||
turns_used=turn + 1,
|
|
||||||
finished_naturally=False,
|
|
||||||
reasoning_per_turn=reasoning_per_turn,
|
|
||||||
tool_errors=tool_errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
assistant_msg = response.choices[0].message
|
|
||||||
|
|
||||||
# Extract reasoning content from the response (all provider formats)
|
|
||||||
reasoning = _extract_reasoning_from_message(assistant_msg)
|
|
||||||
reasoning_per_turn.append(reasoning)
|
|
||||||
|
|
||||||
# Check for tool calls -- standard OpenAI spec.
|
|
||||||
# Fallback: if response has no structured tool_calls but content
|
|
||||||
# contains raw tool call tags (e.g. <tool_call>), parse them using
|
|
||||||
# hermes-agent's standalone parsers. This handles the case where
|
|
||||||
# ManagedServer's ToolCallTranslator couldn't parse because vLLM
|
|
||||||
# isn't installed.
|
|
||||||
if (
|
|
||||||
not assistant_msg.tool_calls
|
|
||||||
and assistant_msg.content
|
|
||||||
and self.tool_schemas
|
|
||||||
and "<tool_call>" in (assistant_msg.content or "")
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
fallback_parser = get_parser("hermes")
|
|
||||||
parsed_content, parsed_calls = fallback_parser.parse(
|
|
||||||
assistant_msg.content
|
|
||||||
)
|
|
||||||
if parsed_calls:
|
|
||||||
assistant_msg.tool_calls = parsed_calls
|
|
||||||
if parsed_content is not None:
|
|
||||||
assistant_msg.content = parsed_content
|
|
||||||
logger.debug(
|
|
||||||
"Fallback parser extracted %d tool calls from raw content",
|
|
||||||
len(parsed_calls),
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # Fall through to no tool calls
|
|
||||||
|
|
||||||
if assistant_msg.tool_calls:
|
|
||||||
# Normalize tool calls to dicts — they may come as objects
|
|
||||||
# (OpenAI API) or dicts (vLLM ToolCallTranslator).
|
|
||||||
def _tc_to_dict(tc):
|
|
||||||
if isinstance(tc, dict):
|
|
||||||
return {
|
|
||||||
"id": tc.get("id", f"call_{uuid.uuid4().hex[:8]}"),
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.get("function", {}).get("name", tc.get("name", "")),
|
|
||||||
"arguments": tc.get("function", {}).get("arguments", tc.get("arguments", "{}")),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
return {
|
|
||||||
"id": tc.id,
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": tc.function.name,
|
|
||||||
"arguments": tc.function.arguments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Build the assistant message dict for conversation history
|
|
||||||
msg_dict: Dict[str, Any] = {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": assistant_msg.content or "",
|
|
||||||
"tool_calls": [_tc_to_dict(tc) for tc in assistant_msg.tool_calls],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Preserve reasoning_content for multi-turn chat template handling
|
|
||||||
# (e.g., Kimi-K2's template renders <think> blocks differently
|
|
||||||
# for history vs. the latest turn based on this field)
|
|
||||||
if reasoning:
|
|
||||||
msg_dict["reasoning_content"] = reasoning
|
|
||||||
|
|
||||||
messages.append(msg_dict)
|
|
||||||
|
|
||||||
# Execute each tool call via hermes-agent's dispatch
|
|
||||||
for tc in assistant_msg.tool_calls:
|
|
||||||
# Handle both object (OpenAI) and dict (vLLM) formats
|
|
||||||
if isinstance(tc, dict):
|
|
||||||
tool_name = tc.get("function", {}).get("name", tc.get("name", ""))
|
|
||||||
tool_args_raw = tc.get("function", {}).get("arguments", tc.get("arguments", "{}"))
|
|
||||||
else:
|
|
||||||
tool_name = tc.function.name
|
|
||||||
tool_args_raw = tc.function.arguments
|
|
||||||
|
|
||||||
# Validate tool name
|
|
||||||
if tool_name not in self.valid_tool_names:
|
|
||||||
tool_result = json.dumps(
|
|
||||||
{
|
|
||||||
"error": f"Unknown tool '{tool_name}'. "
|
|
||||||
f"Available tools: {sorted(self.valid_tool_names)}"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tool_errors.append(ToolError(
|
|
||||||
turn=turn + 1, tool_name=tool_name,
|
|
||||||
arguments=tool_args_raw[:200],
|
|
||||||
error=f"Unknown tool '{tool_name}'",
|
|
||||||
tool_result=tool_result,
|
|
||||||
))
|
|
||||||
logger.warning(
|
|
||||||
"Model called unknown tool '%s' on turn %d",
|
|
||||||
tool_name, turn + 1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Parse arguments
|
|
||||||
try:
|
|
||||||
args = json.loads(tool_args_raw)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
args = None
|
|
||||||
tool_result = json.dumps(
|
|
||||||
{"error": f"Invalid JSON in tool arguments: {e}. Please retry with valid JSON."}
|
|
||||||
)
|
|
||||||
tool_errors.append(ToolError(
|
|
||||||
turn=turn + 1, tool_name=tool_name,
|
|
||||||
arguments=tool_args_raw[:200],
|
|
||||||
error=f"Invalid JSON: {e}",
|
|
||||||
tool_result=tool_result,
|
|
||||||
))
|
|
||||||
logger.warning(
|
|
||||||
"Invalid JSON in tool call arguments for '%s': %s",
|
|
||||||
tool_name, tool_args_raw[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dispatch tool only if arguments parsed successfully
|
|
||||||
if args is not None:
|
|
||||||
try:
|
|
||||||
if tool_name == "terminal":
|
|
||||||
backend = os.getenv("TERMINAL_ENV", "local")
|
|
||||||
cmd_preview = args.get("command", "")[:80]
|
|
||||||
logger.info(
|
|
||||||
"[%s] $ %s", self.task_id[:8], cmd_preview,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_submit_time = _time.monotonic()
|
|
||||||
|
|
||||||
# Todo tool -- handle locally (needs per-loop TodoStore)
|
|
||||||
if tool_name == "todo":
|
|
||||||
tool_result = _todo_tool(
|
|
||||||
todos=args.get("todos"),
|
|
||||||
merge=args.get("merge", False),
|
|
||||||
store=_todo_store,
|
|
||||||
)
|
|
||||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
|
||||||
elif tool_name == "memory":
|
|
||||||
tool_result = json.dumps({"error": "Memory is not available in RL environments."})
|
|
||||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
|
||||||
elif tool_name == "session_search":
|
|
||||||
tool_result = json.dumps({"error": "Session search is not available in RL environments."})
|
|
||||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
|
||||||
else:
|
|
||||||
# Run tool calls in a thread pool so backends that
|
|
||||||
# use asyncio.run() internally (modal, docker, daytona) get
|
|
||||||
# a clean event loop instead of deadlocking.
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
# Capture current tool_name/args for the lambda
|
|
||||||
_tn, _ta, _tid = tool_name, args, self.task_id
|
|
||||||
tool_result = await loop.run_in_executor(
|
|
||||||
_tool_executor,
|
|
||||||
lambda: handle_function_call(
|
|
||||||
_tn, _ta, task_id=_tid,
|
|
||||||
user_task=_user_task,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
tool_elapsed = _time.monotonic() - tool_submit_time
|
|
||||||
|
|
||||||
# Log slow tools and thread pool stats for debugging
|
|
||||||
pool_active = _tool_executor._work_queue.qsize()
|
|
||||||
if tool_elapsed > 30:
|
|
||||||
logger.warning(
|
|
||||||
"[%s] turn %d: %s took %.1fs (pool queue=%d)",
|
|
||||||
self.task_id[:8], turn + 1, tool_name,
|
|
||||||
tool_elapsed, pool_active,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
tool_result = json.dumps(
|
|
||||||
{"error": f"Tool execution failed: {type(e).__name__}: {str(e)}"}
|
|
||||||
)
|
|
||||||
tool_errors.append(ToolError(
|
|
||||||
turn=turn + 1, tool_name=tool_name,
|
|
||||||
arguments=tool_args_raw[:200],
|
|
||||||
error=f"{type(e).__name__}: {str(e)}",
|
|
||||||
tool_result=tool_result,
|
|
||||||
))
|
|
||||||
logger.error(
|
|
||||||
"Tool '%s' execution failed on turn %d: %s",
|
|
||||||
tool_name, turn + 1, e,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Also check if the tool returned an error in its JSON result
|
|
||||||
try:
|
|
||||||
result_data = json.loads(tool_result)
|
|
||||||
if isinstance(result_data, dict):
|
|
||||||
err = result_data.get("error")
|
|
||||||
exit_code = result_data.get("exit_code")
|
|
||||||
if err and exit_code and exit_code < 0:
|
|
||||||
tool_errors.append(ToolError(
|
|
||||||
turn=turn + 1, tool_name=tool_name,
|
|
||||||
arguments=tool_args_raw[:200],
|
|
||||||
error=str(err),
|
|
||||||
tool_result=tool_result[:500],
|
|
||||||
))
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
tc_id = tc.get("id", "") if isinstance(tc, dict) else tc.id
|
|
||||||
tool_result = maybe_persist_tool_result(
|
|
||||||
content=tool_result,
|
|
||||||
tool_name=tool_name,
|
|
||||||
tool_use_id=tc_id,
|
|
||||||
env=get_active_env(self.task_id),
|
|
||||||
config=self.budget_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages.append(
|
|
||||||
{
|
|
||||||
"role": "tool",
|
|
||||||
"tool_call_id": tc_id,
|
|
||||||
"content": tool_result,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
num_tcs = len(assistant_msg.tool_calls)
|
|
||||||
if num_tcs > 0:
|
|
||||||
enforce_turn_budget(
|
|
||||||
messages[-num_tcs:],
|
|
||||||
env=get_active_env(self.task_id),
|
|
||||||
config=self.budget_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
turn_elapsed = _time.monotonic() - turn_start
|
|
||||||
logger.info(
|
|
||||||
"[%s] turn %d: api=%.1fs, %d tools, turn_total=%.1fs",
|
|
||||||
self.task_id[:8], turn + 1, api_elapsed,
|
|
||||||
len(assistant_msg.tool_calls), turn_elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# No tool calls -- model is done
|
|
||||||
msg_dict = {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": assistant_msg.content or "",
|
|
||||||
}
|
|
||||||
if reasoning:
|
|
||||||
msg_dict["reasoning_content"] = reasoning
|
|
||||||
messages.append(msg_dict)
|
|
||||||
|
|
||||||
turn_elapsed = _time.monotonic() - turn_start
|
|
||||||
logger.info(
|
|
||||||
"[%s] turn %d: api=%.1fs, no tools (finished), turn_total=%.1fs",
|
|
||||||
self.task_id[:8], turn + 1, api_elapsed, turn_elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
return AgentResult(
|
|
||||||
messages=messages,
|
|
||||||
managed_state=self._get_managed_state(),
|
|
||||||
turns_used=turn + 1,
|
|
||||||
finished_naturally=True,
|
|
||||||
reasoning_per_turn=reasoning_per_turn,
|
|
||||||
tool_errors=tool_errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hit max turns without the model stopping
|
|
||||||
logger.info("Agent hit max_turns (%d) without finishing", self.max_turns)
|
|
||||||
return AgentResult(
|
|
||||||
messages=messages,
|
|
||||||
managed_state=self._get_managed_state(),
|
|
||||||
turns_used=self.max_turns,
|
|
||||||
finished_naturally=False,
|
|
||||||
reasoning_per_turn=reasoning_per_turn,
|
|
||||||
tool_errors=tool_errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_managed_state(self) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get ManagedServer state if the server supports it.
|
|
||||||
|
|
||||||
Returns state dict with SequenceNodes containing tokens/logprobs/masks,
|
|
||||||
or None if the server doesn't support get_state() (e.g., regular OpenAI server).
|
|
||||||
"""
|
|
||||||
if hasattr(self.server, "get_state"):
|
|
||||||
return self.server.get_state()
|
|
||||||
return None
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,73 +0,0 @@
|
||||||
# OpenThoughts-TBLite Evaluation Environment
|
|
||||||
|
|
||||||
This environment evaluates terminal agents on the [OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite) benchmark, a difficulty-calibrated subset of [Terminal-Bench 2.0](https://www.tbench.ai/leaderboard/terminal-bench/2.0).
|
|
||||||
|
|
||||||
## Source
|
|
||||||
|
|
||||||
OpenThoughts-TBLite was created by the [OpenThoughts](https://www.openthoughts.ai/) Agent team in collaboration with [Snorkel AI](https://snorkel.ai/) and [Bespoke Labs](https://bespokelabs.ai/). The original dataset and documentation live at:
|
|
||||||
|
|
||||||
- **Dataset (source):** [open-thoughts/OpenThoughts-TBLite](https://huggingface.co/datasets/open-thoughts/OpenThoughts-TBLite)
|
|
||||||
- **GitHub:** [open-thoughts/OpenThoughts-TBLite](https://github.com/open-thoughts/OpenThoughts-TBLite)
|
|
||||||
- **Blog post:** [openthoughts.ai/blog/openthoughts-tblite](https://www.openthoughts.ai/blog/openthoughts-tblite)
|
|
||||||
|
|
||||||
## Our Dataset
|
|
||||||
|
|
||||||
We converted the source into the same schema used by our Terminal-Bench 2.0 environment (pre-built Docker Hub images, base64-encoded test tarballs, etc.) and published it as:
|
|
||||||
|
|
||||||
- **Dataset (ours):** [NousResearch/openthoughts-tblite](https://huggingface.co/datasets/NousResearch/openthoughts-tblite)
|
|
||||||
- **Docker images:** `nousresearch/tblite-<task-name>:latest` on Docker Hub (100 images)
|
|
||||||
|
|
||||||
The conversion script is at `scripts/prepare_tblite_dataset.py`.
|
|
||||||
|
|
||||||
## Why TBLite?
|
|
||||||
|
|
||||||
Terminal-Bench 2.0 is one of the strongest frontier evaluations for terminal agents, but when a model scores near the floor (e.g., Qwen 3 8B at <1%), many changes look identical in aggregate score. TBLite addresses this by calibrating task difficulty using Claude Haiku 4.5 as a reference:
|
|
||||||
|
|
||||||
| Difficulty | Pass Rate Range | Tasks |
|
|
||||||
|------------|----------------|-------|
|
|
||||||
| Easy | >= 70% | 40 |
|
|
||||||
| Medium | 40-69% | 26 |
|
|
||||||
| Hard | 10-39% | 26 |
|
|
||||||
| Extreme | < 10% | 8 |
|
|
||||||
|
|
||||||
This gives enough solvable tasks to detect small improvements quickly, while preserving enough hard tasks to avoid saturation. The correlation between TBLite and TB2 scores is **r = 0.911**.
|
|
||||||
|
|
||||||
TBLite also runs 2.6-8x faster than the full TB2, making it practical for iteration loops.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run the full benchmark
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
|
||||||
|
|
||||||
# Filter to specific tasks
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
--env.task_filter "broken-python,pandas-etl"
|
|
||||||
|
|
||||||
# Use a different model
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
--server.model_name "qwen/qwen3-30b"
|
|
||||||
```
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
`TBLiteEvalEnv` is a thin subclass of `TerminalBench2EvalEnv`. All evaluation logic (agent loop, Docker sandbox management, test verification, metrics) is inherited. Only the defaults differ:
|
|
||||||
|
|
||||||
| Setting | TB2 | TBLite |
|
|
||||||
|----------------|----------------------------------|-----------------------------------------|
|
|
||||||
| Dataset | `NousResearch/terminal-bench-2` | `NousResearch/openthoughts-tblite` |
|
|
||||||
| Tasks | 89 | 100 |
|
|
||||||
| Task timeout | 1800s (30 min) | 1200s (20 min) |
|
|
||||||
| Wandb name | `terminal-bench-2` | `openthoughts-tblite` |
|
|
||||||
|
|
||||||
## Citation
|
|
||||||
|
|
||||||
```bibtex
|
|
||||||
@software{OpenThoughts-TBLite,
|
|
||||||
author = {OpenThoughts-Agent team, Snorkel AI, Bespoke Labs},
|
|
||||||
month = Feb,
|
|
||||||
title = {{OpenThoughts-TBLite: A High-Signal Benchmark for Iterating on Terminal Agents}},
|
|
||||||
howpublished = {https://www.openthoughts.ai/blog/openthoughts-tblite},
|
|
||||||
year = {2026}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
# OpenThoughts-TBLite Evaluation -- Default Configuration
|
|
||||||
#
|
|
||||||
# Eval-only environment for the TBLite benchmark (100 difficulty-calibrated
|
|
||||||
# terminal tasks, a faster proxy for Terminal-Bench 2.0).
|
|
||||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
|
||||||
# and OpenRouter for inference.
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/tblite/default.yaml
|
|
||||||
#
|
|
||||||
# # Override model:
|
|
||||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/tblite/default.yaml \
|
|
||||||
# --openai.model_name anthropic/claude-sonnet-4
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file"]
|
|
||||||
max_agent_turns: 60
|
|
||||||
max_token_length: 32000
|
|
||||||
agent_temperature: 0.8
|
|
||||||
terminal_backend: "modal"
|
|
||||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
|
||||||
tool_pool_size: 128 # thread pool for 100 parallel tasks
|
|
||||||
dataset_name: "NousResearch/openthoughts-tblite"
|
|
||||||
test_timeout: 600
|
|
||||||
task_timeout: 1200 # 20 min wall-clock per task (TBLite tasks are faster)
|
|
||||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_name: "openthoughts-tblite"
|
|
||||||
ensure_scores_are_not_same: false
|
|
||||||
data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite"
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
model_name: "anthropic/claude-opus-4.6"
|
|
||||||
server_type: "openai"
|
|
||||||
health_check: false
|
|
||||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
|
||||||
|
|
@ -1,38 +0,0 @@
|
||||||
# OpenThoughts-TBLite Evaluation -- Docker Backend (Local Compute)
|
|
||||||
#
|
|
||||||
# Runs tasks in Docker containers on the local machine.
|
|
||||||
# Sandboxed like Modal but no cloud costs. Good for dev/testing.
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/tblite/local.yaml
|
|
||||||
#
|
|
||||||
# # Override concurrency:
|
|
||||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/tblite/local.yaml \
|
|
||||||
# --env.eval_concurrency 4
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file"]
|
|
||||||
max_agent_turns: 60
|
|
||||||
max_token_length: 32000
|
|
||||||
agent_temperature: 0.8
|
|
||||||
terminal_backend: "docker"
|
|
||||||
terminal_timeout: 300
|
|
||||||
tool_pool_size: 16
|
|
||||||
dataset_name: "NousResearch/openthoughts-tblite"
|
|
||||||
test_timeout: 600
|
|
||||||
task_timeout: 1200
|
|
||||||
eval_concurrency: 8 # max 8 tasks at once
|
|
||||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
|
||||||
use_wandb: false
|
|
||||||
wandb_name: "openthoughts-tblite-local"
|
|
||||||
ensure_scores_are_not_same: false
|
|
||||||
data_dir_to_save_evals: "environments/benchmarks/evals/openthoughts-tblite-local"
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
model_name: "anthropic/claude-sonnet-4"
|
|
||||||
server_type: "openai"
|
|
||||||
health_check: false
|
|
||||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
||||||
# OpenThoughts-TBLite Evaluation -- Local vLLM Backend
|
|
||||||
#
|
|
||||||
# Runs against a local vLLM server with Docker sandboxes.
|
|
||||||
#
|
|
||||||
# Start the vLLM server from the atropos directory:
|
|
||||||
# python -m example_trainer.vllm_api_server \
|
|
||||||
# --model Qwen/Qwen3-4B-Instruct-2507 \
|
|
||||||
# --port 9001 \
|
|
||||||
# --gpu-memory-utilization 0.8 \
|
|
||||||
# --max-model-len=32000
|
|
||||||
#
|
|
||||||
# Then run:
|
|
||||||
# python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/tblite/local_vllm.yaml
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file"]
|
|
||||||
max_agent_turns: 60
|
|
||||||
max_token_length: 16000
|
|
||||||
agent_temperature: 0.6
|
|
||||||
terminal_backend: "docker"
|
|
||||||
terminal_timeout: 300
|
|
||||||
tool_pool_size: 16
|
|
||||||
dataset_name: "NousResearch/openthoughts-tblite"
|
|
||||||
test_timeout: 600
|
|
||||||
task_timeout: 1200
|
|
||||||
eval_concurrency: 8
|
|
||||||
tool_call_parser: "hermes"
|
|
||||||
system_prompt: "You are an expert terminal agent. You MUST use the provided tools to complete tasks. Use the terminal tool to run shell commands, read_file to read files, write_file to write files, search_files to search, and patch to edit files. Do NOT write out solutions as text - execute them using the tools. Always start by exploring the environment with terminal commands."
|
|
||||||
tokenizer_name: "Qwen/Qwen3-4B-Instruct-2507"
|
|
||||||
use_wandb: false
|
|
||||||
wandb_name: "tblite-qwen3-4b-instruct"
|
|
||||||
ensure_scores_are_not_same: false
|
|
||||||
data_dir_to_save_evals: "environments/benchmarks/evals/tblite-qwen3-4b-local"
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "http://localhost:9001"
|
|
||||||
model_name: "Qwen/Qwen3-4B-Instruct-2507"
|
|
||||||
server_type: "vllm"
|
|
||||||
health_check: false
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# OpenThoughts-TBLite Evaluation
|
|
||||||
#
|
|
||||||
# Run from repo root:
|
|
||||||
# bash environments/benchmarks/tblite/run_eval.sh
|
|
||||||
#
|
|
||||||
# Override model:
|
|
||||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
|
||||||
# --openai.model_name anthropic/claude-sonnet-4
|
|
||||||
#
|
|
||||||
# Run a subset:
|
|
||||||
# bash environments/benchmarks/tblite/run_eval.sh \
|
|
||||||
# --env.task_filter broken-python,pandas-etl
|
|
||||||
#
|
|
||||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
|
||||||
# configured via env config fields -- no env vars needed.
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
mkdir -p logs evals/openthoughts-tblite
|
|
||||||
LOG_FILE="logs/tblite_$(date +%Y%m%d_%H%M%S).log"
|
|
||||||
|
|
||||||
echo "OpenThoughts-TBLite Evaluation"
|
|
||||||
echo "Log file: $LOG_FILE"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Unbuffered python output so logs are written in real-time
|
|
||||||
export PYTHONUNBUFFERED=1
|
|
||||||
|
|
||||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
|
||||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
|
||||||
export LOGLEVEL=INFO
|
|
||||||
|
|
||||||
python tblite_env.py evaluate \
|
|
||||||
--config default.yaml \
|
|
||||||
"$@" \
|
|
||||||
2>&1 | tee "$LOG_FILE"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Log saved to: $LOG_FILE"
|
|
||||||
echo "Eval results: evals/openthoughts-tblite/"
|
|
||||||
|
|
@ -1,119 +0,0 @@
|
||||||
"""
|
|
||||||
OpenThoughts-TBLite Evaluation Environment
|
|
||||||
|
|
||||||
A lighter, faster alternative to Terminal-Bench 2.0 for iterating on terminal
|
|
||||||
agents. Uses the same evaluation logic as TerminalBench2EvalEnv but defaults
|
|
||||||
to the NousResearch/openthoughts-tblite dataset (100 difficulty-calibrated
|
|
||||||
tasks vs TB2's 89 harder tasks).
|
|
||||||
|
|
||||||
TBLite tasks are a curated subset of TB2 with a difficulty distribution
|
|
||||||
designed to give meaningful signal even for smaller models:
|
|
||||||
- Easy (40 tasks): >= 70% pass rate with Claude Haiku 4.5
|
|
||||||
- Medium (26 tasks): 40-69% pass rate
|
|
||||||
- Hard (26 tasks): 10-39% pass rate
|
|
||||||
- Extreme (8 tasks): < 10% pass rate
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate
|
|
||||||
|
|
||||||
# Filter to specific tasks:
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate \\
|
|
||||||
--env.task_filter "broken-python,pandas-etl"
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from atroposlib.envs.base import EvalHandlingEnum
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
|
|
||||||
from environments.benchmarks.terminalbench_2.terminalbench2_env import (
|
|
||||||
TerminalBench2EvalConfig,
|
|
||||||
TerminalBench2EvalEnv,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TBLiteEvalConfig(TerminalBench2EvalConfig):
|
|
||||||
"""Configuration for the OpenThoughts-TBLite evaluation environment.
|
|
||||||
|
|
||||||
Inherits all TB2 config fields. Only the dataset default and task timeout
|
|
||||||
differ -- TBLite tasks are calibrated to be faster.
|
|
||||||
"""
|
|
||||||
|
|
||||||
dataset_name: str = Field(
|
|
||||||
default="NousResearch/openthoughts-tblite",
|
|
||||||
description="HuggingFace dataset containing TBLite tasks.",
|
|
||||||
)
|
|
||||||
|
|
||||||
task_timeout: int = Field(
|
|
||||||
default=1200,
|
|
||||||
description="Maximum wall-clock seconds per task. TBLite tasks are "
|
|
||||||
"generally faster than TB2, so 20 minutes is usually sufficient.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TBLiteEvalEnv(TerminalBench2EvalEnv):
|
|
||||||
"""OpenThoughts-TBLite evaluation environment.
|
|
||||||
|
|
||||||
Inherits all evaluation logic from TerminalBench2EvalEnv (agent loop,
|
|
||||||
test verification, Docker image resolution, metrics, wandb logging).
|
|
||||||
Only the default configuration differs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "openthoughts-tblite"
|
|
||||||
env_config_cls = TBLiteEvalConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls) -> Tuple[TBLiteEvalConfig, List[APIServerConfig]]:
|
|
||||||
env_config = TBLiteEvalConfig(
|
|
||||||
enabled_toolsets=["terminal", "file"],
|
|
||||||
disabled_toolsets=None,
|
|
||||||
distribution=None,
|
|
||||||
|
|
||||||
max_agent_turns=60,
|
|
||||||
max_token_length=16000,
|
|
||||||
agent_temperature=0.6,
|
|
||||||
system_prompt=None,
|
|
||||||
|
|
||||||
terminal_backend="modal",
|
|
||||||
terminal_timeout=300,
|
|
||||||
|
|
||||||
test_timeout=180,
|
|
||||||
|
|
||||||
# 100 tasks in parallel
|
|
||||||
tool_pool_size=128,
|
|
||||||
|
|
||||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
|
||||||
group_size=1,
|
|
||||||
steps_per_eval=1,
|
|
||||||
total_steps=1,
|
|
||||||
|
|
||||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
|
||||||
use_wandb=True,
|
|
||||||
wandb_name="openthoughts-tblite",
|
|
||||||
ensure_scores_are_not_same=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
server_configs = [
|
|
||||||
APIServerConfig(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
model_name="anthropic/claude-sonnet-4",
|
|
||||||
server_type="openai",
|
|
||||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
|
||||||
health_check=False,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
TBLiteEvalEnv.cli()
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
# Terminal-Bench 2.0 Evaluation -- Default Configuration
|
|
||||||
#
|
|
||||||
# Eval-only environment for the TB2 benchmark (89 terminal tasks).
|
|
||||||
# Uses Modal terminal backend for per-task cloud-isolated sandboxes
|
|
||||||
# and OpenRouter for inference.
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/terminalbench_2/default.yaml
|
|
||||||
#
|
|
||||||
# # Override model:
|
|
||||||
# python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/terminalbench_2/default.yaml \
|
|
||||||
# --openai.model_name anthropic/claude-sonnet-4
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file"]
|
|
||||||
max_agent_turns: 60
|
|
||||||
max_token_length: 32000
|
|
||||||
agent_temperature: 0.8
|
|
||||||
terminal_backend: "modal"
|
|
||||||
terminal_timeout: 300 # 5 min per command (builds, pip install)
|
|
||||||
tool_pool_size: 128 # thread pool for 89 parallel tasks
|
|
||||||
dataset_name: "NousResearch/terminal-bench-2"
|
|
||||||
test_timeout: 600
|
|
||||||
task_timeout: 1800 # 30 min wall-clock per task, auto-FAIL if exceeded
|
|
||||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_name: "terminal-bench-2"
|
|
||||||
ensure_scores_are_not_same: false
|
|
||||||
data_dir_to_save_evals: "environments/benchmarks/evals/terminal-bench-2"
|
|
||||||
# CRITICAL: Limit concurrent Modal sandbox creations to avoid deadlocks.
|
|
||||||
# Modal's blocking calls (App.lookup, etc.) deadlock when too many sandboxes
|
|
||||||
# are created simultaneously inside thread pool workers via asyncio.run().
|
|
||||||
max_concurrent_tasks: 8
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
model_name: "anthropic/claude-opus-4.6"
|
|
||||||
server_type: "openai"
|
|
||||||
health_check: false
|
|
||||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
|
||||||
|
|
@ -1,42 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# Terminal-Bench 2.0 Evaluation
|
|
||||||
#
|
|
||||||
# Run from repo root:
|
|
||||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh
|
|
||||||
#
|
|
||||||
# Override model:
|
|
||||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
|
||||||
# --openai.model_name anthropic/claude-sonnet-4
|
|
||||||
#
|
|
||||||
# Run a subset:
|
|
||||||
# bash environments/benchmarks/terminalbench_2/run_eval.sh \
|
|
||||||
# --env.task_filter fix-git,git-multibranch
|
|
||||||
#
|
|
||||||
# All terminal settings (backend, timeout, lifetime, pool size) are
|
|
||||||
# configured via env config fields -- no env vars needed.
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
mkdir -p logs evals/terminal-bench-2
|
|
||||||
LOG_FILE="logs/terminalbench2_$(date +%Y%m%d_%H%M%S).log"
|
|
||||||
|
|
||||||
echo "Terminal-Bench 2.0 Evaluation"
|
|
||||||
echo "Log file: $LOG_FILE"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Unbuffered python output so logs are written in real-time
|
|
||||||
export PYTHONUNBUFFERED=1
|
|
||||||
|
|
||||||
# Show INFO-level agent loop timing (api/tool durations per turn)
|
|
||||||
# These go to the log file; tqdm + [START]/[PASS]/[FAIL] go to terminal
|
|
||||||
export LOGLEVEL=INFO
|
|
||||||
|
|
||||||
python terminalbench2_env.py evaluate \
|
|
||||||
--config default.yaml \
|
|
||||||
"$@" \
|
|
||||||
2>&1 | tee "$LOG_FILE"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Log saved to: $LOG_FILE"
|
|
||||||
echo "Eval results: evals/terminal-bench-2/"
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,115 +0,0 @@
|
||||||
# YC-Bench: Long-Horizon Agent Benchmark
|
|
||||||
|
|
||||||
[YC-Bench](https://github.com/collinear-ai/yc-bench) by [Collinear AI](https://collinear.ai/) is a deterministic, long-horizon benchmark that tests LLM agents' ability to act as a tech startup CEO. The agent manages a simulated company over 1-3 years, making compounding decisions about resource allocation, cash flow, task management, and prestige specialisation across 4 skill domains.
|
|
||||||
|
|
||||||
Unlike TerminalBench2 (which evaluates per-task coding ability with binary pass/fail), YC-Bench measures **long-term strategic coherence** — whether an agent can maintain consistent strategy, manage compounding consequences, and adapt plans over hundreds of turns.
|
|
||||||
|
|
||||||
## Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install yc-bench (optional dependency)
|
|
||||||
pip install "hermes-agent[yc-bench]"
|
|
||||||
|
|
||||||
# Or install from source
|
|
||||||
git clone https://github.com/collinear-ai/yc-bench
|
|
||||||
cd yc-bench && pip install -e .
|
|
||||||
|
|
||||||
# Verify
|
|
||||||
yc-bench --help
|
|
||||||
```
|
|
||||||
|
|
||||||
## Running
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# From the repo root:
|
|
||||||
bash environments/benchmarks/yc_bench/run_eval.sh
|
|
||||||
|
|
||||||
# Or directly:
|
|
||||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
--config environments/benchmarks/yc_bench/default.yaml
|
|
||||||
|
|
||||||
# Override model:
|
|
||||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
|
||||||
--openai.model_name anthropic/claude-opus-4-20250514
|
|
||||||
|
|
||||||
# Quick single-preset test:
|
|
||||||
bash environments/benchmarks/yc_bench/run_eval.sh \
|
|
||||||
--env.presets '["fast_test"]' --env.seeds '[1]'
|
|
||||||
```
|
|
||||||
|
|
||||||
## How It Works
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
HermesAgentLoop (our agent)
|
|
||||||
-> terminal tool -> subprocess("yc-bench company status") -> JSON output
|
|
||||||
-> terminal tool -> subprocess("yc-bench task accept --task-id X") -> JSON
|
|
||||||
-> terminal tool -> subprocess("yc-bench sim resume") -> JSON (advance time)
|
|
||||||
-> ... (100-500 turns per run)
|
|
||||||
```
|
|
||||||
|
|
||||||
The environment initialises the simulation via `yc-bench sim init` (NOT `yc-bench run`, which would start yc-bench's own built-in agent loop). Our `HermesAgentLoop` then drives all interaction through CLI commands.
|
|
||||||
|
|
||||||
### Simulation Mechanics
|
|
||||||
|
|
||||||
- **4 skill domains**: research, inference, data_environment, training
|
|
||||||
- **Prestige system** (1.0-10.0): Gates access to higher-paying tasks
|
|
||||||
- **Employee management**: Junior/Mid/Senior with domain-specific skill rates
|
|
||||||
- **Throughput splitting**: `effective_rate = base_rate / N` active tasks per employee
|
|
||||||
- **Financial pressure**: Monthly payroll, bankruptcy = game over
|
|
||||||
- **Deterministic**: SHA256-based RNG — same seed + preset = same world
|
|
||||||
|
|
||||||
### Difficulty Presets
|
|
||||||
|
|
||||||
| Preset | Employees | Tasks | Focus |
|
|
||||||
|-----------|-----------|-------|-------|
|
|
||||||
| tutorial | 3 | 50 | Basic loop mechanics |
|
|
||||||
| easy | 5 | 100 | Throughput awareness |
|
|
||||||
| **medium**| 5 | 150 | Prestige climbing + domain specialisation |
|
|
||||||
| **hard** | 7 | 200 | Precise ETA reasoning |
|
|
||||||
| nightmare | 8 | 300 | Sustained perfection under payroll pressure |
|
|
||||||
| fast_test | (varies) | (varies) | Quick validation (~50 turns) |
|
|
||||||
|
|
||||||
Default eval runs **fast_test + medium + hard** × 3 seeds = 9 runs.
|
|
||||||
|
|
||||||
### Scoring
|
|
||||||
|
|
||||||
```
|
|
||||||
composite = 0.5 × survival + 0.5 × normalised_funds
|
|
||||||
```
|
|
||||||
|
|
||||||
- **Survival** (binary): Did the company avoid bankruptcy?
|
|
||||||
- **Normalised funds** (0.0-1.0): Log-scale relative to initial $250K capital
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
Key fields in `default.yaml`:
|
|
||||||
|
|
||||||
| Field | Default | Description |
|
|
||||||
|-------|---------|-------------|
|
|
||||||
| `presets` | `["fast_test", "medium", "hard"]` | Which presets to evaluate |
|
|
||||||
| `seeds` | `[1, 2, 3]` | RNG seeds per preset |
|
|
||||||
| `max_agent_turns` | 200 | Max LLM calls per run |
|
|
||||||
| `run_timeout` | 3600 | Wall-clock timeout per run (seconds) |
|
|
||||||
| `survival_weight` | 0.5 | Weight of survival in composite score |
|
|
||||||
| `funds_weight` | 0.5 | Weight of normalised funds in composite |
|
|
||||||
| `horizon_years` | null | Override horizon (null = auto from preset) |
|
|
||||||
|
|
||||||
## Cost & Time Estimates
|
|
||||||
|
|
||||||
Each run is 100-500 LLM turns. Approximate costs per run at typical API rates:
|
|
||||||
|
|
||||||
| Preset | Turns | Time | Est. Cost |
|
|
||||||
|--------|-------|------|-----------|
|
|
||||||
| fast_test | ~50 | 5-10 min | $1-5 |
|
|
||||||
| medium | ~200 | 20-40 min | $5-15 |
|
|
||||||
| hard | ~300 | 30-60 min | $10-25 |
|
|
||||||
|
|
||||||
Full default eval (9 runs): ~3-6 hours, $50-200 depending on model.
|
|
||||||
|
|
||||||
## References
|
|
||||||
|
|
||||||
- [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — Official repository
|
|
||||||
- [Collinear AI](https://collinear.ai/) — Company behind yc-bench
|
|
||||||
- [TerminalBench2](../terminalbench_2/) — Per-task coding benchmark (complementary)
|
|
||||||
|
|
@ -1,43 +0,0 @@
|
||||||
# YC-Bench Evaluation -- Default Configuration
|
|
||||||
#
|
|
||||||
# Long-horizon agent benchmark: agent plays CEO of an AI startup over
|
|
||||||
# a simulated 1-3 year run, interacting via yc-bench CLI subcommands.
|
|
||||||
#
|
|
||||||
# Requires: pip install "hermes-agent[yc-bench]"
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/yc_bench/default.yaml
|
|
||||||
#
|
|
||||||
# # Override model:
|
|
||||||
# python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
# --config environments/benchmarks/yc_bench/default.yaml \
|
|
||||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal"]
|
|
||||||
max_agent_turns: 200
|
|
||||||
max_token_length: 32000
|
|
||||||
agent_temperature: 0.0
|
|
||||||
terminal_backend: "local"
|
|
||||||
terminal_timeout: 60
|
|
||||||
presets: ["fast_test", "medium", "hard"]
|
|
||||||
seeds: [1, 2, 3]
|
|
||||||
run_timeout: 3600 # 60 min wall-clock per run, auto-FAIL if exceeded
|
|
||||||
survival_weight: 0.5 # weight of binary survival in composite score
|
|
||||||
funds_weight: 0.5 # weight of normalised final funds in composite score
|
|
||||||
db_dir: "/tmp/yc_bench_dbs"
|
|
||||||
company_name: "BenchCo"
|
|
||||||
start_date: "01/01/2025" # MM/DD/YYYY (yc-bench convention)
|
|
||||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_name: "yc-bench"
|
|
||||||
ensure_scores_are_not_same: false
|
|
||||||
data_dir_to_save_evals: "environments/benchmarks/evals/yc-bench"
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
model_name: "anthropic/claude-sonnet-4.6"
|
|
||||||
server_type: "openai"
|
|
||||||
health_check: false
|
|
||||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# YC-Bench Evaluation
|
|
||||||
#
|
|
||||||
# Requires: pip install "hermes-agent[yc-bench]"
|
|
||||||
#
|
|
||||||
# Run from repo root:
|
|
||||||
# bash environments/benchmarks/yc_bench/run_eval.sh
|
|
||||||
#
|
|
||||||
# Override model:
|
|
||||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
|
||||||
# --openai.model_name anthropic/claude-opus-4-20250514
|
|
||||||
#
|
|
||||||
# Run a single preset:
|
|
||||||
# bash environments/benchmarks/yc_bench/run_eval.sh \
|
|
||||||
# --env.presets '["fast_test"]' --env.seeds '[1]'
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
mkdir -p logs evals/yc-bench
|
|
||||||
LOG_FILE="logs/yc_bench_$(date +%Y%m%d_%H%M%S).log"
|
|
||||||
|
|
||||||
echo "YC-Bench Evaluation"
|
|
||||||
echo "Log: $LOG_FILE"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
PYTHONUNBUFFERED=1 LOGLEVEL="${LOGLEVEL:-INFO}" \
|
|
||||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
--config environments/benchmarks/yc_bench/default.yaml \
|
|
||||||
"$@" \
|
|
||||||
2>&1 | tee "$LOG_FILE"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Log saved to: $LOG_FILE"
|
|
||||||
|
|
@ -1,848 +0,0 @@
|
||||||
"""
|
|
||||||
YCBenchEvalEnv -- YC-Bench Long-Horizon Agent Benchmark Environment
|
|
||||||
|
|
||||||
Evaluates agentic LLMs on YC-Bench: a deterministic, long-horizon benchmark
|
|
||||||
where the agent acts as CEO of an AI startup over a simulated 1-3 year run.
|
|
||||||
The agent manages cash flow, employees, tasks, and prestige across 4 domains,
|
|
||||||
interacting exclusively via CLI subprocess calls against a SQLite-backed
|
|
||||||
discrete-event simulation.
|
|
||||||
|
|
||||||
Unlike TerminalBench2 (per-task binary pass/fail), YC-Bench measures sustained
|
|
||||||
multi-turn strategic coherence -- whether an agent can manage compounding
|
|
||||||
decisions over hundreds of turns without going bankrupt.
|
|
||||||
|
|
||||||
This is an eval-only environment. Run via:
|
|
||||||
|
|
||||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
--config environments/benchmarks/yc_bench/default.yaml
|
|
||||||
|
|
||||||
The evaluate flow:
|
|
||||||
1. setup() -- Verifies yc-bench installed, builds eval matrix (preset x seed)
|
|
||||||
2. evaluate() -- Iterates over all runs sequentially through:
|
|
||||||
a. rollout_and_score_eval() -- Per-run agent loop
|
|
||||||
- Initialises a fresh yc-bench simulation via `sim init` (NOT `run`)
|
|
||||||
- Runs HermesAgentLoop with terminal tool only
|
|
||||||
- Reads final SQLite DB to extract score
|
|
||||||
- Returns survival (0/1) + normalised funds score
|
|
||||||
b. Aggregates per-preset and overall metrics
|
|
||||||
c. Logs results via evaluate_log() and wandb
|
|
||||||
|
|
||||||
Key features:
|
|
||||||
- CLI-only interface: agent calls yc-bench subcommands via terminal tool
|
|
||||||
- Deterministic: same seed + preset = same world (SHA256-based RNG)
|
|
||||||
- Multi-dimensional scoring: survival + normalised final funds
|
|
||||||
- Per-preset difficulty breakdown in results
|
|
||||||
- Isolated SQLite DB per run (no cross-run state leakage)
|
|
||||||
|
|
||||||
Requires: pip install hermes-agent[yc-bench]
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import datetime
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
from atroposlib.envs.base import EvalHandlingEnum
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
|
|
||||||
from environments.agent_loop import HermesAgentLoop
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# System prompt
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
YC_BENCH_SYSTEM_PROMPT = """\
|
|
||||||
You are the autonomous CEO of an early-stage AI startup in a deterministic
|
|
||||||
business simulation. You manage the company exclusively through the `yc-bench`
|
|
||||||
CLI tool. Your primary goal is to **survive** until the simulation horizon ends
|
|
||||||
without going bankrupt, while **maximising final funds**.
|
|
||||||
|
|
||||||
## Simulation Mechanics
|
|
||||||
|
|
||||||
- **Funds**: You start with $250,000 seed capital. Revenue comes from completing
|
|
||||||
tasks. Rewards scale with your prestige: `base × (1 + scale × (prestige − 1))`.
|
|
||||||
- **Domains**: There are 4 skill domains: **research**, **inference**,
|
|
||||||
**data_environment**, and **training**. Each has its own prestige level
|
|
||||||
(1.0-10.0). Higher prestige unlocks better-paying tasks.
|
|
||||||
- **Employees**: You have employees (Junior/Mid/Senior) with domain-specific
|
|
||||||
skill rates. **Throughput splits**: `effective_rate = base_rate / N` where N
|
|
||||||
is the number of active tasks assigned to that employee. Focus beats breadth.
|
|
||||||
- **Payroll**: Deducted automatically on the first business day of each month.
|
|
||||||
Running out of funds = bankruptcy = game over.
|
|
||||||
- **Time**: The simulation runs on business days (Mon-Fri), 09:00-18:00.
|
|
||||||
Time only advances when you call `yc-bench sim resume`.
|
|
||||||
|
|
||||||
## Task Lifecycle
|
|
||||||
|
|
||||||
1. Browse market tasks with `market browse`
|
|
||||||
2. Accept a task with `task accept` (this sets its deadline)
|
|
||||||
3. Assign employees with `task assign`
|
|
||||||
4. Dispatch with `task dispatch` to start work
|
|
||||||
5. Call `sim resume` to advance time and let employees make progress
|
|
||||||
6. Tasks complete when all domain requirements are fulfilled
|
|
||||||
|
|
||||||
**Penalties for failure vary by difficulty preset.** Completing a task on time
|
|
||||||
earns full reward + prestige gain. Missing a deadline or cancelling a task
|
|
||||||
incurs prestige penalties -- cancelling is always more costly than letting a
|
|
||||||
task fail, so cancel only as a last resort.
|
|
||||||
|
|
||||||
## CLI Commands
|
|
||||||
|
|
||||||
### Observe
|
|
||||||
- `yc-bench company status` -- funds, prestige, runway
|
|
||||||
- `yc-bench employee list` -- skills, salary, active tasks
|
|
||||||
- `yc-bench market browse [--domain D] [--required-prestige-lte N]` -- available tasks
|
|
||||||
- `yc-bench task list [--status active|planned]` -- your tasks
|
|
||||||
- `yc-bench task inspect --task-id UUID` -- progress, deadline, assignments
|
|
||||||
- `yc-bench finance ledger [--category monthly_payroll|task_reward]` -- transaction history
|
|
||||||
- `yc-bench report monthly` -- monthly P&L
|
|
||||||
|
|
||||||
### Act
|
|
||||||
- `yc-bench task accept --task-id UUID` -- accept from market
|
|
||||||
- `yc-bench task assign --task-id UUID --employee-id UUID` -- assign employee
|
|
||||||
- `yc-bench task dispatch --task-id UUID` -- start work (needs >=1 assignment)
|
|
||||||
- `yc-bench task cancel --task-id UUID --reason "text"` -- cancel (prestige penalty)
|
|
||||||
- `yc-bench sim resume` -- advance simulation clock
|
|
||||||
|
|
||||||
### Memory (persists across context truncation)
|
|
||||||
- `yc-bench scratchpad read` -- read your persistent notes
|
|
||||||
- `yc-bench scratchpad write --content "text"` -- overwrite notes
|
|
||||||
- `yc-bench scratchpad append --content "text"` -- append to notes
|
|
||||||
- `yc-bench scratchpad clear` -- clear notes
|
|
||||||
|
|
||||||
## Strategy Guidelines
|
|
||||||
|
|
||||||
1. **Specialise in 2-3 domains** to climb the prestige ladder faster and unlock
|
|
||||||
high-reward tasks. Don't spread thin across all 4 domains early on.
|
|
||||||
2. **Focus employees** -- assigning one employee to many tasks halves their
|
|
||||||
throughput per additional task. Keep assignments concentrated.
|
|
||||||
3. **Use the scratchpad** to track your strategy, upcoming deadlines, and
|
|
||||||
employee assignments. This persists even if conversation context is truncated.
|
|
||||||
4. **Monitor runway** -- always know how many months of payroll you can cover.
|
|
||||||
Accept high-reward tasks before payroll dates.
|
|
||||||
5. **Don't over-accept** -- taking too many tasks and missing deadlines cascades
|
|
||||||
into prestige loss, locking you out of profitable contracts.
|
|
||||||
6. Use `finance ledger` and `report monthly` to track revenue trends.
|
|
||||||
|
|
||||||
## Your Turn
|
|
||||||
|
|
||||||
Each turn:
|
|
||||||
1. Call `yc-bench company status` and `yc-bench task list` to orient yourself.
|
|
||||||
2. Check for completed tasks and pending deadlines.
|
|
||||||
3. Browse market for profitable tasks within your prestige level.
|
|
||||||
4. Accept, assign, and dispatch tasks strategically.
|
|
||||||
5. Call `yc-bench sim resume` to advance time.
|
|
||||||
6. Repeat until the simulation ends.
|
|
||||||
|
|
||||||
Think step by step before acting."""
|
|
||||||
|
|
||||||
# Starting funds in cents ($250,000)
|
|
||||||
INITIAL_FUNDS_CENTS = 25_000_000
|
|
||||||
|
|
||||||
# Default horizon per preset (years)
|
|
||||||
_PRESET_HORIZONS = {
|
|
||||||
"tutorial": 1,
|
|
||||||
"easy": 1,
|
|
||||||
"medium": 1,
|
|
||||||
"hard": 1,
|
|
||||||
"nightmare": 1,
|
|
||||||
"fast_test": 1,
|
|
||||||
"default": 3,
|
|
||||||
"high_reward": 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Configuration
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
class YCBenchEvalConfig(HermesAgentEnvConfig):
|
|
||||||
"""
|
|
||||||
Configuration for the YC-Bench evaluation environment.
|
|
||||||
|
|
||||||
Extends HermesAgentEnvConfig with YC-Bench-specific settings for
|
|
||||||
preset selection, seed control, scoring, and simulation parameters.
|
|
||||||
"""
|
|
||||||
|
|
||||||
presets: List[str] = Field(
|
|
||||||
default=["fast_test", "medium", "hard"],
|
|
||||||
description="YC-Bench preset names to evaluate.",
|
|
||||||
)
|
|
||||||
seeds: List[int] = Field(
|
|
||||||
default=[1, 2, 3],
|
|
||||||
description="Random seeds -- each preset x seed = one run.",
|
|
||||||
)
|
|
||||||
run_timeout: int = Field(
|
|
||||||
default=3600,
|
|
||||||
description="Maximum wall-clock seconds per run. Default 60 minutes.",
|
|
||||||
)
|
|
||||||
survival_weight: float = Field(
|
|
||||||
default=0.5,
|
|
||||||
description="Weight of survival (0/1) in composite score.",
|
|
||||||
)
|
|
||||||
funds_weight: float = Field(
|
|
||||||
default=0.5,
|
|
||||||
description="Weight of normalised final funds in composite score.",
|
|
||||||
)
|
|
||||||
db_dir: str = Field(
|
|
||||||
default="/tmp/yc_bench_dbs",
|
|
||||||
description="Directory for per-run SQLite databases.",
|
|
||||||
)
|
|
||||||
horizon_years: Optional[int] = Field(
|
|
||||||
default=None,
|
|
||||||
description=(
|
|
||||||
"Simulation horizon in years. If None (default), inferred from "
|
|
||||||
"preset name (1 year for most, 3 for 'default')."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
company_name: str = Field(
|
|
||||||
default="BenchCo",
|
|
||||||
description="Name of the simulated company.",
|
|
||||||
)
|
|
||||||
start_date: str = Field(
|
|
||||||
default="01/01/2025",
|
|
||||||
description="Simulation start date in MM/DD/YYYY format (yc-bench convention).",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Scoring helpers
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
def _read_final_score(db_path: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Read final game state from a YC-Bench SQLite database.
|
|
||||||
|
|
||||||
Returns dict with final_funds_cents (int), survived (bool),
|
|
||||||
terminal_reason (str).
|
|
||||||
|
|
||||||
Note: yc-bench table names are plural -- 'companies' not 'company',
|
|
||||||
'sim_events' not 'simulation_log'.
|
|
||||||
"""
|
|
||||||
if not os.path.exists(db_path):
|
|
||||||
logger.warning("DB not found at %s", db_path)
|
|
||||||
return {
|
|
||||||
"final_funds_cents": 0,
|
|
||||||
"survived": False,
|
|
||||||
"terminal_reason": "db_missing",
|
|
||||||
}
|
|
||||||
|
|
||||||
conn = None
|
|
||||||
try:
|
|
||||||
conn = sqlite3.connect(db_path)
|
|
||||||
cur = conn.cursor()
|
|
||||||
|
|
||||||
# Read final funds from the 'companies' table
|
|
||||||
cur.execute("SELECT funds_cents FROM companies LIMIT 1")
|
|
||||||
row = cur.fetchone()
|
|
||||||
funds = row[0] if row else 0
|
|
||||||
|
|
||||||
# Determine terminal reason from 'sim_events' table
|
|
||||||
terminal_reason = "unknown"
|
|
||||||
try:
|
|
||||||
cur.execute(
|
|
||||||
"SELECT event_type FROM sim_events "
|
|
||||||
"WHERE event_type IN ('bankruptcy', 'horizon_end') "
|
|
||||||
"ORDER BY scheduled_at DESC LIMIT 1"
|
|
||||||
)
|
|
||||||
event_row = cur.fetchone()
|
|
||||||
if event_row:
|
|
||||||
terminal_reason = event_row[0]
|
|
||||||
except sqlite3.OperationalError:
|
|
||||||
# Table may not exist if simulation didn't progress
|
|
||||||
pass
|
|
||||||
|
|
||||||
survived = funds >= 0 and terminal_reason != "bankruptcy"
|
|
||||||
return {
|
|
||||||
"final_funds_cents": funds,
|
|
||||||
"survived": survived,
|
|
||||||
"terminal_reason": terminal_reason,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Failed to read DB %s: %s", db_path, e)
|
|
||||||
return {
|
|
||||||
"final_funds_cents": 0,
|
|
||||||
"survived": False,
|
|
||||||
"terminal_reason": f"db_error: {e}",
|
|
||||||
}
|
|
||||||
finally:
|
|
||||||
if conn:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_composite_score(
|
|
||||||
final_funds_cents: int,
|
|
||||||
survived: bool,
|
|
||||||
survival_weight: float = 0.5,
|
|
||||||
funds_weight: float = 0.5,
|
|
||||||
initial_funds_cents: int = INITIAL_FUNDS_CENTS,
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Compute composite score from survival and final funds.
|
|
||||||
|
|
||||||
Score = survival_weight * survival_score
|
|
||||||
+ funds_weight * normalised_funds_score
|
|
||||||
|
|
||||||
Normalised funds uses log-scale relative to initial capital:
|
|
||||||
- funds <= 0: 0.0
|
|
||||||
- funds == initial: ~0.15
|
|
||||||
- funds == 10x: ~0.52
|
|
||||||
- funds == 100x: 1.0
|
|
||||||
"""
|
|
||||||
survival_score = 1.0 if survived else 0.0
|
|
||||||
|
|
||||||
if final_funds_cents <= 0:
|
|
||||||
funds_score = 0.0
|
|
||||||
else:
|
|
||||||
max_ratio = 100.0
|
|
||||||
ratio = final_funds_cents / max(initial_funds_cents, 1)
|
|
||||||
funds_score = min(math.log1p(ratio) / math.log1p(max_ratio), 1.0)
|
|
||||||
|
|
||||||
return survival_weight * survival_score + funds_weight * funds_score
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Main Environment
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
class YCBenchEvalEnv(HermesAgentBaseEnv):
|
|
||||||
"""
|
|
||||||
YC-Bench long-horizon agent benchmark environment (eval-only).
|
|
||||||
|
|
||||||
Each eval item is a (preset, seed) pair. The environment initialises the
|
|
||||||
simulation via ``yc-bench sim init`` (NOT ``yc-bench run`` which would start
|
|
||||||
a competing built-in agent loop). The HermesAgentLoop then drives the
|
|
||||||
interaction by calling individual yc-bench CLI commands via the terminal tool.
|
|
||||||
|
|
||||||
After the agent loop ends, the SQLite DB is read to extract the final score.
|
|
||||||
|
|
||||||
Scoring:
|
|
||||||
composite = 0.5 * survival + 0.5 * normalised_funds
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "yc-bench"
|
|
||||||
env_config_cls = YCBenchEvalConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls) -> Tuple[YCBenchEvalConfig, List[APIServerConfig]]:
|
|
||||||
env_config = YCBenchEvalConfig(
|
|
||||||
enabled_toolsets=["terminal"],
|
|
||||||
disabled_toolsets=None,
|
|
||||||
distribution=None,
|
|
||||||
max_agent_turns=200,
|
|
||||||
max_token_length=32000,
|
|
||||||
agent_temperature=0.0,
|
|
||||||
system_prompt=YC_BENCH_SYSTEM_PROMPT,
|
|
||||||
terminal_backend="local",
|
|
||||||
terminal_timeout=60,
|
|
||||||
presets=["fast_test", "medium", "hard"],
|
|
||||||
seeds=[1, 2, 3],
|
|
||||||
run_timeout=3600,
|
|
||||||
survival_weight=0.5,
|
|
||||||
funds_weight=0.5,
|
|
||||||
db_dir="/tmp/yc_bench_dbs",
|
|
||||||
eval_handling=EvalHandlingEnum.STOP_TRAIN,
|
|
||||||
group_size=1,
|
|
||||||
steps_per_eval=1,
|
|
||||||
total_steps=1,
|
|
||||||
tokenizer_name="NousResearch/Hermes-3-Llama-3.1-8B",
|
|
||||||
use_wandb=True,
|
|
||||||
wandb_name="yc-bench",
|
|
||||||
ensure_scores_are_not_same=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
server_configs = [
|
|
||||||
APIServerConfig(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
model_name="anthropic/claude-sonnet-4.6",
|
|
||||||
server_type="openai",
|
|
||||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
|
||||||
health_check=False,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Setup
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
"""Verify yc-bench is installed and build the eval matrix."""
|
|
||||||
# Verify yc-bench CLI is available
|
|
||||||
try:
|
|
||||||
result = subprocess.run(
|
|
||||||
["yc-bench", "--help"], capture_output=True, text=True, timeout=10
|
|
||||||
)
|
|
||||||
if result.returncode != 0:
|
|
||||||
raise FileNotFoundError
|
|
||||||
except (FileNotFoundError, subprocess.TimeoutExpired):
|
|
||||||
raise RuntimeError(
|
|
||||||
"yc-bench CLI not found. Install with:\n"
|
|
||||||
' pip install "hermes-agent[yc-bench]"\n'
|
|
||||||
"Or: git clone https://github.com/collinear-ai/yc-bench "
|
|
||||||
"&& cd yc-bench && pip install -e ."
|
|
||||||
)
|
|
||||||
print("yc-bench CLI verified.")
|
|
||||||
|
|
||||||
# Build eval matrix: preset x seed
|
|
||||||
self.all_eval_items = [
|
|
||||||
{"preset": preset, "seed": seed}
|
|
||||||
for preset in self.config.presets
|
|
||||||
for seed in self.config.seeds
|
|
||||||
]
|
|
||||||
self.iter = 0
|
|
||||||
|
|
||||||
os.makedirs(self.config.db_dir, exist_ok=True)
|
|
||||||
self.eval_metrics: List[Tuple[str, float]] = []
|
|
||||||
|
|
||||||
# Streaming JSONL log for crash-safe result persistence
|
|
||||||
log_dir = os.path.join(os.path.dirname(__file__), "logs")
|
|
||||||
os.makedirs(log_dir, exist_ok=True)
|
|
||||||
run_ts = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
self._streaming_path = os.path.join(log_dir, f"samples_{run_ts}.jsonl")
|
|
||||||
self._streaming_file = open(self._streaming_path, "w", encoding="utf-8")
|
|
||||||
self._streaming_lock = threading.Lock()
|
|
||||||
|
|
||||||
print(f"\nYC-Bench eval matrix: {len(self.all_eval_items)} runs")
|
|
||||||
for item in self.all_eval_items:
|
|
||||||
print(f" preset={item['preset']!r} seed={item['seed']}")
|
|
||||||
print(f"Streaming results to: {self._streaming_path}\n")
|
|
||||||
|
|
||||||
def _save_result(self, result: Dict[str, Any]):
|
|
||||||
"""Write a single run result to the streaming JSONL file immediately."""
|
|
||||||
if not hasattr(self, "_streaming_file") or self._streaming_file.closed:
|
|
||||||
return
|
|
||||||
with self._streaming_lock:
|
|
||||||
self._streaming_file.write(
|
|
||||||
json.dumps(result, ensure_ascii=False, default=str) + "\n"
|
|
||||||
)
|
|
||||||
self._streaming_file.flush()
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Training pipeline stubs (eval-only -- not used)
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
async def get_next_item(self):
|
|
||||||
item = self.all_eval_items[self.iter % len(self.all_eval_items)]
|
|
||||||
self.iter += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
|
||||||
preset = item["preset"]
|
|
||||||
seed = item["seed"]
|
|
||||||
return (
|
|
||||||
f"A new YC-Bench simulation has been initialized "
|
|
||||||
f"(preset='{preset}', seed={seed}).\n"
|
|
||||||
f"Your company '{self.config.company_name}' is ready.\n\n"
|
|
||||||
"Begin by calling:\n"
|
|
||||||
"1. `yc-bench company status` -- see your starting funds and prestige\n"
|
|
||||||
"2. `yc-bench employee list` -- see your team and their skills\n"
|
|
||||||
"3. `yc-bench market browse --required-prestige-lte 1` -- find tasks "
|
|
||||||
"you can take\n\n"
|
|
||||||
"Then accept 2-3 tasks, assign employees, dispatch them, and call "
|
|
||||||
"`yc-bench sim resume` to advance time. Repeat this loop until the "
|
|
||||||
"simulation ends (horizon reached or bankruptcy)."
|
|
||||||
)
|
|
||||||
|
|
||||||
async def compute_reward(self, item, result, ctx) -> float:
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
async def collect_trajectories(self, item):
|
|
||||||
return None, []
|
|
||||||
|
|
||||||
async def score(self, rollout_group_data):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Per-run evaluation
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
async def rollout_and_score_eval(self, eval_item: Dict[str, Any]) -> Dict:
|
|
||||||
"""
|
|
||||||
Evaluate a single (preset, seed) run.
|
|
||||||
|
|
||||||
1. Sets DATABASE_URL and YC_BENCH_EXPERIMENT env vars
|
|
||||||
2. Initialises the simulation via ``yc-bench sim init`` (NOT ``run``)
|
|
||||||
3. Runs HermesAgentLoop with terminal tool
|
|
||||||
4. Reads SQLite DB to compute final score
|
|
||||||
5. Returns result dict with survival, funds, and composite score
|
|
||||||
"""
|
|
||||||
preset = eval_item["preset"]
|
|
||||||
seed = eval_item["seed"]
|
|
||||||
run_id = str(uuid.uuid4())[:8]
|
|
||||||
run_key = f"{preset}_seed{seed}_{run_id}"
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
|
||||||
tqdm.write(f" [START] preset={preset!r} seed={seed} (run_id={run_id})")
|
|
||||||
run_start = time.time()
|
|
||||||
|
|
||||||
# Isolated DB per run -- prevents cross-run state leakage
|
|
||||||
db_path = os.path.join(self.config.db_dir, f"yc_bench_{run_key}.db")
|
|
||||||
os.environ["DATABASE_URL"] = f"sqlite:///{db_path}"
|
|
||||||
os.environ["YC_BENCH_EXPERIMENT"] = preset
|
|
||||||
|
|
||||||
# Determine horizon: explicit config override > preset lookup > default 1
|
|
||||||
horizon = self.config.horizon_years or _PRESET_HORIZONS.get(preset, 1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# ----------------------------------------------------------
|
|
||||||
# Step 1: Initialise the simulation via CLI
|
|
||||||
# IMPORTANT: We use `sim init`, NOT `yc-bench run`.
|
|
||||||
# `yc-bench run` starts yc-bench's own LLM agent loop (via
|
|
||||||
# LiteLLM), which would compete with our HermesAgentLoop.
|
|
||||||
# `sim init` just sets up the world and returns.
|
|
||||||
# ----------------------------------------------------------
|
|
||||||
init_cmd = [
|
|
||||||
"yc-bench", "sim", "init",
|
|
||||||
"--seed", str(seed),
|
|
||||||
"--start-date", self.config.start_date,
|
|
||||||
"--company-name", self.config.company_name,
|
|
||||||
"--horizon-years", str(horizon),
|
|
||||||
]
|
|
||||||
init_result = subprocess.run(
|
|
||||||
init_cmd, capture_output=True, text=True, timeout=30,
|
|
||||||
)
|
|
||||||
if init_result.returncode != 0:
|
|
||||||
error_msg = (init_result.stderr or init_result.stdout).strip()
|
|
||||||
raise RuntimeError(f"yc-bench sim init failed: {error_msg}")
|
|
||||||
|
|
||||||
tqdm.write(f" Simulation initialized (horizon={horizon}yr)")
|
|
||||||
|
|
||||||
# ----------------------------------------------------------
|
|
||||||
# Step 2: Run the HermesAgentLoop
|
|
||||||
# ----------------------------------------------------------
|
|
||||||
tools, valid_names = self._resolve_tools_for_group()
|
|
||||||
|
|
||||||
messages: List[Dict[str, Any]] = [
|
|
||||||
{"role": "system", "content": YC_BENCH_SYSTEM_PROMPT},
|
|
||||||
{"role": "user", "content": self.format_prompt(eval_item)},
|
|
||||||
]
|
|
||||||
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=self.server,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=run_id,
|
|
||||||
temperature=self.config.agent_temperature,
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
budget_config=self.config.build_budget_config(),
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# ----------------------------------------------------------
|
|
||||||
# Step 3: Read final score from the simulation DB
|
|
||||||
# ----------------------------------------------------------
|
|
||||||
score_data = _read_final_score(db_path)
|
|
||||||
final_funds = score_data["final_funds_cents"]
|
|
||||||
survived = score_data["survived"]
|
|
||||||
terminal_reason = score_data["terminal_reason"]
|
|
||||||
|
|
||||||
composite = _compute_composite_score(
|
|
||||||
final_funds_cents=final_funds,
|
|
||||||
survived=survived,
|
|
||||||
survival_weight=self.config.survival_weight,
|
|
||||||
funds_weight=self.config.funds_weight,
|
|
||||||
)
|
|
||||||
|
|
||||||
elapsed = time.time() - run_start
|
|
||||||
status = "SURVIVED" if survived else "BANKRUPT"
|
|
||||||
if final_funds >= 0:
|
|
||||||
funds_str = f"${final_funds / 100:,.0f}"
|
|
||||||
else:
|
|
||||||
funds_str = f"-${abs(final_funds) / 100:,.0f}"
|
|
||||||
|
|
||||||
tqdm.write(
|
|
||||||
f" [{status}] preset={preset!r} seed={seed} "
|
|
||||||
f"funds={funds_str} score={composite:.3f} "
|
|
||||||
f"turns={result.turns_used} ({elapsed:.0f}s)"
|
|
||||||
)
|
|
||||||
|
|
||||||
out = {
|
|
||||||
"preset": preset,
|
|
||||||
"seed": seed,
|
|
||||||
"survived": survived,
|
|
||||||
"final_funds_cents": final_funds,
|
|
||||||
"final_funds_usd": final_funds / 100,
|
|
||||||
"terminal_reason": terminal_reason,
|
|
||||||
"composite_score": composite,
|
|
||||||
"turns_used": result.turns_used,
|
|
||||||
"finished_naturally": result.finished_naturally,
|
|
||||||
"elapsed_seconds": elapsed,
|
|
||||||
"db_path": db_path,
|
|
||||||
"messages": result.messages,
|
|
||||||
}
|
|
||||||
self._save_result(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
elapsed = time.time() - run_start
|
|
||||||
logger.error("Run %s failed: %s", run_key, e, exc_info=True)
|
|
||||||
tqdm.write(
|
|
||||||
f" [ERROR] preset={preset!r} seed={seed}: {e} ({elapsed:.0f}s)"
|
|
||||||
)
|
|
||||||
out = {
|
|
||||||
"preset": preset,
|
|
||||||
"seed": seed,
|
|
||||||
"survived": False,
|
|
||||||
"final_funds_cents": 0,
|
|
||||||
"final_funds_usd": 0.0,
|
|
||||||
"terminal_reason": f"error: {e}",
|
|
||||||
"composite_score": 0.0,
|
|
||||||
"turns_used": 0,
|
|
||||||
"error": str(e),
|
|
||||||
"elapsed_seconds": elapsed,
|
|
||||||
}
|
|
||||||
self._save_result(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Evaluate
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
async def _run_with_timeout(self, item: Dict[str, Any]) -> Dict:
|
|
||||||
"""Wrap a single rollout with a wall-clock timeout."""
|
|
||||||
preset = item["preset"]
|
|
||||||
seed = item["seed"]
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(
|
|
||||||
self.rollout_and_score_eval(item),
|
|
||||||
timeout=self.config.run_timeout,
|
|
||||||
)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
from tqdm import tqdm
|
|
||||||
tqdm.write(
|
|
||||||
f" [TIMEOUT] preset={preset!r} seed={seed} "
|
|
||||||
f"(exceeded {self.config.run_timeout}s)"
|
|
||||||
)
|
|
||||||
out = {
|
|
||||||
"preset": preset,
|
|
||||||
"seed": seed,
|
|
||||||
"survived": False,
|
|
||||||
"final_funds_cents": 0,
|
|
||||||
"final_funds_usd": 0.0,
|
|
||||||
"terminal_reason": f"timeout ({self.config.run_timeout}s)",
|
|
||||||
"composite_score": 0.0,
|
|
||||||
"turns_used": 0,
|
|
||||||
"error": "timeout",
|
|
||||||
}
|
|
||||||
self._save_result(out)
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs) -> None:
|
|
||||||
"""
|
|
||||||
Run YC-Bench evaluation over all (preset, seed) combinations.
|
|
||||||
|
|
||||||
Runs sequentially -- each run is 100-500 turns, parallelising would
|
|
||||||
be prohibitively expensive and cause env var conflicts.
|
|
||||||
"""
|
|
||||||
start_time = time.time()
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# --- tqdm-compatible logging handler (TB2 pattern) ---
|
|
||||||
class _TqdmHandler(logging.Handler):
|
|
||||||
def emit(self, record):
|
|
||||||
try:
|
|
||||||
tqdm.write(self.format(record))
|
|
||||||
except Exception:
|
|
||||||
self.handleError(record)
|
|
||||||
|
|
||||||
root = logging.getLogger()
|
|
||||||
handler = _TqdmHandler()
|
|
||||||
handler.setFormatter(
|
|
||||||
logging.Formatter("%(levelname)s %(name)s: %(message)s")
|
|
||||||
)
|
|
||||||
root.handlers = [handler]
|
|
||||||
for noisy in ("httpx", "openai"):
|
|
||||||
logging.getLogger(noisy).setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# --- Print config summary ---
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print("Starting YC-Bench Evaluation")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f" Presets: {self.config.presets}")
|
|
||||||
print(f" Seeds: {self.config.seeds}")
|
|
||||||
print(f" Total runs: {len(self.all_eval_items)}")
|
|
||||||
print(f" Max turns/run: {self.config.max_agent_turns}")
|
|
||||||
print(f" Run timeout: {self.config.run_timeout}s")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
results = []
|
|
||||||
pbar = tqdm(
|
|
||||||
total=len(self.all_eval_items), desc="YC-Bench", dynamic_ncols=True
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
for item in self.all_eval_items:
|
|
||||||
result = await self._run_with_timeout(item)
|
|
||||||
results.append(result)
|
|
||||||
survived_count = sum(1 for r in results if r.get("survived"))
|
|
||||||
pbar.set_postfix_str(
|
|
||||||
f"survived={survived_count}/{len(results)}"
|
|
||||||
)
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
except (KeyboardInterrupt, asyncio.CancelledError):
|
|
||||||
tqdm.write("\n[INTERRUPTED] Stopping evaluation...")
|
|
||||||
pbar.close()
|
|
||||||
try:
|
|
||||||
from tools.terminal_tool import cleanup_all_environments
|
|
||||||
cleanup_all_environments()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
|
||||||
self._streaming_file.close()
|
|
||||||
return
|
|
||||||
|
|
||||||
pbar.close()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
# --- Compute metrics ---
|
|
||||||
valid = [r for r in results if r is not None]
|
|
||||||
if not valid:
|
|
||||||
print("Warning: No valid results.")
|
|
||||||
return
|
|
||||||
|
|
||||||
total = len(valid)
|
|
||||||
survived_total = sum(1 for r in valid if r.get("survived"))
|
|
||||||
survival_rate = survived_total / total if total else 0.0
|
|
||||||
avg_score = (
|
|
||||||
sum(r.get("composite_score", 0) for r in valid) / total
|
|
||||||
if total
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
preset_results: Dict[str, List[Dict]] = defaultdict(list)
|
|
||||||
for r in valid:
|
|
||||||
preset_results[r["preset"]].append(r)
|
|
||||||
|
|
||||||
eval_metrics = {
|
|
||||||
"eval/survival_rate": survival_rate,
|
|
||||||
"eval/avg_composite_score": avg_score,
|
|
||||||
"eval/total_runs": total,
|
|
||||||
"eval/survived_runs": survived_total,
|
|
||||||
"eval/evaluation_time_seconds": end_time - start_time,
|
|
||||||
}
|
|
||||||
|
|
||||||
for preset, items in sorted(preset_results.items()):
|
|
||||||
ps = sum(1 for r in items if r.get("survived"))
|
|
||||||
pt = len(items)
|
|
||||||
pa = (
|
|
||||||
sum(r.get("composite_score", 0) for r in items) / pt
|
|
||||||
if pt
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
key = preset.replace("-", "_")
|
|
||||||
eval_metrics[f"eval/survival_rate_{key}"] = ps / pt if pt else 0
|
|
||||||
eval_metrics[f"eval/avg_score_{key}"] = pa
|
|
||||||
|
|
||||||
self.eval_metrics = list(eval_metrics.items())
|
|
||||||
|
|
||||||
# --- Print summary ---
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print("YC-Bench Evaluation Results")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(
|
|
||||||
f"Overall survival rate: {survival_rate:.1%} "
|
|
||||||
f"({survived_total}/{total})"
|
|
||||||
)
|
|
||||||
print(f"Average composite score: {avg_score:.4f}")
|
|
||||||
print(f"Evaluation time: {end_time - start_time:.1f}s")
|
|
||||||
|
|
||||||
print("\nPer-preset breakdown:")
|
|
||||||
for preset, items in sorted(preset_results.items()):
|
|
||||||
ps = sum(1 for r in items if r.get("survived"))
|
|
||||||
pt = len(items)
|
|
||||||
pa = (
|
|
||||||
sum(r.get("composite_score", 0) for r in items) / pt
|
|
||||||
if pt
|
|
||||||
else 0
|
|
||||||
)
|
|
||||||
print(f" {preset}: {ps}/{pt} survived avg_score={pa:.4f}")
|
|
||||||
for r in items:
|
|
||||||
status = "SURVIVED" if r.get("survived") else "BANKRUPT"
|
|
||||||
funds = r.get("final_funds_usd", 0)
|
|
||||||
print(
|
|
||||||
f" seed={r['seed']} [{status}] "
|
|
||||||
f"${funds:,.0f} "
|
|
||||||
f"score={r.get('composite_score', 0):.3f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# --- Log results ---
|
|
||||||
samples = [
|
|
||||||
{k: v for k, v in r.items() if k != "messages"} for r in valid
|
|
||||||
]
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self.evaluate_log(
|
|
||||||
metrics=eval_metrics,
|
|
||||||
samples=samples,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
generation_parameters={
|
|
||||||
"temperature": self.config.agent_temperature,
|
|
||||||
"max_tokens": self.config.max_token_length,
|
|
||||||
"max_agent_turns": self.config.max_agent_turns,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error logging results: {e}")
|
|
||||||
|
|
||||||
# --- Cleanup (TB2 pattern) ---
|
|
||||||
if hasattr(self, "_streaming_file") and not self._streaming_file.closed:
|
|
||||||
self._streaming_file.close()
|
|
||||||
print(f"Results saved to: {self._streaming_path}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from tools.terminal_tool import cleanup_all_environments
|
|
||||||
cleanup_all_environments()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
from environments.agent_loop import _tool_executor
|
|
||||||
_tool_executor.shutdown(wait=False, cancel_futures=True)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Wandb logging
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
||||||
"""Log YC-Bench-specific metrics to wandb."""
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
for k, v in self.eval_metrics:
|
|
||||||
wandb_metrics[k] = v
|
|
||||||
self.eval_metrics = []
|
|
||||||
await super().wandb_log(wandb_metrics)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
YCBenchEvalEnv.cli()
|
|
||||||
|
|
@ -1,714 +0,0 @@
|
||||||
"""
|
|
||||||
HermesAgentBaseEnv -- Abstract Base Environment for Hermes-Agent + Atropos
|
|
||||||
|
|
||||||
Provides the Atropos integration plumbing that all hermes-agent environments share:
|
|
||||||
- Two-mode operation (OpenAI server for Phase 1, VLLM ManagedServer for Phase 2)
|
|
||||||
- Per-group toolset/distribution resolution
|
|
||||||
- Agent loop orchestration via HermesAgentLoop
|
|
||||||
- ToolContext creation for reward functions
|
|
||||||
- ScoredDataGroup construction from ManagedServer state
|
|
||||||
|
|
||||||
Subclasses only need to implement:
|
|
||||||
setup() -- Load dataset, initialize state
|
|
||||||
get_next_item() -- Return the next item from the dataset
|
|
||||||
format_prompt() -- Convert a dataset item into the user message
|
|
||||||
compute_reward() -- Score the rollout (has full ToolContext access)
|
|
||||||
evaluate() -- Periodic evaluation
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import uuid
|
|
||||||
from abc import abstractmethod
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
||||||
|
|
||||||
# Ensure the hermes-agent repo root is on sys.path so that imports like
|
|
||||||
# `from model_tools import ...` and `from environments.X import ...` work
|
|
||||||
# regardless of where the script is invoked from.
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
# Load API keys from hermes-agent/.env so all environments can access them
|
|
||||||
_env_path = _repo_root / ".env"
|
|
||||||
if _env_path.exists():
|
|
||||||
load_dotenv(dotenv_path=_env_path)
|
|
||||||
|
|
||||||
# Apply monkey patches for async-safe tool operation inside Atropos's event loop.
|
|
||||||
# This patches SwerexModalEnvironment to use a background thread instead of
|
|
||||||
# asyncio.run(), which would deadlock inside Atropos. Safe for normal CLI too.
|
|
||||||
from environments.patches import apply_patches
|
|
||||||
apply_patches()
|
|
||||||
|
|
||||||
from atroposlib.envs.base import (
|
|
||||||
BaseEnv,
|
|
||||||
BaseEnvConfig,
|
|
||||||
ScoredDataGroup,
|
|
||||||
ScoredDataItem,
|
|
||||||
)
|
|
||||||
from atroposlib.envs.server_handling.server_manager import (
|
|
||||||
APIServerConfig,
|
|
||||||
ServerBaseline,
|
|
||||||
ServerManager,
|
|
||||||
)
|
|
||||||
from atroposlib.type_definitions import Item
|
|
||||||
|
|
||||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
from tools.budget_config import (
|
|
||||||
DEFAULT_RESULT_SIZE_CHARS,
|
|
||||||
DEFAULT_TURN_BUDGET_CHARS,
|
|
||||||
DEFAULT_PREVIEW_SIZE_CHARS,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Import hermes-agent toolset infrastructure
|
|
||||||
from model_tools import get_tool_definitions
|
|
||||||
from toolset_distributions import sample_toolsets_from_distribution
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HermesAgentEnvConfig(BaseEnvConfig):
|
|
||||||
"""
|
|
||||||
Configuration for hermes-agent Atropos environments.
|
|
||||||
|
|
||||||
Extends BaseEnvConfig with agent-specific settings for toolsets,
|
|
||||||
terminal backend, dataset loading, and tool call parsing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# --- Toolset configuration ---
|
|
||||||
# Mutually exclusive: use either enabled_toolsets OR distribution
|
|
||||||
enabled_toolsets: Optional[List[str]] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Explicit list of hermes toolsets to enable (e.g., ['terminal', 'file', 'web']). "
|
|
||||||
"If None and distribution is also None, all available toolsets are enabled.",
|
|
||||||
)
|
|
||||||
disabled_toolsets: Optional[List[str]] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Toolsets to disable. Applied as a filter on top of enabled_toolsets or distribution.",
|
|
||||||
)
|
|
||||||
distribution: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Name of a toolset distribution from toolset_distributions.py "
|
|
||||||
"(e.g., 'development', 'terminal_tasks'). Sampled once per group. "
|
|
||||||
"Mutually exclusive with enabled_toolsets.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Agent loop configuration ---
|
|
||||||
max_agent_turns: int = Field(
|
|
||||||
default=30,
|
|
||||||
description="Maximum number of LLM calls (tool-calling iterations) per rollout.",
|
|
||||||
)
|
|
||||||
system_prompt: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="System prompt for the agent. Tools are handled via the tools= parameter, "
|
|
||||||
"not embedded in the prompt text.",
|
|
||||||
)
|
|
||||||
agent_temperature: float = Field(
|
|
||||||
default=1.0,
|
|
||||||
description="Sampling temperature for agent generation during rollouts.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Terminal backend ---
|
|
||||||
terminal_backend: str = Field(
|
|
||||||
default="local",
|
|
||||||
description="Terminal backend: 'local', 'docker', 'modal', 'daytona', 'ssh', 'singularity'. "
|
|
||||||
"Modal or Daytona recommended for production RL (cloud isolation per rollout).",
|
|
||||||
)
|
|
||||||
terminal_timeout: int = Field(
|
|
||||||
default=120,
|
|
||||||
description="Per-command timeout in seconds for terminal tool calls. "
|
|
||||||
"Commands exceeding this are killed. Increase for tasks with long-running "
|
|
||||||
"commands (compilation, pip install, etc.).",
|
|
||||||
)
|
|
||||||
terminal_lifetime: int = Field(
|
|
||||||
default=3600,
|
|
||||||
description="Sandbox inactivity lifetime in seconds. The cleanup thread kills "
|
|
||||||
"sandboxes that have been idle longer than this. Must be longer than "
|
|
||||||
"the longest gap between tool calls (e.g., waiting for LLM response).",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Dataset ---
|
|
||||||
dataset_name: Optional[str] = Field(
|
|
||||||
default=None,
|
|
||||||
description="HuggingFace dataset name. Optional if tasks are defined inline.",
|
|
||||||
)
|
|
||||||
dataset_split: str = Field(
|
|
||||||
default="train",
|
|
||||||
description="Dataset split to use.",
|
|
||||||
)
|
|
||||||
prompt_field: str = Field(
|
|
||||||
default="prompt",
|
|
||||||
description="Which field in the dataset contains the prompt.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Thread pool ---
|
|
||||||
tool_pool_size: int = Field(
|
|
||||||
default=128,
|
|
||||||
description="Thread pool size for tool execution. Each concurrent task needs a "
|
|
||||||
"thread for tool calls. Must be large enough for parallel evaluation. "
|
|
||||||
"Too small = thread pool starvation.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Phase 2: Tool call parsing ---
|
|
||||||
tool_call_parser: str = Field(
|
|
||||||
default="hermes",
|
|
||||||
description="Tool call parser name for Phase 2 (VLLM server type). "
|
|
||||||
"Ignored in Phase 1 (OpenAI server type where VLLM parses natively). "
|
|
||||||
"Options: hermes, mistral, llama3_json, qwen, deepseek_v3, etc.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Tool result budget ---
|
|
||||||
# Defaults imported from tools.budget_config (single source of truth).
|
|
||||||
default_result_size_chars: int = Field(
|
|
||||||
default=DEFAULT_RESULT_SIZE_CHARS,
|
|
||||||
description="Default per-tool threshold (chars) for persisting large results "
|
|
||||||
"to sandbox. Results exceeding this are written to /tmp/hermes-results/ "
|
|
||||||
"and replaced with a preview. Per-tool registry values take precedence "
|
|
||||||
"unless overridden via tool_result_overrides.",
|
|
||||||
)
|
|
||||||
turn_budget_chars: int = Field(
|
|
||||||
default=DEFAULT_TURN_BUDGET_CHARS,
|
|
||||||
description="Aggregate char budget per assistant turn. If all tool results "
|
|
||||||
"in a single turn exceed this, the largest are persisted to disk first.",
|
|
||||||
)
|
|
||||||
preview_size_chars: int = Field(
|
|
||||||
default=DEFAULT_PREVIEW_SIZE_CHARS,
|
|
||||||
description="Size of the inline preview shown after a tool result is persisted.",
|
|
||||||
)
|
|
||||||
tool_result_overrides: Optional[Dict[str, int]] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Per-tool threshold overrides (chars). Keys are tool names, "
|
|
||||||
"values are char thresholds. Overrides both the default and registry "
|
|
||||||
"per-tool values. Example: {'terminal': 10000, 'search_files': 5000}. "
|
|
||||||
"Note: read_file is pinned to infinity and cannot be overridden.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# --- Provider-specific parameters ---
|
|
||||||
# Passed as extra_body to the OpenAI client's chat.completions.create() call.
|
|
||||||
# Useful for OpenRouter provider preferences, transforms, route settings, etc.
|
|
||||||
# Example YAML:
|
|
||||||
# extra_body:
|
|
||||||
# provider:
|
|
||||||
# ignore: ["DeepInfra", "Fireworks"]
|
|
||||||
# order: ["Together"]
|
|
||||||
# transforms: ["middle-out"]
|
|
||||||
extra_body: Optional[Dict[str, Any]] = Field(
|
|
||||||
default=None,
|
|
||||||
description="Extra body parameters passed to the OpenAI client's "
|
|
||||||
"chat.completions.create(). Used for OpenRouter provider preferences, "
|
|
||||||
"transforms, and other provider-specific settings.",
|
|
||||||
)
|
|
||||||
|
|
||||||
def build_budget_config(self):
|
|
||||||
"""Build a BudgetConfig from env config fields."""
|
|
||||||
from tools.budget_config import BudgetConfig
|
|
||||||
return BudgetConfig(
|
|
||||||
default_result_size=self.default_result_size_chars,
|
|
||||||
turn_budget=self.turn_budget_chars,
|
|
||||||
preview_size=self.preview_size_chars,
|
|
||||||
tool_overrides=dict(self.tool_result_overrides) if self.tool_result_overrides else {},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HermesAgentBaseEnv(BaseEnv):
|
|
||||||
"""
|
|
||||||
Abstract base environment for hermes-agent Atropos integration.
|
|
||||||
|
|
||||||
Handles two modes of operation:
|
|
||||||
- Phase 1 (OpenAI server type): Uses server.chat_completion() directly.
|
|
||||||
The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing
|
|
||||||
and reasoning extraction natively. DummyManagedServer provides placeholder
|
|
||||||
tokens. Good for SFT data gen, verifier testing, evaluation.
|
|
||||||
|
|
||||||
- Phase 2 (VLLM server type): Uses ManagedServer for exact token IDs + logprobs
|
|
||||||
via /generate. Client-side tool call parser reconstructs structured tool_calls
|
|
||||||
from raw output. Full RL training capability.
|
|
||||||
|
|
||||||
Subclasses must implement:
|
|
||||||
setup() -- Load dataset, initialize state
|
|
||||||
get_next_item() -- Return the next item to roll out
|
|
||||||
format_prompt() -- Convert a dataset item into the user message string
|
|
||||||
compute_reward() -- Score the rollout using ToolContext
|
|
||||||
evaluate() -- Periodic evaluation
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Optional[str] = "hermes-agent"
|
|
||||||
env_config_cls = HermesAgentEnvConfig
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: HermesAgentEnvConfig,
|
|
||||||
server_configs: Union[ServerBaseline, List[APIServerConfig]],
|
|
||||||
slurm=False,
|
|
||||||
testing=False,
|
|
||||||
):
|
|
||||||
super().__init__(config, server_configs, slurm, testing)
|
|
||||||
|
|
||||||
# Set terminal environment variables so hermes tools pick them up.
|
|
||||||
# These can all be overridden per-environment via config fields instead
|
|
||||||
# of requiring users to set shell env vars.
|
|
||||||
if config.terminal_backend:
|
|
||||||
os.environ["TERMINAL_ENV"] = config.terminal_backend
|
|
||||||
os.environ["TERMINAL_TIMEOUT"] = str(config.terminal_timeout)
|
|
||||||
os.environ["TERMINAL_LIFETIME_SECONDS"] = str(config.terminal_lifetime)
|
|
||||||
print(
|
|
||||||
f"🖥️ Terminal: backend={config.terminal_backend}, "
|
|
||||||
f"timeout={config.terminal_timeout}s, lifetime={config.terminal_lifetime}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resize the agent loop's thread pool for tool execution.
|
|
||||||
# This must be large enough for the number of concurrent tasks
|
|
||||||
# (e.g., 89 parallel TB2 eval tasks each need a thread for tool calls).
|
|
||||||
from environments.agent_loop import resize_tool_pool
|
|
||||||
resize_tool_pool(config.tool_pool_size)
|
|
||||||
|
|
||||||
# Set tool_parser on the ServerManager so ManagedServer uses it
|
|
||||||
# for bidirectional tool call translation (raw text ↔ OpenAI tool_calls).
|
|
||||||
if hasattr(self.server, 'tool_parser'):
|
|
||||||
self.server.tool_parser = config.tool_call_parser
|
|
||||||
print(f"🔧 Tool parser: {config.tool_call_parser}")
|
|
||||||
|
|
||||||
# Current group's resolved tools (set in collect_trajectories)
|
|
||||||
self._current_group_tools: Optional[Tuple[List[Dict], Set[str]]] = None
|
|
||||||
|
|
||||||
# Tool error tracking for wandb logging
|
|
||||||
self._tool_error_buffer: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Toolset resolution (per-group)
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def _resolve_tools_for_group(self) -> Tuple[List[Dict[str, Any]], Set[str]]:
|
|
||||||
"""
|
|
||||||
Resolve toolsets for a group. Called once in collect_trajectories(),
|
|
||||||
then shared by all collect_trajectory() calls in the group.
|
|
||||||
|
|
||||||
If distribution is set, samples probabilistically.
|
|
||||||
If enabled_toolsets is set, uses that explicit list.
|
|
||||||
disabled_toolsets is applied as a filter on top.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(tool_schemas, valid_tool_names) tuple
|
|
||||||
"""
|
|
||||||
config = self.config
|
|
||||||
|
|
||||||
if config.distribution:
|
|
||||||
group_toolsets = sample_toolsets_from_distribution(config.distribution)
|
|
||||||
logger.info("Sampled toolsets from '%s': %s", config.distribution, group_toolsets)
|
|
||||||
else:
|
|
||||||
group_toolsets = config.enabled_toolsets # None means "all available"
|
|
||||||
if group_toolsets is None:
|
|
||||||
logger.warning(
|
|
||||||
"enabled_toolsets is None -- loading ALL tools including messaging. "
|
|
||||||
"Set explicit enabled_toolsets for RL training."
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = get_tool_definitions(
|
|
||||||
enabled_toolsets=group_toolsets,
|
|
||||||
disabled_toolsets=config.disabled_toolsets,
|
|
||||||
quiet_mode=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_names = {t["function"]["name"] for t in tools} if tools else set()
|
|
||||||
logger.info("Resolved %d tools for group: %s", len(valid_names), sorted(valid_names))
|
|
||||||
return tools, valid_names
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Server mode detection
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def _use_managed_server(self) -> bool:
|
|
||||||
"""
|
|
||||||
Determine if we should use ManagedServer (Phase 2) or direct server (Phase 1).
|
|
||||||
|
|
||||||
Phase 2 (ManagedServer) is used when the server type is 'vllm' or 'sglang',
|
|
||||||
which go through the /generate endpoint for exact token tracking.
|
|
||||||
|
|
||||||
Phase 1 (direct server) is used for 'openai' server type, which uses
|
|
||||||
/v1/chat/completions with native tool call parsing.
|
|
||||||
"""
|
|
||||||
if not self.server.servers:
|
|
||||||
return False
|
|
||||||
|
|
||||||
server = self.server.servers[0]
|
|
||||||
# If the server is an OpenAI server (not VLLM/SGLang), use direct mode
|
|
||||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
|
||||||
return not isinstance(server, OpenAIServer)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Core Atropos integration
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
async def collect_trajectories(
|
|
||||||
self, item: Item
|
|
||||||
) -> Tuple[
|
|
||||||
Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]],
|
|
||||||
List[Item],
|
|
||||||
]:
|
|
||||||
"""
|
|
||||||
Override collect_trajectories to resolve toolsets once per group,
|
|
||||||
then delegate to the standard group-level collection.
|
|
||||||
|
|
||||||
The default BaseEnv.collect_trajectories() calls collect_trajectory()
|
|
||||||
group_size times in parallel. We resolve tools once here and store
|
|
||||||
them for all those calls to use.
|
|
||||||
"""
|
|
||||||
# Resolve toolsets for this group (shared by all rollouts in the group)
|
|
||||||
self._current_group_tools = self._resolve_tools_for_group()
|
|
||||||
|
|
||||||
# Delegate to the default implementation which calls collect_trajectory()
|
|
||||||
# group_size times via asyncio.gather
|
|
||||||
return await super().collect_trajectories(item)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Wandb rollout display -- format trajectories nicely
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _format_trajectory_for_display(messages: List[Dict[str, Any]]) -> str:
|
|
||||||
"""
|
|
||||||
Format a conversation's messages into a readable trajectory string
|
|
||||||
for wandb rollout tables. Shows tool calls, tool results, and reasoning
|
|
||||||
in a structured way instead of raw token decoding.
|
|
||||||
"""
|
|
||||||
parts = []
|
|
||||||
for msg in messages:
|
|
||||||
role = msg.get("role", "unknown")
|
|
||||||
content = msg.get("content", "")
|
|
||||||
|
|
||||||
if role == "system":
|
|
||||||
parts.append(f"[SYSTEM]\n{content}")
|
|
||||||
|
|
||||||
elif role == "user":
|
|
||||||
parts.append(f"[USER]\n{content}")
|
|
||||||
|
|
||||||
elif role == "assistant":
|
|
||||||
# Show reasoning if present
|
|
||||||
reasoning = msg.get("reasoning_content", "")
|
|
||||||
if reasoning:
|
|
||||||
# Truncate long reasoning for display
|
|
||||||
if len(reasoning) > 300:
|
|
||||||
reasoning = reasoning[:300] + "..."
|
|
||||||
parts.append(f"[ASSISTANT thinking]\n{reasoning}")
|
|
||||||
|
|
||||||
# Show content
|
|
||||||
if content:
|
|
||||||
parts.append(f"[ASSISTANT]\n{content}")
|
|
||||||
|
|
||||||
# Show tool calls
|
|
||||||
tool_calls = msg.get("tool_calls", [])
|
|
||||||
for tc in tool_calls:
|
|
||||||
func = tc.get("function", {})
|
|
||||||
name = func.get("name", "?")
|
|
||||||
args = func.get("arguments", "{}")
|
|
||||||
# Truncate long arguments for display
|
|
||||||
if len(args) > 200:
|
|
||||||
args = args[:200] + "..."
|
|
||||||
parts.append(f"[TOOL CALL] {name}({args})")
|
|
||||||
|
|
||||||
elif role == "tool":
|
|
||||||
tool_id = msg.get("tool_call_id", "")
|
|
||||||
result = content
|
|
||||||
# Truncate long tool results for display
|
|
||||||
if len(result) > 500:
|
|
||||||
result = result[:500] + "..."
|
|
||||||
parts.append(f"[TOOL RESULT] {result}")
|
|
||||||
|
|
||||||
return "\n\n".join(parts)
|
|
||||||
|
|
||||||
async def add_rollouts_for_wandb(
|
|
||||||
self,
|
|
||||||
scored_data,
|
|
||||||
item=None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Override to show formatted trajectories with tool calls visible,
|
|
||||||
instead of raw token decoding which loses all structure.
|
|
||||||
"""
|
|
||||||
num_keep = self.config.num_rollouts_per_group_for_logging
|
|
||||||
if num_keep == -1:
|
|
||||||
num_keep = self.config.group_size
|
|
||||||
|
|
||||||
group = []
|
|
||||||
for i in range(min(num_keep, len(scored_data.get("scores", [])))):
|
|
||||||
score = scored_data["scores"][i]
|
|
||||||
|
|
||||||
# Use messages if available for rich display
|
|
||||||
messages = None
|
|
||||||
if scored_data.get("messages") and i < len(scored_data["messages"]):
|
|
||||||
messages = scored_data["messages"][i]
|
|
||||||
|
|
||||||
if messages:
|
|
||||||
text = self._format_trajectory_for_display(messages)
|
|
||||||
elif scored_data.get("tokens") and i < len(scored_data["tokens"]):
|
|
||||||
text = self.tokenizer.decode(scored_data["tokens"][i])
|
|
||||||
else:
|
|
||||||
text = "(no data)"
|
|
||||||
|
|
||||||
group.append((text, score))
|
|
||||||
|
|
||||||
self.rollouts_for_wandb.append(group)
|
|
||||||
if len(self.rollouts_for_wandb) > self.config.num_rollouts_to_keep:
|
|
||||||
self.rollouts_for_wandb.pop(0)
|
|
||||||
|
|
||||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
||||||
"""Log base metrics including tool errors to wandb."""
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
|
|
||||||
# Log tool error stats
|
|
||||||
if self._tool_error_buffer:
|
|
||||||
wandb_metrics["train/tool_errors_count"] = len(self._tool_error_buffer)
|
|
||||||
|
|
||||||
# Log error details as a summary string (tables can crash wandb on tmp cleanup)
|
|
||||||
error_summaries = []
|
|
||||||
for err in self._tool_error_buffer:
|
|
||||||
error_summaries.append(
|
|
||||||
f"[turn {err['turn']}] {err['tool']}({err['args'][:80]}) -> {err['error'][:150]}"
|
|
||||||
)
|
|
||||||
wandb_metrics["train/tool_error_details"] = "\n".join(error_summaries)
|
|
||||||
|
|
||||||
# Also print to stdout for immediate visibility
|
|
||||||
for summary in error_summaries:
|
|
||||||
print(f" Tool Error: {summary}")
|
|
||||||
|
|
||||||
self._tool_error_buffer = []
|
|
||||||
else:
|
|
||||||
wandb_metrics["train/tool_errors_count"] = 0
|
|
||||||
|
|
||||||
await super().wandb_log(wandb_metrics)
|
|
||||||
|
|
||||||
async def collect_trajectory(
|
|
||||||
self, item: Item
|
|
||||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
|
||||||
"""
|
|
||||||
Run a single rollout: agent loop + reward computation.
|
|
||||||
|
|
||||||
This is called group_size times in parallel by collect_trajectories().
|
|
||||||
Each call gets its own task_id for terminal/browser session isolation.
|
|
||||||
"""
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# Get group-level tools (resolved once in collect_trajectories)
|
|
||||||
if self._current_group_tools is None:
|
|
||||||
# Fallback: resolve per-trajectory if called outside collect_trajectories
|
|
||||||
tools, valid_names = self._resolve_tools_for_group()
|
|
||||||
else:
|
|
||||||
tools, valid_names = self._current_group_tools
|
|
||||||
|
|
||||||
# Build initial messages
|
|
||||||
messages: List[Dict[str, Any]] = []
|
|
||||||
if self.config.system_prompt:
|
|
||||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
|
||||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
|
||||||
|
|
||||||
# Run the agent loop
|
|
||||||
result: AgentResult
|
|
||||||
if self._use_managed_server():
|
|
||||||
# Phase 2: ManagedServer with ToolCallTranslator -- exact tokens + logprobs
|
|
||||||
# tool_parser is set on ServerManager in __init__ and passed through
|
|
||||||
# to ManagedServer, which uses ToolCallTranslator for bidirectional
|
|
||||||
# translation between raw text and OpenAI tool_calls.
|
|
||||||
try:
|
|
||||||
async with self.server.managed_server(
|
|
||||||
tokenizer=self.tokenizer,
|
|
||||||
preserve_think_blocks=bool(self.config.thinking_mode),
|
|
||||||
) as managed:
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=managed,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=task_id,
|
|
||||||
temperature=self.config.agent_temperature,
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
budget_config=self.config.build_budget_config(),
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
except NotImplementedError:
|
|
||||||
# DummyManagedServer not allowed -- fall back to Phase 1
|
|
||||||
logger.warning(
|
|
||||||
"ManagedServer not available (OpenAI server?). "
|
|
||||||
"Falling back to direct server mode."
|
|
||||||
)
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=self.server,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=task_id,
|
|
||||||
temperature=self.config.agent_temperature,
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
budget_config=self.config.build_budget_config(),
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
else:
|
|
||||||
# Phase 1: OpenAI server -- native tool_calls, placeholder tokens
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=self.server,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=task_id,
|
|
||||||
temperature=self.config.agent_temperature,
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
budget_config=self.config.build_budget_config(),
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Skip reward computation if the agent loop produced no meaningful work
|
|
||||||
# (e.g., API call failed on turn 1). No point spinning up a Modal sandbox
|
|
||||||
# just to verify files that were never created.
|
|
||||||
only_system_and_user = all(
|
|
||||||
msg.get("role") in {"system", "user"} for msg in result.messages
|
|
||||||
)
|
|
||||||
if result.turns_used == 0 or only_system_and_user:
|
|
||||||
logger.warning(
|
|
||||||
"Agent loop produced no output (turns=%d, msgs=%d). Skipping reward.",
|
|
||||||
result.turns_used, len(result.messages),
|
|
||||||
)
|
|
||||||
reward = 0.0
|
|
||||||
else:
|
|
||||||
# Compute reward using ToolContext (gives verifier full tool access)
|
|
||||||
ctx = ToolContext(task_id)
|
|
||||||
try:
|
|
||||||
reward = await self.compute_reward(item, result, ctx)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("compute_reward failed: %s", e)
|
|
||||||
reward = 0.0
|
|
||||||
finally:
|
|
||||||
ctx.cleanup()
|
|
||||||
|
|
||||||
# Track tool errors for wandb logging
|
|
||||||
if result.tool_errors:
|
|
||||||
for err in result.tool_errors:
|
|
||||||
self._tool_error_buffer.append({
|
|
||||||
"turn": err.turn,
|
|
||||||
"tool": err.tool_name,
|
|
||||||
"args": err.arguments[:150],
|
|
||||||
"error": err.error[:300],
|
|
||||||
"result": err.tool_result[:300],
|
|
||||||
})
|
|
||||||
|
|
||||||
# Build ScoredDataItem from ManagedServer state
|
|
||||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
|
||||||
# Phase 1: placeholder tokens (still need a valid ScoredDataItem for the pipeline)
|
|
||||||
nodes = (result.managed_state or {}).get("nodes", [])
|
|
||||||
|
|
||||||
if nodes:
|
|
||||||
# Phase 2 (or DummyManagedServer): use actual node data
|
|
||||||
node = nodes[-1] # Final sequence node = full trajectory
|
|
||||||
scored_item: Dict[str, Any] = {
|
|
||||||
"tokens": node.tokens,
|
|
||||||
"masks": node.masked_tokens,
|
|
||||||
"scores": reward,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Include logprobs if available (Phase 2)
|
|
||||||
if hasattr(node, "logprobs") and node.logprobs:
|
|
||||||
scored_item["advantages"] = None # Computed by trainer
|
|
||||||
scored_item["ref_logprobs"] = None
|
|
||||||
else:
|
|
||||||
# Phase 1 with no managed state: create placeholder tokens
|
|
||||||
# so the data pipeline doesn't break. These are NOT suitable
|
|
||||||
# for training but allow process mode (SFT data gen) to work.
|
|
||||||
# Tokenize the full conversation to get approximate tokens.
|
|
||||||
full_text = "\n".join(
|
|
||||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
|
||||||
)
|
|
||||||
if self.tokenizer:
|
|
||||||
tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
|
|
||||||
else:
|
|
||||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
|
||||||
|
|
||||||
scored_item = {
|
|
||||||
"tokens": tokens,
|
|
||||||
"masks": [-100] + tokens[1:], # Mask first token as prompt
|
|
||||||
"scores": reward,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Always include messages for wandb rollout display and data logging
|
|
||||||
scored_item["messages"] = result.messages
|
|
||||||
|
|
||||||
return scored_item, []
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Abstract methods -- subclasses must implement
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def setup(self):
|
|
||||||
"""
|
|
||||||
Load dataset, initialize state.
|
|
||||||
|
|
||||||
Called once when the environment starts. Typical implementation:
|
|
||||||
self.dataset = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
|
|
||||||
self.iter = 0
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_next_item(self) -> Item:
|
|
||||||
"""
|
|
||||||
Return the next item from the dataset for rollout.
|
|
||||||
|
|
||||||
Called by the base env's main loop to get items for workers.
|
|
||||||
Should cycle through the dataset.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def format_prompt(self, item: Item) -> str:
|
|
||||||
"""
|
|
||||||
Convert a dataset item into the user message for the agent.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
item: Dataset item (dict, tuple, etc.)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The prompt string to send to the agent
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def compute_reward(
|
|
||||||
self, item: Item, result: AgentResult, ctx: ToolContext
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Score the rollout. Has full access to:
|
|
||||||
- item: the original dataset item (ground truth, test commands, etc.)
|
|
||||||
- result: AgentResult with full messages, turn count, reasoning, etc.
|
|
||||||
- ctx: ToolContext -- call ANY hermes-agent tool (terminal, file, web,
|
|
||||||
browser, vision...) scoped to this rollout's sandbox. Nothing
|
|
||||||
is off-limits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
item: The dataset item that was rolled out
|
|
||||||
result: The agent's rollout result
|
|
||||||
ctx: ToolContext with full tool access for verification
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Reward float (typically 0.0 to 1.0, but any float is valid)
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def evaluate(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Periodic evaluation. Called every steps_per_eval steps.
|
|
||||||
|
|
||||||
Typical implementation runs the agent on a held-out eval set
|
|
||||||
and logs metrics via wandb/evaluate_log.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
# SWE Environment -- Default Configuration
|
|
||||||
#
|
|
||||||
# SWE-bench style tasks with Modal sandboxes for cloud isolation.
|
|
||||||
# Uses terminal + file + web toolsets.
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# python environments/hermes_swe_env/hermes_swe_env.py serve \
|
|
||||||
# --config environments/hermes_swe_env/default.yaml
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file", "web"]
|
|
||||||
max_agent_turns: 30
|
|
||||||
max_token_length: 4096
|
|
||||||
group_size: 4
|
|
||||||
terminal_backend: "modal"
|
|
||||||
tool_call_parser: "hermes"
|
|
||||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
|
||||||
dataset_name: "bigcode/humanevalpack"
|
|
||||||
dataset_split: "test"
|
|
||||||
prompt_field: "prompt"
|
|
||||||
steps_per_eval: 50
|
|
||||||
total_steps: 500
|
|
||||||
use_wandb: true
|
|
||||||
wandb_name: "hermes-swe"
|
|
||||||
system_prompt: >
|
|
||||||
You are a skilled software engineer. You have access to a terminal,
|
|
||||||
file tools, and web search. Use these tools to complete the coding task.
|
|
||||||
Write clean, working code and verify it runs correctly before finishing.
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "http://localhost:8000/v1"
|
|
||||||
model_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
|
||||||
server_type: "openai"
|
|
||||||
api_key: ""
|
|
||||||
|
|
@ -1,229 +0,0 @@
|
||||||
"""
|
|
||||||
HermesSweEnv -- SWE-Bench Style Environment with Modal Sandboxes
|
|
||||||
|
|
||||||
A concrete environment for software engineering tasks where the model writes code
|
|
||||||
and the reward function runs tests to verify correctness. Uses Modal terminal
|
|
||||||
backend for cloud-isolated sandboxes per rollout.
|
|
||||||
|
|
||||||
The reward function uses ToolContext.terminal() to run test commands in the same
|
|
||||||
Modal sandbox the model used during its agentic loop. All filesystem state from
|
|
||||||
the model's tool calls is preserved for verification.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Phase 1: OpenAI server type
|
|
||||||
vllm serve YourModel --tool-parser hermes
|
|
||||||
run-api
|
|
||||||
python environments/hermes_swe_env.py serve \\
|
|
||||||
--openai.base_url http://localhost:8000/v1 \\
|
|
||||||
--openai.model_name YourModel \\
|
|
||||||
--openai.server_type openai \\
|
|
||||||
--env.dataset_name bigcode/humanevalpack \\
|
|
||||||
--env.terminal_backend modal
|
|
||||||
|
|
||||||
# Phase 2: VLLM server type (full RL training)
|
|
||||||
python environments/hermes_swe_env.py serve \\
|
|
||||||
--openai.base_url http://localhost:8000/v1 \\
|
|
||||||
--openai.model_name YourModel \\
|
|
||||||
--openai.server_type vllm \\
|
|
||||||
--env.tool_call_parser hermes \\
|
|
||||||
--env.terminal_backend modal
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
# Ensure repo root is on sys.path for imports
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
from datasets import load_dataset
|
|
||||||
|
|
||||||
from atroposlib.envs.base import ScoredDataGroup
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
from atroposlib.type_definitions import Item
|
|
||||||
|
|
||||||
from environments.agent_loop import AgentResult
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HermesSweEnvConfig(HermesAgentEnvConfig):
|
|
||||||
"""Config with defaults for SWE-bench style tasks."""
|
|
||||||
|
|
||||||
pass # Inherits all fields, overrides defaults in config_init
|
|
||||||
|
|
||||||
|
|
||||||
class HermesSweEnv(HermesAgentBaseEnv):
|
|
||||||
"""
|
|
||||||
SWE-bench style environment using Modal terminal backend.
|
|
||||||
|
|
||||||
The model gets a coding task, uses terminal + file + web tools to solve it,
|
|
||||||
and the reward function runs tests in the same Modal sandbox to verify.
|
|
||||||
|
|
||||||
Subclass this for specific SWE datasets (HumanEval, SWE-bench, etc.)
|
|
||||||
and customize format_prompt() and compute_reward() as needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "hermes-swe"
|
|
||||||
env_config_cls = HermesSweEnvConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls) -> Tuple[HermesSweEnvConfig, List[APIServerConfig]]:
|
|
||||||
"""
|
|
||||||
Default configuration for the SWE environment.
|
|
||||||
|
|
||||||
Uses Modal terminal backend for cloud isolation and terminal + file + web toolsets.
|
|
||||||
"""
|
|
||||||
env_config = HermesSweEnvConfig(
|
|
||||||
# Toolsets: terminal for running code, file for reading/writing, web for docs
|
|
||||||
enabled_toolsets=["terminal", "file", "web"],
|
|
||||||
disabled_toolsets=None,
|
|
||||||
distribution=None,
|
|
||||||
# Agent settings -- SWE tasks need more turns
|
|
||||||
max_agent_turns=30,
|
|
||||||
max_token_length=4096,
|
|
||||||
agent_temperature=1.0,
|
|
||||||
system_prompt=(
|
|
||||||
"You are a skilled software engineer. You have access to a terminal, "
|
|
||||||
"file tools, and web search. Use these tools to complete the coding task. "
|
|
||||||
"Write clean, working code and verify it runs correctly before finishing."
|
|
||||||
),
|
|
||||||
# Modal backend for cloud-isolated sandboxes
|
|
||||||
terminal_backend="modal",
|
|
||||||
# Dataset -- override via CLI for your specific SWE dataset
|
|
||||||
dataset_name="bigcode/humanevalpack",
|
|
||||||
dataset_split="test",
|
|
||||||
prompt_field="prompt",
|
|
||||||
# Atropos settings
|
|
||||||
group_size=4,
|
|
||||||
tokenizer_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
|
||||||
tool_call_parser="hermes",
|
|
||||||
steps_per_eval=50,
|
|
||||||
total_steps=500,
|
|
||||||
use_wandb=True,
|
|
||||||
wandb_name="hermes-swe",
|
|
||||||
)
|
|
||||||
|
|
||||||
server_configs = [
|
|
||||||
APIServerConfig(
|
|
||||||
base_url="http://localhost:8000/v1",
|
|
||||||
model_name="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
|
|
||||||
server_type="openai", # Phase 1; switch to "vllm" for Phase 2
|
|
||||||
api_key="",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
"""Load the SWE dataset."""
|
|
||||||
if self.config.dataset_name:
|
|
||||||
self.dataset = load_dataset(
|
|
||||||
self.config.dataset_name, split=self.config.dataset_split
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Placeholder if no dataset specified
|
|
||||||
self.dataset = []
|
|
||||||
self.iter = 0
|
|
||||||
self.reward_buffer: List[float] = []
|
|
||||||
|
|
||||||
async def get_next_item(self) -> Dict[str, Any]:
|
|
||||||
"""Cycle through the SWE dataset."""
|
|
||||||
if not self.dataset:
|
|
||||||
raise ValueError("No dataset loaded. Set dataset_name in config.")
|
|
||||||
item = self.dataset[self.iter % len(self.dataset)]
|
|
||||||
self.iter += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def format_prompt(self, item: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
Format the SWE task prompt.
|
|
||||||
|
|
||||||
Override this in subclasses for different dataset formats.
|
|
||||||
Default assumes the dataset has a 'prompt' field and optionally a 'test' field.
|
|
||||||
"""
|
|
||||||
prompt = item.get(self.config.prompt_field, "")
|
|
||||||
|
|
||||||
# If the dataset has test information, include it in the prompt
|
|
||||||
test_info = item.get("test", item.get("test_code", item.get("tests", "")))
|
|
||||||
if test_info:
|
|
||||||
prompt += f"\n\nTests to pass:\n{test_info}"
|
|
||||||
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
async def compute_reward(
|
|
||||||
self, item: Dict[str, Any], result: AgentResult, ctx: ToolContext
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Score by running tests in the model's Modal sandbox.
|
|
||||||
|
|
||||||
Default implementation:
|
|
||||||
- If the dataset item has a 'test' or 'test_code' field, run it
|
|
||||||
- Check exit code: 0 = pass, non-zero = fail
|
|
||||||
- Partial credit for file creation
|
|
||||||
|
|
||||||
Override this in subclasses for more sophisticated reward logic.
|
|
||||||
"""
|
|
||||||
# Find the test command from the dataset item
|
|
||||||
test_code = item.get("test", item.get("test_code", item.get("tests", "")))
|
|
||||||
|
|
||||||
if test_code:
|
|
||||||
# Run the test in the model's sandbox
|
|
||||||
test_result = ctx.terminal(
|
|
||||||
f'cd /workspace && python3 -c "{test_code}"', timeout=60
|
|
||||||
)
|
|
||||||
|
|
||||||
if test_result["exit_code"] == 0:
|
|
||||||
self.reward_buffer.append(1.0)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Partial credit: check if the model created any Python files
|
|
||||||
file_check = ctx.terminal("find /workspace -name '*.py' -newer /tmp/.start_marker 2>/dev/null | head -5")
|
|
||||||
if file_check["exit_code"] == 0 and file_check.get("output", "").strip():
|
|
||||||
self.reward_buffer.append(0.1)
|
|
||||||
return 0.1
|
|
||||||
|
|
||||||
self.reward_buffer.append(0.0)
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Run evaluation on a held-out set.
|
|
||||||
|
|
||||||
Override for dataset-specific evaluation logic.
|
|
||||||
"""
|
|
||||||
start_time = time.time()
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
eval_metrics = {"eval/placeholder": 0.0}
|
|
||||||
await self.evaluate_log(
|
|
||||||
metrics=eval_metrics,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
||||||
"""Log SWE-specific metrics."""
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
|
|
||||||
if self.reward_buffer:
|
|
||||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / len(
|
|
||||||
self.reward_buffer
|
|
||||||
)
|
|
||||||
wandb_metrics["train/pass_rate"] = sum(
|
|
||||||
1 for r in self.reward_buffer if r == 1.0
|
|
||||||
) / len(self.reward_buffer)
|
|
||||||
self.reward_buffer = []
|
|
||||||
|
|
||||||
await super().wandb_log(wandb_metrics)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
HermesSweEnv.cli()
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
"""
|
|
||||||
Monkey patches for making hermes-agent tools work inside async frameworks (Atropos).
|
|
||||||
|
|
||||||
Problem:
|
|
||||||
Some tools use asyncio.run() internally (e.g., Modal backend via SWE-ReX,
|
|
||||||
web_extract). This crashes when called from inside Atropos's event loop because
|
|
||||||
asyncio.run() can't be nested.
|
|
||||||
|
|
||||||
Solution:
|
|
||||||
The Modal environment (tools/environments/modal.py) now uses a dedicated
|
|
||||||
_AsyncWorker thread internally, making it safe for both CLI and Atropos use.
|
|
||||||
No monkey-patching is required.
|
|
||||||
|
|
||||||
This module is kept for backward compatibility. apply_patches() is a no-op.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
Call apply_patches() once at import time (done automatically by hermes_base_env.py).
|
|
||||||
This is idempotent and safe to call multiple times.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_patches_applied = False
|
|
||||||
|
|
||||||
|
|
||||||
def apply_patches():
|
|
||||||
"""Apply all monkey patches needed for Atropos compatibility."""
|
|
||||||
global _patches_applied
|
|
||||||
if _patches_applied:
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.debug("apply_patches() called; no patches needed (async safety is built-in)")
|
|
||||||
_patches_applied = True
|
|
||||||
|
|
@ -1,34 +0,0 @@
|
||||||
# Terminal Test Environment -- Default Configuration
|
|
||||||
#
|
|
||||||
# Simple file-creation tasks for validating the full Atropos + hermes-agent stack.
|
|
||||||
# Uses Modal terminal backend and OpenRouter (Claude) for inference.
|
|
||||||
# API keys loaded from ~/hermes-agent/.env
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# run-api
|
|
||||||
# python environments/terminal_test_env/terminal_test_env.py serve \
|
|
||||||
# --config environments/terminal_test_env/default.yaml
|
|
||||||
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file"]
|
|
||||||
max_agent_turns: 10
|
|
||||||
max_token_length: 2048
|
|
||||||
group_size: 3
|
|
||||||
total_steps: 3
|
|
||||||
steps_per_eval: 3
|
|
||||||
terminal_backend: "modal"
|
|
||||||
tool_call_parser: "hermes"
|
|
||||||
tokenizer_name: "NousResearch/DeepHermes-3-Llama-3-3B-Preview"
|
|
||||||
ensure_scores_are_not_same: false
|
|
||||||
use_wandb: false
|
|
||||||
system_prompt: >
|
|
||||||
You are a helpful assistant with access to a terminal and file tools.
|
|
||||||
Complete the user's request by using the available tools.
|
|
||||||
Be precise and follow instructions exactly.
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
model_name: "anthropic/claude-opus-4.6"
|
|
||||||
server_type: "openai"
|
|
||||||
health_check: false
|
|
||||||
# api_key loaded from OPENROUTER_API_KEY in .env
|
|
||||||
|
|
@ -1,292 +0,0 @@
|
||||||
"""
|
|
||||||
TerminalTestEnv -- Simple Test Environment for Validating the Stack
|
|
||||||
|
|
||||||
A self-contained environment with inline tasks (no external dataset needed).
|
|
||||||
Each task asks the model to create a file at a known path with specific content.
|
|
||||||
The reward verifier cats the file and checks if the content matches.
|
|
||||||
|
|
||||||
Enables only terminal + file toolsets. Uses Modal terminal backend with
|
|
||||||
OpenRouter (Claude) by default.
|
|
||||||
|
|
||||||
Training tasks (3):
|
|
||||||
1. Create ~/greeting.txt with "Hello from Hermes Agent"
|
|
||||||
2. Create ~/count.txt with numbers 1-5, one per line
|
|
||||||
3. Create ~/answer.txt with the result of 123 + 456
|
|
||||||
|
|
||||||
Eval task (1):
|
|
||||||
1. Create ~/result.txt with the result of 6 * 7
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Start Atropos API server
|
|
||||||
run-api
|
|
||||||
|
|
||||||
# Run environment (uses OpenRouter + Modal by default)
|
|
||||||
python environments/terminal_test_env.py serve
|
|
||||||
|
|
||||||
# Process mode (no run-api needed, saves to JSONL)
|
|
||||||
python environments/terminal_test_env.py process \\
|
|
||||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
||||||
|
|
||||||
# Ensure repo root is on sys.path for imports
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
from atroposlib.envs.base import ScoredDataGroup
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
from atroposlib.type_definitions import Item
|
|
||||||
|
|
||||||
from environments.agent_loop import AgentResult
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Inline task definitions -- no external dataset needed
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
TRAIN_TASKS = [
|
|
||||||
{
|
|
||||||
"prompt": "Create a file at ~/greeting.txt containing exactly the text: Hello from Hermes Agent",
|
|
||||||
"verify_path": "~/greeting.txt",
|
|
||||||
"expected_content": "Hello from Hermes Agent",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"prompt": "Create a file at ~/count.txt containing the numbers 1 through 5, one per line",
|
|
||||||
"verify_path": "~/count.txt",
|
|
||||||
"expected_content": "1\n2\n3\n4\n5",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"prompt": "Create a file at ~/answer.txt containing the result of 123 + 456",
|
|
||||||
"verify_path": "~/answer.txt",
|
|
||||||
"expected_content": "579",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
EVAL_TASKS = [
|
|
||||||
{
|
|
||||||
"prompt": "Create a file at ~/result.txt containing the result of 6 * 7",
|
|
||||||
"verify_path": "~/result.txt",
|
|
||||||
"expected_content": "42",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TerminalTestEnvConfig(HermesAgentEnvConfig):
|
|
||||||
"""Config with defaults suitable for terminal testing."""
|
|
||||||
|
|
||||||
pass # Inherits all fields, overrides defaults in config_init
|
|
||||||
|
|
||||||
|
|
||||||
class TerminalTestEnv(HermesAgentBaseEnv):
|
|
||||||
"""
|
|
||||||
Simple test environment with inline file-creation tasks.
|
|
||||||
|
|
||||||
All tasks follow the same pattern: "create a file at ~/X.txt with content Y".
|
|
||||||
The verifier runs `cat ~/X.txt` in the rollout's terminal and checks the output
|
|
||||||
against the expected string. Same verifier logic for all tasks.
|
|
||||||
|
|
||||||
This environment is designed to validate the full stack end-to-end:
|
|
||||||
- Agent loop executes tool calls (terminal/file)
|
|
||||||
- ToolContext provides terminal access to the reward function
|
|
||||||
- Reward function verifies file content via cat
|
|
||||||
- Scored data flows through the Atropos pipeline
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "terminal-test"
|
|
||||||
env_config_cls = TerminalTestEnvConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls) -> Tuple[TerminalTestEnvConfig, List[APIServerConfig]]:
|
|
||||||
"""
|
|
||||||
Default configuration for the terminal test environment.
|
|
||||||
|
|
||||||
Uses Modal terminal backend for cloud isolation and OpenRouter with
|
|
||||||
Claude for inference. API keys loaded from ~/hermes-agent/.env.
|
|
||||||
"""
|
|
||||||
env_config = TerminalTestEnvConfig(
|
|
||||||
# Terminal + file tools only
|
|
||||||
enabled_toolsets=["terminal", "file"],
|
|
||||||
disabled_toolsets=None,
|
|
||||||
distribution=None,
|
|
||||||
# Agent settings
|
|
||||||
max_agent_turns=10, # Simple tasks, don't need many turns
|
|
||||||
max_token_length=16000,
|
|
||||||
agent_temperature=1.0,
|
|
||||||
system_prompt=(
|
|
||||||
"You are a helpful assistant with access to a terminal and file tools. "
|
|
||||||
"Complete the user's request by using the available tools. "
|
|
||||||
"Be precise and follow instructions exactly."
|
|
||||||
),
|
|
||||||
# Modal terminal backend for cloud-isolated sandboxes per rollout
|
|
||||||
terminal_backend="modal",
|
|
||||||
# Atropos settings
|
|
||||||
group_size=3, # 3 rollouts per group
|
|
||||||
tokenizer_name="NousResearch/q-30b-t-h45-e1",
|
|
||||||
tool_call_parser="hermes",
|
|
||||||
steps_per_eval=3, # Eval after all 3 steps
|
|
||||||
total_steps=3, # 3 groups total (1 group per step)
|
|
||||||
use_wandb=True,
|
|
||||||
wandb_name="terminal-test",
|
|
||||||
ensure_scores_are_not_same=False, # Allow all-same scores for simple tasks
|
|
||||||
# No external dataset
|
|
||||||
dataset_name=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# OpenRouter with Claude -- API key loaded from .env (OPENROUTER_API_KEY)
|
|
||||||
server_configs = [
|
|
||||||
APIServerConfig(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
model_name="anthropic/claude-opus-4.6",
|
|
||||||
server_type="openai",
|
|
||||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
|
||||||
health_check=False, # OpenRouter doesn't have a /health endpoint
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
"""Initialize inline task lists."""
|
|
||||||
self.train_tasks = list(TRAIN_TASKS)
|
|
||||||
self.eval_tasks = list(EVAL_TASKS)
|
|
||||||
self.iter = 0
|
|
||||||
# Track reward stats for wandb logging
|
|
||||||
self.reward_buffer: List[float] = []
|
|
||||||
|
|
||||||
async def get_next_item(self) -> Dict[str, str]:
|
|
||||||
"""Cycle through training tasks."""
|
|
||||||
item = self.train_tasks[self.iter % len(self.train_tasks)]
|
|
||||||
self.iter += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def format_prompt(self, item: Dict[str, str]) -> str:
|
|
||||||
"""The prompt is directly in the task item."""
|
|
||||||
return item["prompt"]
|
|
||||||
|
|
||||||
async def compute_reward(
|
|
||||||
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Verify by cat-ing the expected file path and checking content matches.
|
|
||||||
Same verifier for all tasks -- they all write a file at a known path.
|
|
||||||
|
|
||||||
Scoring:
|
|
||||||
1.0 = exact match
|
|
||||||
0.5 = expected content is present but has extra stuff
|
|
||||||
0.0 = file doesn't exist or content doesn't match
|
|
||||||
"""
|
|
||||||
verify_result = ctx.terminal(f"cat {item['verify_path']}")
|
|
||||||
|
|
||||||
# File doesn't exist or can't be read
|
|
||||||
if verify_result["exit_code"] != 0:
|
|
||||||
self.reward_buffer.append(0.0)
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
actual = verify_result.get("output", "").strip()
|
|
||||||
expected = item["expected_content"].strip()
|
|
||||||
|
|
||||||
# Exact match
|
|
||||||
if actual == expected:
|
|
||||||
self.reward_buffer.append(1.0)
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Partial credit: expected content is present but has extra stuff
|
|
||||||
if expected in actual:
|
|
||||||
self.reward_buffer.append(0.5)
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
self.reward_buffer.append(0.0)
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Run eval tasks using the agent loop and verify results.
|
|
||||||
Logs accuracy metrics.
|
|
||||||
"""
|
|
||||||
start_time = time.time()
|
|
||||||
correct = 0
|
|
||||||
total = len(self.eval_tasks)
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
for eval_item in self.eval_tasks:
|
|
||||||
try:
|
|
||||||
# For eval, we do a simple single-turn completion (not full agent loop)
|
|
||||||
# to keep eval fast. The agent loop is tested via training.
|
|
||||||
completion = await self.server.chat_completion(
|
|
||||||
messages=[
|
|
||||||
{"role": "system", "content": self.config.system_prompt or ""},
|
|
||||||
{"role": "user", "content": eval_item["prompt"]},
|
|
||||||
],
|
|
||||||
n=1,
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
temperature=0.0,
|
|
||||||
split="eval",
|
|
||||||
)
|
|
||||||
|
|
||||||
response_content = (
|
|
||||||
completion.choices[0].message.content if completion.choices else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
samples.append(
|
|
||||||
{
|
|
||||||
"prompt": eval_item["prompt"],
|
|
||||||
"response": response_content,
|
|
||||||
"expected": eval_item["expected_content"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error("Eval failed for item: %s", e)
|
|
||||||
samples.append(
|
|
||||||
{
|
|
||||||
"prompt": eval_item["prompt"],
|
|
||||||
"response": f"ERROR: {e}",
|
|
||||||
"expected": eval_item["expected_content"],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
eval_metrics = {
|
|
||||||
"eval/num_samples": total,
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.evaluate_log(
|
|
||||||
metrics=eval_metrics,
|
|
||||||
samples=samples,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
||||||
"""Log training metrics including reward stats and accuracy."""
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
|
|
||||||
if self.reward_buffer:
|
|
||||||
total = len(self.reward_buffer)
|
|
||||||
correct = sum(1 for r in self.reward_buffer if r == 1.0)
|
|
||||||
partial = sum(1 for r in self.reward_buffer if r == 0.5)
|
|
||||||
|
|
||||||
wandb_metrics["train/avg_reward"] = sum(self.reward_buffer) / total
|
|
||||||
wandb_metrics["train/accuracy"] = correct / total
|
|
||||||
wandb_metrics["train/partial_match_rate"] = partial / total
|
|
||||||
wandb_metrics["train/total_rollouts"] = total
|
|
||||||
self.reward_buffer = []
|
|
||||||
|
|
||||||
await super().wandb_log(wandb_metrics)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
TerminalTestEnv.cli()
|
|
||||||
|
|
@ -1,120 +0,0 @@
|
||||||
"""
|
|
||||||
Tool Call Parser Registry
|
|
||||||
|
|
||||||
Client-side parsers that extract structured tool_calls from raw model output text.
|
|
||||||
Used in Phase 2 (VLLM server type) where ManagedServer's /generate endpoint returns
|
|
||||||
raw text without tool call parsing.
|
|
||||||
|
|
||||||
Each parser is a standalone reimplementation of the corresponding VLLM parser's
|
|
||||||
non-streaming extract_tool_calls() logic. No VLLM dependency -- only standard library
|
|
||||||
(re, json, uuid) and openai types.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
|
|
||||||
parser = get_parser("hermes")
|
|
||||||
content, tool_calls = parser.parse(raw_model_output)
|
|
||||||
# content = text with tool call markup stripped
|
|
||||||
# tool_calls = list of ChatCompletionMessageToolCall objects, or None
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Type alias for parser return value
|
|
||||||
ParseResult = Tuple[Optional[str], Optional[List[ChatCompletionMessageToolCall]]]
|
|
||||||
|
|
||||||
|
|
||||||
class ToolCallParser(ABC):
|
|
||||||
"""
|
|
||||||
Base class for tool call parsers.
|
|
||||||
|
|
||||||
Each parser knows how to extract structured tool_calls from a specific
|
|
||||||
model family's raw output text format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
"""
|
|
||||||
Parse raw model output text for tool calls.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: Raw decoded text from the model's completion
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (content, tool_calls) where:
|
|
||||||
- content: text with tool call markup stripped (the message 'content' field),
|
|
||||||
or None if the entire output was tool calls
|
|
||||||
- tool_calls: list of ChatCompletionMessageToolCall objects,
|
|
||||||
or None if no tool calls were found
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
# Global parser registry: name -> parser class
|
|
||||||
PARSER_REGISTRY: Dict[str, Type[ToolCallParser]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def register_parser(name: str):
|
|
||||||
"""
|
|
||||||
Decorator to register a parser class under a given name.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
@register_parser("hermes")
|
|
||||||
class HermesToolCallParser(ToolCallParser):
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(cls: Type[ToolCallParser]) -> Type[ToolCallParser]:
|
|
||||||
PARSER_REGISTRY[name] = cls
|
|
||||||
return cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser(name: str) -> ToolCallParser:
|
|
||||||
"""
|
|
||||||
Get a parser instance by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Parser name (e.g., "hermes", "mistral", "llama3_json")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Instantiated parser
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
KeyError: If parser name is not found in registry
|
|
||||||
"""
|
|
||||||
if name not in PARSER_REGISTRY:
|
|
||||||
available = sorted(PARSER_REGISTRY.keys())
|
|
||||||
raise KeyError(
|
|
||||||
f"Tool call parser '{name}' not found. Available parsers: {available}"
|
|
||||||
)
|
|
||||||
return PARSER_REGISTRY[name]()
|
|
||||||
|
|
||||||
|
|
||||||
def list_parsers() -> List[str]:
|
|
||||||
"""Return sorted list of registered parser names."""
|
|
||||||
return sorted(PARSER_REGISTRY.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# Import all parser modules to trigger registration via @register_parser decorators
|
|
||||||
# Each module registers itself when imported
|
|
||||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.longcat_parser import LongcatToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.mistral_parser import MistralToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.llama_parser import LlamaToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.qwen_parser import QwenToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.deepseek_v3_parser import DeepSeekV3ToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.deepseek_v3_1_parser import DeepSeekV31ToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.kimi_k2_parser import KimiK2ToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.glm47_parser import Glm47ToolCallParser # noqa: E402, F401
|
|
||||||
from environments.tool_call_parsers.qwen3_coder_parser import Qwen3CoderToolCallParser # noqa: E402, F401
|
|
||||||
|
|
@ -1,72 +0,0 @@
|
||||||
"""
|
|
||||||
DeepSeek V3.1 tool call parser.
|
|
||||||
|
|
||||||
Similar to V3 but with a slightly different format:
|
|
||||||
<|tool▁call▁begin|>function_name<|tool▁sep|>arguments<|tool▁call▁end|>
|
|
||||||
|
|
||||||
Note: V3 has type+name before the separator, V3.1 has name before and args after.
|
|
||||||
|
|
||||||
Based on VLLM's DeepSeekV31ToolParser.extract_tool_calls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("deepseek_v3_1")
|
|
||||||
@register_parser("deepseek_v31")
|
|
||||||
class DeepSeekV31ToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for DeepSeek V3.1 tool calls.
|
|
||||||
|
|
||||||
Slightly different regex than V3: function_name comes before the separator,
|
|
||||||
arguments come after (no type field, no json code block wrapper).
|
|
||||||
"""
|
|
||||||
|
|
||||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
|
||||||
|
|
||||||
# Regex captures: function_name, function_arguments
|
|
||||||
PATTERN = re.compile(
|
|
||||||
r"<|tool▁call▁begin|>(?P<function_name>.*?)<|tool▁sep|>(?P<function_arguments>.*?)<|tool▁call▁end|>",
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
if self.START_TOKEN not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
matches = self.PATTERN.findall(text)
|
|
||||||
if not matches:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
for match in matches:
|
|
||||||
func_name, func_args = match
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=func_name.strip(),
|
|
||||||
arguments=func_args.strip(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
content = text[: text.find(self.START_TOKEN)].strip()
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,89 +0,0 @@
|
||||||
"""
|
|
||||||
DeepSeek V3 tool call parser.
|
|
||||||
|
|
||||||
Format uses special unicode tokens:
|
|
||||||
<|tool▁calls▁begin|>
|
|
||||||
<|tool▁call▁begin|>type<|tool▁sep|>function_name
|
|
||||||
```json
|
|
||||||
{"arg": "value"}
|
|
||||||
```
|
|
||||||
<|tool▁call▁end|>
|
|
||||||
<|tool▁calls▁end|>
|
|
||||||
|
|
||||||
Fixes Issue #989: Support for multiple simultaneous tool calls.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
import logging
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
@register_parser("deepseek_v3")
|
|
||||||
class DeepSeekV3ToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for DeepSeek V3 tool calls.
|
|
||||||
|
|
||||||
Uses special unicode tokens with fullwidth angle brackets and block elements.
|
|
||||||
Extracts type, function name, and JSON arguments from the structured format.
|
|
||||||
Ensures all tool calls are captured when the model executes multiple actions.
|
|
||||||
"""
|
|
||||||
|
|
||||||
START_TOKEN = "<|tool▁calls▁begin|>"
|
|
||||||
|
|
||||||
# Updated PATTERN: Using \s* instead of literal \n for increased robustness
|
|
||||||
# against variations in model formatting (Issue #989).
|
|
||||||
PATTERN = re.compile(
|
|
||||||
r"<|tool▁call▁begin|>(?P<type>.*?)<|tool▁sep|>(?P<function_name>.*?)\s*```json\s*(?P<function_arguments>.*?)\s*```\s*<|tool▁call▁end|>",
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
"""
|
|
||||||
Parses the input text and extracts all available tool calls.
|
|
||||||
"""
|
|
||||||
if self.START_TOKEN not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Using finditer to capture ALL tool calls in the sequence
|
|
||||||
matches = list(self.PATTERN.finditer(text))
|
|
||||||
if not matches:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
func_name = match.group("function_name").strip()
|
|
||||||
func_args = match.group("function_arguments").strip()
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=func_name,
|
|
||||||
arguments=func_args,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if tool_calls:
|
|
||||||
# Content is text before the first tool call block
|
|
||||||
content_index = text.find(self.START_TOKEN)
|
|
||||||
content = text[:content_index].strip()
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error parsing DeepSeek V3 tool calls: {e}")
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,109 +0,0 @@
|
||||||
"""
|
|
||||||
GLM 4.5 (GLM-4-MoE) tool call parser.
|
|
||||||
|
|
||||||
Format uses custom arg_key/arg_value tags rather than standard JSON:
|
|
||||||
<tool_call>function_name
|
|
||||||
<arg_key>param1</arg_key><arg_value>value1</arg_value>
|
|
||||||
<arg_key>param2</arg_key><arg_value>value2</arg_value>
|
|
||||||
</tool_call>
|
|
||||||
|
|
||||||
Values are deserialized using json.loads -> ast.literal_eval -> raw string fallback.
|
|
||||||
|
|
||||||
Based on VLLM's Glm4MoeModelToolParser.extract_tool_calls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import ast
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
def _deserialize_value(value: str) -> Any:
|
|
||||||
"""
|
|
||||||
Try to deserialize a string value to its native Python type.
|
|
||||||
Attempts json.loads, then ast.literal_eval, then returns raw string.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return json.loads(value)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
return ast.literal_eval(value)
|
|
||||||
except (ValueError, SyntaxError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("glm45")
|
|
||||||
class Glm45ToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for GLM 4.5 (GLM-4-MoE) tool calls.
|
|
||||||
|
|
||||||
Uses <tool_call>...</tool_call> tags with <arg_key>/<arg_value> pairs
|
|
||||||
instead of standard JSON arguments.
|
|
||||||
"""
|
|
||||||
|
|
||||||
FUNC_CALL_REGEX = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
|
|
||||||
FUNC_DETAIL_REGEX = re.compile(r"<tool_call>([^\n]*)\n(.*)</tool_call>", re.DOTALL)
|
|
||||||
FUNC_ARG_REGEX = re.compile(
|
|
||||||
r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>", re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
START_TOKEN = "<tool_call>"
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
if self.START_TOKEN not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
matched_calls = self.FUNC_CALL_REGEX.findall(text)
|
|
||||||
if not matched_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
|
|
||||||
for match in matched_calls:
|
|
||||||
detail = self.FUNC_DETAIL_REGEX.search(match)
|
|
||||||
if not detail:
|
|
||||||
continue
|
|
||||||
|
|
||||||
func_name = detail.group(1).strip()
|
|
||||||
func_args_raw = detail.group(2)
|
|
||||||
|
|
||||||
# Parse arg_key/arg_value pairs
|
|
||||||
pairs = self.FUNC_ARG_REGEX.findall(func_args_raw) if func_args_raw else []
|
|
||||||
arg_dict: Dict[str, Any] = {}
|
|
||||||
for key, value in pairs:
|
|
||||||
arg_key = key.strip()
|
|
||||||
arg_val = _deserialize_value(value.strip())
|
|
||||||
arg_dict[arg_key] = arg_val
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=func_name,
|
|
||||||
arguments=json.dumps(arg_dict, ensure_ascii=False),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
content = text[: text.find(self.START_TOKEN)].strip()
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,35 +0,0 @@
|
||||||
"""
|
|
||||||
GLM 4.7 tool call parser.
|
|
||||||
|
|
||||||
Same as GLM 4.5 but with slightly different regex patterns.
|
|
||||||
The tool_call tags may wrap differently and arg parsing handles
|
|
||||||
newlines between key/value pairs.
|
|
||||||
|
|
||||||
Based on VLLM's Glm47MoeModelToolParser (extends Glm4MoeModelToolParser).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, register_parser
|
|
||||||
from environments.tool_call_parsers.glm45_parser import Glm45ToolCallParser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("glm47")
|
|
||||||
class Glm47ToolCallParser(Glm45ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for GLM 4.7 tool calls.
|
|
||||||
Extends GLM 4.5 with updated regex patterns.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
# GLM 4.7 uses a slightly different detail regex that includes
|
|
||||||
# the <tool_call> wrapper and optional arg_key content
|
|
||||||
self.FUNC_DETAIL_REGEX = re.compile(
|
|
||||||
r"<tool_call>(.*?)(<arg_key>.*?)?</tool_call>", re.DOTALL
|
|
||||||
)
|
|
||||||
# GLM 4.7 handles newlines between arg_key and arg_value tags
|
|
||||||
self.FUNC_ARG_REGEX = re.compile(
|
|
||||||
r"<arg_key>(.*?)</arg_key>(?:\\n|\s)*<arg_value>(.*?)</arg_value>",
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
@ -1,75 +0,0 @@
|
||||||
"""
|
|
||||||
Hermes tool call parser.
|
|
||||||
|
|
||||||
Format: <tool_call>{"name": "func", "arguments": {...}}</tool_call>
|
|
||||||
Based on VLLM's Hermes2ProToolParser.extract_tool_calls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("hermes")
|
|
||||||
class HermesToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Hermes-format tool calls.
|
|
||||||
|
|
||||||
Matches <tool_call>...</tool_call> tags containing JSON with "name" and "arguments".
|
|
||||||
Also handles unclosed <tool_call> at end-of-string (truncated generation).
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Matches both closed and unclosed tool_call tags
|
|
||||||
PATTERN = re.compile(
|
|
||||||
r"<tool_call>\s*(.*?)\s*</tool_call>|<tool_call>\s*(.*)", re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
if "<tool_call>" not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
matches = self.PATTERN.findall(text)
|
|
||||||
if not matches:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
for match in matches:
|
|
||||||
# match is a tuple: (closed_content, unclosed_content)
|
|
||||||
raw_json = match[0] if match[0] else match[1]
|
|
||||||
if not raw_json.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
tc_data = json.loads(raw_json)
|
|
||||||
if "name" not in tc_data:
|
|
||||||
continue
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=tc_data["name"],
|
|
||||||
arguments=json.dumps(
|
|
||||||
tc_data.get("arguments", {}), ensure_ascii=False
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
# Content is everything before the first <tool_call> tag
|
|
||||||
content = text[: text.find("<tool_call>")].strip()
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,93 +0,0 @@
|
||||||
"""
|
|
||||||
Kimi K2 tool call parser.
|
|
||||||
|
|
||||||
Format:
|
|
||||||
<|tool_calls_section_begin|>
|
|
||||||
<|tool_call_begin|>function_id:0<|tool_call_argument_begin|>{"arg": "val"}<|tool_call_end|>
|
|
||||||
<|tool_calls_section_end|>
|
|
||||||
|
|
||||||
The function_id format is typically "functions.func_name:index" or "func_name:index".
|
|
||||||
|
|
||||||
Based on VLLM's KimiK2ToolParser.extract_tool_calls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("kimi_k2")
|
|
||||||
class KimiK2ToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Kimi K2 tool calls.
|
|
||||||
|
|
||||||
Uses section begin/end tokens wrapping individual tool call begin/end tokens.
|
|
||||||
The tool_call_id contains the function name (after last dot, before colon).
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Support both singular and plural variants
|
|
||||||
START_TOKENS = [
|
|
||||||
"<|tool_calls_section_begin|>",
|
|
||||||
"<|tool_call_section_begin|>",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Regex captures: tool_call_id (e.g., "functions.get_weather:0"), function_arguments
|
|
||||||
PATTERN = re.compile(
|
|
||||||
r"<\|tool_call_begin\|>\s*(?P<tool_call_id>[^<]+:\d+)\s*"
|
|
||||||
r"<\|tool_call_argument_begin\|>\s*"
|
|
||||||
r"(?P<function_arguments>(?:(?!<\|tool_call_begin\|>).)*?)\s*"
|
|
||||||
r"<\|tool_call_end\|>",
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
# Check for any variant of the start token
|
|
||||||
has_start = any(token in text for token in self.START_TOKENS)
|
|
||||||
if not has_start:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
matches = self.PATTERN.findall(text)
|
|
||||||
if not matches:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
for match in matches:
|
|
||||||
function_id, function_args = match
|
|
||||||
|
|
||||||
# Extract function name from ID format: "functions.get_weather:0" -> "get_weather"
|
|
||||||
function_name = function_id.split(":")[0].split(".")[-1]
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=function_id, # Preserve the original ID format
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=function_name,
|
|
||||||
arguments=function_args.strip(),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
# Content is everything before the tool calls section
|
|
||||||
earliest_start = len(text)
|
|
||||||
for token in self.START_TOKENS:
|
|
||||||
idx = text.find(token)
|
|
||||||
if idx >= 0 and idx < earliest_start:
|
|
||||||
earliest_start = idx
|
|
||||||
|
|
||||||
content = text[:earliest_start].strip()
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,96 +0,0 @@
|
||||||
"""
|
|
||||||
Llama 3.x / 4 tool call parser.
|
|
||||||
|
|
||||||
Format: The model outputs JSON objects with "name" and "arguments" (or "parameters") keys.
|
|
||||||
May be preceded by <|python_tag|> token. Supports multiple JSON objects separated
|
|
||||||
by content or semicolons.
|
|
||||||
|
|
||||||
Based on VLLM's Llama3JsonToolParser.extract_tool_calls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("llama3_json")
|
|
||||||
@register_parser("llama4_json")
|
|
||||||
class LlamaToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Llama 3.x and 4 JSON-format tool calls.
|
|
||||||
|
|
||||||
Finds JSON objects containing "name" + ("arguments" or "parameters") keys.
|
|
||||||
Uses Python's json.JSONDecoder.raw_decode for robust extraction of
|
|
||||||
JSON objects from mixed text.
|
|
||||||
"""
|
|
||||||
|
|
||||||
BOT_TOKEN = "<|python_tag|>"
|
|
||||||
|
|
||||||
# Regex to find the start of potential JSON objects
|
|
||||||
JSON_START = re.compile(r"\{")
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
# Quick check: need either the bot token or a JSON brace
|
|
||||||
if self.BOT_TOKEN not in text and "{" not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
decoder = json.JSONDecoder()
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
end_index = -1 # Track where the last parsed JSON ended
|
|
||||||
|
|
||||||
for match in self.JSON_START.finditer(text):
|
|
||||||
start = match.start()
|
|
||||||
# Skip if this brace is inside a previously parsed JSON object
|
|
||||||
if start <= end_index:
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
obj, json_end = decoder.raw_decode(text[start:])
|
|
||||||
end_index = start + json_end
|
|
||||||
|
|
||||||
# Must have "name" and either "arguments" or "parameters"
|
|
||||||
name = obj.get("name")
|
|
||||||
args = obj.get("arguments", obj.get("parameters"))
|
|
||||||
|
|
||||||
if not name or args is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Normalize arguments to JSON string
|
|
||||||
if isinstance(args, dict):
|
|
||||||
args = json.dumps(args, ensure_ascii=False)
|
|
||||||
elif not isinstance(args, str):
|
|
||||||
args = json.dumps(args, ensure_ascii=False)
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(name=name, arguments=args),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except (json.JSONDecodeError, KeyError, ValueError):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
# Content is everything before the first tool call JSON
|
|
||||||
# Find where the first tool call starts in the text
|
|
||||||
first_tc_start = text.find("{")
|
|
||||||
if self.BOT_TOKEN in text:
|
|
||||||
first_tc_start = text.find(self.BOT_TOKEN)
|
|
||||||
content = text[:first_tc_start].strip() if first_tc_start > 0 else None
|
|
||||||
|
|
||||||
return content, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
"""
|
|
||||||
Longcat Flash Chat tool call parser.
|
|
||||||
|
|
||||||
Same as Hermes but uses <longcat_tool_call> tags instead of <tool_call>.
|
|
||||||
Based on VLLM's LongcatFlashToolParser (extends Hermes2ProToolParser).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("longcat")
|
|
||||||
class LongcatToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Longcat Flash Chat tool calls.
|
|
||||||
Identical logic to Hermes, just different tag names.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PATTERN = re.compile(
|
|
||||||
r"<longcat_tool_call>\s*(.*?)\s*</longcat_tool_call>|<longcat_tool_call>\s*(.*)",
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
if "<longcat_tool_call>" not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
matches = self.PATTERN.findall(text)
|
|
||||||
if not matches:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
for match in matches:
|
|
||||||
raw_json = match[0] if match[0] else match[1]
|
|
||||||
if not raw_json.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
tc_data = json.loads(raw_json)
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:8]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=tc_data["name"],
|
|
||||||
arguments=json.dumps(
|
|
||||||
tc_data.get("arguments", {}), ensure_ascii=False
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
content = text[: text.find("<longcat_tool_call>")].strip()
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,137 +0,0 @@
|
||||||
"""
|
|
||||||
Mistral tool call parser.
|
|
||||||
|
|
||||||
Supports two formats depending on tokenizer version:
|
|
||||||
- Pre-v11: content[TOOL_CALLS] [{"name": ..., "arguments": {...}}, ...]
|
|
||||||
- v11+: content[TOOL_CALLS]tool_name1{"arg": "val"}[TOOL_CALLS]tool_name2{"arg": "val"}
|
|
||||||
|
|
||||||
Based on VLLM's MistralToolParser.extract_tool_calls()
|
|
||||||
The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_mistral_id() -> str:
|
|
||||||
"""Mistral tool call IDs are 9-char alphanumeric strings."""
|
|
||||||
import random
|
|
||||||
import string
|
|
||||||
|
|
||||||
return "".join(random.choices(string.ascii_letters + string.digits, k=9))
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("mistral")
|
|
||||||
class MistralToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Mistral-format tool calls.
|
|
||||||
|
|
||||||
Detects format by checking if the content after [TOOL_CALLS] starts with '['
|
|
||||||
(pre-v11 JSON array) or with a tool name (v11+ format).
|
|
||||||
"""
|
|
||||||
|
|
||||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
|
||||||
BOT_TOKEN = "[TOOL_CALLS]"
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
if self.BOT_TOKEN not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
parts = text.split(self.BOT_TOKEN)
|
|
||||||
content = parts[0].strip()
|
|
||||||
raw_tool_calls = parts[1:]
|
|
||||||
|
|
||||||
# Detect format: if the first raw part starts with '[', it's pre-v11
|
|
||||||
first_raw = raw_tool_calls[0].strip() if raw_tool_calls else ""
|
|
||||||
is_pre_v11 = first_raw.startswith("[") or first_raw.startswith("{")
|
|
||||||
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
|
|
||||||
if not is_pre_v11:
|
|
||||||
# v11+ format: [TOOL_CALLS]tool_name{args}[TOOL_CALLS]tool_name2{args2}
|
|
||||||
for raw in raw_tool_calls:
|
|
||||||
raw = raw.strip()
|
|
||||||
if not raw or "{" not in raw:
|
|
||||||
continue
|
|
||||||
|
|
||||||
brace_idx = raw.find("{")
|
|
||||||
tool_name = raw[:brace_idx].strip()
|
|
||||||
args_str = raw[brace_idx:]
|
|
||||||
|
|
||||||
# Validate and clean the JSON arguments
|
|
||||||
try:
|
|
||||||
parsed_args = json.loads(args_str)
|
|
||||||
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass # Keep raw if parsing fails
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=_generate_mistral_id(),
|
|
||||||
type="function",
|
|
||||||
function=Function(name=tool_name, arguments=args_str),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Pre-v11 format: [TOOL_CALLS] [{"name": ..., "arguments": {...}}]
|
|
||||||
try:
|
|
||||||
parsed = json.loads(first_raw)
|
|
||||||
if isinstance(parsed, dict):
|
|
||||||
parsed = [parsed]
|
|
||||||
|
|
||||||
for tc in parsed:
|
|
||||||
if "name" not in tc:
|
|
||||||
continue
|
|
||||||
args = tc.get("arguments", {})
|
|
||||||
if isinstance(args, dict):
|
|
||||||
args = json.dumps(args, ensure_ascii=False)
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=_generate_mistral_id(),
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=tc["name"], arguments=args
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# Fallback: extract JSON objects using raw_decode
|
|
||||||
decoder = json.JSONDecoder()
|
|
||||||
idx = 0
|
|
||||||
while idx < len(first_raw):
|
|
||||||
try:
|
|
||||||
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
|
||||||
if isinstance(obj, dict) and "name" in obj:
|
|
||||||
args = obj.get("arguments", {})
|
|
||||||
if isinstance(args, dict):
|
|
||||||
args = json.dumps(args, ensure_ascii=False)
|
|
||||||
tool_calls.append(
|
|
||||||
ChatCompletionMessageToolCall(
|
|
||||||
id=_generate_mistral_id(),
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=obj["name"], arguments=args
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
idx = end_idx
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
return content if content else None, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,163 +0,0 @@
|
||||||
"""
|
|
||||||
Qwen3-Coder tool call parser.
|
|
||||||
|
|
||||||
Format uses XML-style nested tags:
|
|
||||||
<tool_call>
|
|
||||||
<function=function_name>
|
|
||||||
<parameter=param_name>value</parameter>
|
|
||||||
<parameter=param_name2>value2</parameter>
|
|
||||||
</function>
|
|
||||||
</tool_call>
|
|
||||||
|
|
||||||
Parameters are extracted from <parameter=name>value</parameter> tags and
|
|
||||||
type-converted using the schema if available, otherwise treated as strings.
|
|
||||||
|
|
||||||
Based on VLLM's Qwen3CoderToolParser.extract_tool_calls()
|
|
||||||
"""
|
|
||||||
|
|
||||||
import ast
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
from openai.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
Function,
|
|
||||||
)
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import ParseResult, ToolCallParser, register_parser
|
|
||||||
|
|
||||||
|
|
||||||
def _try_convert_value(value: str) -> Any:
|
|
||||||
"""
|
|
||||||
Try to convert a parameter value string to a native Python type.
|
|
||||||
Handles null, numbers, booleans, JSON objects/arrays, and falls back to string.
|
|
||||||
"""
|
|
||||||
stripped = value.strip()
|
|
||||||
|
|
||||||
# Handle null
|
|
||||||
if stripped.lower() == "null":
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Try JSON first (handles objects, arrays, strings, numbers, booleans)
|
|
||||||
try:
|
|
||||||
return json.loads(stripped)
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Try Python literal eval (handles tuples, etc.)
|
|
||||||
try:
|
|
||||||
return ast.literal_eval(stripped)
|
|
||||||
except (ValueError, SyntaxError, TypeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Return as string
|
|
||||||
return stripped
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("qwen3_coder")
|
|
||||||
class Qwen3CoderToolCallParser(ToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Qwen3-Coder XML-format tool calls.
|
|
||||||
|
|
||||||
Uses nested XML tags: <tool_call><function=name><parameter=key>val</parameter></function></tool_call>
|
|
||||||
"""
|
|
||||||
|
|
||||||
START_TOKEN = "<tool_call>"
|
|
||||||
FUNCTION_PREFIX = "<function="
|
|
||||||
|
|
||||||
# Find complete tool_call blocks (or unclosed at end)
|
|
||||||
TOOL_CALL_REGEX = re.compile(
|
|
||||||
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find function blocks within a tool_call
|
|
||||||
FUNCTION_REGEX = re.compile(
|
|
||||||
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Find parameter blocks within a function
|
|
||||||
PARAMETER_REGEX = re.compile(
|
|
||||||
r"<parameter=(.*?)(?:</parameter>|(?=<parameter=)|(?=</function>)|$)",
|
|
||||||
re.DOTALL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_function_call(self, function_str: str) -> Optional[ChatCompletionMessageToolCall]:
|
|
||||||
"""Parse a single <function=name>...</function> block into a ToolCall."""
|
|
||||||
try:
|
|
||||||
# Extract function name: everything before the first '>'
|
|
||||||
gt_idx = function_str.index(">")
|
|
||||||
func_name = function_str[:gt_idx].strip()
|
|
||||||
params_str = function_str[gt_idx + 1:]
|
|
||||||
|
|
||||||
# Extract parameters
|
|
||||||
param_dict: Dict[str, Any] = {}
|
|
||||||
for match_text in self.PARAMETER_REGEX.findall(params_str):
|
|
||||||
if ">" not in match_text:
|
|
||||||
continue
|
|
||||||
eq_idx = match_text.index(">")
|
|
||||||
param_name = match_text[:eq_idx].strip()
|
|
||||||
param_value = match_text[eq_idx + 1:]
|
|
||||||
|
|
||||||
# Clean up whitespace
|
|
||||||
if param_value.startswith("\n"):
|
|
||||||
param_value = param_value[1:]
|
|
||||||
if param_value.endswith("\n"):
|
|
||||||
param_value = param_value[:-1]
|
|
||||||
|
|
||||||
param_dict[param_name] = _try_convert_value(param_value)
|
|
||||||
|
|
||||||
return ChatCompletionMessageToolCall(
|
|
||||||
id=f"call_{uuid.uuid4().hex[:24]}",
|
|
||||||
type="function",
|
|
||||||
function=Function(
|
|
||||||
name=func_name,
|
|
||||||
arguments=json.dumps(param_dict, ensure_ascii=False),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def parse(self, text: str) -> ParseResult:
|
|
||||||
if self.FUNCTION_PREFIX not in text:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Find all tool_call blocks
|
|
||||||
tc_matches = self.TOOL_CALL_REGEX.findall(text)
|
|
||||||
raw_blocks = [m[0] if m[0] else m[1] for m in tc_matches]
|
|
||||||
|
|
||||||
# Fallback: if no tool_call tags, try the whole text
|
|
||||||
if not raw_blocks:
|
|
||||||
raw_blocks = [text]
|
|
||||||
|
|
||||||
# Find function blocks within each tool_call
|
|
||||||
function_strs: List[str] = []
|
|
||||||
for block in raw_blocks:
|
|
||||||
func_matches = self.FUNCTION_REGEX.findall(block)
|
|
||||||
function_strs.extend(m[0] if m[0] else m[1] for m in func_matches)
|
|
||||||
|
|
||||||
if not function_strs:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
# Parse each function call
|
|
||||||
tool_calls: List[ChatCompletionMessageToolCall] = []
|
|
||||||
for func_str in function_strs:
|
|
||||||
tc = self._parse_function_call(func_str)
|
|
||||||
if tc is not None:
|
|
||||||
tool_calls.append(tc)
|
|
||||||
|
|
||||||
if not tool_calls:
|
|
||||||
return text, None
|
|
||||||
|
|
||||||
# Content before tool calls
|
|
||||||
first_tc = text.find(self.START_TOKEN)
|
|
||||||
if first_tc < 0:
|
|
||||||
first_tc = text.find(self.FUNCTION_PREFIX)
|
|
||||||
content = text[:first_tc].strip() if first_tc > 0 else None
|
|
||||||
|
|
||||||
return content, tool_calls
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
return text, None
|
|
||||||
|
|
@ -1,19 +0,0 @@
|
||||||
"""
|
|
||||||
Qwen 2.5 tool call parser.
|
|
||||||
|
|
||||||
Uses the same <tool_call> format as Hermes.
|
|
||||||
Registered as a separate parser name for clarity when using --tool-parser=qwen.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from environments.tool_call_parsers import register_parser
|
|
||||||
from environments.tool_call_parsers.hermes_parser import HermesToolCallParser
|
|
||||||
|
|
||||||
|
|
||||||
@register_parser("qwen")
|
|
||||||
class QwenToolCallParser(HermesToolCallParser):
|
|
||||||
"""
|
|
||||||
Parser for Qwen 2.5 tool calls.
|
|
||||||
Same <tool_call>{"name": ..., "arguments": ...}</tool_call> format as Hermes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pass # Identical format -- inherits everything from Hermes
|
|
||||||
|
|
@ -1,473 +0,0 @@
|
||||||
"""
|
|
||||||
ToolContext -- Unrestricted Tool Access for Reward Functions
|
|
||||||
|
|
||||||
A per-rollout handle that gives reward/verification functions direct access to
|
|
||||||
ALL hermes-agent tools, scoped to the rollout's task_id. The same task_id means
|
|
||||||
the terminal/browser session is the SAME one the model used during its rollout --
|
|
||||||
all state (files, processes, browser tabs) is preserved.
|
|
||||||
|
|
||||||
The verifier author decides which tools to use. Nothing is hardcoded or gated.
|
|
||||||
|
|
||||||
Example usage in a compute_reward():
|
|
||||||
async def compute_reward(self, item, result, ctx):
|
|
||||||
# Run tests in the model's terminal sandbox
|
|
||||||
test = ctx.terminal("pytest -v")
|
|
||||||
if test["exit_code"] == 0:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Check if a file was created
|
|
||||||
content = ctx.read_file("/workspace/solution.py")
|
|
||||||
if content.get("content"):
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
return 0.0
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import concurrent.futures
|
|
||||||
|
|
||||||
from model_tools import handle_function_call
|
|
||||||
from tools.terminal_tool import cleanup_vm
|
|
||||||
from tools.browser_tool import cleanup_browser
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Thread pool for running sync tool calls that internally use asyncio.run()
|
|
||||||
_tool_executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
|
|
||||||
|
|
||||||
|
|
||||||
def _run_tool_in_thread(tool_name: str, arguments: Dict[str, Any], task_id: str) -> str:
|
|
||||||
"""
|
|
||||||
Run a tool call in a thread pool executor so backends that use asyncio.run()
|
|
||||||
internally (modal, docker, daytona) get a clean event loop.
|
|
||||||
|
|
||||||
If we're already in an async context, executes handle_function_call() in a
|
|
||||||
disposable worker thread and blocks for the result.
|
|
||||||
If not (e.g., called from sync code), runs directly.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
# We're in an async context -- need to run in thread
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
|
|
||||||
future = pool.submit(
|
|
||||||
handle_function_call, tool_name, arguments, task_id
|
|
||||||
)
|
|
||||||
return future.result(timeout=300)
|
|
||||||
except RuntimeError:
|
|
||||||
# No running event loop -- safe to call directly
|
|
||||||
return handle_function_call(tool_name, arguments, task_id)
|
|
||||||
|
|
||||||
|
|
||||||
class ToolContext:
|
|
||||||
"""
|
|
||||||
Open-ended access to all hermes-agent tools for a specific rollout.
|
|
||||||
|
|
||||||
Passed to compute_reward() so verifiers can use any tool they need:
|
|
||||||
terminal commands, file reads/writes, web searches, browser automation, etc.
|
|
||||||
All calls share the rollout's task_id for session isolation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, task_id: str):
|
|
||||||
self.task_id = task_id
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# Terminal tools
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def terminal(self, command: str, timeout: int = 180) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Run a command in the rollout's terminal session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Shell command to execute
|
|
||||||
timeout: Command timeout in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'exit_code' (int) and 'output' (str)
|
|
||||||
"""
|
|
||||||
import os
|
|
||||||
backend = os.getenv("TERMINAL_ENV", "local")
|
|
||||||
logger.debug("ToolContext.terminal [%s backend] task=%s: %s", backend, self.task_id[:8], command[:100])
|
|
||||||
|
|
||||||
# Run via thread helper so modal/docker/daytona backends' asyncio.run() doesn't deadlock
|
|
||||||
result = _run_tool_in_thread(
|
|
||||||
"terminal",
|
|
||||||
{"command": command, "timeout": timeout},
|
|
||||||
self.task_id,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"exit_code": -1, "output": result}
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# File tools
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def read_file(self, path: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Read a file from the rollout's filesystem.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: File path to read
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with file content or error
|
|
||||||
"""
|
|
||||||
result = handle_function_call(
|
|
||||||
"read_file", {"path": path}, task_id=self.task_id
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
def write_file(self, path: str, content: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Write a TEXT file in the rollout's filesystem.
|
|
||||||
|
|
||||||
Uses a shell heredoc under the hood, so this is only safe for text content.
|
|
||||||
For binary files (images, compiled artifacts, etc.), use upload_file() instead.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: File path to write
|
|
||||||
content: Text content to write
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with success status or error
|
|
||||||
"""
|
|
||||||
result = handle_function_call(
|
|
||||||
"write_file", {"path": path, "content": content}, task_id=self.task_id
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
def upload_file(self, local_path: str, remote_path: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Upload a local file to the rollout's sandbox (binary-safe).
|
|
||||||
|
|
||||||
Unlike write_file() which passes content through a shell heredoc (text-only),
|
|
||||||
this method base64-encodes the file and decodes it inside the sandbox.
|
|
||||||
Safe for any file type: binaries, images, archives, etc.
|
|
||||||
|
|
||||||
For large files (>1MB), the content is split into chunks to avoid
|
|
||||||
hitting shell command-length limits.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
local_path: Path to a local file on the host
|
|
||||||
remote_path: Destination path inside the sandbox
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'exit_code' and 'output'
|
|
||||||
"""
|
|
||||||
import base64
|
|
||||||
from pathlib import Path as _Path
|
|
||||||
|
|
||||||
local = _Path(local_path)
|
|
||||||
if not local.exists():
|
|
||||||
return {"exit_code": -1, "output": f"Local file not found: {local_path}"}
|
|
||||||
|
|
||||||
raw = local.read_bytes()
|
|
||||||
b64 = base64.b64encode(raw).decode("ascii")
|
|
||||||
|
|
||||||
# Ensure parent directory exists in the sandbox
|
|
||||||
parent = str(_Path(remote_path).parent)
|
|
||||||
if parent not in {".", "/"}:
|
|
||||||
self.terminal(f"mkdir -p {parent}", timeout=10)
|
|
||||||
|
|
||||||
# For small files, single command is fine
|
|
||||||
chunk_size = 60_000 # ~60KB per chunk (well within shell limits)
|
|
||||||
if len(b64) <= chunk_size:
|
|
||||||
result = self.terminal(
|
|
||||||
f"printf '%s' '{b64}' | base64 -d > {remote_path}",
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# For larger files, write base64 in chunks then decode
|
|
||||||
tmp_b64 = "/tmp/_hermes_upload.b64"
|
|
||||||
self.terminal(f": > {tmp_b64}", timeout=5) # truncate
|
|
||||||
for i in range(0, len(b64), chunk_size):
|
|
||||||
chunk = b64[i : i + chunk_size]
|
|
||||||
self.terminal(f"printf '%s' '{chunk}' >> {tmp_b64}", timeout=15)
|
|
||||||
result = self.terminal(
|
|
||||||
f"base64 -d {tmp_b64} > {remote_path} && rm -f {tmp_b64}",
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def upload_dir(self, local_dir: str, remote_dir: str) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Upload an entire local directory to the rollout's sandbox (binary-safe).
|
|
||||||
|
|
||||||
Recursively uploads all files, preserving directory structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
local_dir: Path to a local directory on the host
|
|
||||||
remote_dir: Destination directory inside the sandbox
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of results, one per file uploaded
|
|
||||||
"""
|
|
||||||
from pathlib import Path as _Path
|
|
||||||
|
|
||||||
local = _Path(local_dir)
|
|
||||||
if not local.exists() or not local.is_dir():
|
|
||||||
return [{"exit_code": -1, "output": f"Local directory not found: {local_dir}"}]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for file_path in sorted(local.rglob("*")):
|
|
||||||
if file_path.is_file():
|
|
||||||
relative = file_path.relative_to(local)
|
|
||||||
target = f"{remote_dir}/{relative}"
|
|
||||||
results.append(self.upload_file(str(file_path), target))
|
|
||||||
return results
|
|
||||||
|
|
||||||
def download_file(self, remote_path: str, local_path: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Download a file from the rollout's sandbox to the host (binary-safe).
|
|
||||||
|
|
||||||
The inverse of upload_file(). Base64-encodes the file inside the sandbox,
|
|
||||||
reads the encoded data through the terminal, and decodes it locally.
|
|
||||||
Safe for any file type.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
remote_path: Path to the file inside the sandbox
|
|
||||||
local_path: Destination path on the host
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with 'success' (bool) and 'bytes' (int) or 'error' (str)
|
|
||||||
"""
|
|
||||||
import base64
|
|
||||||
from pathlib import Path as _Path
|
|
||||||
|
|
||||||
# Base64-encode the file inside the sandbox and capture output
|
|
||||||
result = self.terminal(
|
|
||||||
f"base64 {remote_path} 2>/dev/null",
|
|
||||||
timeout=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.get("exit_code", -1) != 0:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to read remote file: {result.get('output', '')}",
|
|
||||||
}
|
|
||||||
|
|
||||||
b64_data = result.get("output", "").strip()
|
|
||||||
if not b64_data:
|
|
||||||
return {"success": False, "error": f"Remote file is empty or missing: {remote_path}"}
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = base64.b64decode(b64_data)
|
|
||||||
except Exception as e:
|
|
||||||
return {"success": False, "error": f"Base64 decode failed: {e}"}
|
|
||||||
|
|
||||||
# Write to local host filesystem
|
|
||||||
local = _Path(local_path)
|
|
||||||
local.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
local.write_bytes(raw)
|
|
||||||
|
|
||||||
return {"success": True, "bytes": len(raw)}
|
|
||||||
|
|
||||||
def download_dir(self, remote_dir: str, local_dir: str) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Download a directory from the rollout's sandbox to the host (binary-safe).
|
|
||||||
|
|
||||||
Lists all files in the remote directory, then downloads each one.
|
|
||||||
Preserves directory structure.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
remote_dir: Path to the directory inside the sandbox
|
|
||||||
local_dir: Destination directory on the host
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of results, one per file downloaded
|
|
||||||
"""
|
|
||||||
from pathlib import Path as _Path
|
|
||||||
|
|
||||||
# List files in the remote directory
|
|
||||||
ls_result = self.terminal(
|
|
||||||
f"find {remote_dir} -type f 2>/dev/null",
|
|
||||||
timeout=15,
|
|
||||||
)
|
|
||||||
|
|
||||||
if ls_result.get("exit_code", -1) != 0:
|
|
||||||
return [{"success": False, "error": f"Failed to list remote dir: {remote_dir}"}]
|
|
||||||
|
|
||||||
file_list = ls_result.get("output", "").strip()
|
|
||||||
if not file_list:
|
|
||||||
return [{"success": False, "error": f"Remote directory is empty or missing: {remote_dir}"}]
|
|
||||||
|
|
||||||
results = []
|
|
||||||
for remote_file in file_list.splitlines():
|
|
||||||
remote_file = remote_file.strip()
|
|
||||||
if not remote_file:
|
|
||||||
continue
|
|
||||||
# Compute the relative path to preserve directory structure
|
|
||||||
if remote_file.startswith(remote_dir):
|
|
||||||
relative = remote_file[len(remote_dir):].lstrip("/")
|
|
||||||
else:
|
|
||||||
relative = _Path(remote_file).name
|
|
||||||
local_file = str(_Path(local_dir) / relative)
|
|
||||||
results.append(self.download_file(remote_file, local_file))
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
def search(self, query: str, path: str = ".") -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Search for text in the rollout's filesystem.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Search query
|
|
||||||
path: Directory to search in
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with search results
|
|
||||||
"""
|
|
||||||
result = handle_function_call(
|
|
||||||
"search_files", {"pattern": query, "path": path}, task_id=self.task_id
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# Web tools
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def web_search(self, query: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Search the web.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: Search query
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with search results
|
|
||||||
"""
|
|
||||||
result = handle_function_call("web_search", {"query": query})
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
def web_extract(self, urls: List[str]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Extract content from URLs.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
urls: List of URLs to extract content from
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with extracted content
|
|
||||||
"""
|
|
||||||
result = handle_function_call("web_extract", {"urls": urls})
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# Browser tools
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def browser_navigate(self, url: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Navigate the rollout's browser session to a URL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: URL to navigate to
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with page snapshot or error
|
|
||||||
"""
|
|
||||||
result = handle_function_call(
|
|
||||||
"browser_navigate", {"url": url}, task_id=self.task_id
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
def browser_snapshot(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Take a snapshot of the current browser page.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with page content/accessibility snapshot
|
|
||||||
"""
|
|
||||||
result = handle_function_call(
|
|
||||||
"browser_snapshot", {}, task_id=self.task_id
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
return json.loads(result)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return {"error": result}
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# Generic tool access
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
|
||||||
"""
|
|
||||||
Call any hermes-agent tool by name.
|
|
||||||
|
|
||||||
This is the generic escape hatch -- if a tool doesn't have a convenience
|
|
||||||
wrapper above, you can call it directly here.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: Name of the tool (e.g., "vision_analyze", "skills_list")
|
|
||||||
arguments: Dict of arguments for the tool
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Raw JSON string result from the tool
|
|
||||||
"""
|
|
||||||
return _run_tool_in_thread(tool_name, arguments, self.task_id)
|
|
||||||
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
# Cleanup
|
|
||||||
# -------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def cleanup(self):
|
|
||||||
"""
|
|
||||||
Release all resources (terminal VMs, browser sessions, background processes)
|
|
||||||
for this rollout.
|
|
||||||
|
|
||||||
Called automatically by the base environment via try/finally after
|
|
||||||
compute_reward() completes. You generally don't need to call this yourself.
|
|
||||||
"""
|
|
||||||
# Kill any background processes from this rollout (safety net)
|
|
||||||
try:
|
|
||||||
from tools.process_registry import process_registry
|
|
||||||
killed = process_registry.kill_all(task_id=self.task_id)
|
|
||||||
if killed:
|
|
||||||
logger.debug("Process cleanup for task %s: killed %d process(es)", self.task_id, killed)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("Process cleanup for task %s: %s", self.task_id, e)
|
|
||||||
|
|
||||||
try:
|
|
||||||
cleanup_vm(self.task_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("VM cleanup for task %s: %s", self.task_id, e)
|
|
||||||
|
|
||||||
# Suppress browser_tool's noisy debug prints during cleanup.
|
|
||||||
# The cleanup still runs (safe), it just doesn't spam the console.
|
|
||||||
_prev_quiet = os.environ.get("HERMES_QUIET")
|
|
||||||
os.environ["HERMES_QUIET"] = "1"
|
|
||||||
try:
|
|
||||||
cleanup_browser(self.task_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug("Browser cleanup for task %s: %s", self.task_id, e)
|
|
||||||
finally:
|
|
||||||
if _prev_quiet is None:
|
|
||||||
os.environ.pop("HERMES_QUIET", None)
|
|
||||||
else:
|
|
||||||
os.environ["HERMES_QUIET"] = _prev_quiet
|
|
||||||
|
|
@ -1,719 +0,0 @@
|
||||||
"""
|
|
||||||
WebResearchEnv — RL Environment for Multi-Step Web Research
|
|
||||||
============================================================
|
|
||||||
|
|
||||||
Trains models to do accurate, efficient, multi-source web research.
|
|
||||||
|
|
||||||
Reward signals:
|
|
||||||
- Answer correctness (LLM judge, 0.0–1.0)
|
|
||||||
- Source diversity (used ≥2 distinct domains)
|
|
||||||
- Efficiency (penalizes excessive tool calls)
|
|
||||||
- Tool usage (bonus for actually using web tools)
|
|
||||||
|
|
||||||
Dataset: FRAMES benchmark (Google, 2024) — multi-hop factual questions
|
|
||||||
HuggingFace: google/frames-benchmark
|
|
||||||
Fallback: built-in sample questions (no HF token needed)
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Phase 1 (OpenAI-compatible server)
|
|
||||||
python environments/web_research_env.py serve \\
|
|
||||||
--openai.base_url http://localhost:8000/v1 \\
|
|
||||||
--openai.model_name YourModel \\
|
|
||||||
--openai.server_type openai
|
|
||||||
|
|
||||||
# Process mode (offline data generation)
|
|
||||||
python environments/web_research_env.py process \\
|
|
||||||
--env.data_path_to_save_groups data/web_research.jsonl
|
|
||||||
|
|
||||||
# Standalone eval
|
|
||||||
python environments/web_research_env.py evaluate \\
|
|
||||||
--openai.base_url http://localhost:8000/v1 \\
|
|
||||||
--openai.model_name YourModel
|
|
||||||
|
|
||||||
Built by: github.com/jackx707
|
|
||||||
Inspired by: GroceryMind — production Hermes agent doing live web research
|
|
||||||
across German grocery stores (firecrawl + hermes-agent)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
|
|
||||||
# Ensure hermes-agent root is on path
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Optional HuggingFace datasets import
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
try:
|
|
||||||
from datasets import load_dataset
|
|
||||||
HF_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
HF_AVAILABLE = False
|
|
||||||
|
|
||||||
from atroposlib.envs.base import ScoredDataGroup
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
from atroposlib.type_definitions import Item
|
|
||||||
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
from environments.agent_loop import AgentResult
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Fallback sample dataset (used when HuggingFace is unavailable)
|
|
||||||
# Multi-hop questions requiring real web search to answer.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
SAMPLE_QUESTIONS = [
|
|
||||||
{
|
|
||||||
"question": "What is the current population of the capital city of the country that won the 2022 FIFA World Cup?",
|
|
||||||
"answer": "Buenos Aires has approximately 3 million people in the city proper, or around 15 million in the greater metro area.",
|
|
||||||
"difficulty": "medium",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "Who is the CEO of the company that makes the most widely used open-source container orchestration platform?",
|
|
||||||
"answer": "The Linux Foundation oversees Kubernetes. CNCF (Cloud Native Computing Foundation) is the specific body — it does not have a traditional CEO but has an executive director.",
|
|
||||||
"difficulty": "medium",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "What programming language was used to write the original version of the web framework used by Instagram?",
|
|
||||||
"answer": "Django, which Instagram was built on, is written in Python.",
|
|
||||||
"difficulty": "easy",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "In what year was the university founded where the inventor of the World Wide Web currently holds a professorship?",
|
|
||||||
"answer": "Tim Berners-Lee holds a professorship at MIT (founded 1861) and the University of Southampton (founded 1952).",
|
|
||||||
"difficulty": "hard",
|
|
||||||
"hops": 3,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "What is the latest stable version of the programming language that ranks #1 on the TIOBE index as of this year?",
|
|
||||||
"answer": "Python is currently #1 on TIOBE. The latest stable version should be verified via the official python.org site.",
|
|
||||||
"difficulty": "medium",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "How many employees does the parent company of Instagram have?",
|
|
||||||
"answer": "Meta Platforms (parent of Instagram) employs approximately 70,000+ people as of recent reports.",
|
|
||||||
"difficulty": "medium",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "What is the current interest rate set by the central bank of the country where the Eiffel Tower is located?",
|
|
||||||
"answer": "The European Central Bank sets rates for France/eurozone. The current rate should be verified — it has changed frequently in 2023-2025.",
|
|
||||||
"difficulty": "hard",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "Which company acquired the startup founded by the creator of Oculus VR?",
|
|
||||||
"answer": "Palmer Luckey founded Oculus VR, which was acquired by Facebook (now Meta). He later founded Anduril Industries.",
|
|
||||||
"difficulty": "medium",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "What is the market cap of the company that owns the most popular search engine in Russia?",
|
|
||||||
"answer": "Yandex (now split into separate entities after 2024 restructuring). Current market cap should be verified via financial sources.",
|
|
||||||
"difficulty": "hard",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"question": "What was the GDP growth rate of the country that hosted the most recent Summer Olympics?",
|
|
||||||
"answer": "Paris, France hosted the 2024 Summer Olympics. France's recent GDP growth should be verified via World Bank or IMF data.",
|
|
||||||
"difficulty": "hard",
|
|
||||||
"hops": 2,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Configuration
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class WebResearchEnvConfig(HermesAgentEnvConfig):
|
|
||||||
"""Configuration for the web research RL environment."""
|
|
||||||
|
|
||||||
# Reward weights
|
|
||||||
correctness_weight: float = Field(
|
|
||||||
default=0.6,
|
|
||||||
description="Weight for answer correctness in reward (LLM judge score).",
|
|
||||||
)
|
|
||||||
tool_usage_weight: float = Field(
|
|
||||||
default=0.2,
|
|
||||||
description="Weight for tool usage signal (did the model actually use web tools?).",
|
|
||||||
)
|
|
||||||
efficiency_weight: float = Field(
|
|
||||||
default=0.2,
|
|
||||||
description="Weight for efficiency signal (penalizes excessive tool calls).",
|
|
||||||
)
|
|
||||||
diversity_bonus: float = Field(
|
|
||||||
default=0.1,
|
|
||||||
description="Bonus reward for citing ≥2 distinct domains.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Efficiency thresholds
|
|
||||||
efficient_max_calls: int = Field(
|
|
||||||
default=5,
|
|
||||||
description="Maximum tool calls before efficiency penalty begins.",
|
|
||||||
)
|
|
||||||
heavy_penalty_calls: int = Field(
|
|
||||||
default=10,
|
|
||||||
description="Tool call count where efficiency penalty steepens.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Eval
|
|
||||||
eval_size: int = Field(
|
|
||||||
default=20,
|
|
||||||
description="Number of held-out items for evaluation.",
|
|
||||||
)
|
|
||||||
eval_split_ratio: float = Field(
|
|
||||||
default=0.1,
|
|
||||||
description="Fraction of dataset to hold out for evaluation (0.0–1.0).",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Dataset
|
|
||||||
dataset_name: str = Field(
|
|
||||||
default="google/frames-benchmark",
|
|
||||||
description="HuggingFace dataset name for research questions.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Environment
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
class WebResearchEnv(HermesAgentBaseEnv):
|
|
||||||
"""
|
|
||||||
RL environment for training multi-step web research skills.
|
|
||||||
|
|
||||||
The model is given a factual question requiring 2-3 hops of web research
|
|
||||||
and must use web_search / web_extract tools to find and synthesize the answer.
|
|
||||||
|
|
||||||
Reward is multi-signal:
|
|
||||||
60% — answer correctness (LLM judge)
|
|
||||||
20% — tool usage (did the model actually search the web?)
|
|
||||||
20% — efficiency (penalizes >5 tool calls)
|
|
||||||
|
|
||||||
Bonus +0.1 for source diversity (≥2 distinct domains cited).
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "web-research"
|
|
||||||
env_config_cls = WebResearchEnvConfig
|
|
||||||
|
|
||||||
# Default toolsets for this environment — web + file for saving notes
|
|
||||||
default_toolsets = ["web", "file"]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls) -> Tuple[WebResearchEnvConfig, List[APIServerConfig]]:
|
|
||||||
"""Default configuration for the web research environment."""
|
|
||||||
env_config = WebResearchEnvConfig(
|
|
||||||
enabled_toolsets=["web", "file"],
|
|
||||||
max_agent_turns=15,
|
|
||||||
agent_temperature=1.0,
|
|
||||||
system_prompt=(
|
|
||||||
"You are a highly capable research agent. When asked a factual question, "
|
|
||||||
"always use web_search to find current, accurate information before answering. "
|
|
||||||
"Cite at least 2 sources. Be concise and accurate."
|
|
||||||
),
|
|
||||||
group_size=4,
|
|
||||||
total_steps=1000,
|
|
||||||
steps_per_eval=100,
|
|
||||||
use_wandb=True,
|
|
||||||
wandb_name="web-research",
|
|
||||||
)
|
|
||||||
|
|
||||||
server_configs = [
|
|
||||||
APIServerConfig(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
model_name="anthropic/claude-sonnet-4.5",
|
|
||||||
server_type="openai",
|
|
||||||
api_key=os.getenv("OPENROUTER_API_KEY", ""),
|
|
||||||
health_check=False,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self._items: list[dict] = []
|
|
||||||
self._eval_items: list[dict] = []
|
|
||||||
self._index: int = 0
|
|
||||||
|
|
||||||
# Metrics tracking for wandb
|
|
||||||
self._reward_buffer: list[float] = []
|
|
||||||
self._correctness_buffer: list[float] = []
|
|
||||||
self._tool_usage_buffer: list[float] = []
|
|
||||||
self._efficiency_buffer: list[float] = []
|
|
||||||
self._diversity_buffer: list[float] = []
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 1. Setup — load dataset
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def setup(self) -> None:
|
|
||||||
"""Load the FRAMES benchmark or fall back to built-in samples."""
|
|
||||||
if HF_AVAILABLE:
|
|
||||||
try:
|
|
||||||
logger.info("Loading FRAMES benchmark from HuggingFace...")
|
|
||||||
ds = load_dataset(self.config.dataset_name, split="test")
|
|
||||||
self._items = [
|
|
||||||
{
|
|
||||||
"question": row["Prompt"],
|
|
||||||
"answer": row["Answer"],
|
|
||||||
"difficulty": row.get("reasoning_types", "unknown"),
|
|
||||||
"hops": 2,
|
|
||||||
}
|
|
||||||
for row in ds
|
|
||||||
]
|
|
||||||
# Hold out for eval
|
|
||||||
eval_size = max(
|
|
||||||
self.config.eval_size,
|
|
||||||
int(len(self._items) * self.config.eval_split_ratio),
|
|
||||||
)
|
|
||||||
random.shuffle(self._items)
|
|
||||||
self._eval_items = self._items[:eval_size]
|
|
||||||
self._items = self._items[eval_size:]
|
|
||||||
logger.info(
|
|
||||||
f"Loaded {len(self._items)} train / {len(self._eval_items)} eval items "
|
|
||||||
f"from FRAMES benchmark."
|
|
||||||
)
|
|
||||||
return
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Could not load FRAMES from HuggingFace: {e}. Using built-in samples.")
|
|
||||||
|
|
||||||
# Fallback
|
|
||||||
random.shuffle(SAMPLE_QUESTIONS)
|
|
||||||
split = max(1, len(SAMPLE_QUESTIONS) * 8 // 10)
|
|
||||||
self._items = SAMPLE_QUESTIONS[:split]
|
|
||||||
self._eval_items = SAMPLE_QUESTIONS[split:]
|
|
||||||
logger.info(
|
|
||||||
f"Using built-in sample dataset: {len(self._items)} train / "
|
|
||||||
f"{len(self._eval_items)} eval items."
|
|
||||||
)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 2. get_next_item — return the next question
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def get_next_item(self) -> dict:
|
|
||||||
"""Return the next item, cycling through the dataset."""
|
|
||||||
if not self._items:
|
|
||||||
raise RuntimeError("Dataset is empty. Did you call setup()?")
|
|
||||||
item = self._items[self._index % len(self._items)]
|
|
||||||
self._index += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 3. format_prompt — build the user-facing prompt
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
def format_prompt(self, item: dict) -> str:
|
|
||||||
"""Format the research question as a task prompt."""
|
|
||||||
return (
|
|
||||||
f"Research the following question thoroughly using web search. "
|
|
||||||
f"You MUST search the web to find current, accurate information — "
|
|
||||||
f"do not rely solely on your training data.\n\n"
|
|
||||||
f"Question: {item['question']}\n\n"
|
|
||||||
f"Requirements:\n"
|
|
||||||
f"- Use web_search and/or web_extract tools to find information\n"
|
|
||||||
f"- Search at least 2 different sources\n"
|
|
||||||
f"- Provide a concise, accurate answer (2-4 sentences)\n"
|
|
||||||
f"- Cite the sources you used"
|
|
||||||
)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 4. compute_reward — multi-signal scoring
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def compute_reward(
|
|
||||||
self,
|
|
||||||
item: dict,
|
|
||||||
result: AgentResult,
|
|
||||||
ctx: ToolContext,
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Multi-signal reward function:
|
|
||||||
|
|
||||||
correctness_weight * correctness — LLM judge comparing answer to ground truth
|
|
||||||
tool_usage_weight * tool_used — binary: did the model use web tools?
|
|
||||||
efficiency_weight * efficiency — penalizes wasteful tool usage
|
|
||||||
+ diversity_bonus — source diversity (≥2 distinct domains)
|
|
||||||
"""
|
|
||||||
# Extract final response from messages (last assistant message with content)
|
|
||||||
final_response = ""
|
|
||||||
tools_used: list[str] = []
|
|
||||||
for msg in reversed(result.messages):
|
|
||||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
|
||||||
final_response = msg["content"]
|
|
||||||
# Collect tool names from tool call messages
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
|
||||||
name = fn.get("name", "")
|
|
||||||
if name:
|
|
||||||
tools_used.append(name)
|
|
||||||
tool_call_count: int = result.turns_used or len(tools_used)
|
|
||||||
|
|
||||||
cfg = self.config
|
|
||||||
|
|
||||||
# ---- Signal 1: Answer correctness (LLM judge) ----------------
|
|
||||||
correctness = await self._llm_judge(
|
|
||||||
question=item["question"],
|
|
||||||
expected=item["answer"],
|
|
||||||
model_answer=final_response,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---- Signal 2: Web tool usage --------------------------------
|
|
||||||
web_tools = {"web_search", "web_extract", "search", "firecrawl"}
|
|
||||||
tool_used = 1.0 if any(t in web_tools for t in tools_used) else 0.0
|
|
||||||
|
|
||||||
# ---- Signal 3: Efficiency ------------------------------------
|
|
||||||
if tool_call_count <= cfg.efficient_max_calls:
|
|
||||||
efficiency = 1.0
|
|
||||||
elif tool_call_count <= cfg.heavy_penalty_calls:
|
|
||||||
efficiency = 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.08
|
|
||||||
else:
|
|
||||||
efficiency = max(0.0, 1.0 - (tool_call_count - cfg.efficient_max_calls) * 0.12)
|
|
||||||
|
|
||||||
# ---- Bonus: Source diversity ---------------------------------
|
|
||||||
domains = self._extract_domains(final_response)
|
|
||||||
diversity = cfg.diversity_bonus if len(domains) >= 2 else 0.0
|
|
||||||
|
|
||||||
# ---- Combine ------------------------------------------------
|
|
||||||
reward = (
|
|
||||||
cfg.correctness_weight * correctness
|
|
||||||
+ cfg.tool_usage_weight * tool_used
|
|
||||||
+ cfg.efficiency_weight * efficiency
|
|
||||||
+ diversity
|
|
||||||
)
|
|
||||||
reward = min(1.0, max(0.0, reward)) # clamp to [0, 1]
|
|
||||||
|
|
||||||
# Track for wandb
|
|
||||||
self._reward_buffer.append(reward)
|
|
||||||
self._correctness_buffer.append(correctness)
|
|
||||||
self._tool_usage_buffer.append(tool_used)
|
|
||||||
self._efficiency_buffer.append(efficiency)
|
|
||||||
self._diversity_buffer.append(diversity)
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Reward breakdown — correctness={correctness:.2f}, "
|
|
||||||
f"tool_used={tool_used:.1f}, efficiency={efficiency:.2f}, "
|
|
||||||
f"diversity={diversity:.1f} → total={reward:.3f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return reward
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 5. evaluate — run on held-out eval split
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs) -> None:
|
|
||||||
"""Run evaluation on the held-out split using the full agent loop with tools.
|
|
||||||
|
|
||||||
Each eval item runs through the same agent loop as training —
|
|
||||||
the model can use web_search, web_extract, etc. to research answers.
|
|
||||||
This measures actual agentic research capability, not just knowledge.
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from environments.agent_loop import HermesAgentLoop
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
|
|
||||||
items = self._eval_items
|
|
||||||
if not items:
|
|
||||||
logger.warning("No eval items available.")
|
|
||||||
return
|
|
||||||
|
|
||||||
eval_size = min(self.config.eval_size, len(items))
|
|
||||||
eval_items = items[:eval_size]
|
|
||||||
|
|
||||||
logger.info(f"Running eval on {len(eval_items)} questions (with agent loop + tools)...")
|
|
||||||
start_time = time.time()
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
# Resolve tools once for all eval items
|
|
||||||
tools, valid_names = self._resolve_tools_for_group()
|
|
||||||
|
|
||||||
for i, item in enumerate(eval_items):
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
logger.info(f"Eval [{i+1}/{len(eval_items)}]: {item['question'][:80]}...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Build messages
|
|
||||||
messages: List[Dict[str, Any]] = []
|
|
||||||
if self.config.system_prompt:
|
|
||||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
|
||||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
|
||||||
|
|
||||||
# Run the full agent loop with tools
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=self.server,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=task_id,
|
|
||||||
temperature=0.0, # Deterministic for eval
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
budget_config=self.config.build_budget_config(),
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Extract final response and tool usage from messages
|
|
||||||
final_response = ""
|
|
||||||
tool_call_count = 0
|
|
||||||
for msg in reversed(result.messages):
|
|
||||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
|
||||||
final_response = msg["content"]
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
tool_call_count += len(msg["tool_calls"])
|
|
||||||
|
|
||||||
# Compute reward (includes LLM judge for correctness)
|
|
||||||
# Temporarily save buffer lengths so we can extract the
|
|
||||||
# correctness score without calling judge twice, and avoid
|
|
||||||
# polluting training metric buffers with eval data.
|
|
||||||
buf_len = len(self._correctness_buffer)
|
|
||||||
ctx = ToolContext(task_id)
|
|
||||||
try:
|
|
||||||
reward = await self.compute_reward(item, result, ctx)
|
|
||||||
finally:
|
|
||||||
ctx.cleanup()
|
|
||||||
|
|
||||||
# Extract correctness from the buffer (compute_reward appended it)
|
|
||||||
# then remove eval entries from training buffers
|
|
||||||
correctness = (
|
|
||||||
self._correctness_buffer[buf_len]
|
|
||||||
if len(self._correctness_buffer) > buf_len
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
# Roll back buffers to avoid polluting training metrics
|
|
||||||
for buf in (
|
|
||||||
self._reward_buffer, self._correctness_buffer,
|
|
||||||
self._tool_usage_buffer, self._efficiency_buffer,
|
|
||||||
self._diversity_buffer,
|
|
||||||
):
|
|
||||||
if len(buf) > buf_len:
|
|
||||||
buf.pop()
|
|
||||||
|
|
||||||
samples.append({
|
|
||||||
"prompt": item["question"],
|
|
||||||
"response": final_response[:500],
|
|
||||||
"expected": item["answer"],
|
|
||||||
"correctness": correctness,
|
|
||||||
"reward": reward,
|
|
||||||
"tool_calls": tool_call_count,
|
|
||||||
"turns": result.turns_used,
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f" → correctness={correctness:.2f}, reward={reward:.3f}, "
|
|
||||||
f"tools={tool_call_count}, turns={result.turns_used}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Eval error on item: {e}")
|
|
||||||
samples.append({
|
|
||||||
"prompt": item["question"],
|
|
||||||
"response": f"ERROR: {e}",
|
|
||||||
"expected": item["answer"],
|
|
||||||
"correctness": 0.0,
|
|
||||||
"reward": 0.0,
|
|
||||||
"tool_calls": 0,
|
|
||||||
"turns": 0,
|
|
||||||
})
|
|
||||||
|
|
||||||
end_time = time.time()
|
|
||||||
|
|
||||||
# Compute aggregate metrics
|
|
||||||
correctness_scores = [s["correctness"] for s in samples]
|
|
||||||
rewards = [s["reward"] for s in samples]
|
|
||||||
tool_counts = [s["tool_calls"] for s in samples]
|
|
||||||
n = len(samples)
|
|
||||||
|
|
||||||
eval_metrics = {
|
|
||||||
"eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
|
|
||||||
"eval/mean_reward": sum(rewards) / n if n else 0.0,
|
|
||||||
"eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,
|
|
||||||
"eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,
|
|
||||||
"eval/n_items": n,
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Eval complete — correctness={eval_metrics['eval/mean_correctness']:.3f}, "
|
|
||||||
f"reward={eval_metrics['eval/mean_reward']:.3f}, "
|
|
||||||
f"tool_usage={eval_metrics['eval/tool_usage_rate']:.0%}"
|
|
||||||
)
|
|
||||||
|
|
||||||
await self.evaluate_log(
|
|
||||||
metrics=eval_metrics,
|
|
||||||
samples=samples,
|
|
||||||
start_time=start_time,
|
|
||||||
end_time=end_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# 6. wandb_log — custom metrics
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def wandb_log(self, wandb_metrics: Optional[Dict] = None) -> None:
|
|
||||||
"""Log reward breakdown metrics to wandb."""
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
|
|
||||||
if self._reward_buffer:
|
|
||||||
n = len(self._reward_buffer)
|
|
||||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
|
||||||
wandb_metrics["train/mean_correctness"] = sum(self._correctness_buffer) / n
|
|
||||||
wandb_metrics["train/mean_tool_usage"] = sum(self._tool_usage_buffer) / n
|
|
||||||
wandb_metrics["train/mean_efficiency"] = sum(self._efficiency_buffer) / n
|
|
||||||
wandb_metrics["train/mean_diversity"] = sum(self._diversity_buffer) / n
|
|
||||||
wandb_metrics["train/total_rollouts"] = n
|
|
||||||
|
|
||||||
# Accuracy buckets
|
|
||||||
wandb_metrics["train/correct_rate"] = (
|
|
||||||
sum(1 for c in self._correctness_buffer if c >= 0.7) / n
|
|
||||||
)
|
|
||||||
wandb_metrics["train/tool_usage_rate"] = (
|
|
||||||
sum(1 for t in self._tool_usage_buffer if t > 0) / n
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clear buffers
|
|
||||||
self._reward_buffer.clear()
|
|
||||||
self._correctness_buffer.clear()
|
|
||||||
self._tool_usage_buffer.clear()
|
|
||||||
self._efficiency_buffer.clear()
|
|
||||||
self._diversity_buffer.clear()
|
|
||||||
|
|
||||||
await super().wandb_log(wandb_metrics)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
# Private helpers
|
|
||||||
# ------------------------------------------------------------------
|
|
||||||
|
|
||||||
async def _llm_judge(
|
|
||||||
self,
|
|
||||||
question: str,
|
|
||||||
expected: str,
|
|
||||||
model_answer: str,
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Use the server's LLM to judge answer correctness.
|
|
||||||
Falls back to keyword heuristic if LLM call fails.
|
|
||||||
"""
|
|
||||||
if not model_answer or not model_answer.strip():
|
|
||||||
return 0.0
|
|
||||||
|
|
||||||
judge_prompt = (
|
|
||||||
"You are an impartial judge evaluating the quality of an AI research answer.\n\n"
|
|
||||||
f"Question: {question}\n\n"
|
|
||||||
f"Reference answer: {expected}\n\n"
|
|
||||||
f"Model answer: {model_answer}\n\n"
|
|
||||||
"Score the model answer on a scale from 0.0 to 1.0 where:\n"
|
|
||||||
" 1.0 = fully correct and complete\n"
|
|
||||||
" 0.7 = mostly correct with minor gaps\n"
|
|
||||||
" 0.4 = partially correct\n"
|
|
||||||
" 0.1 = mentions relevant topic but wrong or very incomplete\n"
|
|
||||||
" 0.0 = completely wrong or no answer\n\n"
|
|
||||||
"Consider: factual accuracy, completeness, and relevance.\n"
|
|
||||||
'Respond with ONLY a JSON object: {"score": <float>, "reason": "<one sentence>"}'
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await self.server.chat_completion(
|
|
||||||
messages=[{"role": "user", "content": judge_prompt}],
|
|
||||||
n=1,
|
|
||||||
max_tokens=150,
|
|
||||||
temperature=0.0,
|
|
||||||
split="eval",
|
|
||||||
)
|
|
||||||
text = response.choices[0].message.content if response.choices else ""
|
|
||||||
parsed = self._parse_judge_json(text)
|
|
||||||
if parsed is not None:
|
|
||||||
return float(parsed)
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"LLM judge failed: {e}. Using heuristic.")
|
|
||||||
|
|
||||||
return self._heuristic_score(expected, model_answer)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _parse_judge_json(text: str) -> Optional[float]:
|
|
||||||
"""Extract the score float from LLM judge JSON response."""
|
|
||||||
try:
|
|
||||||
clean = re.sub(r"```(?:json)?|```", "", text).strip()
|
|
||||||
data = json.loads(clean)
|
|
||||||
score = float(data.get("score", -1))
|
|
||||||
if 0.0 <= score <= 1.0:
|
|
||||||
return score
|
|
||||||
except Exception:
|
|
||||||
match = re.search(r'"score"\s*:\s*([0-9.]+)', text)
|
|
||||||
if match:
|
|
||||||
score = float(match.group(1))
|
|
||||||
if 0.0 <= score <= 1.0:
|
|
||||||
return score
|
|
||||||
return None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _heuristic_score(expected: str, model_answer: str) -> float:
|
|
||||||
"""Lightweight keyword overlap score as fallback."""
|
|
||||||
stopwords = {
|
|
||||||
"the", "a", "an", "is", "are", "was", "were", "of", "in", "on",
|
|
||||||
"at", "to", "for", "with", "and", "or", "but", "it", "its",
|
|
||||||
"this", "that", "as", "by", "from", "be", "has", "have", "had",
|
|
||||||
}
|
|
||||||
|
|
||||||
def tokenize(text: str) -> set:
|
|
||||||
tokens = re.findall(r'\b\w+\b', text.lower())
|
|
||||||
return {t for t in tokens if t not in stopwords and len(t) > 2}
|
|
||||||
|
|
||||||
expected_tokens = tokenize(expected)
|
|
||||||
answer_tokens = tokenize(model_answer)
|
|
||||||
|
|
||||||
if not expected_tokens:
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
overlap = len(expected_tokens & answer_tokens)
|
|
||||||
union = len(expected_tokens | answer_tokens)
|
|
||||||
|
|
||||||
jaccard = overlap / union if union > 0 else 0.0
|
|
||||||
recall = overlap / len(expected_tokens)
|
|
||||||
return min(1.0, 0.4 * jaccard + 0.6 * recall)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _extract_domains(text: str) -> set:
|
|
||||||
"""Extract unique domains from URLs cited in the response."""
|
|
||||||
urls = re.findall(r'https?://[^\s\)>\]"\']+', text)
|
|
||||||
domains = set()
|
|
||||||
for url in urls:
|
|
||||||
try:
|
|
||||||
parsed = urlparse(url)
|
|
||||||
domain = parsed.netloc.lower().lstrip("www.")
|
|
||||||
if domain:
|
|
||||||
domains.add(domain)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return domains
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Entry point
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
WebResearchEnv.cli()
|
|
||||||
|
|
@ -2138,22 +2138,6 @@ OPTIONAL_ENV_VARS = {
|
||||||
"password": True,
|
"password": True,
|
||||||
"category": "tool",
|
"category": "tool",
|
||||||
},
|
},
|
||||||
"TINKER_API_KEY": {
|
|
||||||
"description": "Tinker API key for RL training",
|
|
||||||
"prompt": "Tinker API key",
|
|
||||||
"url": "https://tinker-console.thinkingmachines.ai/keys",
|
|
||||||
"tools": ["rl_start_training", "rl_check_status", "rl_stop_training"],
|
|
||||||
"password": True,
|
|
||||||
"category": "tool",
|
|
||||||
},
|
|
||||||
"WANDB_API_KEY": {
|
|
||||||
"description": "Weights & Biases API key for experiment tracking",
|
|
||||||
"prompt": "WandB API key",
|
|
||||||
"url": "https://wandb.ai/authorize",
|
|
||||||
"tools": ["rl_get_results", "rl_check_status"],
|
|
||||||
"password": True,
|
|
||||||
"category": "tool",
|
|
||||||
},
|
|
||||||
"VOICE_TOOLS_OPENAI_KEY": {
|
"VOICE_TOOLS_OPENAI_KEY": {
|
||||||
"description": "OpenAI API key for voice transcription (Whisper) and OpenAI TTS",
|
"description": "OpenAI API key for voice transcription (Whisper) and OpenAI TTS",
|
||||||
"prompt": "OpenAI API Key (for Whisper STT + TTS)",
|
"prompt": "OpenAI API Key (for Whisper STT + TTS)",
|
||||||
|
|
@ -4990,8 +4974,7 @@ def set_config_value(key: str, value: str):
|
||||||
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
'FAL_KEY', 'TELEGRAM_BOT_TOKEN', 'DISCORD_BOT_TOKEN',
|
||||||
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
'TERMINAL_SSH_HOST', 'TERMINAL_SSH_USER', 'TERMINAL_SSH_KEY',
|
||||||
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
'SUDO_PASSWORD', 'SLACK_BOT_TOKEN', 'SLACK_APP_TOKEN',
|
||||||
'GITHUB_TOKEN', 'HONCHO_API_KEY', 'WANDB_API_KEY',
|
'GITHUB_TOKEN', 'HONCHO_API_KEY',
|
||||||
'TINKER_API_KEY',
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if key.upper() in api_keys or key.upper().endswith(('_API_KEY', '_TOKEN')) or key.upper().startswith('TERMINAL_SSH'):
|
if key.upper() in api_keys or key.upper().endswith(('_API_KEY', '_TOKEN')) or key.upper().startswith('TERMINAL_SSH'):
|
||||||
|
|
|
||||||
|
|
@ -1595,28 +1595,6 @@ def run_doctor(args):
|
||||||
for _issue in _r.issues:
|
for _issue in _r.issues:
|
||||||
issues.append(_issue)
|
issues.append(_issue)
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Check: Submodules
|
|
||||||
# =========================================================================
|
|
||||||
print()
|
|
||||||
print(color("◆ Submodules", Colors.CYAN, Colors.BOLD))
|
|
||||||
|
|
||||||
# tinker-atropos (RL training backend)
|
|
||||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
|
||||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
|
||||||
if py_version >= (3, 11):
|
|
||||||
try:
|
|
||||||
__import__("tinker_atropos")
|
|
||||||
check_ok("tinker-atropos", "(RL training backend)")
|
|
||||||
except ImportError:
|
|
||||||
install_cmd = f"{_python_install_cmd()} -e ./tinker-atropos"
|
|
||||||
check_warn("tinker-atropos found but not installed", f"(run: {install_cmd})")
|
|
||||||
issues.append(f"Install tinker-atropos: {install_cmd}")
|
|
||||||
else:
|
|
||||||
check_warn("tinker-atropos requires Python 3.11+", f"(current: {py_version.major}.{py_version.minor})")
|
|
||||||
else:
|
|
||||||
check_warn("tinker-atropos not found", "(run: git submodule update --init --recursive)")
|
|
||||||
|
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
# Check: Tool Availability
|
# Check: Tool Availability
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
|
||||||
|
|
@ -522,14 +522,6 @@ def _print_setup_summary(config: dict, hermes_home):
|
||||||
elif managed_nous_tools_enabled() and subscription_features.nous_auth_present:
|
elif managed_nous_tools_enabled() and subscription_features.nous_auth_present:
|
||||||
tool_status.append(("Modal Execution (optional via Nous subscription)", True, None))
|
tool_status.append(("Modal Execution (optional via Nous subscription)", True, None))
|
||||||
|
|
||||||
# Tinker + WandB (RL training)
|
|
||||||
if get_env_value("TINKER_API_KEY") and get_env_value("WANDB_API_KEY"):
|
|
||||||
tool_status.append(("RL Training (Tinker)", True, None))
|
|
||||||
elif get_env_value("TINKER_API_KEY"):
|
|
||||||
tool_status.append(("RL Training (Tinker)", False, "WANDB_API_KEY"))
|
|
||||||
else:
|
|
||||||
tool_status.append(("RL Training (Tinker)", False, "TINKER_API_KEY"))
|
|
||||||
|
|
||||||
# Home Assistant
|
# Home Assistant
|
||||||
if get_env_value("HASS_TOKEN"):
|
if get_env_value("HASS_TOKEN"):
|
||||||
tool_status.append(("Smart Home (Home Assistant)", True, None))
|
tool_status.append(("Smart Home (Home Assistant)", True, None))
|
||||||
|
|
|
||||||
|
|
@ -141,8 +141,6 @@ def show_status(args):
|
||||||
"Browser Use": "BROWSER_USE_API_KEY", # Optional — local browser works without this
|
"Browser Use": "BROWSER_USE_API_KEY", # Optional — local browser works without this
|
||||||
"Browserbase": "BROWSERBASE_API_KEY", # Optional — direct credentials only
|
"Browserbase": "BROWSERBASE_API_KEY", # Optional — direct credentials only
|
||||||
"FAL": "FAL_KEY",
|
"FAL": "FAL_KEY",
|
||||||
"Tinker": "TINKER_API_KEY",
|
|
||||||
"WandB": "WANDB_API_KEY",
|
|
||||||
"ElevenLabs": "ELEVENLABS_API_KEY",
|
"ElevenLabs": "ELEVENLABS_API_KEY",
|
||||||
"GitHub": "GITHUB_TOKEN",
|
"GitHub": "GITHUB_TOKEN",
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,6 @@ CONFIGURABLE_TOOLSETS = [
|
||||||
("delegation", "👥 Task Delegation", "delegate_task"),
|
("delegation", "👥 Task Delegation", "delegate_task"),
|
||||||
("cronjob", "⏰ Cron Jobs", "create/list/update/pause/resume/run, with optional attached skills"),
|
("cronjob", "⏰ Cron Jobs", "create/list/update/pause/resume/run, with optional attached skills"),
|
||||||
("messaging", "📨 Cross-Platform Messaging", "send_message"),
|
("messaging", "📨 Cross-Platform Messaging", "send_message"),
|
||||||
("rl", "🧪 RL Training", "Tinker-Atropos training tools"),
|
|
||||||
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
("homeassistant", "🏠 Home Assistant", "smart home device control"),
|
||||||
("spotify", "🎵 Spotify", "playback, search, playlists, library"),
|
("spotify", "🎵 Spotify", "playback, search, playlists, library"),
|
||||||
("discord", "💬 Discord (read/participate)", "fetch messages, search members, create thread"),
|
("discord", "💬 Discord (read/participate)", "fetch messages, search members, create thread"),
|
||||||
|
|
@ -87,7 +86,7 @@ CONFIGURABLE_TOOLSETS = [
|
||||||
# Video gen is off by default — it's a niche, paid, slow feature. Users
|
# Video gen is off by default — it's a niche, paid, slow feature. Users
|
||||||
# who want it opt in via `hermes tools` → Video Generation, which walks
|
# who want it opt in via `hermes tools` → Video Generation, which walks
|
||||||
# them through provider + model selection.
|
# them through provider + model selection.
|
||||||
_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "rl", "spotify", "discord", "discord_admin", "video", "video_gen"}
|
_DEFAULT_OFF_TOOLSETS = {"moa", "homeassistant", "spotify", "discord", "discord_admin", "video", "video_gen"}
|
||||||
|
|
||||||
# Platform-scoped toolsets: only appear in the `hermes tools` checklist for
|
# Platform-scoped toolsets: only appear in the `hermes tools` checklist for
|
||||||
# these platforms, and only resolve/save for these platforms. A toolset
|
# these platforms, and only resolve/save for these platforms. A toolset
|
||||||
|
|
@ -424,22 +423,6 @@ TOOL_CATEGORIES = {
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
"rl": {
|
|
||||||
"name": "RL Training",
|
|
||||||
"icon": "🧪",
|
|
||||||
"requires_python": (3, 11),
|
|
||||||
"providers": [
|
|
||||||
{
|
|
||||||
"name": "Tinker / Atropos",
|
|
||||||
"tag": "RL training platform",
|
|
||||||
"env_vars": [
|
|
||||||
{"key": "TINKER_API_KEY", "prompt": "Tinker API key", "url": "https://tinker-console.thinkingmachines.ai/keys"},
|
|
||||||
{"key": "WANDB_API_KEY", "prompt": "WandB API key", "url": "https://wandb.ai/authorize"},
|
|
||||||
],
|
|
||||||
"post_setup": "rl_training",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"langfuse": {
|
"langfuse": {
|
||||||
"name": "Langfuse Observability",
|
"name": "Langfuse Observability",
|
||||||
"icon": "📊",
|
"icon": "📊",
|
||||||
|
|
@ -912,24 +895,6 @@ def _run_post_setup(post_setup_key: str):
|
||||||
_print_warning(f" Spotify login failed: {exc}")
|
_print_warning(f" Spotify login failed: {exc}")
|
||||||
_print_info(" Run manually: hermes auth spotify")
|
_print_info(" Run manually: hermes auth spotify")
|
||||||
|
|
||||||
elif post_setup_key == "rl_training":
|
|
||||||
try:
|
|
||||||
__import__("tinker_atropos")
|
|
||||||
except ImportError:
|
|
||||||
tinker_dir = PROJECT_ROOT / "tinker-atropos"
|
|
||||||
if tinker_dir.exists() and (tinker_dir / "pyproject.toml").exists():
|
|
||||||
_print_info(" Installing tinker-atropos submodule...")
|
|
||||||
result = _pip_install(["-e", str(tinker_dir)])
|
|
||||||
if result.returncode == 0:
|
|
||||||
_print_success(" tinker-atropos installed")
|
|
||||||
else:
|
|
||||||
_print_warning(" tinker-atropos install failed - run manually:")
|
|
||||||
_print_info(' uv pip install -e "./tinker-atropos"')
|
|
||||||
else:
|
|
||||||
_print_warning(" tinker-atropos submodule not found - run:")
|
|
||||||
_print_info(" git submodule update --init --recursive")
|
|
||||||
_print_info(' uv pip install -e "./tinker-atropos"')
|
|
||||||
|
|
||||||
elif post_setup_key == "langfuse":
|
elif post_setup_key == "langfuse":
|
||||||
# Install the langfuse SDK.
|
# Install the langfuse SDK.
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -97,9 +97,7 @@ def _run_async(coro):
|
||||||
asyncio.run()'s create-and-destroy lifecycle.
|
asyncio.run()'s create-and-destroy lifecycle.
|
||||||
|
|
||||||
This is the single source of truth for sync->async bridging in tool
|
This is the single source of truth for sync->async bridging in tool
|
||||||
handlers. The RL paths (agent_loop.py, tool_context.py) also provide
|
handlers. Each handler is self-protecting via this function.
|
||||||
outer thread-pool wrapping as defense-in-depth, but each handler is
|
|
||||||
self-protecting via this function.
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|
@ -231,13 +229,6 @@ _LEGACY_TOOLSET_MAP = {
|
||||||
"browser_vision", "browser_console"
|
"browser_vision", "browser_console"
|
||||||
],
|
],
|
||||||
"cronjob_tools": ["cronjob"],
|
"cronjob_tools": ["cronjob"],
|
||||||
"rl_tools": [
|
|
||||||
"rl_list_environments", "rl_select_environment",
|
|
||||||
"rl_get_current_config", "rl_edit_config",
|
|
||||||
"rl_start_training", "rl_check_status",
|
|
||||||
"rl_stop_training", "rl_get_results",
|
|
||||||
"rl_list_runs", "rl_test_inference"
|
|
||||||
],
|
|
||||||
"file_tools": ["read_file", "write_file", "patch", "search_files"],
|
"file_tools": ["read_file", "write_file", "patch", "search_files"],
|
||||||
"tts_tools": ["text_to_speech"],
|
"tts_tools": ["text_to_speech"],
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -192,7 +192,6 @@ stdenv.mkDerivation {
|
||||||
source .venv/bin/activate
|
source .venv/bin/activate
|
||||||
uv pip install -e ".[all]"
|
uv pip install -e ".[all]"
|
||||||
[ -d mini-swe-agent ] && uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
[ -d mini-swe-agent ] && uv pip install -e ./mini-swe-agent 2>/dev/null || true
|
||||||
[ -d tinker-atropos ] && uv pip install -e ./tinker-atropos 2>/dev/null || true
|
|
||||||
mkdir -p .nix-stamps
|
mkdir -p .nix-stamps
|
||||||
echo "$STAMP_VALUE" > "$STAMP"
|
echo "$STAMP_VALUE" > "$STAMP"
|
||||||
else
|
else
|
||||||
|
|
|
||||||
|
|
@ -1,303 +0,0 @@
|
||||||
---
|
|
||||||
name: hermes-atropos-environments
|
|
||||||
description: Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/evaluate). Use when creating, reviewing, or fixing RL environments in the hermes-agent repo.
|
|
||||||
version: 1.1.0
|
|
||||||
author: Hermes Agent
|
|
||||||
license: MIT
|
|
||||||
platforms: [linux, macos, windows]
|
|
||||||
metadata:
|
|
||||||
hermes:
|
|
||||||
tags: [atropos, rl, environments, training, reinforcement-learning, reward-functions]
|
|
||||||
related_skills: [axolotl, fine-tuning-with-trl, lm-evaluation-harness]
|
|
||||||
---
|
|
||||||
|
|
||||||
# Hermes Agent Atropos Environments
|
|
||||||
|
|
||||||
Guide for building RL environments in the hermes-agent repo that integrate with the Atropos training framework.
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
```
|
|
||||||
Atropos BaseEnv (atroposlib/envs/base.py)
|
|
||||||
└── HermesAgentBaseEnv (environments/hermes_base_env.py)
|
|
||||||
├── Handles agent loop orchestration
|
|
||||||
├── Handles tool resolution per group
|
|
||||||
├── Handles ToolContext for reward verification
|
|
||||||
└── YOUR ENVIRONMENT (environments/your_env.py)
|
|
||||||
Only implements: setup, get_next_item, format_prompt,
|
|
||||||
compute_reward, evaluate, wandb_log
|
|
||||||
```
|
|
||||||
|
|
||||||
Hermes environments are special because they run a **multi-turn agent loop with tool calling** — not just single-turn completions. The base env handles the loop; you implement the task and scoring.
|
|
||||||
|
|
||||||
## File Locations
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|------|---------|
|
|
||||||
| `environments/hermes_base_env.py` | Base class with agent loop + tool resolution |
|
|
||||||
| `environments/agent_loop.py` | `HermesAgentLoop` + `AgentResult` dataclass |
|
|
||||||
| `environments/tool_context.py` | `ToolContext` for reward verification |
|
|
||||||
| `environments/tool_call_parsers.py` | Phase 2 tool call parsers (hermes, mistral, etc.) |
|
|
||||||
| `environments/your_env.py` | Your environment implementation |
|
|
||||||
|
|
||||||
## Inference Setup — Ask the User First
|
|
||||||
|
|
||||||
**IMPORTANT:** Before running any test, evaluation, or data generation command, always ask the user how they want to handle inference. Do NOT assume OpenRouter or any specific endpoint. Present these options:
|
|
||||||
|
|
||||||
1. **OpenRouter** — Ask which model they want to use (e.g., `anthropic/claude-sonnet-4.5`, `google/gemini-2.5-pro`, `meta-llama/llama-3.3-70b-instruct`, etc.). Requires `OPENROUTER_API_KEY` in environment.
|
|
||||||
2. **Self-hosted VLLM endpoint** — Ask for their base URL (e.g., `http://localhost:8000/v1`) and model name. Set `--openai.server_type vllm`.
|
|
||||||
3. **Other OpenAI-compatible API** — Ask for the base URL, model name, and any required API key. Set `--openai.server_type openai` and `--openai.health_check false`.
|
|
||||||
4. **Local Atropos training server** — For `serve` mode with a live training loop. Default `http://localhost:8000/v1`.
|
|
||||||
|
|
||||||
Once the user tells you their setup, use those values in all CLI commands for that session. Example prompts:
|
|
||||||
|
|
||||||
> "Before I run this, how would you like to handle inference?
|
|
||||||
> 1. OpenRouter (I'll need your preferred model, e.g. claude-sonnet-4.5)
|
|
||||||
> 2. A self-hosted VLLM endpoint (give me the URL and model name)
|
|
||||||
> 3. Another OpenAI-compatible API (give me the URL, model, and any auth details)
|
|
||||||
> 4. Local Atropos training server (serve mode)"
|
|
||||||
|
|
||||||
### Key flags by provider:
|
|
||||||
|
|
||||||
| Provider | `--openai.server_type` | `--openai.health_check` | `--openai.api_key` |
|
|
||||||
|----------|----------------------|------------------------|-------------------|
|
|
||||||
| OpenRouter | `openai` | `false` | `$OPENROUTER_API_KEY` |
|
|
||||||
| VLLM (self-hosted) | `vllm` | (default) | (not needed) |
|
|
||||||
| Other OpenAI-compatible | `openai` | `false` | As needed |
|
|
||||||
| Local Atropos | (default) | (default) | (not needed) |
|
|
||||||
|
|
||||||
## Required Methods
|
|
||||||
|
|
||||||
### 1. `setup()` — Load dataset and initialize state
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def setup(self) -> None:
|
|
||||||
"""Called once at startup. Load datasets, initialize state."""
|
|
||||||
# Try HuggingFace first, fallback to built-in samples
|
|
||||||
try:
|
|
||||||
from datasets import load_dataset
|
|
||||||
ds = load_dataset("your/dataset", split="test")
|
|
||||||
self._items = [...]
|
|
||||||
except Exception:
|
|
||||||
self._items = BUILTIN_SAMPLES
|
|
||||||
|
|
||||||
# Always split into train/eval
|
|
||||||
random.shuffle(self._items)
|
|
||||||
eval_size = max(20, int(len(self._items) * 0.1))
|
|
||||||
self._eval_items = self._items[:eval_size]
|
|
||||||
self._items = self._items[eval_size:]
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. `get_next_item()` — Return next training item
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def get_next_item(self) -> dict:
|
|
||||||
"""Return next item, cycling through dataset."""
|
|
||||||
item = self._items[self._index % len(self._items)]
|
|
||||||
self._index += 1
|
|
||||||
return item
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. `format_prompt(item)` — Convert item to user message
|
|
||||||
|
|
||||||
```python
|
|
||||||
def format_prompt(self, item: dict) -> str:
|
|
||||||
"""Convert a dataset item into the user-facing prompt."""
|
|
||||||
return f"Research this question: {item['question']}"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. `compute_reward(item, result, ctx)` — Score the rollout
|
|
||||||
|
|
||||||
**CRITICAL**: `result` is an `AgentResult`, NOT a dict. It has these attributes:
|
|
||||||
- `result.messages` — List of message dicts (OpenAI format)
|
|
||||||
- `result.turns_used` — Number of LLM calls made
|
|
||||||
- `result.finished_naturally` — True if model stopped voluntarily
|
|
||||||
- `result.tool_errors` — List of ToolError objects
|
|
||||||
|
|
||||||
**AgentResult does NOT have**: `final_response`, `tool_calls`, `tools_used`.
|
|
||||||
You must extract these from `result.messages`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def compute_reward(self, item, result: AgentResult, ctx: ToolContext) -> float:
|
|
||||||
# Extract final response (last assistant message with content)
|
|
||||||
final_response = ""
|
|
||||||
tools_used = []
|
|
||||||
for msg in reversed(result.messages):
|
|
||||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
|
||||||
final_response = msg["content"]
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
|
||||||
name = fn.get("name", "")
|
|
||||||
if name:
|
|
||||||
tools_used.append(name)
|
|
||||||
|
|
||||||
# Score using LLM judge, heuristic, or ToolContext verification
|
|
||||||
correctness = await self._llm_judge(item, final_response)
|
|
||||||
return correctness
|
|
||||||
```
|
|
||||||
|
|
||||||
`ctx` (ToolContext) gives you terminal/file access to the agent's sandbox for verification:
|
|
||||||
```python
|
|
||||||
# Run tests in the agent's sandbox
|
|
||||||
result = ctx.terminal("pytest /workspace/test.py")
|
|
||||||
return 1.0 if result["exit_code"] == 0 else 0.0
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. `evaluate()` — Periodic evaluation with full agent loop
|
|
||||||
|
|
||||||
**MUST use the full agent loop with tools**, not single-turn chat_completion.
|
|
||||||
The whole point of hermes-agent environments is agentic evaluation:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def evaluate(self, *args, **kwargs) -> None:
|
|
||||||
import time, uuid
|
|
||||||
from environments.agent_loop import HermesAgentLoop
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
tools, valid_names = self._resolve_tools_for_group()
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
for item in self._eval_items[:self.config.eval_size]:
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
messages = []
|
|
||||||
if self.config.system_prompt:
|
|
||||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
|
||||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
|
||||||
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=self.server,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=task_id,
|
|
||||||
temperature=0.0, # Deterministic for eval
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
ctx = ToolContext(task_id)
|
|
||||||
try:
|
|
||||||
reward = await self.compute_reward(item, result, ctx)
|
|
||||||
finally:
|
|
||||||
ctx.cleanup()
|
|
||||||
|
|
||||||
samples.append({"prompt": ..., "response": ..., "reward": reward})
|
|
||||||
|
|
||||||
eval_metrics = {"eval/mean_reward": ...}
|
|
||||||
await self.evaluate_log(metrics=eval_metrics, samples=samples,
|
|
||||||
start_time=start_time, end_time=time.time())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6. `wandb_log()` — Custom metrics logging
|
|
||||||
|
|
||||||
Always call `super().wandb_log()` at the end:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def wandb_log(self, wandb_metrics=None):
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
if self._reward_buffer:
|
|
||||||
n = len(self._reward_buffer)
|
|
||||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
|
||||||
self._reward_buffer.clear()
|
|
||||||
await super().wandb_log(wandb_metrics) # MUST call super
|
|
||||||
```
|
|
||||||
|
|
||||||
**Pitfall**: `compute_reward` appends to metric buffers. During eval, this pollutes training metrics. Roll back buffer entries added during eval.
|
|
||||||
|
|
||||||
## Config Class
|
|
||||||
|
|
||||||
Always create a custom config subclass with Pydantic Field descriptors. Key inherited fields you can tune: `enabled_toolsets`, `max_agent_turns`, `agent_temperature`, `system_prompt`, `terminal_backend`, `group_size`, `steps_per_eval`, `total_steps`.
|
|
||||||
|
|
||||||
## config_init() — Default Configuration
|
|
||||||
|
|
||||||
Classmethod returning `(YourEnvConfig, [APIServerConfig(...)])`. Set server_type to "openai" for OpenRouter/external APIs. Load API key from environment variable.
|
|
||||||
|
|
||||||
## Three CLI Modes
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# SERVE — Full training loop (connects to Atropos API server)
|
|
||||||
python environments/my_env.py serve --openai.base_url http://localhost:8000/v1
|
|
||||||
|
|
||||||
# PROCESS — Offline data generation (saves JSONL)
|
|
||||||
python environments/my_env.py process --env.total_steps 10 --env.group_size 1 \
|
|
||||||
--env.use_wandb false --env.data_path_to_save_groups output.jsonl \
|
|
||||||
--openai.base_url "<USER_BASE_URL>" \
|
|
||||||
--openai.model_name "<USER_MODEL>" \
|
|
||||||
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
|
|
||||||
|
|
||||||
# EVALUATE — Standalone eval (runs setup + evaluate only)
|
|
||||||
python environments/my_env.py evaluate --env.eval_size 20 \
|
|
||||||
--env.data_dir_to_save_evals /tmp/eval_results \
|
|
||||||
--openai.base_url "<USER_BASE_URL>" \
|
|
||||||
--openai.model_name "<USER_MODEL>" \
|
|
||||||
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
|
|
||||||
```
|
|
||||||
|
|
||||||
Config priority: CLI args > YAML file > config_init() defaults.
|
|
||||||
|
|
||||||
## Common Pitfalls
|
|
||||||
|
|
||||||
1. **AgentResult has .messages, not .final_response** — Extract the final response by iterating reversed(result.messages) looking for the last assistant message with content.
|
|
||||||
|
|
||||||
2. **evaluate() must use HermesAgentLoop, not chat_completion** — Single-turn chat_completion has no tools. The whole point of hermes-agent benchmarks is agentic evaluation with tool use.
|
|
||||||
|
|
||||||
3. **Don't call _llm_judge twice** — If compute_reward already calls it, extract the score from the buffer instead of calling judge separately in evaluate().
|
|
||||||
|
|
||||||
4. **Eval pollutes training buffers** — compute_reward appends to metric buffers. During eval, roll back buffer entries to keep training metrics clean.
|
|
||||||
|
|
||||||
5. **Always set health_check=false for OpenRouter** — OpenRouter has no /health endpoint.
|
|
||||||
|
|
||||||
6. **Set data_dir_to_save_evals in evaluate mode** — Without it, results aren't saved.
|
|
||||||
|
|
||||||
7. **default_toolsets class variable vs enabled_toolsets config** — The class variable is a hint; the config field is what actually controls tool resolution.
|
|
||||||
|
|
||||||
8. **Tool call parsing in messages** — Tool calls are dicts with `{"function": {"name": ..., "arguments": ...}}`. Always check `isinstance(tc, dict)`.
|
|
||||||
|
|
||||||
9. **ToolContext.cleanup()** — Always call in a finally block to release sandbox resources.
|
|
||||||
|
|
||||||
10. **server_type must be "openai" for external APIs** — Without it, Atropos assumes a local VLLM server.
|
|
||||||
|
|
||||||
11. **Always ask the user for their inference setup** — Never hardcode or assume a specific provider/model. See the "Inference Setup" section above.
|
|
||||||
|
|
||||||
## Reward Function Patterns
|
|
||||||
|
|
||||||
### LLM Judge (for open-ended tasks)
|
|
||||||
Use `self.server.chat_completion()` with a scoring prompt. Parse JSON response for score float. Always include a heuristic fallback (keyword overlap) for when the judge call fails.
|
|
||||||
|
|
||||||
### Binary Verification (for code/terminal tasks)
|
|
||||||
Use `ctx.terminal("pytest test.py -q")` to run tests in the agent's sandbox. Return 1.0 for pass, 0.0 for fail.
|
|
||||||
|
|
||||||
### Multi-Signal (combine multiple indicators)
|
|
||||||
Weight correctness (0.6) + tool usage (0.2) + efficiency (0.2) + optional bonuses. Clamp to [0, 1].
|
|
||||||
|
|
||||||
## Testing Your Environment
|
|
||||||
|
|
||||||
1. **Import test**: `python -c "from environments.my_env import MyEnv; print('OK')"`
|
|
||||||
2. **Ask the user for inference setup** (see "Inference Setup" section above)
|
|
||||||
3. **Process mode** (1 item): Verify JSONL output has valid tokens, masks, scores
|
|
||||||
4. **Evaluate mode**: Verify full agent loop runs with tools, metrics logged correctly
|
|
||||||
5. **Check reward range**: Scores should be in [0, 1], not all identical
|
|
||||||
|
|
||||||
## Minimum Implementation Checklist
|
|
||||||
|
|
||||||
```python
|
|
||||||
class MyEnv(HermesAgentBaseEnv):
|
|
||||||
name = "my-env"
|
|
||||||
env_config_cls = MyEnvConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls): ... # Default server + env config
|
|
||||||
async def setup(self): ... # Load dataset + train/eval split
|
|
||||||
async def get_next_item(self): ... # Cycle through training items
|
|
||||||
def format_prompt(self, item): ... # Item → user message string
|
|
||||||
async def compute_reward(self, item, result, ctx): ... # Score rollout
|
|
||||||
async def evaluate(self, *args, **kwargs): ... # Full agent loop eval
|
|
||||||
async def wandb_log(self, metrics=None): ... # Custom metrics + super()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
MyEnv.cli()
|
|
||||||
```
|
|
||||||
|
|
@ -1,59 +0,0 @@
|
||||||
# AgentResult Fields Reference
|
|
||||||
|
|
||||||
`AgentResult` is defined in `environments/agent_loop.py` as a dataclass.
|
|
||||||
|
|
||||||
## Fields
|
|
||||||
|
|
||||||
| Field | Type | Description |
|
|
||||||
|-------|------|-------------|
|
|
||||||
| `messages` | `List[Dict[str, Any]]` | Full conversation history in OpenAI message format |
|
|
||||||
| `managed_state` | `Optional[Dict]` | ManagedServer.get_state() if Phase 2, else None |
|
|
||||||
| `turns_used` | `int` | Number of LLM calls made during the loop |
|
|
||||||
| `finished_naturally` | `bool` | True if model stopped calling tools on its own |
|
|
||||||
| `reasoning_per_turn` | `List[Optional[str]]` | Extracted reasoning content per turn |
|
|
||||||
| `tool_errors` | `List[ToolError]` | Tool errors encountered during the loop |
|
|
||||||
|
|
||||||
## ToolError Fields
|
|
||||||
|
|
||||||
| Field | Type | Description |
|
|
||||||
|-------|------|-------------|
|
|
||||||
| `turn` | `int` | Which turn the error occurred |
|
|
||||||
| `tool_name` | `str` | Name of the tool that failed |
|
|
||||||
| `arguments` | `str` | Arguments passed to the tool |
|
|
||||||
| `error` | `str` | Error message |
|
|
||||||
| `tool_result` | `str` | The result returned to the model |
|
|
||||||
|
|
||||||
## Extracting Data from Messages
|
|
||||||
|
|
||||||
Messages follow OpenAI format. Common patterns:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Get final assistant response
|
|
||||||
for msg in reversed(result.messages):
|
|
||||||
if msg.get("role") == "assistant" and msg.get("content"):
|
|
||||||
final_response = msg["content"]
|
|
||||||
break
|
|
||||||
|
|
||||||
# Get all tool names used
|
|
||||||
tools = []
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
|
||||||
tools.append(fn.get("name", ""))
|
|
||||||
|
|
||||||
# Get tool results
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "tool":
|
|
||||||
tool_output = msg.get("content", "")
|
|
||||||
call_id = msg.get("tool_call_id", "")
|
|
||||||
```
|
|
||||||
|
|
||||||
## Fields that DO NOT EXIST
|
|
||||||
|
|
||||||
These are common mistakes — AgentResult does NOT have:
|
|
||||||
- `final_response` — extract from messages
|
|
||||||
- `tool_calls` — extract from messages
|
|
||||||
- `tools_used` — extract from messages
|
|
||||||
- `output` — extract from messages
|
|
||||||
- `response` — extract from messages
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
||||||
# Atropos BaseEnv Reference
|
|
||||||
|
|
||||||
Source: `atroposlib/envs/base.py` (~2124 lines)
|
|
||||||
|
|
||||||
## Abstract Methods (MUST implement)
|
|
||||||
|
|
||||||
| Method | Signature | Description |
|
|
||||||
|--------|-----------|-------------|
|
|
||||||
| `get_next_item()` | `async def get_next_item(self) -> Item` | Return next item for trajectory. Return None to pause. |
|
|
||||||
| `evaluate()` | `async def evaluate(self, *args, **kwargs)` | Called every steps_per_eval steps. |
|
|
||||||
| `setup()` | `async def setup(self)` | Called once at start. Load datasets, init models. |
|
|
||||||
| `collect_trajectory()` | `async def collect_trajectory(self, item) -> Tuple[Optional[ScoredDataItem], List[Item]]` | Single rollout. Or override collect_trajectories instead. |
|
|
||||||
|
|
||||||
## Overridable Methods
|
|
||||||
|
|
||||||
| Method | Default Behavior | Override When |
|
|
||||||
|--------|-----------------|---------------|
|
|
||||||
| `collect_trajectories()` | Runs collect_trajectory group_size times in parallel | Batch generation, MCTS, coupled rollouts |
|
|
||||||
| `wandb_log()` | Logs completion lengths, rollout table, perf stats | Add custom metrics (always call super) |
|
|
||||||
| `config_init()` | Returns (env_config_cls(), ServerBaseline()) | Custom defaults + server configs |
|
|
||||||
| `postprocess_histories()` | Passthrough | Final processing before sending to trainer |
|
|
||||||
| `save_checkpoint()` | Saves JSON to checkpoint_dir | Custom serialization |
|
|
||||||
| `cleanup()` | No-op | Release resources after each rollout |
|
|
||||||
|
|
||||||
## ScoredDataGroup Structure
|
|
||||||
|
|
||||||
```python
|
|
||||||
ScoredDataGroup = TypedDict with:
|
|
||||||
tokens: List[List[int]] # Token IDs per rollout
|
|
||||||
masks: List[List[int]] # -100=prompt, token_id=completion
|
|
||||||
scores: List[float] # Score per rollout
|
|
||||||
advantages: Optional[...] # Per-token advantages
|
|
||||||
ref_logprobs: Optional[...] # Reference model logprobs
|
|
||||||
messages: Optional[...] # OpenAI-format messages
|
|
||||||
inference_logprobs: Optional[...] # Inference logprobs
|
|
||||||
```
|
|
||||||
|
|
||||||
## BaseEnvConfig Key Fields
|
|
||||||
|
|
||||||
| Field | Default | Description |
|
|
||||||
|-------|---------|-------------|
|
|
||||||
| `group_size` | 4 | Responses grouped for scoring |
|
|
||||||
| `steps_per_eval` | 100 | Steps between evaluations |
|
|
||||||
| `max_token_length` | 2048 | Max token length for generations |
|
|
||||||
| `total_steps` | 1000 | Total training steps |
|
|
||||||
| `use_wandb` | True | Enable wandb logging |
|
|
||||||
| `tokenizer_name` | DeepHermes-3 | Tokenizer for token encoding |
|
|
||||||
| `ensure_scores_are_not_same` | True | Skip groups with identical scores |
|
|
||||||
| `worker_timeout` | 600 | Task timeout seconds |
|
|
||||||
|
|
||||||
## Data Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
env_manager() → add_train_workers() → handle_env()
|
|
||||||
→ collect_trajectories() → postprocess_histories()
|
|
||||||
→ handle_send_to_api() → training server
|
|
||||||
```
|
|
||||||
|
|
||||||
## Atropos Environment Statistics (82 environments analyzed)
|
|
||||||
|
|
||||||
- 95% implement setup, collect_trajectories, evaluate, get_next_item
|
|
||||||
- 76% override wandb_log
|
|
||||||
- 54% have custom config class
|
|
||||||
- Most use collect_trajectories (plural), not collect_trajectory (singular)
|
|
||||||
- Common reward patterns: LLM-judge (~40), regex-extract (~35), code-exec (~12)
|
|
||||||
|
|
@ -1,199 +0,0 @@
|
||||||
# Usage Patterns — Testing Environments and Evaluating Models
|
|
||||||
|
|
||||||
## Pattern 1: Test Your Environment Works (process mode)
|
|
||||||
|
|
||||||
Use `process` mode to verify your environment runs end-to-end before
|
|
||||||
committing. This generates trajectories without needing an Atropos
|
|
||||||
training server.
|
|
||||||
|
|
||||||
**Before running:** Ask the user for their inference setup (see SKILL.md "Inference Setup" section). Replace `<BASE_URL>`, `<MODEL>`, and `<SERVER_TYPE>` below with their chosen values.
|
|
||||||
|
|
||||||
### Step 1: Run 1 trajectory
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd ~/.hermes/hermes-agent
|
|
||||||
source venv/bin/activate
|
|
||||||
|
|
||||||
python environments/your_env.py process \
|
|
||||||
--env.total_steps 1 \
|
|
||||||
--env.group_size 1 \
|
|
||||||
--env.use_wandb false \
|
|
||||||
--env.data_path_to_save_groups /tmp/test_output.jsonl \
|
|
||||||
--openai.base_url "<BASE_URL>" \
|
|
||||||
--openai.model_name "<MODEL>" \
|
|
||||||
--openai.server_type <SERVER_TYPE> \
|
|
||||||
--openai.health_check false
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Verify the output
|
|
||||||
|
|
||||||
```python
|
|
||||||
import json
|
|
||||||
for line in open("/tmp/test_output.jsonl"):
|
|
||||||
data = json.loads(line)
|
|
||||||
print(f"Scores: {data.get('scores', [])}")
|
|
||||||
print(f"Token sequences: {len(data.get('tokens', []))}")
|
|
||||||
# Check messages include tool calls
|
|
||||||
for msg_list in data.get("messages", []):
|
|
||||||
roles = [m.get("role") for m in msg_list]
|
|
||||||
print(f"Roles: {roles}")
|
|
||||||
for m in reversed(msg_list):
|
|
||||||
if m.get("role") == "assistant" and m.get("content"):
|
|
||||||
print(f"Response: {m['content'][:200]}...")
|
|
||||||
break
|
|
||||||
```
|
|
||||||
|
|
||||||
### What to check:
|
|
||||||
- **Scores are not all 0.0** — if so, compute_reward is broken
|
|
||||||
- **Scores are in [0, 1]** — not negative, not >1
|
|
||||||
- **Messages include "tool" role entries** — agent used tools
|
|
||||||
- **Token sequences are non-empty**
|
|
||||||
- **An HTML visualization is generated** next to the .jsonl
|
|
||||||
|
|
||||||
### Common failures:
|
|
||||||
- `'AgentResult' object has no attribute 'X'` — accessing a field that doesn't exist. See agentresult-fields.md.
|
|
||||||
- Score always 0.0 — reward function erroring silently
|
|
||||||
- Score always 1.0 — verification too lenient or not running
|
|
||||||
|
|
||||||
|
|
||||||
## Pattern 2: Evaluate a Model (evaluate mode)
|
|
||||||
|
|
||||||
Use `evaluate` mode to benchmark a model on your environment's eval
|
|
||||||
split. This runs the full agent loop with tools for each eval item.
|
|
||||||
|
|
||||||
### Step 1: Run evaluation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/your_env.py evaluate \
|
|
||||||
--env.eval_size 20 \
|
|
||||||
--env.use_wandb false \
|
|
||||||
--env.data_dir_to_save_evals /tmp/eval_results \
|
|
||||||
--openai.base_url "<BASE_URL>" \
|
|
||||||
--openai.model_name "<MODEL>" \
|
|
||||||
--openai.server_type <SERVER_TYPE> \
|
|
||||||
--openai.health_check false
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Read results
|
|
||||||
|
|
||||||
Stdout shows a lighteval-compatible table:
|
|
||||||
|
|
||||||
```
|
|
||||||
Evaluation Results: your-env_eval
|
|
||||||
|Metric | Value|
|
|
||||||
|mean correctness| 0.850 |
|
|
||||||
|mean reward | 0.920 |
|
|
||||||
|mean tool calls | 4.300 |
|
|
||||||
|n items | 20 |
|
|
||||||
Evaluation completed in 367 seconds
|
|
||||||
```
|
|
||||||
|
|
||||||
JSON results saved to the eval directory:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import json
|
|
||||||
data = json.load(open("/tmp/eval_results/metrics.json"))
|
|
||||||
for metric, value in data["results"]["all"].items():
|
|
||||||
print(f"{metric}: {value}")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Compare models
|
|
||||||
|
|
||||||
Run evaluate with different models and compare the metrics.json files.
|
|
||||||
|
|
||||||
### What to check:
|
|
||||||
- **"data_dir_to_save_evals is not set"** — you forgot the flag, results won't be saved
|
|
||||||
- **Tool usage rate = 0** — evaluate() is using chat_completion instead of HermesAgentLoop
|
|
||||||
- **All scores identical** — judge failing, falling back to heuristic
|
|
||||||
- **Very slow** — each item runs a full agent loop (~30-90s). Use `--env.eval_size 5` for quick checks.
|
|
||||||
|
|
||||||
|
|
||||||
## Pattern 3: Generate Training Data (process mode, larger scale)
|
|
||||||
|
|
||||||
Generate trajectory data for offline training or analysis:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/your_env.py process \
|
|
||||||
--env.total_steps 50 \
|
|
||||||
--env.group_size 4 \
|
|
||||||
--env.use_wandb false \
|
|
||||||
--env.data_path_to_save_groups data/trajectories.jsonl \
|
|
||||||
--openai.base_url "<BASE_URL>" \
|
|
||||||
--openai.model_name "<MODEL>" \
|
|
||||||
--openai.server_type <SERVER_TYPE> \
|
|
||||||
--openai.health_check false
|
|
||||||
```
|
|
||||||
|
|
||||||
### Analyze the distribution:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import json
|
|
||||||
scores = []
|
|
||||||
for line in open("data/trajectories.jsonl"):
|
|
||||||
data = json.loads(line)
|
|
||||||
scores.extend(data.get("scores", []))
|
|
||||||
|
|
||||||
print(f"Total: {len(scores)}, Mean: {sum(scores)/len(scores):.3f}")
|
|
||||||
for bucket in [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]:
|
|
||||||
count = sum(1 for s in scores if abs(s - bucket) < 0.1)
|
|
||||||
print(f" {bucket:.1f}: {'█' * count} ({count})")
|
|
||||||
```
|
|
||||||
|
|
||||||
### What to check:
|
|
||||||
- **Score distribution has variance** — RL needs score variance. All-same scores are useless.
|
|
||||||
|
|
||||||
|
|
||||||
## Pattern 4: Full RL Training (serve mode)
|
|
||||||
|
|
||||||
For actual RL training with Atropos:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Terminal 1: Start Atropos API server
|
|
||||||
run-api
|
|
||||||
|
|
||||||
# Terminal 2: Start your environment
|
|
||||||
python environments/your_env.py serve \
|
|
||||||
--config environments/your_env/default.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
For Phase 2 with VLLM:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Terminal 1: VLLM server
|
|
||||||
python -m vllm.entrypoints.openai.api_server --model your-model --port 8000
|
|
||||||
|
|
||||||
# Terminal 2: Atropos API
|
|
||||||
run-api
|
|
||||||
|
|
||||||
# Terminal 3: Environment
|
|
||||||
python environments/your_env.py serve \
|
|
||||||
--openai.base_url http://localhost:8000/v1 \
|
|
||||||
--openai.model_name your-model \
|
|
||||||
--openai.server_type vllm
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Pattern 5: Quick Smoke Test
|
|
||||||
|
|
||||||
Verify imports and config before spending money on API calls:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from environments.your_env import YourEnv
|
|
||||||
print(f"Name: {YourEnv.name}")
|
|
||||||
cfg, servers = YourEnv.config_init()
|
|
||||||
print(f"Toolsets: {cfg.enabled_toolsets}")
|
|
||||||
print(f"Server: {servers[0].model_name}")
|
|
||||||
print("All imports OK")
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
## Timing Expectations
|
|
||||||
|
|
||||||
| Mode | Items | Time per item | Total |
|
|
||||||
|------|-------|--------------|-------|
|
|
||||||
| process (1 item) | 1 | 30-90s | ~1 min |
|
|
||||||
| evaluate (5 items) | 5 | 30-90s | ~5 min |
|
|
||||||
| evaluate (20 items) | 20 | 30-90s | ~15-30 min |
|
|
||||||
| process (50 items) | 50 | 30-90s | ~30-75 min |
|
|
||||||
|
|
||||||
Times are for cloud APIs with Claude Sonnet-class models. Local models may be faster or slower depending on hardware.
|
|
||||||
|
|
@ -166,14 +166,6 @@ youtube = [
|
||||||
]
|
]
|
||||||
# `hermes dashboard` (localhost SPA + API). Not in core to keep the default install lean.
|
# `hermes dashboard` (localhost SPA + API). Not in core to keep the default install lean.
|
||||||
web = ["fastapi==0.133.1", "uvicorn[standard]==0.41.0"]
|
web = ["fastapi==0.133.1", "uvicorn[standard]==0.41.0"]
|
||||||
rl = [
|
|
||||||
"atroposlib @ git+https://github.com/NousResearch/atropos.git@c20c85256e5a45ad31edf8b7276e9c5ee1995a30",
|
|
||||||
"tinker @ git+https://github.com/thinking-machines-lab/tinker.git@30517b667f18a3dfb7ef33fb56cf686d5820ba2b",
|
|
||||||
"fastapi==0.133.1",
|
|
||||||
"uvicorn[standard]==0.41.0",
|
|
||||||
"wandb==0.25.1",
|
|
||||||
]
|
|
||||||
yc-bench = ["yc-bench @ git+https://github.com/collinear-ai/yc-bench.git@bfb0c88062450f46341bd9a5298903fc2e952a5c ; python_version >= '3.12'"]
|
|
||||||
all = [
|
all = [
|
||||||
# Policy (2026-05-12): `[all]` includes only extras that genuinely
|
# Policy (2026-05-12): `[all]` includes only extras that genuinely
|
||||||
# CAN'T be lazy-installed via `tools/lazy_deps.py` — i.e. things every
|
# CAN'T be lazy-installed via `tools/lazy_deps.py` — i.e. things every
|
||||||
|
|
@ -215,7 +207,7 @@ hermes-agent = "run_agent:main"
|
||||||
hermes-acp = "acp_adapter.entry:main"
|
hermes-acp = "acp_adapter.entry:main"
|
||||||
|
|
||||||
[tool.setuptools]
|
[tool.setuptools]
|
||||||
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli", "hermes_bootstrap", "hermes_constants", "hermes_state", "hermes_time", "hermes_logging", "rl_cli", "utils"]
|
py-modules = ["run_agent", "model_tools", "toolsets", "batch_runner", "trajectory_compressor", "toolset_distributions", "cli", "hermes_bootstrap", "hermes_constants", "hermes_state", "hermes_time", "hermes_logging", "utils"]
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
[tool.setuptools.package-data]
|
||||||
hermes_cli = ["web_dist/**/*"]
|
hermes_cli = ["web_dist/**/*"]
|
||||||
|
|
@ -238,11 +230,7 @@ python-version = "3.13"
|
||||||
unknown-argument = "warn"
|
unknown-argument = "warn"
|
||||||
redundant-cast = "ignore"
|
redundant-cast = "ignore"
|
||||||
|
|
||||||
[tool.ty.src]
|
|
||||||
exclude = ["tinker-atropos"]
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
exclude = ["tinker-atropos"]
|
|
||||||
preview = true # required for PLW1514 (unspecified-encoding) — preview rule
|
preview = true # required for PLW1514 (unspecified-encoding) — preview rule
|
||||||
|
|
||||||
[tool.ruff.lint]
|
[tool.ruff.lint]
|
||||||
|
|
|
||||||
446
rl_cli.py
446
rl_cli.py
|
|
@ -1,446 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
RL Training CLI Runner
|
|
||||||
|
|
||||||
Dedicated CLI runner for RL training workflows with:
|
|
||||||
- Extended timeouts for long-running training
|
|
||||||
- RL-focused system prompts
|
|
||||||
- Full toolset including RL training tools
|
|
||||||
- Special handling for 30-minute check intervals
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python rl_cli.py "Train a model on GSM8k for math reasoning"
|
|
||||||
python rl_cli.py --interactive
|
|
||||||
python rl_cli.py --list-environments
|
|
||||||
|
|
||||||
Environment Variables:
|
|
||||||
TINKER_API_KEY: API key for Tinker service (required)
|
|
||||||
WANDB_API_KEY: API key for WandB metrics (required)
|
|
||||||
OPENROUTER_API_KEY: API key for OpenRouter (required for agent)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import fire
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from hermes_constants import OPENROUTER_BASE_URL, get_hermes_home
|
|
||||||
|
|
||||||
# Load .env from ~/.hermes/.env first, then project root as dev fallback.
|
|
||||||
# User-managed env files should override stale shell exports on restart.
|
|
||||||
_hermes_home = get_hermes_home()
|
|
||||||
_project_env = Path(__file__).parent / '.env'
|
|
||||||
|
|
||||||
from hermes_cli.env_loader import load_hermes_dotenv
|
|
||||||
|
|
||||||
_loaded_env_paths = load_hermes_dotenv(hermes_home=_hermes_home, project_env=_project_env)
|
|
||||||
for _env_path in _loaded_env_paths:
|
|
||||||
print(f"✅ Loaded environment variables from {_env_path}")
|
|
||||||
|
|
||||||
# Set terminal working directory to tinker-atropos submodule
|
|
||||||
# This ensures terminal commands run in the right context for RL work
|
|
||||||
tinker_atropos_dir = Path(__file__).parent / 'tinker-atropos'
|
|
||||||
if tinker_atropos_dir.exists():
|
|
||||||
os.environ['TERMINAL_CWD'] = str(tinker_atropos_dir)
|
|
||||||
os.environ['HERMES_QUIET'] = '1' # Disable temp subdirectory creation
|
|
||||||
print(f"📂 Terminal working directory: {tinker_atropos_dir}")
|
|
||||||
else:
|
|
||||||
# Fall back to hermes-agent directory if submodule not found
|
|
||||||
os.environ['TERMINAL_CWD'] = str(Path(__file__).parent)
|
|
||||||
os.environ['HERMES_QUIET'] = '1'
|
|
||||||
print(f"⚠️ tinker-atropos submodule not found, using: {Path(__file__).parent}")
|
|
||||||
|
|
||||||
# Import agent and tools
|
|
||||||
from run_agent import AIAgent
|
|
||||||
from tools.rl_training_tool import get_missing_keys
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Config Loading
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
DEFAULT_MODEL = "anthropic/claude-opus-4.5"
|
|
||||||
DEFAULT_BASE_URL = OPENROUTER_BASE_URL
|
|
||||||
|
|
||||||
|
|
||||||
def load_hermes_config() -> dict:
|
|
||||||
"""
|
|
||||||
Load configuration from ~/.hermes/config.yaml.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Configuration with model, base_url, etc.
|
|
||||||
"""
|
|
||||||
config_path = _hermes_home / 'config.yaml'
|
|
||||||
|
|
||||||
config = {
|
|
||||||
"model": DEFAULT_MODEL,
|
|
||||||
"base_url": DEFAULT_BASE_URL,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config_path.exists():
|
|
||||||
try:
|
|
||||||
with open(config_path, "r", encoding='utf-8') as f:
|
|
||||||
file_config = yaml.safe_load(f) or {}
|
|
||||||
|
|
||||||
# Get model from config
|
|
||||||
if "model" in file_config:
|
|
||||||
if isinstance(file_config["model"], str):
|
|
||||||
config["model"] = file_config["model"]
|
|
||||||
elif isinstance(file_config["model"], dict):
|
|
||||||
config["model"] = file_config["model"].get("default", DEFAULT_MODEL)
|
|
||||||
|
|
||||||
# Get base_url if specified
|
|
||||||
if "base_url" in file_config:
|
|
||||||
config["base_url"] = file_config["base_url"]
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ Warning: Failed to load config.yaml: {e}")
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# RL-Specific Configuration
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
# Extended timeouts for long-running RL operations
|
|
||||||
RL_MAX_ITERATIONS = 200 # Allow many more iterations for long workflows
|
|
||||||
|
|
||||||
# RL-focused system prompt
|
|
||||||
RL_SYSTEM_PROMPT = """You are an automated post-training engineer specializing in reinforcement learning for language models.
|
|
||||||
|
|
||||||
## Your Capabilities
|
|
||||||
|
|
||||||
You have access to RL training tools for running reinforcement learning on models through Tinker-Atropos:
|
|
||||||
|
|
||||||
1. **DISCOVER**: Use `rl_list_environments` to see available RL environments
|
|
||||||
2. **INSPECT**: Read environment files to understand how they work (verifiers, data loading, rewards)
|
|
||||||
3. **INSPECT DATA**: Use terminal to explore HuggingFace datasets and understand their format
|
|
||||||
4. **CREATE**: Copy existing environments as templates, modify for your needs
|
|
||||||
5. **CONFIGURE**: Use `rl_select_environment` and `rl_edit_config` to set up training
|
|
||||||
6. **TEST**: Always use `rl_test_inference` before full training to validate your setup
|
|
||||||
7. **TRAIN**: Use `rl_start_training` to begin, `rl_check_status` to monitor
|
|
||||||
8. **EVALUATE**: Use `rl_get_results` and analyze WandB metrics to assess performance
|
|
||||||
|
|
||||||
## Environment Files
|
|
||||||
|
|
||||||
Environment files are located in: `tinker-atropos/tinker_atropos/environments/`
|
|
||||||
|
|
||||||
Study existing environments to learn patterns. Look for:
|
|
||||||
- `load_dataset()` calls - how data is loaded
|
|
||||||
- `score_answer()` / `score()` - verification logic
|
|
||||||
- `get_next_item()` - prompt formatting
|
|
||||||
- `system_prompt` - instruction format
|
|
||||||
- `config_init()` - default configuration
|
|
||||||
|
|
||||||
## Creating New Environments
|
|
||||||
|
|
||||||
To create a new environment:
|
|
||||||
1. Read an existing environment file (e.g., gsm8k_tinker.py)
|
|
||||||
2. Use terminal to explore the target dataset format
|
|
||||||
3. Copy the environment file as a template
|
|
||||||
4. Modify the dataset loading, prompt formatting, and verifier logic
|
|
||||||
5. Test with `rl_test_inference` before training
|
|
||||||
|
|
||||||
## Important Guidelines
|
|
||||||
|
|
||||||
- **Always test before training**: Training runs take hours - verify everything works first
|
|
||||||
- **Monitor metrics**: Check WandB for reward/mean and percent_correct
|
|
||||||
- **Status check intervals**: Wait at least 30 minutes between status checks
|
|
||||||
- **Early stopping**: Stop training early if metrics look bad or stagnant
|
|
||||||
- **Iterate quickly**: Start with small total_steps to validate, then scale up
|
|
||||||
|
|
||||||
## Available Toolsets
|
|
||||||
|
|
||||||
You have access to:
|
|
||||||
- **RL tools**: Environment discovery, config management, training, testing
|
|
||||||
- **Terminal**: Run commands, inspect files, explore datasets
|
|
||||||
- **Web**: Search for information, documentation, papers
|
|
||||||
- **File tools**: Read and modify code files
|
|
||||||
|
|
||||||
When asked to train a model, follow this workflow:
|
|
||||||
1. List available environments
|
|
||||||
2. Select and configure the appropriate environment
|
|
||||||
3. Test with sample prompts
|
|
||||||
4. Start training with conservative settings
|
|
||||||
5. Monitor progress and adjust as needed
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Toolsets to enable for RL workflows
|
|
||||||
RL_TOOLSETS = ["terminal", "web", "rl"]
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Helper Functions
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
def check_requirements():
|
|
||||||
"""Check that all required environment variables and services are available."""
|
|
||||||
errors = []
|
|
||||||
|
|
||||||
# Check API keys
|
|
||||||
if not os.getenv("OPENROUTER_API_KEY"):
|
|
||||||
errors.append("OPENROUTER_API_KEY not set - required for agent")
|
|
||||||
|
|
||||||
missing_rl_keys = get_missing_keys()
|
|
||||||
if missing_rl_keys:
|
|
||||||
errors.append(f"Missing RL API keys: {', '.join(missing_rl_keys)}")
|
|
||||||
|
|
||||||
if errors:
|
|
||||||
print("❌ Missing requirements:")
|
|
||||||
for error in errors:
|
|
||||||
print(f" - {error}")
|
|
||||||
print("\nPlease set these environment variables in your .env file or shell.")
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def check_tinker_atropos():
|
|
||||||
"""Check if tinker-atropos submodule is properly set up."""
|
|
||||||
tinker_path = Path(__file__).parent / "tinker-atropos"
|
|
||||||
|
|
||||||
if not tinker_path.exists():
|
|
||||||
return False, "tinker-atropos submodule not found. Run: git submodule update --init"
|
|
||||||
|
|
||||||
envs_path = tinker_path / "tinker_atropos" / "environments"
|
|
||||||
if not envs_path.exists():
|
|
||||||
return False, f"environments directory not found at {envs_path}"
|
|
||||||
|
|
||||||
env_files = list(envs_path.glob("*.py"))
|
|
||||||
env_files = [f for f in env_files if not f.name.startswith("_")]
|
|
||||||
|
|
||||||
return True, {"path": str(tinker_path), "environments_count": len(env_files)}
|
|
||||||
|
|
||||||
|
|
||||||
def list_environments_sync():
|
|
||||||
"""List available environments (synchronous wrapper)."""
|
|
||||||
from tools.rl_training_tool import rl_list_environments
|
|
||||||
import json
|
|
||||||
|
|
||||||
async def _list():
|
|
||||||
result = await rl_list_environments()
|
|
||||||
return json.loads(result)
|
|
||||||
|
|
||||||
return asyncio.run(_list())
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
|
||||||
# Main CLI
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
def main(
|
|
||||||
task: str = None,
|
|
||||||
model: str = None,
|
|
||||||
api_key: str = None,
|
|
||||||
base_url: str = None,
|
|
||||||
max_iterations: int = RL_MAX_ITERATIONS,
|
|
||||||
interactive: bool = False,
|
|
||||||
list_environments: bool = False,
|
|
||||||
check_server: bool = False,
|
|
||||||
verbose: bool = False,
|
|
||||||
save_trajectories: bool = True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
RL Training CLI - Dedicated runner for RL training workflows.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task: The training task/goal (e.g., "Train a model on GSM8k for math")
|
|
||||||
model: Model to use for the agent (reads from ~/.hermes/config.yaml if not provided)
|
|
||||||
api_key: OpenRouter API key (uses OPENROUTER_API_KEY env var if not provided)
|
|
||||||
base_url: API base URL (reads from config or defaults to OpenRouter)
|
|
||||||
max_iterations: Maximum agent iterations (default: 200 for long workflows)
|
|
||||||
interactive: Run in interactive mode (multiple conversations)
|
|
||||||
list_environments: Just list available RL environments and exit
|
|
||||||
check_server: Check if RL API server is running and exit
|
|
||||||
verbose: Enable verbose logging
|
|
||||||
save_trajectories: Save conversation trajectories (default: True for RL)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
# Train on a specific environment
|
|
||||||
python rl_cli.py "Train a model on GSM8k math problems"
|
|
||||||
|
|
||||||
# Interactive mode
|
|
||||||
python rl_cli.py --interactive
|
|
||||||
|
|
||||||
# List available environments
|
|
||||||
python rl_cli.py --list-environments
|
|
||||||
|
|
||||||
# Check server status
|
|
||||||
python rl_cli.py --check-server
|
|
||||||
"""
|
|
||||||
# Load config from ~/.hermes/config.yaml
|
|
||||||
config = load_hermes_config()
|
|
||||||
|
|
||||||
# Use config values if not explicitly provided
|
|
||||||
if model is None:
|
|
||||||
model = config["model"]
|
|
||||||
if base_url is None:
|
|
||||||
base_url = config["base_url"]
|
|
||||||
|
|
||||||
print("🎯 RL Training Agent")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Handle setup check
|
|
||||||
if check_server:
|
|
||||||
print("\n🔍 Checking tinker-atropos setup...")
|
|
||||||
ok, result = check_tinker_atropos()
|
|
||||||
if ok:
|
|
||||||
print("✅ tinker-atropos submodule found")
|
|
||||||
print(f" Path: {result.get('path')}")
|
|
||||||
print(f" Environments found: {result.get('environments_count', 0)}")
|
|
||||||
|
|
||||||
# Also check API keys
|
|
||||||
missing = get_missing_keys()
|
|
||||||
if missing:
|
|
||||||
print(f"\n⚠️ Missing API keys: {', '.join(missing)}")
|
|
||||||
print(" Add them to ~/.hermes/.env")
|
|
||||||
else:
|
|
||||||
print("✅ API keys configured")
|
|
||||||
else:
|
|
||||||
print(f"❌ tinker-atropos not set up: {result}")
|
|
||||||
print("\nTo set up:")
|
|
||||||
print(" git submodule update --init")
|
|
||||||
print(" pip install -e ./tinker-atropos")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Handle environment listing
|
|
||||||
if list_environments:
|
|
||||||
print("\n📋 Available RL Environments:")
|
|
||||||
print("-" * 40)
|
|
||||||
try:
|
|
||||||
data = list_environments_sync()
|
|
||||||
if "error" in data:
|
|
||||||
print(f"❌ Error: {data['error']}")
|
|
||||||
return
|
|
||||||
|
|
||||||
envs = data.get("environments", [])
|
|
||||||
if not envs:
|
|
||||||
print("No environments found.")
|
|
||||||
print("\nMake sure tinker-atropos is set up:")
|
|
||||||
print(" git submodule update --init")
|
|
||||||
return
|
|
||||||
|
|
||||||
for env in envs:
|
|
||||||
print(f"\n 📦 {env['name']}")
|
|
||||||
print(f" Class: {env['class_name']}")
|
|
||||||
print(f" Path: {env['file_path']}")
|
|
||||||
if env.get('description'):
|
|
||||||
desc = env['description'][:100] + "..." if len(env.get('description', '')) > 100 else env.get('description', '')
|
|
||||||
print(f" Description: {desc}")
|
|
||||||
|
|
||||||
print(f"\n📊 Total: {len(envs)} environments")
|
|
||||||
print("\nUse `rl_select_environment(name)` to select an environment for training.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Error listing environments: {e}")
|
|
||||||
print("\nMake sure tinker-atropos is set up:")
|
|
||||||
print(" git submodule update --init")
|
|
||||||
print(" pip install -e ./tinker-atropos")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Check requirements
|
|
||||||
if not check_requirements():
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Set default task if none provided
|
|
||||||
if not task and not interactive:
|
|
||||||
print("\n⚠️ No task provided. Use --interactive for interactive mode or provide a task.")
|
|
||||||
print("\nExamples:")
|
|
||||||
print(' python rl_cli.py "Train a model on GSM8k math problems"')
|
|
||||||
print(' python rl_cli.py "Create an RL environment for code generation"')
|
|
||||||
print(' python rl_cli.py --interactive')
|
|
||||||
return
|
|
||||||
|
|
||||||
# Get API key
|
|
||||||
api_key = api_key or os.getenv("OPENROUTER_API_KEY")
|
|
||||||
if not api_key:
|
|
||||||
print("❌ No API key provided. Set OPENROUTER_API_KEY or pass --api-key")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
print(f"\n🤖 Model: {model}")
|
|
||||||
print(f"🔧 Max iterations: {max_iterations}")
|
|
||||||
print(f"📁 Toolsets: {', '.join(RL_TOOLSETS)}")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Create agent with RL configuration
|
|
||||||
agent = AIAgent(
|
|
||||||
base_url=base_url,
|
|
||||||
api_key=api_key,
|
|
||||||
model=model,
|
|
||||||
max_iterations=max_iterations,
|
|
||||||
enabled_toolsets=RL_TOOLSETS,
|
|
||||||
save_trajectories=save_trajectories,
|
|
||||||
verbose_logging=verbose,
|
|
||||||
quiet_mode=False,
|
|
||||||
ephemeral_system_prompt=RL_SYSTEM_PROMPT,
|
|
||||||
)
|
|
||||||
|
|
||||||
if interactive:
|
|
||||||
# Interactive mode - multiple conversations
|
|
||||||
print("\n🔄 Interactive RL Training Mode")
|
|
||||||
print("Type 'quit' or 'exit' to end the session.")
|
|
||||||
print("Type 'status' to check active training runs.")
|
|
||||||
print("-" * 40)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
user_input = input("\n🎯 RL Task> ").strip()
|
|
||||||
|
|
||||||
if not user_input:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if user_input.lower() in {'quit', 'exit', 'q'}:
|
|
||||||
print("\n👋 Goodbye!")
|
|
||||||
break
|
|
||||||
|
|
||||||
if user_input.lower() == 'status':
|
|
||||||
# Quick status check
|
|
||||||
from tools.rl_training_tool import rl_list_runs
|
|
||||||
import json
|
|
||||||
result = asyncio.run(rl_list_runs())
|
|
||||||
runs = json.loads(result)
|
|
||||||
if isinstance(runs, list) and runs:
|
|
||||||
print("\n📊 Active Runs:")
|
|
||||||
for run in runs:
|
|
||||||
print(f" - {run['run_id']}: {run['environment']} ({run['status']})")
|
|
||||||
else:
|
|
||||||
print("\nNo active runs.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Run the agent
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
agent.run_conversation(user_input)
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n\n👋 Interrupted. Goodbye!")
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Error: {e}")
|
|
||||||
if verbose:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
else:
|
|
||||||
# Single task mode
|
|
||||||
print(f"\n📝 Task: {task}")
|
|
||||||
print("-" * 40)
|
|
||||||
|
|
||||||
try:
|
|
||||||
agent.run_conversation(task)
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("✅ Task completed")
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print("\n\n⚠️ Interrupted by user")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"\n❌ Error: {e}")
|
|
||||||
if verbose:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
fire.Fire(main)
|
|
||||||
|
|
@ -958,20 +958,6 @@ except Exception:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# tinker-atropos (RL training) is optional and OFF by default. Matches the
|
|
||||||
# Linux/macOS install.sh behavior. Reasons not to auto-install:
|
|
||||||
# - tinker-atropos/pyproject.toml pulls atroposlib + tinker from git+https
|
|
||||||
# (NousResearch/atropos + thinking-machines-lab/tinker) which can fail on
|
|
||||||
# locked-down networks, flaky DNS, or rate-limited github.com and would
|
|
||||||
# previously kill the whole install mid-flight on Windows.
|
|
||||||
# - It's an RL training submodule, not part of the default agent surface.
|
|
||||||
# Users who don't do RL training never need it.
|
|
||||||
# Users who do want it can run the one-liner we print below.
|
|
||||||
if (Test-Path "tinker-atropos\pyproject.toml") {
|
|
||||||
Write-Info "tinker-atropos submodule found — skipping install (optional, for RL training)"
|
|
||||||
Write-Info " To install later: $UvCmd pip install -e `".\tinker-atropos`""
|
|
||||||
}
|
|
||||||
|
|
||||||
Pop-Location
|
Pop-Location
|
||||||
|
|
||||||
Write-Success "All dependencies installed"
|
Write-Success "All dependencies installed"
|
||||||
|
|
|
||||||
|
|
@ -1051,11 +1051,6 @@ install_deps() {
|
||||||
log_info "Termux note: matrix e2ee and local faster-whisper extras are excluded from .[termux-all] due to upstream Android wheel/toolchain blockers."
|
log_info "Termux note: matrix e2ee and local faster-whisper extras are excluded from .[termux-all] due to upstream Android wheel/toolchain blockers."
|
||||||
log_info "Termux note: browser/WhatsApp tooling is not installed by default; see the Termux guide for optional follow-up steps."
|
log_info "Termux note: browser/WhatsApp tooling is not installed by default; see the Termux guide for optional follow-up steps."
|
||||||
|
|
||||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
|
||||||
log_info "tinker-atropos submodule found — skipping install (optional, for RL training)"
|
|
||||||
log_info " To install later: $PIP_PYTHON -m pip install -e \"./tinker-atropos\""
|
|
||||||
fi
|
|
||||||
|
|
||||||
log_success "All dependencies installed"
|
log_success "All dependencies installed"
|
||||||
return 0
|
return 0
|
||||||
fi
|
fi
|
||||||
|
|
@ -1243,13 +1238,6 @@ PY
|
||||||
|
|
||||||
log_success "Main package installed"
|
log_success "Main package installed"
|
||||||
|
|
||||||
# tinker-atropos (RL training) is optional — skip by default.
|
|
||||||
# To enable RL tools: git submodule update --init tinker-atropos && uv pip install -e "./tinker-atropos"
|
|
||||||
if [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
|
||||||
log_info "tinker-atropos submodule found — skipping install (optional, for RL training)"
|
|
||||||
log_info " To install: $UV_CMD pip install -e \"./tinker-atropos\""
|
|
||||||
fi
|
|
||||||
|
|
||||||
log_success "All dependencies installed"
|
log_success "All dependencies installed"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -267,22 +267,6 @@ else
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Submodules (terminal backend + RL training)
|
|
||||||
# ============================================================================
|
|
||||||
|
|
||||||
echo -e "${CYAN}→${NC} Installing optional submodules..."
|
|
||||||
|
|
||||||
# tinker-atropos (RL training backend)
|
|
||||||
if is_termux; then
|
|
||||||
echo -e "${CYAN}→${NC} Skipping tinker-atropos on Termux (not part of the tested Android path)"
|
|
||||||
elif [ -d "tinker-atropos" ] && [ -f "tinker-atropos/pyproject.toml" ]; then
|
|
||||||
$UV_CMD pip install -e "./tinker-atropos" && \
|
|
||||||
echo -e "${GREEN}✓${NC} tinker-atropos installed" || \
|
|
||||||
echo -e "${YELLOW}⚠${NC} tinker-atropos install failed (RL tools may not work)"
|
|
||||||
else
|
|
||||||
echo -e "${YELLOW}⚠${NC} tinker-atropos not found (run: git submodule update --init --recursive)"
|
|
||||||
fi
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Optional: ripgrep (for faster file search)
|
# Optional: ripgrep (for faster file search)
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
|
||||||
|
|
@ -101,7 +101,6 @@ _CREDENTIAL_NAMES = frozenset({
|
||||||
"RETAINDB_API_KEY",
|
"RETAINDB_API_KEY",
|
||||||
"HINDSIGHT_API_KEY",
|
"HINDSIGHT_API_KEY",
|
||||||
"HINDSIGHT_LLM_API_KEY",
|
"HINDSIGHT_LLM_API_KEY",
|
||||||
"TINKER_API_KEY",
|
|
||||||
"DAYTONA_API_KEY",
|
"DAYTONA_API_KEY",
|
||||||
"TWILIO_AUTH_TOKEN",
|
"TWILIO_AUTH_TOKEN",
|
||||||
"TELEGRAM_BOT_TOKEN",
|
"TELEGRAM_BOT_TOKEN",
|
||||||
|
|
|
||||||
|
|
@ -1,164 +0,0 @@
|
||||||
"""Security tests for Terminal-Bench 2 archive extraction."""
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import importlib
|
|
||||||
import io
|
|
||||||
import sys
|
|
||||||
import tarfile
|
|
||||||
import types
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
def _stub_module(name: str, **attrs):
|
|
||||||
module = types.ModuleType(name)
|
|
||||||
for key, value in attrs.items():
|
|
||||||
setattr(module, key, value)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
def _load_terminalbench_module(monkeypatch):
|
|
||||||
class _EvalHandlingEnum:
|
|
||||||
STOP_TRAIN = "stop_train"
|
|
||||||
|
|
||||||
class _APIServerConfig:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
self.args = args
|
|
||||||
self.kwargs = kwargs
|
|
||||||
|
|
||||||
class _AgentResult:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class _HermesAgentLoop:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class _HermesAgentBaseEnv:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class _HermesAgentEnvConfig:
|
|
||||||
pass
|
|
||||||
|
|
||||||
class _ToolContext:
|
|
||||||
pass
|
|
||||||
|
|
||||||
stub_modules = {
|
|
||||||
"atroposlib": _stub_module("atroposlib"),
|
|
||||||
"atroposlib.envs": _stub_module("atroposlib.envs"),
|
|
||||||
"atroposlib.envs.base": _stub_module(
|
|
||||||
"atroposlib.envs.base",
|
|
||||||
EvalHandlingEnum=_EvalHandlingEnum,
|
|
||||||
),
|
|
||||||
"atroposlib.envs.server_handling": _stub_module("atroposlib.envs.server_handling"),
|
|
||||||
"atroposlib.envs.server_handling.server_manager": _stub_module(
|
|
||||||
"atroposlib.envs.server_handling.server_manager",
|
|
||||||
APIServerConfig=_APIServerConfig,
|
|
||||||
),
|
|
||||||
"environments.agent_loop": _stub_module(
|
|
||||||
"environments.agent_loop",
|
|
||||||
AgentResult=_AgentResult,
|
|
||||||
HermesAgentLoop=_HermesAgentLoop,
|
|
||||||
),
|
|
||||||
"environments.hermes_base_env": _stub_module(
|
|
||||||
"environments.hermes_base_env",
|
|
||||||
HermesAgentBaseEnv=_HermesAgentBaseEnv,
|
|
||||||
HermesAgentEnvConfig=_HermesAgentEnvConfig,
|
|
||||||
),
|
|
||||||
"environments.tool_context": _stub_module(
|
|
||||||
"environments.tool_context",
|
|
||||||
ToolContext=_ToolContext,
|
|
||||||
),
|
|
||||||
"tools.terminal_tool": _stub_module(
|
|
||||||
"tools.terminal_tool",
|
|
||||||
register_task_env_overrides=lambda *args, **kwargs: None,
|
|
||||||
clear_task_env_overrides=lambda *args, **kwargs: None,
|
|
||||||
cleanup_vm=lambda *args, **kwargs: None,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
stub_modules["atroposlib"].envs = stub_modules["atroposlib.envs"]
|
|
||||||
stub_modules["atroposlib.envs"].base = stub_modules["atroposlib.envs.base"]
|
|
||||||
stub_modules["atroposlib.envs"].server_handling = stub_modules["atroposlib.envs.server_handling"]
|
|
||||||
stub_modules["atroposlib.envs.server_handling"].server_manager = stub_modules[
|
|
||||||
"atroposlib.envs.server_handling.server_manager"
|
|
||||||
]
|
|
||||||
|
|
||||||
for name, module in stub_modules.items():
|
|
||||||
monkeypatch.setitem(sys.modules, name, module)
|
|
||||||
|
|
||||||
module_name = "environments.benchmarks.terminalbench_2.terminalbench2_env"
|
|
||||||
sys.modules.pop(module_name, None)
|
|
||||||
return importlib.import_module(module_name)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_tar_b64(entries):
|
|
||||||
buf = io.BytesIO()
|
|
||||||
with tarfile.open(fileobj=buf, mode="w:gz") as tar:
|
|
||||||
for entry in entries:
|
|
||||||
kind = entry["kind"]
|
|
||||||
info = tarfile.TarInfo(entry["name"])
|
|
||||||
|
|
||||||
if kind == "dir":
|
|
||||||
info.type = tarfile.DIRTYPE
|
|
||||||
tar.addfile(info)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if kind == "file":
|
|
||||||
data = entry["data"].encode("utf-8")
|
|
||||||
info.size = len(data)
|
|
||||||
tar.addfile(info, io.BytesIO(data))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if kind == "symlink":
|
|
||||||
info.type = tarfile.SYMTYPE
|
|
||||||
info.linkname = entry["target"]
|
|
||||||
tar.addfile(info)
|
|
||||||
continue
|
|
||||||
|
|
||||||
raise ValueError(f"Unknown tar entry kind: {kind}")
|
|
||||||
|
|
||||||
return base64.b64encode(buf.getvalue()).decode("ascii")
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_base64_tar_allows_safe_files(tmp_path, monkeypatch):
|
|
||||||
module = _load_terminalbench_module(monkeypatch)
|
|
||||||
archive = _build_tar_b64(
|
|
||||||
[
|
|
||||||
{"kind": "dir", "name": "nested"},
|
|
||||||
{"kind": "file", "name": "nested/hello.txt", "data": "hello"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
target = tmp_path / "extract"
|
|
||||||
module._extract_base64_tar(archive, target)
|
|
||||||
|
|
||||||
assert (target / "nested" / "hello.txt").read_text(encoding="utf-8") == "hello"
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_base64_tar_rejects_path_traversal(tmp_path, monkeypatch):
|
|
||||||
module = _load_terminalbench_module(monkeypatch)
|
|
||||||
archive = _build_tar_b64(
|
|
||||||
[
|
|
||||||
{"kind": "file", "name": "../escape.txt", "data": "owned"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
target = tmp_path / "extract"
|
|
||||||
with pytest.raises(ValueError, match="Unsafe archive member path"):
|
|
||||||
module._extract_base64_tar(archive, target)
|
|
||||||
|
|
||||||
assert not (tmp_path / "escape.txt").exists()
|
|
||||||
|
|
||||||
|
|
||||||
def test_extract_base64_tar_rejects_symlinks(tmp_path, monkeypatch):
|
|
||||||
module = _load_terminalbench_module(monkeypatch)
|
|
||||||
archive = _build_tar_b64(
|
|
||||||
[
|
|
||||||
{"kind": "symlink", "name": "link", "target": "../../escape.txt"},
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
target = tmp_path / "extract"
|
|
||||||
with pytest.raises(ValueError, match="Unsupported archive member type"):
|
|
||||||
module._extract_base64_tar(archive, target)
|
|
||||||
|
|
||||||
assert not (target / "link").exists()
|
|
||||||
|
|
@ -39,8 +39,6 @@ class TestExplicitAllowlist:
|
||||||
"OPENROUTER_API_KEY",
|
"OPENROUTER_API_KEY",
|
||||||
"OPENAI_API_KEY",
|
"OPENAI_API_KEY",
|
||||||
"ANTHROPIC_API_KEY",
|
"ANTHROPIC_API_KEY",
|
||||||
"WANDB_API_KEY",
|
|
||||||
"TINKER_API_KEY",
|
|
||||||
"HONCHO_API_KEY",
|
"HONCHO_API_KEY",
|
||||||
"FIRECRAWL_API_KEY",
|
"FIRECRAWL_API_KEY",
|
||||||
"BROWSERBASE_API_KEY",
|
"BROWSERBASE_API_KEY",
|
||||||
|
|
|
||||||
|
|
@ -18,4 +18,3 @@ def test_setup_hermes_script_has_termux_path():
|
||||||
assert ".[termux]" in content
|
assert ".[termux]" in content
|
||||||
assert "constraints-termux.txt" in content
|
assert "constraints-termux.txt" in content
|
||||||
assert "$PREFIX/bin" in content
|
assert "$PREFIX/bin" in content
|
||||||
assert "Skipping tinker-atropos on Termux" in content
|
|
||||||
|
|
|
||||||
|
|
@ -1,505 +0,0 @@
|
||||||
"""
|
|
||||||
Tests for environments/agent_loop.py — HermesAgentLoop.
|
|
||||||
|
|
||||||
Tests the multi-turn agent engine using mocked servers, without needing
|
|
||||||
real API keys or running servers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# Ensure repo root is importable
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent.parent))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from environments.agent_loop import (
|
|
||||||
AgentResult,
|
|
||||||
HermesAgentLoop,
|
|
||||||
ToolError,
|
|
||||||
_extract_reasoning_from_message,
|
|
||||||
resize_tool_pool,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Mock server infrastructure ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockFunction:
|
|
||||||
name: str
|
|
||||||
arguments: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockToolCall:
|
|
||||||
id: str
|
|
||||||
function: MockFunction
|
|
||||||
type: str = "function"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockMessage:
|
|
||||||
content: Optional[str]
|
|
||||||
role: str = "assistant"
|
|
||||||
tool_calls: Optional[List[MockToolCall]] = None
|
|
||||||
reasoning_content: Optional[str] = None
|
|
||||||
reasoning: Optional[str] = None
|
|
||||||
reasoning_details: Optional[list] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockChoice:
|
|
||||||
message: MockMessage
|
|
||||||
finish_reason: str = "stop"
|
|
||||||
index: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockChatCompletion:
|
|
||||||
choices: List[MockChoice]
|
|
||||||
id: str = "chatcmpl-mock"
|
|
||||||
model: str = "mock-model"
|
|
||||||
|
|
||||||
|
|
||||||
class MockServer:
|
|
||||||
"""
|
|
||||||
Mock server that returns pre-configured responses in sequence.
|
|
||||||
Mimics the chat_completion() interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, responses: List[MockChatCompletion]):
|
|
||||||
self.responses = responses
|
|
||||||
self.call_count = 0
|
|
||||||
self.call_history: List[Dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def chat_completion(self, **kwargs) -> MockChatCompletion:
|
|
||||||
self.call_history.append(kwargs)
|
|
||||||
if self.call_count >= len(self.responses):
|
|
||||||
# Return a simple text response if we run out
|
|
||||||
return MockChatCompletion(
|
|
||||||
choices=[MockChoice(message=MockMessage(content="Done."))]
|
|
||||||
)
|
|
||||||
resp = self.responses[self.call_count]
|
|
||||||
self.call_count += 1
|
|
||||||
return resp
|
|
||||||
|
|
||||||
|
|
||||||
def make_text_response(content: str) -> MockChatCompletion:
|
|
||||||
"""Create a simple text-only response (no tool calls)."""
|
|
||||||
return MockChatCompletion(
|
|
||||||
choices=[MockChoice(message=MockMessage(content=content))]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_tool_response(
|
|
||||||
tool_name: str,
|
|
||||||
arguments: dict,
|
|
||||||
content: str = "",
|
|
||||||
tool_call_id: str = "call_001",
|
|
||||||
) -> MockChatCompletion:
|
|
||||||
"""Create a response with a single tool call."""
|
|
||||||
return MockChatCompletion(
|
|
||||||
choices=[
|
|
||||||
MockChoice(
|
|
||||||
message=MockMessage(
|
|
||||||
content=content,
|
|
||||||
tool_calls=[
|
|
||||||
MockToolCall(
|
|
||||||
id=tool_call_id,
|
|
||||||
function=MockFunction(
|
|
||||||
name=tool_name,
|
|
||||||
arguments=json.dumps(arguments),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
],
|
|
||||||
),
|
|
||||||
finish_reason="tool_calls",
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Tests ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentResult:
|
|
||||||
def test_defaults(self):
|
|
||||||
result = AgentResult(messages=[])
|
|
||||||
assert result.messages == []
|
|
||||||
assert result.managed_state is None
|
|
||||||
assert result.turns_used == 0
|
|
||||||
assert result.finished_naturally is False
|
|
||||||
assert result.reasoning_per_turn == []
|
|
||||||
assert result.tool_errors == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestExtractReasoning:
|
|
||||||
def test_reasoning_content_field(self):
|
|
||||||
msg = MockMessage(content="hello", reasoning_content="I think...")
|
|
||||||
assert _extract_reasoning_from_message(msg) == "I think..."
|
|
||||||
|
|
||||||
def test_reasoning_field(self):
|
|
||||||
msg = MockMessage(content="hello", reasoning="Let me consider...")
|
|
||||||
assert _extract_reasoning_from_message(msg) == "Let me consider..."
|
|
||||||
|
|
||||||
def test_reasoning_details(self):
|
|
||||||
detail = MagicMock()
|
|
||||||
detail.text = "Detail reasoning"
|
|
||||||
msg = MockMessage(content="hello", reasoning_details=[detail])
|
|
||||||
assert _extract_reasoning_from_message(msg) == "Detail reasoning"
|
|
||||||
|
|
||||||
def test_reasoning_details_dict_format(self):
|
|
||||||
msg = MockMessage(
|
|
||||||
content="hello",
|
|
||||||
reasoning_details=[{"text": "Dict reasoning"}],
|
|
||||||
)
|
|
||||||
assert _extract_reasoning_from_message(msg) == "Dict reasoning"
|
|
||||||
|
|
||||||
def test_no_reasoning(self):
|
|
||||||
msg = MockMessage(content="hello")
|
|
||||||
assert _extract_reasoning_from_message(msg) is None
|
|
||||||
|
|
||||||
def test_reasoning_content_takes_priority(self):
|
|
||||||
msg = MockMessage(
|
|
||||||
content="hello",
|
|
||||||
reasoning_content="First",
|
|
||||||
reasoning="Second",
|
|
||||||
)
|
|
||||||
assert _extract_reasoning_from_message(msg) == "First"
|
|
||||||
|
|
||||||
|
|
||||||
class TestHermesAgentLoop:
|
|
||||||
"""Test the agent loop with mock servers."""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def basic_tools(self):
|
|
||||||
"""Minimal tool schema for testing."""
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "terminal",
|
|
||||||
"description": "Run a command",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"command": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Command to run",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["command"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "read_file",
|
|
||||||
"description": "Read a file",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"path": {"type": "string"},
|
|
||||||
},
|
|
||||||
"required": ["path"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def valid_names(self):
|
|
||||||
return {"terminal", "read_file", "todo"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_simple_text_response(self, basic_tools, valid_names):
|
|
||||||
"""Model responds with text only, no tool calls."""
|
|
||||||
server = MockServer([make_text_response("Hello! How can I help?")])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally is True
|
|
||||||
assert result.turns_used == 1
|
|
||||||
assert len(result.messages) >= 2 # user + assistant
|
|
||||||
assert result.messages[-1]["role"] == "assistant"
|
|
||||||
assert result.messages[-1]["content"] == "Hello! How can I help?"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_then_text(self, basic_tools, valid_names):
|
|
||||||
"""Model calls a tool, then responds with text."""
|
|
||||||
server = MockServer([
|
|
||||||
make_tool_response("todo", {"todos": [{"id": "1", "content": "test", "status": "pending"}]}),
|
|
||||||
make_text_response("I created a todo for you."),
|
|
||||||
])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Create a todo"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally is True
|
|
||||||
assert result.turns_used == 2
|
|
||||||
# Should have: user, assistant (tool_call), tool (result), assistant (text)
|
|
||||||
roles = [m["role"] for m in result.messages]
|
|
||||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_max_turns_reached(self, basic_tools, valid_names):
|
|
||||||
"""Model keeps calling tools until max_turns is hit."""
|
|
||||||
# Create responses that always call a tool
|
|
||||||
responses = [
|
|
||||||
make_tool_response("todo", {"todos": [{"id": str(i), "content": f"task {i}", "status": "pending"}]}, tool_call_id=f"call_{i}")
|
|
||||||
for i in range(10)
|
|
||||||
]
|
|
||||||
server = MockServer(responses)
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=3,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Keep going"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally is False
|
|
||||||
assert result.turns_used == 3
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_tool_name(self, basic_tools, valid_names):
|
|
||||||
"""Model calls a tool not in valid_tool_names."""
|
|
||||||
server = MockServer([
|
|
||||||
make_tool_response("nonexistent_tool", {"arg": "val"}),
|
|
||||||
make_text_response("OK, that didn't work."),
|
|
||||||
])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Call something weird"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Should record a tool error
|
|
||||||
assert len(result.tool_errors) >= 1
|
|
||||||
assert result.tool_errors[0].tool_name == "nonexistent_tool"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_empty_response(self, basic_tools, valid_names):
|
|
||||||
"""Server returns empty response."""
|
|
||||||
server = MockServer([MockChatCompletion(choices=[])])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally is False
|
|
||||||
assert result.turns_used == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_api_error_handling(self, basic_tools, valid_names):
|
|
||||||
"""Server raises an exception."""
|
|
||||||
|
|
||||||
class FailingServer:
|
|
||||||
async def chat_completion(self, **kwargs):
|
|
||||||
raise ConnectionError("Server unreachable")
|
|
||||||
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=FailingServer(),
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally is False
|
|
||||||
assert result.turns_used == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tools_passed_to_server(self, basic_tools, valid_names):
|
|
||||||
"""Verify tools are passed in the chat_completion kwargs."""
|
|
||||||
server = MockServer([make_text_response("OK")])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
await agent.run(messages)
|
|
||||||
|
|
||||||
assert len(server.call_history) == 1
|
|
||||||
assert "tools" in server.call_history[0]
|
|
||||||
assert server.call_history[0]["tools"] == basic_tools
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_extra_body_forwarded(self, basic_tools, valid_names):
|
|
||||||
"""extra_body should be forwarded to server."""
|
|
||||||
extra = {"provider": {"ignore": ["DeepInfra"]}}
|
|
||||||
server = MockServer([make_text_response("OK")])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
extra_body=extra,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
await agent.run(messages)
|
|
||||||
|
|
||||||
assert server.call_history[0].get("extra_body") == extra
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_managed_state_returned(self, basic_tools, valid_names):
|
|
||||||
"""If server has get_state(), result should include managed_state."""
|
|
||||||
server = MockServer([make_text_response("OK")])
|
|
||||||
server.get_state = lambda: {"nodes": [{"test": True}]}
|
|
||||||
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.managed_state is not None
|
|
||||||
assert "nodes" in result.managed_state
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_managed_state_without_get_state(self, basic_tools, valid_names):
|
|
||||||
"""Regular server without get_state() should return None managed_state."""
|
|
||||||
server = MockServer([make_text_response("OK")])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Hi"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.managed_state is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_memory_tool_blocked(self, basic_tools):
|
|
||||||
"""Memory tool should return error in RL environments."""
|
|
||||||
valid = {"terminal", "read_file", "todo", "memory"}
|
|
||||||
server = MockServer([
|
|
||||||
make_tool_response("memory", {"action": "add", "target": "user", "content": "test"}),
|
|
||||||
make_text_response("Done"),
|
|
||||||
])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Remember this"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Find the tool response
|
|
||||||
tool_msgs = [m for m in result.messages if m["role"] == "tool"]
|
|
||||||
assert len(tool_msgs) >= 1
|
|
||||||
tool_result = json.loads(tool_msgs[0]["content"])
|
|
||||||
assert "error" in tool_result
|
|
||||||
assert "not available" in tool_result["error"].lower()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_session_search_blocked(self, basic_tools):
|
|
||||||
"""session_search should return error in RL environments."""
|
|
||||||
valid = {"terminal", "read_file", "todo", "session_search"}
|
|
||||||
server = MockServer([
|
|
||||||
make_tool_response("session_search", {"query": "test"}),
|
|
||||||
make_text_response("Done"),
|
|
||||||
])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "Search sessions"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
tool_msgs = [m for m in result.messages if m["role"] == "tool"]
|
|
||||||
assert len(tool_msgs) >= 1
|
|
||||||
tool_result = json.loads(tool_msgs[0]["content"])
|
|
||||||
assert "error" in tool_result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_reasoning_content_preserved(self, basic_tools, valid_names):
|
|
||||||
"""Reasoning content should be extracted and preserved."""
|
|
||||||
resp = MockChatCompletion(
|
|
||||||
choices=[
|
|
||||||
MockChoice(
|
|
||||||
message=MockMessage(
|
|
||||||
content="The answer is 42.",
|
|
||||||
reasoning_content="Let me think about this step by step...",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
server = MockServer([resp])
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=basic_tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=10,
|
|
||||||
)
|
|
||||||
messages = [{"role": "user", "content": "What is the meaning of life?"}]
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert len(result.reasoning_per_turn) == 1
|
|
||||||
assert result.reasoning_per_turn[0] == "Let me think about this step by step..."
|
|
||||||
|
|
||||||
|
|
||||||
class TestResizeToolPool:
|
|
||||||
def test_resize_works(self):
|
|
||||||
"""resize_tool_pool should not raise."""
|
|
||||||
resize_tool_pool(16) # Small pool for testing
|
|
||||||
resize_tool_pool(128) # Restore default
|
|
||||||
|
|
||||||
def test_resize_shuts_down_previous_executor(self, monkeypatch):
|
|
||||||
"""Replacing the global tool executor should shut down the old pool."""
|
|
||||||
import environments.agent_loop as agent_loop_module
|
|
||||||
|
|
||||||
old_executor = MagicMock()
|
|
||||||
new_executor = MagicMock()
|
|
||||||
|
|
||||||
monkeypatch.setattr(agent_loop_module, "_tool_executor", old_executor)
|
|
||||||
monkeypatch.setattr(
|
|
||||||
agent_loop_module.concurrent.futures,
|
|
||||||
"ThreadPoolExecutor",
|
|
||||||
MagicMock(return_value=new_executor),
|
|
||||||
)
|
|
||||||
|
|
||||||
resize_tool_pool(16)
|
|
||||||
|
|
||||||
old_executor.shutdown.assert_called_once_with(wait=False)
|
|
||||||
assert agent_loop_module._tool_executor is new_executor
|
|
||||||
|
|
@ -1,552 +0,0 @@
|
||||||
"""Integration tests for HermesAgentLoop tool calling.
|
|
||||||
|
|
||||||
Tests the full agent loop with real LLM calls via OpenRouter.
|
|
||||||
Uses stepfun/step-3.5-flash:free by default (zero cost), falls back
|
|
||||||
to anthropic/claude-sonnet-4 if the free model is unavailable.
|
|
||||||
|
|
||||||
These tests verify:
|
|
||||||
1. Single tool call: model calls a tool, gets result, responds
|
|
||||||
2. Multi-tool call: model calls multiple tools in one turn
|
|
||||||
3. Multi-turn: model calls tools across multiple turns
|
|
||||||
4. Unknown tool rejection: model calling a non-existent tool gets an error
|
|
||||||
5. Max turns: loop stops when max_turns is reached
|
|
||||||
6. No tools: model responds without calling any tools
|
|
||||||
7. Tool error handling: tool execution errors are captured
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_agent_loop_tool_calling.py -v
|
|
||||||
pytest tests/test_agent_loop_tool_calling.py -v -k "single" # run one test
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Set
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# pytestmark removed — tests skip gracefully via OPENROUTER_API_KEY check on line 59
|
|
||||||
|
|
||||||
# Ensure repo root is importable
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
|
||||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Test infrastructure
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
# Models to try, in order of preference (free first)
|
|
||||||
_MODELS = [
|
|
||||||
"stepfun/step-3.5-flash:free",
|
|
||||||
"google/gemini-2.0-flash-001",
|
|
||||||
"anthropic/claude-sonnet-4",
|
|
||||||
]
|
|
||||||
|
|
||||||
def _get_api_key():
|
|
||||||
key = os.getenv("OPENROUTER_API_KEY", "")
|
|
||||||
if not key:
|
|
||||||
pytest.skip("OPENROUTER_API_KEY not set")
|
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
def _make_server(model: str = None):
|
|
||||||
"""Create an OpenAI server for testing."""
|
|
||||||
from atroposlib.envs.server_handling.openai_server import OpenAIServer
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
|
|
||||||
config = APIServerConfig(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
model_name=model or _MODELS[0],
|
|
||||||
server_type="openai",
|
|
||||||
api_key=_get_api_key(),
|
|
||||||
health_check=False,
|
|
||||||
)
|
|
||||||
return OpenAIServer(config)
|
|
||||||
|
|
||||||
|
|
||||||
async def _try_models(test_fn):
|
|
||||||
"""Try running a test with each model until one works."""
|
|
||||||
last_error = None
|
|
||||||
for model in _MODELS:
|
|
||||||
try:
|
|
||||||
server = _make_server(model)
|
|
||||||
return await test_fn(server, model)
|
|
||||||
except Exception as e:
|
|
||||||
last_error = e
|
|
||||||
if "rate" in str(e).lower() or "limit" in str(e).lower():
|
|
||||||
continue # Rate limited, try next model
|
|
||||||
raise # Real error
|
|
||||||
pytest.skip(f"All models failed. Last error: {last_error}")
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Fake tools for testing
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
# Simple calculator tool
|
|
||||||
CALC_TOOL = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "calculate",
|
|
||||||
"description": "Calculate a math expression. Returns the numeric result.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"expression": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Math expression to evaluate, e.g. '2 + 3'"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["expression"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Weather lookup tool
|
|
||||||
WEATHER_TOOL = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get the current weather for a city. Returns temperature and conditions.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "City name, e.g. 'Tokyo'"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Lookup tool (always succeeds)
|
|
||||||
LOOKUP_TOOL = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "lookup",
|
|
||||||
"description": "Look up a fact. Returns a short answer string.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"query": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "What to look up"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["query"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
# Error tool (always fails)
|
|
||||||
ERROR_TOOL = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "failing_tool",
|
|
||||||
"description": "A tool that always fails with an error.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"input": {"type": "string"}
|
|
||||||
},
|
|
||||||
"required": ["input"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
|
||||||
"""Handle fake tool calls for testing."""
|
|
||||||
if tool_name == "calculate":
|
|
||||||
expr = args.get("expression", "0")
|
|
||||||
try:
|
|
||||||
# Safe eval for simple math
|
|
||||||
result = eval(expr, {"__builtins__": {}}, {})
|
|
||||||
return json.dumps({"result": result})
|
|
||||||
except Exception as e:
|
|
||||||
return json.dumps({"error": str(e)})
|
|
||||||
|
|
||||||
elif tool_name == "get_weather":
|
|
||||||
city = args.get("city", "Unknown")
|
|
||||||
# Return canned weather
|
|
||||||
return json.dumps({
|
|
||||||
"city": city,
|
|
||||||
"temperature": 22,
|
|
||||||
"conditions": "sunny",
|
|
||||||
"humidity": 45,
|
|
||||||
})
|
|
||||||
|
|
||||||
elif tool_name == "lookup":
|
|
||||||
query = args.get("query", "")
|
|
||||||
return json.dumps({"answer": f"The answer to '{query}' is 42."})
|
|
||||||
|
|
||||||
elif tool_name == "failing_tool":
|
|
||||||
raise RuntimeError("This tool always fails!")
|
|
||||||
|
|
||||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Tests
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_single_tool_call():
|
|
||||||
"""Model should call a single tool, get the result, and respond."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[WEATHER_TOOL],
|
|
||||||
valid_tool_names={"get_weather"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What's the weather in Tokyo? Use the get_weather tool."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert isinstance(result, AgentResult)
|
|
||||||
assert result.turns_used >= 2, f"Expected at least 2 turns (tool call + response), got {result.turns_used}"
|
|
||||||
|
|
||||||
# Verify a tool call happened
|
|
||||||
tool_calls_found = False
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
if tc["function"]["name"] == "get_weather":
|
|
||||||
tool_calls_found = True
|
|
||||||
args = json.loads(tc["function"]["arguments"])
|
|
||||||
assert "city" in args
|
|
||||||
assert tool_calls_found, "Model should have called get_weather"
|
|
||||||
|
|
||||||
# Verify tool result is in conversation
|
|
||||||
tool_results = [m for m in result.messages if m.get("role") == "tool"]
|
|
||||||
assert len(tool_results) >= 1, "Should have at least one tool result"
|
|
||||||
|
|
||||||
# Verify the final response references the weather
|
|
||||||
final_msg = result.messages[-1]
|
|
||||||
assert final_msg["role"] == "assistant"
|
|
||||||
assert final_msg["content"], "Final response should have content"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multi_tool_single_turn():
|
|
||||||
"""Model should call multiple tools in a single turn."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[WEATHER_TOOL, CALC_TOOL],
|
|
||||||
valid_tool_names={"get_weather", "calculate"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": (
|
|
||||||
"I need two things at once: "
|
|
||||||
"1) What's the weather in Paris? Use get_weather. "
|
|
||||||
"2) What is 15 * 7? Use calculate. "
|
|
||||||
"Call BOTH tools in a single response."
|
|
||||||
)},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Count distinct tools called
|
|
||||||
tools_called = set()
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
tools_called.add(tc["function"]["name"])
|
|
||||||
|
|
||||||
# At minimum, both tools should have been called (maybe in different turns)
|
|
||||||
assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}"
|
|
||||||
assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_multi_turn_conversation():
|
|
||||||
"""Agent should handle multiple turns of tool calls."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[LOOKUP_TOOL, CALC_TOOL],
|
|
||||||
valid_tool_names={"lookup", "calculate"},
|
|
||||||
max_turns=10,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": (
|
|
||||||
"First, use the lookup tool to look up 'meaning of life'. "
|
|
||||||
"Then use calculate to compute 6 * 7. "
|
|
||||||
"Do these in separate tool calls, one at a time."
|
|
||||||
)},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Should have used both tools
|
|
||||||
tools_called = set()
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
tools_called.add(tc["function"]["name"])
|
|
||||||
|
|
||||||
assert "lookup" in tools_called, f"lookup not called. Called: {tools_called}"
|
|
||||||
assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
|
||||||
|
|
||||||
# Should finish naturally
|
|
||||||
assert result.finished_naturally, "Should finish naturally after answering"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_tool_rejected():
|
|
||||||
"""If the model calls a tool not in valid_tool_names, it gets an error."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
# Only allow "calculate" but give schema for both
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[CALC_TOOL, WEATHER_TOOL],
|
|
||||||
valid_tool_names={"calculate"}, # weather NOT allowed
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What's the weather in London? Use get_weather."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Check if get_weather was called and rejected
|
|
||||||
if result.tool_errors:
|
|
||||||
weather_errors = [e for e in result.tool_errors if e.tool_name == "get_weather"]
|
|
||||||
assert len(weather_errors) > 0, "get_weather should have been rejected"
|
|
||||||
assert "Unknown tool" in weather_errors[0].error
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_max_turns_limit():
|
|
||||||
"""Agent should stop after max_turns even if model keeps calling tools."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[LOOKUP_TOOL],
|
|
||||||
valid_tool_names={"lookup"},
|
|
||||||
max_turns=2, # Very low limit
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": (
|
|
||||||
"Keep looking up facts. Look up 'fact 1', then 'fact 2', "
|
|
||||||
"then 'fact 3', then 'fact 4'. Do them one at a time."
|
|
||||||
)},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.turns_used <= 2, f"Should stop at max_turns=2, used {result.turns_used}"
|
|
||||||
assert not result.finished_naturally, "Should NOT finish naturally (hit max_turns)"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_tools_direct_response():
|
|
||||||
"""When no tools are useful, model should respond directly."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[WEATHER_TOOL],
|
|
||||||
valid_tool_names={"get_weather"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=200,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What is 2 + 2? Just answer directly, no tools needed."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally, "Should finish naturally with a direct response"
|
|
||||||
assert result.turns_used == 1, f"Should take exactly 1 turn for a direct answer, took {result.turns_used}"
|
|
||||||
|
|
||||||
final = result.messages[-1]
|
|
||||||
assert final["role"] == "assistant"
|
|
||||||
assert final["content"], "Should have text content"
|
|
||||||
assert "4" in final["content"], "Should contain the answer '4'"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_error_handling():
|
|
||||||
"""Tool execution errors should be captured and reported to the model."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[ERROR_TOOL],
|
|
||||||
valid_tool_names={"failing_tool"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "Please call the failing_tool with input 'test'."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# The tool error should be recorded
|
|
||||||
assert len(result.tool_errors) >= 1, "Should have at least one tool error"
|
|
||||||
assert "RuntimeError" in result.tool_errors[0].error or "always fails" in result.tool_errors[0].error
|
|
||||||
|
|
||||||
# The error should be in the conversation as a tool result
|
|
||||||
tool_results = [m for m in result.messages if m.get("role") == "tool"]
|
|
||||||
assert len(tool_results) >= 1
|
|
||||||
error_result = json.loads(tool_results[0]["content"])
|
|
||||||
assert "error" in error_result
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_agent_result_structure():
|
|
||||||
"""Verify the AgentResult has all expected fields populated."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[CALC_TOOL],
|
|
||||||
valid_tool_names={"calculate"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=300,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What is 3 + 4? Use the calculate tool."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Structural checks
|
|
||||||
assert isinstance(result, AgentResult)
|
|
||||||
assert isinstance(result.messages, list)
|
|
||||||
assert len(result.messages) >= 3, "Should have user + assistant(tool) + tool_result + assistant(final)"
|
|
||||||
assert isinstance(result.turns_used, int)
|
|
||||||
assert result.turns_used > 0
|
|
||||||
assert isinstance(result.finished_naturally, bool)
|
|
||||||
assert isinstance(result.tool_errors, list)
|
|
||||||
assert isinstance(result.reasoning_per_turn, list)
|
|
||||||
|
|
||||||
# Messages should follow OpenAI format
|
|
||||||
for msg in result.messages:
|
|
||||||
assert "role" in msg, f"Message missing 'role': {msg}"
|
|
||||||
assert msg["role"] in ("system", "user", "assistant", "tool"), f"Invalid role: {msg['role']}"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_conversation_history_preserved():
|
|
||||||
"""The full conversation history should be in result.messages."""
|
|
||||||
|
|
||||||
async def _run(server, model):
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=server,
|
|
||||||
tool_schemas=[WEATHER_TOOL],
|
|
||||||
valid_tool_names={"get_weather"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "You are a helpful weather assistant."},
|
|
||||||
{"role": "user", "content": "What's the weather in Berlin? Use get_weather."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# System message should be preserved
|
|
||||||
assert result.messages[0]["role"] == "system"
|
|
||||||
assert "weather assistant" in result.messages[0]["content"]
|
|
||||||
|
|
||||||
# User message should be preserved
|
|
||||||
assert result.messages[1]["role"] == "user"
|
|
||||||
assert "Berlin" in result.messages[1]["content"]
|
|
||||||
|
|
||||||
# Should have assistant + tool + assistant sequence
|
|
||||||
roles = [m["role"] for m in result.messages]
|
|
||||||
assert "tool" in roles, "Should have tool results in conversation"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
await _try_models(_run)
|
|
||||||
|
|
@ -1,359 +0,0 @@
|
||||||
"""Integration tests for HermesAgentLoop with a local vLLM server.
|
|
||||||
|
|
||||||
Tests the full Phase 2 flow: ManagedServer + tool calling with a real
|
|
||||||
vLLM backend, producing actual token IDs and logprobs for RL training.
|
|
||||||
|
|
||||||
Requires a running vLLM server. Start one from the atropos directory:
|
|
||||||
|
|
||||||
python -m example_trainer.vllm_api_server \
|
|
||||||
--model Qwen/Qwen3-4B-Thinking-2507 \
|
|
||||||
--port 9001 \
|
|
||||||
--gpu-memory-utilization 0.8 \
|
|
||||||
--max-model-len=32000
|
|
||||||
|
|
||||||
Tests are automatically skipped if the server is not reachable.
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_agent_loop_vllm.py -v
|
|
||||||
pytest tests/test_agent_loop_vllm.py -v -k "single"
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import requests
|
|
||||||
|
|
||||||
# Ensure repo root is importable
|
|
||||||
_repo_root = Path(__file__).resolve().parent.parent.parent
|
|
||||||
if str(_repo_root) not in sys.path:
|
|
||||||
sys.path.insert(0, str(_repo_root))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from environments.agent_loop import AgentResult, HermesAgentLoop
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Configuration
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
VLLM_HOST = "localhost"
|
|
||||||
VLLM_PORT = 9001
|
|
||||||
VLLM_BASE_URL = f"http://{VLLM_HOST}:{VLLM_PORT}"
|
|
||||||
VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
|
|
||||||
|
|
||||||
|
|
||||||
def _vllm_is_running() -> bool:
|
|
||||||
"""Check if the vLLM server is reachable."""
|
|
||||||
try:
|
|
||||||
r = requests.get(f"{VLLM_BASE_URL}/health", timeout=3)
|
|
||||||
return r.status_code == 200
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
# Skip all tests in this module if vLLM is not running
|
|
||||||
pytestmark = pytest.mark.skipif(
|
|
||||||
not _vllm_is_running(),
|
|
||||||
reason=(
|
|
||||||
f"vLLM server not reachable at {VLLM_BASE_URL}. "
|
|
||||||
"Start it with: python -m example_trainer.vllm_api_server "
|
|
||||||
f"--model {VLLM_MODEL} --port {VLLM_PORT} "
|
|
||||||
"--gpu-memory-utilization 0.8 --max-model-len=32000"
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Server setup
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def _make_server_manager():
|
|
||||||
"""Create a ServerManager pointing to the local vLLM server."""
|
|
||||||
from atroposlib.envs.server_handling.server_manager import (
|
|
||||||
ServerManager,
|
|
||||||
APIServerConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = APIServerConfig(
|
|
||||||
base_url=VLLM_BASE_URL,
|
|
||||||
model_name=VLLM_MODEL,
|
|
||||||
server_type="vllm",
|
|
||||||
health_check=False,
|
|
||||||
)
|
|
||||||
sm = ServerManager([config], tool_parser="hermes")
|
|
||||||
sm.servers[0].server_healthy = True
|
|
||||||
return sm
|
|
||||||
|
|
||||||
|
|
||||||
def _get_tokenizer():
|
|
||||||
"""Load the tokenizer for the model."""
|
|
||||||
from transformers import AutoTokenizer
|
|
||||||
return AutoTokenizer.from_pretrained(VLLM_MODEL)
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Fake tools
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
WEATHER_TOOL = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get the current weather for a city. Returns temperature and conditions.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "City name, e.g. 'Tokyo'",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["city"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
CALC_TOOL = {
|
|
||||||
"type": "function",
|
|
||||||
"function": {
|
|
||||||
"name": "calculate",
|
|
||||||
"description": "Calculate a math expression. Returns the numeric result.",
|
|
||||||
"parameters": {
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"expression": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "Math expression, e.g. '2 + 3'",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"required": ["expression"],
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
|
|
||||||
"""Handle fake tool calls for testing."""
|
|
||||||
if tool_name == "get_weather":
|
|
||||||
city = args.get("city", "Unknown")
|
|
||||||
return json.dumps({
|
|
||||||
"city": city,
|
|
||||||
"temperature": 22,
|
|
||||||
"conditions": "sunny",
|
|
||||||
"humidity": 45,
|
|
||||||
})
|
|
||||||
elif tool_name == "calculate":
|
|
||||||
expr = args.get("expression", "0")
|
|
||||||
try:
|
|
||||||
result = eval(expr, {"__builtins__": {}}, {})
|
|
||||||
return json.dumps({"result": result})
|
|
||||||
except Exception as e:
|
|
||||||
return json.dumps({"error": str(e)})
|
|
||||||
return json.dumps({"error": f"Unknown tool: {tool_name}"})
|
|
||||||
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# Tests
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_vllm_single_tool_call():
|
|
||||||
"""vLLM model calls a tool, gets result, responds — full Phase 2 flow."""
|
|
||||||
sm = _make_server_manager()
|
|
||||||
tokenizer = _get_tokenizer()
|
|
||||||
|
|
||||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=managed,
|
|
||||||
tool_schemas=[WEATHER_TOOL],
|
|
||||||
valid_tool_names={"get_weather"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What's the weather in Tokyo? Use the get_weather tool."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert isinstance(result, AgentResult)
|
|
||||||
assert result.turns_used >= 2, f"Expected at least 2 turns, got {result.turns_used}"
|
|
||||||
|
|
||||||
# Verify tool call happened
|
|
||||||
tool_calls_found = False
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
if tc["function"]["name"] == "get_weather":
|
|
||||||
tool_calls_found = True
|
|
||||||
args = json.loads(tc["function"]["arguments"])
|
|
||||||
assert "city" in args
|
|
||||||
assert tool_calls_found, "Model should have called get_weather"
|
|
||||||
|
|
||||||
# Verify tool results in conversation
|
|
||||||
tool_results = [m for m in result.messages if m.get("role") == "tool"]
|
|
||||||
assert len(tool_results) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_vllm_multi_tool_calls():
|
|
||||||
"""vLLM model calls multiple tools across turns."""
|
|
||||||
sm = _make_server_manager()
|
|
||||||
tokenizer = _get_tokenizer()
|
|
||||||
|
|
||||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=managed,
|
|
||||||
tool_schemas=[WEATHER_TOOL, CALC_TOOL],
|
|
||||||
valid_tool_names={"get_weather", "calculate"},
|
|
||||||
max_turns=10,
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": (
|
|
||||||
"I need two things: "
|
|
||||||
"1) What's the weather in Paris? Use get_weather. "
|
|
||||||
"2) What is 15 * 7? Use calculate."
|
|
||||||
)},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Both tools should be called
|
|
||||||
tools_called = set()
|
|
||||||
for msg in result.messages:
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
tools_called.add(tc["function"]["name"])
|
|
||||||
|
|
||||||
assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}"
|
|
||||||
assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_vllm_managed_server_produces_nodes():
|
|
||||||
"""ManagedServer should produce SequenceNodes with tokens and logprobs."""
|
|
||||||
sm = _make_server_manager()
|
|
||||||
tokenizer = _get_tokenizer()
|
|
||||||
|
|
||||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=managed,
|
|
||||||
tool_schemas=[WEATHER_TOOL],
|
|
||||||
valid_tool_names={"get_weather"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What's the weather in Berlin? Use get_weather."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Get the managed state — should have SequenceNodes
|
|
||||||
state = managed.get_state()
|
|
||||||
|
|
||||||
assert state is not None, "ManagedServer should return state"
|
|
||||||
nodes = state.get("nodes", [])
|
|
||||||
assert len(nodes) >= 1, f"Should have at least 1 node, got {len(nodes)}"
|
|
||||||
|
|
||||||
node = nodes[0]
|
|
||||||
assert hasattr(node, "tokens"), "Node should have tokens"
|
|
||||||
assert hasattr(node, "logprobs"), "Node should have logprobs"
|
|
||||||
assert len(node.tokens) > 0, "Tokens should not be empty"
|
|
||||||
assert len(node.logprobs) > 0, "Logprobs should not be empty"
|
|
||||||
assert len(node.tokens) == len(node.logprobs), (
|
|
||||||
f"Tokens ({len(node.tokens)}) and logprobs ({len(node.logprobs)}) should have same length"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_vllm_no_tools_direct_response():
|
|
||||||
"""vLLM model should respond directly when no tools are needed."""
|
|
||||||
sm = _make_server_manager()
|
|
||||||
tokenizer = _get_tokenizer()
|
|
||||||
|
|
||||||
async with sm.managed_server(tokenizer=tokenizer) as managed:
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=managed,
|
|
||||||
tool_schemas=[WEATHER_TOOL],
|
|
||||||
valid_tool_names={"get_weather"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What is 2 + 2? Answer directly, no tools."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
assert result.finished_naturally, "Should finish naturally"
|
|
||||||
assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}"
|
|
||||||
|
|
||||||
final = result.messages[-1]
|
|
||||||
assert final["role"] == "assistant"
|
|
||||||
assert final["content"], "Should have content"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_vllm_thinking_content_extracted():
|
|
||||||
"""Qwen3-Thinking model should produce reasoning content."""
|
|
||||||
sm = _make_server_manager()
|
|
||||||
tokenizer = _get_tokenizer()
|
|
||||||
|
|
||||||
async with sm.managed_server(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
preserve_think_blocks=True,
|
|
||||||
) as managed:
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=managed,
|
|
||||||
tool_schemas=[CALC_TOOL],
|
|
||||||
valid_tool_names={"calculate"},
|
|
||||||
max_turns=5,
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1000,
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": "What is 123 * 456? Use the calculate tool."},
|
|
||||||
]
|
|
||||||
|
|
||||||
with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
# Qwen3-Thinking should generate <think> blocks
|
|
||||||
# Check if any content contains thinking markers
|
|
||||||
has_thinking = False
|
|
||||||
for msg in result.messages:
|
|
||||||
content = msg.get("content", "") or ""
|
|
||||||
if "<think>" in content or "</think>" in content:
|
|
||||||
has_thinking = True
|
|
||||||
break
|
|
||||||
|
|
||||||
# Also check reasoning_per_turn
|
|
||||||
has_reasoning = any(r for r in result.reasoning_per_turn if r)
|
|
||||||
|
|
||||||
# At least one of these should be true for a thinking model
|
|
||||||
assert has_thinking or has_reasoning, (
|
|
||||||
"Qwen3-Thinking should produce <think> blocks or reasoning content"
|
|
||||||
)
|
|
||||||
|
|
@ -23,7 +23,7 @@ class TestStreamingAssemblyRepair:
|
||||||
|
|
||||||
These tests verify the REPAIR FUNCTION itself works correctly for the
|
These tests verify the REPAIR FUNCTION itself works correctly for the
|
||||||
cases that arise during streaming assembly. Integration tests that
|
cases that arise during streaming assembly. Integration tests that
|
||||||
exercise the full streaming path are in test_agent_loop_tool_calling.py.
|
exercise the full streaming path are in run_agent.py's streaming tests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# -- Truncation cases (most common streaming failure) --
|
# -- Truncation cases (most common streaming failure) --
|
||||||
|
|
|
||||||
|
|
@ -278,7 +278,7 @@ class TestLegacyToolsetMap:
|
||||||
expected = [
|
expected = [
|
||||||
"web_tools", "terminal_tools", "vision_tools", "moa_tools",
|
"web_tools", "terminal_tools", "vision_tools", "moa_tools",
|
||||||
"image_tools", "skills_tools", "browser_tools", "cronjob_tools",
|
"image_tools", "skills_tools", "browser_tools", "cronjob_tools",
|
||||||
"rl_tools", "file_tools", "tts_tools",
|
"file_tools", "tts_tools",
|
||||||
]
|
]
|
||||||
for name in expected:
|
for name in expected:
|
||||||
assert name in _LEGACY_TOOLSET_MAP, f"Missing legacy toolset: {name}"
|
assert name in _LEGACY_TOOLSET_MAP, f"Missing legacy toolset: {name}"
|
||||||
|
|
|
||||||
|
|
@ -1,178 +0,0 @@
|
||||||
"""
|
|
||||||
Tests for ManagedServer / tool-parser integration.
|
|
||||||
|
|
||||||
Validates that:
|
|
||||||
1. The installed atroposlib API still matches Hermes's expectations
|
|
||||||
2. Hermes's parser registry remains compatible with ManagedServer parsing
|
|
||||||
3. HermesAgentBaseEnv wires the selected parser into ServerManager correctly
|
|
||||||
|
|
||||||
These tests verify the contract between hermes-agent's environments/ code
|
|
||||||
and atroposlib's ManagedServer. They detect API incompatibilities early.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import inspect
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
||||||
|
|
||||||
try:
|
|
||||||
import atroposlib # noqa: F401
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestManagedServerAPI:
|
|
||||||
"""Test that ManagedServer's API matches what hermes-agent expects."""
|
|
||||||
|
|
||||||
def test_managed_server_init_signature(self):
|
|
||||||
"""ManagedServer should accept tool_call_parser parameter."""
|
|
||||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
|
||||||
|
|
||||||
sig = inspect.signature(ManagedServer.__init__)
|
|
||||||
params = list(sig.parameters.keys())
|
|
||||||
|
|
||||||
# Core params that must exist
|
|
||||||
assert "self" in params
|
|
||||||
assert "server" in params
|
|
||||||
assert "tokenizer" in params
|
|
||||||
assert "track_tree" in params
|
|
||||||
|
|
||||||
# tool_call_parser — required for tool_call_support branch
|
|
||||||
# If this fails, atroposlib hasn't been updated to tool_call_support
|
|
||||||
has_tool_parser = "tool_call_parser" in params
|
|
||||||
if not has_tool_parser:
|
|
||||||
pytest.skip(
|
|
||||||
"ManagedServer does not have tool_call_parser param — "
|
|
||||||
"baseline atroposlib (pre tool_call_support branch)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_server_manager_managed_server_signature(self):
|
|
||||||
"""ServerManager.managed_server() should accept tool_call_parser."""
|
|
||||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
|
||||||
|
|
||||||
sig = inspect.signature(ServerManager.managed_server)
|
|
||||||
params = list(sig.parameters.keys())
|
|
||||||
|
|
||||||
assert "self" in params
|
|
||||||
assert "tokenizer" in params
|
|
||||||
|
|
||||||
has_tool_parser = "tool_call_parser" in params
|
|
||||||
if not has_tool_parser:
|
|
||||||
pytest.skip(
|
|
||||||
"ServerManager.managed_server() does not have tool_call_parser param — "
|
|
||||||
"baseline atroposlib (pre tool_call_support branch)"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_managed_server_chat_template_kwargs(self):
|
|
||||||
"""ManagedServer should have CHAT_TEMPLATE_KWARGS for forwarding tools/thinking."""
|
|
||||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
|
||||||
|
|
||||||
if not hasattr(ManagedServer, "CHAT_TEMPLATE_KWARGS"):
|
|
||||||
pytest.skip(
|
|
||||||
"ManagedServer does not have CHAT_TEMPLATE_KWARGS — "
|
|
||||||
"baseline atroposlib (pre tool_call_support branch)"
|
|
||||||
)
|
|
||||||
|
|
||||||
kwargs = ManagedServer.CHAT_TEMPLATE_KWARGS
|
|
||||||
assert "tools" in kwargs, "tools must be in CHAT_TEMPLATE_KWARGS"
|
|
||||||
|
|
||||||
def test_no_get_logprobs_method(self):
|
|
||||||
"""get_logprobs should be removed in tool_call_support branch."""
|
|
||||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
|
||||||
|
|
||||||
# In baseline, get_logprobs exists. In tool_call_support, it's removed.
|
|
||||||
# We just note the state — not a hard fail either way.
|
|
||||||
has_get_logprobs = hasattr(ManagedServer, "get_logprobs")
|
|
||||||
if has_get_logprobs:
|
|
||||||
pytest.skip(
|
|
||||||
"ManagedServer still has get_logprobs — baseline atroposlib"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestParserCompatibility:
|
|
||||||
"""Test that hermes-agent's parsers match ManagedServer's expectations."""
|
|
||||||
|
|
||||||
def test_parser_parse_returns_correct_format(self):
|
|
||||||
"""
|
|
||||||
ManagedServer expects parser.parse(text) -> (content, tool_calls)
|
|
||||||
where tool_calls is a list of objects with .id, .function.name, .function.arguments
|
|
||||||
"""
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
|
|
||||||
parser = get_parser("hermes")
|
|
||||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
|
|
||||||
tc = tool_calls[0]
|
|
||||||
# ManagedServer accesses these attrs directly
|
|
||||||
assert hasattr(tc, "id")
|
|
||||||
assert hasattr(tc, "function")
|
|
||||||
assert hasattr(tc.function, "name")
|
|
||||||
assert hasattr(tc.function, "arguments")
|
|
||||||
|
|
||||||
def test_parser_no_tools_returns_none(self):
|
|
||||||
"""ManagedServer checks `if parsed_tool_calls:` — None should be falsy."""
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
|
|
||||||
parser = get_parser("hermes")
|
|
||||||
content, tool_calls = parser.parse("Just text, no tools")
|
|
||||||
assert tool_calls is None
|
|
||||||
|
|
||||||
def test_parser_content_is_string_or_none(self):
|
|
||||||
"""ManagedServer uses `parsed_content or ""` — must be str or None."""
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
|
|
||||||
parser = get_parser("hermes")
|
|
||||||
|
|
||||||
# With tool calls
|
|
||||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>'
|
|
||||||
content, _ = parser.parse(text)
|
|
||||||
assert content is None or isinstance(content, str)
|
|
||||||
|
|
||||||
# Without tool calls
|
|
||||||
content2, _ = parser.parse("Just text")
|
|
||||||
assert isinstance(content2, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestBaseEnvCompatibility:
|
|
||||||
"""Test that hermes_base_env.py's tool-parser wiring matches the current API."""
|
|
||||||
|
|
||||||
def test_hermes_base_env_sets_server_manager_tool_parser(self):
|
|
||||||
"""Hermes wires parser selection through ServerManager.tool_parser."""
|
|
||||||
import ast
|
|
||||||
|
|
||||||
base_env_path = Path(__file__).parent.parent.parent / "environments" / "hermes_base_env.py"
|
|
||||||
source = base_env_path.read_text()
|
|
||||||
tree = ast.parse(source)
|
|
||||||
|
|
||||||
found_assignment = False
|
|
||||||
for node in ast.walk(tree):
|
|
||||||
if isinstance(node, ast.Assign):
|
|
||||||
for target in node.targets:
|
|
||||||
if isinstance(target, ast.Attribute) and target.attr == "tool_parser":
|
|
||||||
parent = target.value
|
|
||||||
if (
|
|
||||||
isinstance(parent, ast.Attribute)
|
|
||||||
and parent.attr == "server"
|
|
||||||
and isinstance(parent.value, ast.Name)
|
|
||||||
and parent.value.id == "self"
|
|
||||||
):
|
|
||||||
found_assignment = True
|
|
||||||
|
|
||||||
assert found_assignment, (
|
|
||||||
"hermes_base_env.py should set self.server.tool_parser from config.tool_call_parser"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_hermes_base_env_uses_config_tool_call_parser(self):
|
|
||||||
"""Verify hermes_base_env uses the config field rather than a local parser instance."""
|
|
||||||
base_env_path = Path(__file__).parent.parent.parent / "environments" / "hermes_base_env.py"
|
|
||||||
source = base_env_path.read_text()
|
|
||||||
|
|
||||||
assert 'tool_call_parser: str = Field(' in source
|
|
||||||
assert 'self.server.tool_parser = config.tool_call_parser' in source
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
"""Tests for rl_training_tool.py — file handle lifecycle and cleanup.
|
|
||||||
|
|
||||||
Verifies that _stop_training_run properly closes log file handles,
|
|
||||||
terminates processes, and handles edge cases on failure paths.
|
|
||||||
Inspired by PR #715 (0xbyt4).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from tools.rl_training_tool import RunState, _stop_training_run
|
|
||||||
|
|
||||||
|
|
||||||
def _make_run_state(**overrides) -> RunState:
|
|
||||||
"""Create a minimal RunState for testing."""
|
|
||||||
defaults = {
|
|
||||||
"run_id": "test-run-001",
|
|
||||||
"environment": "test_env",
|
|
||||||
"config": {},
|
|
||||||
}
|
|
||||||
defaults.update(overrides)
|
|
||||||
return RunState(**defaults)
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopTrainingRunFileHandles:
|
|
||||||
"""Verify that _stop_training_run closes log file handles stored as attributes."""
|
|
||||||
|
|
||||||
def test_closes_all_log_file_handles(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
files = {}
|
|
||||||
for attr in ("api_log_file", "trainer_log_file", "env_log_file"):
|
|
||||||
fh = MagicMock()
|
|
||||||
setattr(state, attr, fh)
|
|
||||||
files[attr] = fh
|
|
||||||
|
|
||||||
_stop_training_run(state)
|
|
||||||
|
|
||||||
for attr, fh in files.items():
|
|
||||||
fh.close.assert_called_once()
|
|
||||||
assert getattr(state, attr) is None
|
|
||||||
|
|
||||||
def test_clears_file_attrs_to_none(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
state.api_log_file = MagicMock()
|
|
||||||
|
|
||||||
_stop_training_run(state)
|
|
||||||
|
|
||||||
assert state.api_log_file is None
|
|
||||||
|
|
||||||
def test_close_exception_does_not_propagate(self):
|
|
||||||
"""If a file handle .close() raises, it must not crash."""
|
|
||||||
state = _make_run_state()
|
|
||||||
bad_fh = MagicMock()
|
|
||||||
bad_fh.close.side_effect = OSError("already closed")
|
|
||||||
good_fh = MagicMock()
|
|
||||||
state.api_log_file = bad_fh
|
|
||||||
state.trainer_log_file = good_fh
|
|
||||||
|
|
||||||
_stop_training_run(state) # should not raise
|
|
||||||
|
|
||||||
bad_fh.close.assert_called_once()
|
|
||||||
good_fh.close.assert_called_once()
|
|
||||||
|
|
||||||
def test_handles_missing_file_attrs(self):
|
|
||||||
"""RunState without log file attrs should not crash."""
|
|
||||||
state = _make_run_state()
|
|
||||||
# No log file attrs set at all — getattr(..., None) should handle it
|
|
||||||
_stop_training_run(state) # should not raise
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopTrainingRunProcesses:
|
|
||||||
"""Verify that _stop_training_run terminates processes correctly."""
|
|
||||||
|
|
||||||
def test_terminates_running_processes(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
for attr in ("api_process", "trainer_process", "env_process"):
|
|
||||||
proc = MagicMock()
|
|
||||||
proc.poll.return_value = None # still running
|
|
||||||
setattr(state, attr, proc)
|
|
||||||
|
|
||||||
_stop_training_run(state)
|
|
||||||
|
|
||||||
for attr in ("api_process", "trainer_process", "env_process"):
|
|
||||||
getattr(state, attr).terminate.assert_called_once()
|
|
||||||
|
|
||||||
def test_does_not_terminate_exited_processes(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
proc = MagicMock()
|
|
||||||
proc.poll.return_value = 0 # already exited
|
|
||||||
state.api_process = proc
|
|
||||||
|
|
||||||
_stop_training_run(state)
|
|
||||||
|
|
||||||
proc.terminate.assert_not_called()
|
|
||||||
|
|
||||||
def test_handles_none_processes(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
# All process attrs are None by default
|
|
||||||
_stop_training_run(state) # should not raise
|
|
||||||
|
|
||||||
def test_handles_mixed_running_and_exited_processes(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
# api still running
|
|
||||||
api = MagicMock()
|
|
||||||
api.poll.return_value = None
|
|
||||||
state.api_process = api
|
|
||||||
# trainer already exited
|
|
||||||
trainer = MagicMock()
|
|
||||||
trainer.poll.return_value = 0
|
|
||||||
state.trainer_process = trainer
|
|
||||||
# env is None
|
|
||||||
state.env_process = None
|
|
||||||
|
|
||||||
_stop_training_run(state)
|
|
||||||
|
|
||||||
api.terminate.assert_called_once()
|
|
||||||
trainer.terminate.assert_not_called()
|
|
||||||
|
|
||||||
|
|
||||||
class TestStopTrainingRunStatus:
|
|
||||||
"""Verify status transitions in _stop_training_run."""
|
|
||||||
|
|
||||||
def test_sets_status_to_stopped_when_running(self):
|
|
||||||
state = _make_run_state(status="running")
|
|
||||||
_stop_training_run(state)
|
|
||||||
assert state.status == "stopped"
|
|
||||||
|
|
||||||
def test_does_not_change_status_when_failed(self):
|
|
||||||
state = _make_run_state(status="failed")
|
|
||||||
_stop_training_run(state)
|
|
||||||
assert state.status == "failed"
|
|
||||||
|
|
||||||
def test_does_not_change_status_when_pending(self):
|
|
||||||
state = _make_run_state(status="pending")
|
|
||||||
_stop_training_run(state)
|
|
||||||
assert state.status == "pending"
|
|
||||||
|
|
||||||
def test_no_crash_with_no_processes_and_no_files(self):
|
|
||||||
state = _make_run_state()
|
|
||||||
_stop_training_run(state) # should not raise
|
|
||||||
assert state.status == "pending"
|
|
||||||
|
|
@ -1,274 +0,0 @@
|
||||||
"""
|
|
||||||
Tests for environments/tool_call_parsers/ — client-side tool call parsers.
|
|
||||||
|
|
||||||
These parsers extract structured tool_calls from raw model output text.
|
|
||||||
Used in Phase 2 (VLLM/generate) where the server returns raw tokens.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# Ensure repo root is importable
|
|
||||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
|
||||||
|
|
||||||
try:
|
|
||||||
from environments.tool_call_parsers import (
|
|
||||||
ParseResult,
|
|
||||||
ToolCallParser,
|
|
||||||
get_parser,
|
|
||||||
list_parsers,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
pytest.skip("atroposlib not installed", allow_module_level=True)
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Registry tests ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestParserRegistry:
|
|
||||||
def test_list_parsers_returns_nonempty(self):
|
|
||||||
parsers = list_parsers()
|
|
||||||
assert len(parsers) > 0
|
|
||||||
|
|
||||||
def test_hermes_parser_registered(self):
|
|
||||||
parsers = list_parsers()
|
|
||||||
assert "hermes" in parsers
|
|
||||||
|
|
||||||
def test_get_parser_returns_instance(self):
|
|
||||||
parser = get_parser("hermes")
|
|
||||||
assert isinstance(parser, ToolCallParser)
|
|
||||||
|
|
||||||
def test_get_parser_unknown_raises(self):
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
get_parser("nonexistent_parser_xyz")
|
|
||||||
|
|
||||||
def test_all_registered_parsers_instantiate(self):
|
|
||||||
"""Every registered parser should be importable and instantiable."""
|
|
||||||
for name in list_parsers():
|
|
||||||
parser = get_parser(name)
|
|
||||||
assert isinstance(parser, ToolCallParser)
|
|
||||||
assert hasattr(parser, "parse")
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Hermes parser tests ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestHermesParser:
|
|
||||||
@pytest.fixture
|
|
||||||
def parser(self):
|
|
||||||
return get_parser("hermes")
|
|
||||||
|
|
||||||
def test_no_tool_call(self, parser):
|
|
||||||
text = "Hello, I can help you with that."
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert content == text
|
|
||||||
assert tool_calls is None
|
|
||||||
|
|
||||||
def test_single_tool_call(self, parser):
|
|
||||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls -la"}}</tool_call>'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "terminal"
|
|
||||||
args = json.loads(tool_calls[0].function.arguments)
|
|
||||||
assert args["command"] == "ls -la"
|
|
||||||
|
|
||||||
def test_tool_call_with_surrounding_text(self, parser):
|
|
||||||
text = 'Let me check that for you.\n<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "terminal"
|
|
||||||
# Content should have the surrounding text
|
|
||||||
if content is not None:
|
|
||||||
assert "check that" in content or content.strip() != ""
|
|
||||||
|
|
||||||
def test_multiple_tool_calls(self, parser):
|
|
||||||
text = (
|
|
||||||
'<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
|
|
||||||
'<tool_call>{"name": "read_file", "arguments": {"path": "test.py"}}</tool_call>'
|
|
||||||
)
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 2
|
|
||||||
names = {tc.function.name for tc in tool_calls}
|
|
||||||
assert "terminal" in names
|
|
||||||
assert "read_file" in names
|
|
||||||
|
|
||||||
def test_tool_call_ids_are_unique(self, parser):
|
|
||||||
text = (
|
|
||||||
'<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
|
|
||||||
'<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>'
|
|
||||||
)
|
|
||||||
_, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
ids = [tc.id for tc in tool_calls]
|
|
||||||
assert len(ids) == len(set(ids)), "Tool call IDs must be unique"
|
|
||||||
|
|
||||||
def test_empty_string(self, parser):
|
|
||||||
content, tool_calls = parser.parse("")
|
|
||||||
assert tool_calls is None
|
|
||||||
|
|
||||||
def test_malformed_json_in_tool_call(self, parser):
|
|
||||||
text = '<tool_call>not valid json</tool_call>'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
# Should either return None tool_calls or handle gracefully
|
|
||||||
# (implementation may vary — some parsers return error tool calls)
|
|
||||||
|
|
||||||
def test_truncated_tool_call(self, parser):
|
|
||||||
"""Test handling of unclosed tool_call tag (model truncated mid-generation)."""
|
|
||||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls -la"}'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
# Parser should handle truncated output gracefully
|
|
||||||
# Either parse it successfully or return None
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Parse result contract tests (applies to ALL parsers) ───────────────
|
|
||||||
|
|
||||||
class TestParseResultContract:
|
|
||||||
"""Ensure all parsers conform to the ParseResult contract."""
|
|
||||||
|
|
||||||
@pytest.fixture(params=["hermes"]) # Add more as needed
|
|
||||||
def parser(self, request):
|
|
||||||
return get_parser(request.param)
|
|
||||||
|
|
||||||
def test_returns_tuple_of_two(self, parser):
|
|
||||||
result = parser.parse("hello world")
|
|
||||||
assert isinstance(result, tuple)
|
|
||||||
assert len(result) == 2
|
|
||||||
|
|
||||||
def test_no_tools_returns_none_tool_calls(self, parser):
|
|
||||||
content, tool_calls = parser.parse("Just plain text, no tools.")
|
|
||||||
assert tool_calls is None
|
|
||||||
assert content is not None
|
|
||||||
|
|
||||||
def test_tool_calls_are_proper_objects(self, parser):
|
|
||||||
"""When tool calls are found, they should be ChatCompletionMessageToolCall objects."""
|
|
||||||
# Use hermes format since that's universal
|
|
||||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "echo hi"}}</tool_call>'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
if tool_calls is not None:
|
|
||||||
for tc in tool_calls:
|
|
||||||
assert hasattr(tc, "id")
|
|
||||||
assert hasattr(tc, "function")
|
|
||||||
assert hasattr(tc.function, "name")
|
|
||||||
assert hasattr(tc.function, "arguments")
|
|
||||||
assert tc.id is not None
|
|
||||||
assert isinstance(tc.function.name, str)
|
|
||||||
assert isinstance(tc.function.arguments, str)
|
|
||||||
|
|
||||||
|
|
||||||
# ─── DeepSeek V3 parser tests ───────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestDeepSeekV3Parser:
|
|
||||||
@pytest.fixture
|
|
||||||
def parser(self):
|
|
||||||
return get_parser("deepseek_v3")
|
|
||||||
|
|
||||||
def test_no_tool_call(self, parser):
|
|
||||||
text = "Hello, how can I help you?"
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert content == text
|
|
||||||
assert tool_calls is None
|
|
||||||
|
|
||||||
def test_single_tool_call(self, parser):
|
|
||||||
text = (
|
|
||||||
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
|
|
||||||
'```json\n{"city": "London"}\n```<|tool▁call▁end|><|tool▁calls▁end|>'
|
|
||||||
)
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "get_weather"
|
|
||||||
args = json.loads(tool_calls[0].function.arguments)
|
|
||||||
assert args["city"] == "London"
|
|
||||||
|
|
||||||
def test_multiple_tool_calls(self, parser):
|
|
||||||
text = (
|
|
||||||
'<|tool▁calls▁begin|>'
|
|
||||||
'<|tool▁call▁begin|>function<|tool▁sep|>get_weather\n'
|
|
||||||
'```json\n{"city": "London"}\n```<|tool▁call▁end|>'
|
|
||||||
'<|tool▁call▁begin|>function<|tool▁sep|>get_time\n'
|
|
||||||
'```json\n{"timezone": "UTC"}\n```<|tool▁call▁end|>'
|
|
||||||
'<|tool▁calls▁end|>'
|
|
||||||
)
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 2, f"Expected 2 tool calls, got {len(tool_calls)}"
|
|
||||||
names = [tc.function.name for tc in tool_calls]
|
|
||||||
assert "get_weather" in names
|
|
||||||
assert "get_time" in names
|
|
||||||
|
|
||||||
def test_tool_call_with_preceding_text(self, parser):
|
|
||||||
text = (
|
|
||||||
'Let me check that for you.\n'
|
|
||||||
'<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>terminal\n'
|
|
||||||
'```json\n{"command": "ls"}\n```<|tool▁call▁end|><|tool▁calls▁end|>'
|
|
||||||
)
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
|
|
||||||
|
|
||||||
# ─── Mistral parser tests ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestMistralParser:
|
|
||||||
@pytest.fixture
|
|
||||||
def parser(self):
|
|
||||||
return get_parser("mistral")
|
|
||||||
|
|
||||||
def test_no_tool_call(self, parser):
|
|
||||||
text = "Hello, how can I help you?"
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert content == text
|
|
||||||
assert tool_calls is None
|
|
||||||
|
|
||||||
def test_pre_v11_single_tool_call(self, parser):
|
|
||||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "func"
|
|
||||||
args = json.loads(tool_calls[0].function.arguments)
|
|
||||||
assert args["key"] == "val"
|
|
||||||
|
|
||||||
def test_pre_v11_nested_json(self, parser):
|
|
||||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "func"
|
|
||||||
args = json.loads(tool_calls[0].function.arguments)
|
|
||||||
assert args["nested"]["deep"] is True
|
|
||||||
|
|
||||||
def test_v11_single_tool_call(self, parser):
|
|
||||||
text = '[TOOL_CALLS]get_weather{"city": "London"}'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "get_weather"
|
|
||||||
args = json.loads(tool_calls[0].function.arguments)
|
|
||||||
assert args["city"] == "London"
|
|
||||||
|
|
||||||
def test_v11_multiple_tool_calls(self, parser):
|
|
||||||
text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 2
|
|
||||||
names = [tc.function.name for tc in tool_calls]
|
|
||||||
assert "func1" in names
|
|
||||||
assert "func2" in names
|
|
||||||
|
|
||||||
def test_preceding_text_preserved(self, parser):
|
|
||||||
text = 'Hello[TOOL_CALLS]func{"a": 1}'
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert content == "Hello"
|
|
||||||
assert tool_calls is not None
|
|
||||||
assert len(tool_calls) == 1
|
|
||||||
assert tool_calls[0].function.name == "func"
|
|
||||||
|
|
||||||
def test_malformed_json_fallback(self, parser):
|
|
||||||
text = "[TOOL_CALLS] not valid json"
|
|
||||||
content, tool_calls = parser.parse(text)
|
|
||||||
assert tool_calls is None
|
|
||||||
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit 65f084ee8054a5d02aeac76e24ed60388511c82b
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
"""Configurable budget constants for tool result persistence.
|
"""Configurable budget constants for tool result persistence.
|
||||||
|
|
||||||
Overridable at the RL environment level via HermesAgentEnvConfig fields.
|
|
||||||
Per-tool resolution: pinned > config overrides > registry > default.
|
Per-tool resolution: pinned > config overrides > registry > default.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
13
toolsets.py
13
toolsets.py
|
|
@ -170,17 +170,6 @@ TOOLSETS = {
|
||||||
"includes": []
|
"includes": []
|
||||||
},
|
},
|
||||||
|
|
||||||
"rl": {
|
|
||||||
"description": "RL training tools for running reinforcement learning on Tinker-Atropos",
|
|
||||||
"tools": [
|
|
||||||
"rl_list_environments", "rl_select_environment",
|
|
||||||
"rl_get_current_config", "rl_edit_config",
|
|
||||||
"rl_start_training", "rl_check_status",
|
|
||||||
"rl_stop_training", "rl_get_results",
|
|
||||||
"rl_list_runs", "rl_test_inference"
|
|
||||||
],
|
|
||||||
"includes": []
|
|
||||||
},
|
|
||||||
|
|
||||||
"file": {
|
"file": {
|
||||||
"description": "File manipulation tools: read, write, patch (with fuzzy matching), and search (content + files)",
|
"description": "File manipulation tools: read, write, patch (with fuzzy matching), and search (content + files)",
|
||||||
|
|
@ -390,7 +379,7 @@ TOOLSETS = {
|
||||||
# Mirrors hermes-cli so cron's "default" toolset is the same set of
|
# Mirrors hermes-cli so cron's "default" toolset is the same set of
|
||||||
# core tools users see interactively — then `hermes tools` filters
|
# core tools users see interactively — then `hermes tools` filters
|
||||||
# them down per the platform config. _DEFAULT_OFF_TOOLSETS (moa,
|
# them down per the platform config. _DEFAULT_OFF_TOOLSETS (moa,
|
||||||
# homeassistant, rl) are excluded by _get_platform_tools() unless
|
# homeassistant) are excluded by _get_platform_tools() unless
|
||||||
# the user explicitly enables them.
|
# the user explicitly enables them.
|
||||||
"description": "Default cron toolset - same core tools as hermes-cli; gated by `hermes tools`",
|
"description": "Default cron toolset - same core tools as hermes-cli; gated by `hermes tools`",
|
||||||
"tools": _HERMES_CORE_TOOLS,
|
"tools": _HERMES_CORE_TOOLS,
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,6 @@ hermes-agent/
|
||||||
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
├── cron/ # Scheduler (jobs.py, scheduler.py)
|
||||||
├── plugins/memory/ # Memory provider plugins
|
├── plugins/memory/ # Memory provider plugins
|
||||||
├── plugins/context_engine/ # Context engine plugins
|
├── plugins/context_engine/ # Context engine plugins
|
||||||
├── environments/ # RL training environments (Atropos)
|
|
||||||
├── skills/ # Bundled skills (always available)
|
├── skills/ # Bundled skills (always available)
|
||||||
├── optional-skills/ # Official optional skills (install explicitly)
|
├── optional-skills/ # Official optional skills (install explicitly)
|
||||||
├── website/ # Docusaurus documentation site
|
├── website/ # Docusaurus documentation site
|
||||||
|
|
@ -185,7 +184,6 @@ If you are new to the codebase:
|
||||||
8. **[Gateway Internals](./gateway-internals.md)** — messaging platform gateway
|
8. **[Gateway Internals](./gateway-internals.md)** — messaging platform gateway
|
||||||
9. **[Context Compression & Prompt Caching](./context-compression-and-caching.md)** — compression and caching
|
9. **[Context Compression & Prompt Caching](./context-compression-and-caching.md)** — compression and caching
|
||||||
10. **[ACP Internals](./acp-internals.md)** — IDE integration
|
10. **[ACP Internals](./acp-internals.md)** — IDE integration
|
||||||
11. **[Environments, Benchmarks & Data Generation](./environments.md)** — RL training
|
|
||||||
|
|
||||||
## Major Subsystems
|
## Major Subsystems
|
||||||
|
|
||||||
|
|
@ -247,11 +245,11 @@ Exposes Hermes as an editor-native agent over stdio/JSON-RPC for VS Code, Zed, a
|
||||||
|
|
||||||
→ [ACP Internals](./acp-internals.md)
|
→ [ACP Internals](./acp-internals.md)
|
||||||
|
|
||||||
### RL / Environments / Trajectories
|
### Trajectories
|
||||||
|
|
||||||
Full environment framework for evaluation and RL training. Integrates with Atropos, supports multiple tool-call parsers, and generates ShareGPT-format trajectories.
|
Generates ShareGPT-format trajectories from agent sessions for training data generation.
|
||||||
|
|
||||||
→ [Environments, Benchmarks & Data Generation](./environments.md), [Trajectories & Training Format](./trajectory-format.md)
|
→ [Trajectories & Training Format](./trajectory-format.md)
|
||||||
|
|
||||||
## Design Principles
|
## Design Principles
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,9 +50,6 @@ export VIRTUAL_ENV="$(pwd)/venv"
|
||||||
|
|
||||||
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
# Install with all extras (messaging, cron, CLI menus, dev tools)
|
||||||
uv pip install -e ".[all,dev]"
|
uv pip install -e ".[all,dev]"
|
||||||
# tinker-atropos is a git submodule — needs `git submodule update --init` first
|
|
||||||
# if you didn't clone with `--recurse-submodules`
|
|
||||||
uv pip install -e "./tinker-atropos"
|
|
||||||
|
|
||||||
# Optional: browser tools
|
# Optional: browser tools
|
||||||
npm install
|
npm install
|
||||||
|
|
|
||||||
|
|
@ -1,520 +0,0 @@
|
||||||
---
|
|
||||||
sidebar_position: 5
|
|
||||||
title: "Environments, Benchmarks & Data Generation"
|
|
||||||
description: "Building RL training environments, running evaluation benchmarks, and generating SFT data with the Hermes-Agent Atropos integration"
|
|
||||||
---
|
|
||||||
|
|
||||||
# Environments, Benchmarks & Data Generation
|
|
||||||
|
|
||||||
Hermes Agent includes a full environment framework that connects its tool-calling capabilities to the [Atropos](https://github.com/NousResearch/atropos) RL training framework. This enables three workflows:
|
|
||||||
|
|
||||||
1. **RL Training** — Train language models on multi-turn agentic tasks with GRPO
|
|
||||||
2. **Benchmarks** — Evaluate models on standardised agentic benchmarks
|
|
||||||
3. **Data Generation** — Generate SFT training data from agent rollouts
|
|
||||||
|
|
||||||
All three share the same core: an **environment** class that defines tasks, runs an agent loop, and scores the output.
|
|
||||||
|
|
||||||
:::info Repo environments vs RL training tools
|
|
||||||
The Python environment framework documented here lives under the repo's `environments/` directory and is the implementation-level API for Hermes/Atropos integration. This is separate from the user-facing `rl_*` tools, which operate as an orchestration surface for remote RL training workflows.
|
|
||||||
:::
|
|
||||||
|
|
||||||
:::tip Quick Links
|
|
||||||
- **Want to run benchmarks?** Jump to [Available Benchmarks](#available-benchmarks)
|
|
||||||
- **Want to train with RL?** See [RL Training Tools](/user-guide/features/rl-training) for the agent-driven interface, or [Running Environments](#running-environments) for manual execution
|
|
||||||
- **Want to create a new environment?** See [Creating Environments](#creating-environments)
|
|
||||||
:::
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
The environment system is built on a three-layer inheritance chain:
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
classDiagram
|
|
||||||
class BaseEnv {
|
|
||||||
Server management
|
|
||||||
Worker scheduling
|
|
||||||
Wandb logging
|
|
||||||
CLI: serve / process / evaluate
|
|
||||||
}
|
|
||||||
|
|
||||||
class HermesAgentBaseEnv {
|
|
||||||
Terminal backend configuration
|
|
||||||
Tool resolution
|
|
||||||
Agent loop engine
|
|
||||||
ToolContext access
|
|
||||||
}
|
|
||||||
|
|
||||||
class TerminalTestEnv {
|
|
||||||
Stack testing
|
|
||||||
}
|
|
||||||
|
|
||||||
class HermesSweEnv {
|
|
||||||
SWE training
|
|
||||||
}
|
|
||||||
|
|
||||||
class TerminalBench2EvalEnv {
|
|
||||||
Benchmark evaluation
|
|
||||||
}
|
|
||||||
|
|
||||||
class TBLiteEvalEnv {
|
|
||||||
Fast benchmark
|
|
||||||
}
|
|
||||||
|
|
||||||
class YCBenchEvalEnv {
|
|
||||||
Long-horizon benchmark
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseEnv <|-- HermesAgentBaseEnv
|
|
||||||
HermesAgentBaseEnv <|-- TerminalTestEnv
|
|
||||||
HermesAgentBaseEnv <|-- HermesSweEnv
|
|
||||||
HermesAgentBaseEnv <|-- TerminalBench2EvalEnv
|
|
||||||
TerminalBench2EvalEnv <|-- TBLiteEvalEnv
|
|
||||||
TerminalBench2EvalEnv <|-- YCBenchEvalEnv
|
|
||||||
```
|
|
||||||
|
|
||||||
### BaseEnv (Atropos)
|
|
||||||
|
|
||||||
The foundation from `atroposlib`. Provides:
|
|
||||||
- **Server management** — connects to OpenAI-compatible APIs (VLLM, SGLang, OpenRouter)
|
|
||||||
- **Worker scheduling** — parallel rollout coordination
|
|
||||||
- **Wandb integration** — metrics logging and rollout visualisation
|
|
||||||
- **CLI interface** — three subcommands: `serve`, `process`, `evaluate`
|
|
||||||
- **Eval logging** — `evaluate_log()` saves results to JSON + JSONL
|
|
||||||
|
|
||||||
### HermesAgentBaseEnv
|
|
||||||
|
|
||||||
The hermes-agent layer (`environments/hermes_base_env.py`). Adds:
|
|
||||||
- **Terminal backend configuration** — sets `TERMINAL_ENV` for sandboxed execution (local, Docker, Modal, Daytona, SSH, Singularity)
|
|
||||||
- **Tool resolution** — `_resolve_tools_for_group()` calls hermes-agent's `get_tool_definitions()` to get the right tool schemas based on enabled/disabled toolsets
|
|
||||||
- **Agent loop integration** — `collect_trajectory()` runs `HermesAgentLoop` and scores the result
|
|
||||||
- **Two-phase operation** — Phase 1 (OpenAI server) for eval/SFT, Phase 2 (VLLM ManagedServer) for full RL with logprobs
|
|
||||||
- **Async safety patches** — monkey-patches Modal backend to work inside Atropos's event loop
|
|
||||||
|
|
||||||
### Concrete Environments
|
|
||||||
|
|
||||||
Your environment inherits from `HermesAgentBaseEnv` and implements five methods:
|
|
||||||
|
|
||||||
| Method | Purpose |
|
|
||||||
|--------|---------|
|
|
||||||
| `setup()` | Load dataset, initialise state |
|
|
||||||
| `get_next_item()` | Return the next item for rollout |
|
|
||||||
| `format_prompt(item)` | Convert an item into the user message |
|
|
||||||
| `compute_reward(item, result, ctx)` | Score the rollout (0.0–1.0) |
|
|
||||||
| `evaluate()` | Periodic evaluation logic |
|
|
||||||
|
|
||||||
## Core Components
|
|
||||||
|
|
||||||
### Agent Loop
|
|
||||||
|
|
||||||
`HermesAgentLoop` (`environments/agent_loop.py`) is the reusable multi-turn agent engine. It runs the same tool-calling pattern as hermes-agent's main loop:
|
|
||||||
|
|
||||||
1. Send messages + tool schemas to the API via `server.chat_completion()`
|
|
||||||
2. If the response contains `tool_calls`, dispatch each via `handle_function_call()`
|
|
||||||
3. Append tool results to the conversation, go back to step 1
|
|
||||||
4. If no `tool_calls`, the agent is done
|
|
||||||
|
|
||||||
Tool calls execute in a thread pool (`ThreadPoolExecutor(128)`) so that async backends (Modal, Docker) don't deadlock inside Atropos's event loop.
|
|
||||||
|
|
||||||
Returns an `AgentResult`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@dataclass
|
|
||||||
class AgentResult:
|
|
||||||
messages: List[Dict[str, Any]] # Full conversation history
|
|
||||||
turns_used: int # Number of LLM calls made
|
|
||||||
finished_naturally: bool # True if model stopped on its own
|
|
||||||
reasoning_per_turn: List[Optional[str]] # Extracted reasoning content
|
|
||||||
tool_errors: List[ToolError] # Errors encountered during tool dispatch
|
|
||||||
managed_state: Optional[Dict] # VLLM ManagedServer state (Phase 2)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Tool Context
|
|
||||||
|
|
||||||
`ToolContext` (`environments/tool_context.py`) gives reward functions direct access to the **same sandbox** the model used during its rollout. The `task_id` scoping means all state (files, processes, browser tabs) is preserved.
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def compute_reward(self, item, result, ctx: ToolContext):
|
|
||||||
# Run tests in the model's terminal sandbox
|
|
||||||
test = ctx.terminal("pytest -v")
|
|
||||||
if test["exit_code"] == 0:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
# Check if a file was created
|
|
||||||
content = ctx.read_file("/workspace/solution.py")
|
|
||||||
if content.get("content"):
|
|
||||||
return 0.5
|
|
||||||
|
|
||||||
# Download files for local verification
|
|
||||||
ctx.download_file("/remote/output.bin", "/local/output.bin")
|
|
||||||
return 0.0
|
|
||||||
```
|
|
||||||
|
|
||||||
Available methods:
|
|
||||||
|
|
||||||
| Category | Methods |
|
|
||||||
|----------|---------|
|
|
||||||
| **Terminal** | `terminal(command, timeout)` |
|
|
||||||
| **Files** | `read_file(path)`, `write_file(path, content)`, `search(query, path)` |
|
|
||||||
| **Transfers** | `upload_file()`, `upload_dir()`, `download_file()`, `download_dir()` |
|
|
||||||
| **Web** | `web_search(query)`, `web_extract(urls)` |
|
|
||||||
| **Browser** | `browser_navigate(url)`, `browser_snapshot()` |
|
|
||||||
| **Generic** | `call_tool(name, args)` — escape hatch for any hermes-agent tool |
|
|
||||||
| **Cleanup** | `cleanup()` — release all resources |
|
|
||||||
|
|
||||||
### Tool Call Parsers
|
|
||||||
|
|
||||||
For **Phase 2** (VLLM ManagedServer), the server returns raw text without structured tool calls. Client-side parsers in `environments/tool_call_parsers/` extract `tool_calls` from raw output:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from environments.tool_call_parsers import get_parser
|
|
||||||
|
|
||||||
parser = get_parser("hermes") # or "mistral", "llama3_json", "qwen", "deepseek_v3", etc.
|
|
||||||
content, tool_calls = parser.parse(raw_model_output)
|
|
||||||
```
|
|
||||||
|
|
||||||
Available parsers: `hermes`, `mistral`, `llama3_json`, `llama4_json`, `qwen`, `qwen3_coder`, `deepseek_v3`, `deepseek_v3_1` (alias `deepseek_v31`), `kimi_k2`, `longcat`, `glm45`, `glm47`.
|
|
||||||
|
|
||||||
In Phase 1 (OpenAI server type), parsers are not needed — the server handles tool call parsing natively.
|
|
||||||
|
|
||||||
## Available Benchmarks
|
|
||||||
|
|
||||||
### TerminalBench2
|
|
||||||
|
|
||||||
**89 challenging terminal tasks** with per-task Docker sandbox environments.
|
|
||||||
|
|
||||||
| | |
|
|
||||||
|---|---|
|
|
||||||
| **What it tests** | Single-task coding/sysadmin ability |
|
|
||||||
| **Scoring** | Binary pass/fail (test suite verification) |
|
|
||||||
| **Sandbox** | Modal cloud sandboxes (per-task Docker images) |
|
|
||||||
| **Tools** | `terminal` + `file` |
|
|
||||||
| **Tasks** | 89 tasks across multiple categories |
|
|
||||||
| **Cost** | ~$50–200 for full eval (parallel execution) |
|
|
||||||
| **Time** | ~2–4 hours |
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
--config environments/benchmarks/terminalbench_2/default.yaml
|
|
||||||
|
|
||||||
# Run specific tasks
|
|
||||||
python environments/benchmarks/terminalbench_2/terminalbench2_env.py evaluate \
|
|
||||||
--config environments/benchmarks/terminalbench_2/default.yaml \
|
|
||||||
--env.task_filter fix-git,git-multibranch
|
|
||||||
```
|
|
||||||
|
|
||||||
Dataset: [NousResearch/terminal-bench-2](https://huggingface.co/datasets/NousResearch/terminal-bench-2) on HuggingFace.
|
|
||||||
|
|
||||||
### TBLite (OpenThoughts Terminal Bench Lite)
|
|
||||||
|
|
||||||
**100 difficulty-calibrated tasks** — a faster proxy for TerminalBench2.
|
|
||||||
|
|
||||||
| | |
|
|
||||||
|---|---|
|
|
||||||
| **What it tests** | Same as TB2 (coding/sysadmin), calibrated difficulty tiers |
|
|
||||||
| **Scoring** | Binary pass/fail |
|
|
||||||
| **Sandbox** | Modal cloud sandboxes |
|
|
||||||
| **Tools** | `terminal` + `file` |
|
|
||||||
| **Tasks** | 100 tasks: Easy (40), Medium (26), Hard (26), Extreme (8) |
|
|
||||||
| **Correlation** | r=0.911 with full TB2 |
|
|
||||||
| **Speed** | 2.6–8× faster than TB2 |
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
--config environments/benchmarks/tblite/default.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
TBLite is a thin subclass of TerminalBench2 — only the dataset and timeouts differ. Created by the OpenThoughts Agent team (Snorkel AI + Bespoke Labs). Dataset: [NousResearch/openthoughts-tblite](https://huggingface.co/datasets/NousResearch/openthoughts-tblite).
|
|
||||||
|
|
||||||
### YC-Bench
|
|
||||||
|
|
||||||
**Long-horizon strategic benchmark** — the agent plays CEO of an AI startup.
|
|
||||||
|
|
||||||
| | |
|
|
||||||
|---|---|
|
|
||||||
| **What it tests** | Multi-turn strategic coherence over hundreds of turns |
|
|
||||||
| **Scoring** | Composite: `0.5 × survival + 0.5 × normalised_funds` |
|
|
||||||
| **Sandbox** | Local terminal (no Modal needed) |
|
|
||||||
| **Tools** | `terminal` only |
|
|
||||||
| **Runs** | 9 default (3 presets × 3 seeds), sequential |
|
|
||||||
| **Cost** | ~$50–200 for full eval |
|
|
||||||
| **Time** | ~3–6 hours |
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Install yc-bench (optional dependency)
|
|
||||||
pip install "hermes-agent[yc-bench]"
|
|
||||||
|
|
||||||
# Run evaluation
|
|
||||||
bash environments/benchmarks/yc_bench/run_eval.sh
|
|
||||||
|
|
||||||
# Or directly
|
|
||||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
--config environments/benchmarks/yc_bench/default.yaml
|
|
||||||
|
|
||||||
# Quick single-preset test
|
|
||||||
python environments/benchmarks/yc_bench/yc_bench_env.py evaluate \
|
|
||||||
--config environments/benchmarks/yc_bench/default.yaml \
|
|
||||||
--env.presets '["fast_test"]' --env.seeds '[1]'
|
|
||||||
```
|
|
||||||
|
|
||||||
YC-Bench uses [collinear-ai/yc-bench](https://github.com/collinear-ai/yc-bench) — a deterministic simulation with 4 skill domains (research, inference, data_environment, training), prestige system, employee management, and financial pressure. Unlike TB2's per-task binary scoring, YC-Bench measures whether an agent can maintain coherent strategy over hundreds of compounding decisions.
|
|
||||||
|
|
||||||
## Training Environments
|
|
||||||
|
|
||||||
### TerminalTestEnv
|
|
||||||
|
|
||||||
A minimal self-contained environment with inline tasks (no external dataset). Used for **validating the full stack** end-to-end. Each task asks the model to create a file at a known path; the verifier checks the content.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Process mode (saves rollouts to JSONL, no training server needed)
|
|
||||||
python environments/terminal_test_env/terminal_test_env.py process \
|
|
||||||
--env.data_path_to_save_groups terminal_test_output.jsonl
|
|
||||||
|
|
||||||
# Serve mode (connects to Atropos API for RL training)
|
|
||||||
python environments/terminal_test_env/terminal_test_env.py serve
|
|
||||||
```
|
|
||||||
|
|
||||||
### HermesSweEnv
|
|
||||||
|
|
||||||
SWE-bench style training environment. The model gets a coding task, uses terminal + file + web tools to solve it, and the reward function runs tests in the same Modal sandbox.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
|
||||||
--openai.model_name YourModel \
|
|
||||||
--env.dataset_name bigcode/humanevalpack \
|
|
||||||
--env.terminal_backend modal
|
|
||||||
```
|
|
||||||
|
|
||||||
## Running Environments
|
|
||||||
|
|
||||||
Every environment is a standalone Python script with three CLI subcommands:
|
|
||||||
|
|
||||||
### `evaluate` — Run a benchmark
|
|
||||||
|
|
||||||
For eval-only environments (benchmarks). Runs all items, computes metrics, logs to wandb.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/benchmarks/tblite/tblite_env.py evaluate \
|
|
||||||
--config environments/benchmarks/tblite/default.yaml \
|
|
||||||
--openai.model_name anthropic/claude-sonnet-4.6
|
|
||||||
```
|
|
||||||
|
|
||||||
No training server or `run-api` needed. The environment handles everything.
|
|
||||||
|
|
||||||
### `process` — Generate SFT data
|
|
||||||
|
|
||||||
Runs rollouts and saves scored trajectories to JSONL. Useful for generating training data without a full RL loop.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python environments/terminal_test_env/terminal_test_env.py process \
|
|
||||||
--env.data_path_to_save_groups output.jsonl \
|
|
||||||
--openai.model_name anthropic/claude-sonnet-4.6
|
|
||||||
```
|
|
||||||
|
|
||||||
Output format: each line is a scored trajectory with the full conversation history, reward, and metadata.
|
|
||||||
|
|
||||||
### `serve` — Connect to Atropos for RL training
|
|
||||||
|
|
||||||
Connects the environment to a running Atropos API server (`run-api`). Used during live RL training.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Terminal 1: Start the Atropos API
|
|
||||||
run-api
|
|
||||||
|
|
||||||
# Terminal 2: Start the environment
|
|
||||||
python environments/hermes_swe_env/hermes_swe_env.py serve \
|
|
||||||
--openai.model_name YourModel
|
|
||||||
```
|
|
||||||
|
|
||||||
The environment receives items from Atropos, runs agent rollouts, computes rewards, and sends scored trajectories back for training.
|
|
||||||
|
|
||||||
## Two-Phase Operation
|
|
||||||
|
|
||||||
### Phase 1: OpenAI Server (Eval / SFT)
|
|
||||||
|
|
||||||
Uses `server.chat_completion()` with `tools=` parameter. The server (VLLM, SGLang, OpenRouter, OpenAI) handles tool call parsing natively. Returns `ChatCompletion` objects with structured `tool_calls`.
|
|
||||||
|
|
||||||
- **Use for**: evaluation, SFT data generation, benchmarks, testing
|
|
||||||
- **Placeholder tokens** are created for the Atropos pipeline (since real token IDs aren't available from the OpenAI API)
|
|
||||||
|
|
||||||
### Phase 2: VLLM ManagedServer (Full RL)
|
|
||||||
|
|
||||||
Uses ManagedServer for exact token IDs + logprobs via `/generate`. A client-side [tool call parser](#tool-call-parsers) reconstructs structured `tool_calls` from raw output.
|
|
||||||
|
|
||||||
- **Use for**: full RL training with GRPO/PPO
|
|
||||||
- **Real tokens**, masks, and logprobs flow through the pipeline
|
|
||||||
- Set `tool_call_parser` in config to match your model's format (e.g., `"hermes"`, `"qwen"`, `"mistral"`)
|
|
||||||
|
|
||||||
## Creating Environments
|
|
||||||
|
|
||||||
### Training Environment
|
|
||||||
|
|
||||||
```python
|
|
||||||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
||||||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
||||||
|
|
||||||
class MyEnvConfig(HermesAgentEnvConfig):
|
|
||||||
my_custom_field: str = "default_value"
|
|
||||||
|
|
||||||
class MyEnv(HermesAgentBaseEnv):
|
|
||||||
name = "my-env"
|
|
||||||
env_config_cls = MyEnvConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls):
|
|
||||||
env_config = MyEnvConfig(
|
|
||||||
enabled_toolsets=["terminal", "file"],
|
|
||||||
terminal_backend="modal",
|
|
||||||
max_agent_turns=30,
|
|
||||||
)
|
|
||||||
server_configs = [APIServerConfig(
|
|
||||||
base_url="https://openrouter.ai/api/v1",
|
|
||||||
model_name="anthropic/claude-sonnet-4.6",
|
|
||||||
server_type="openai",
|
|
||||||
)]
|
|
||||||
return env_config, server_configs
|
|
||||||
|
|
||||||
async def setup(self):
|
|
||||||
from datasets import load_dataset
|
|
||||||
self.dataset = list(load_dataset("my-dataset", split="train"))
|
|
||||||
self.iter = 0
|
|
||||||
|
|
||||||
async def get_next_item(self):
|
|
||||||
item = self.dataset[self.iter % len(self.dataset)]
|
|
||||||
self.iter += 1
|
|
||||||
return item
|
|
||||||
|
|
||||||
def format_prompt(self, item):
|
|
||||||
return item["instruction"]
|
|
||||||
|
|
||||||
async def compute_reward(self, item, result, ctx):
|
|
||||||
# ctx gives full tool access to the rollout's sandbox
|
|
||||||
test = ctx.terminal("pytest -v")
|
|
||||||
return 1.0 if test["exit_code"] == 0 else 0.0
|
|
||||||
|
|
||||||
async def evaluate(self, *args, **kwargs):
|
|
||||||
# Periodic evaluation during training
|
|
||||||
pass
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
MyEnv.cli()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Eval-Only Benchmark
|
|
||||||
|
|
||||||
For benchmarks, follow the pattern used by TerminalBench2, TBLite, and YC-Bench:
|
|
||||||
|
|
||||||
1. **Create under** `environments/benchmarks/your-benchmark/`
|
|
||||||
2. **Set eval-only config**: `eval_handling=STOP_TRAIN`, `steps_per_eval=1`, `total_steps=1`
|
|
||||||
3. **Stub training methods**: `collect_trajectories()` returns `(None, [])`, `score()` returns `None`
|
|
||||||
4. **Implement** `rollout_and_score_eval(eval_item)` — the per-item agent loop + scoring
|
|
||||||
5. **Implement** `evaluate()` — orchestrates all runs, computes aggregate metrics
|
|
||||||
6. **Add streaming JSONL** for crash-safe result persistence
|
|
||||||
7. **Add cleanup**: `KeyboardInterrupt` handling, `cleanup_all_environments()`, `_tool_executor.shutdown()`
|
|
||||||
8. **Run with** `evaluate` subcommand
|
|
||||||
|
|
||||||
See `environments/benchmarks/yc_bench/yc_bench_env.py` for a clean, well-documented reference implementation.
|
|
||||||
|
|
||||||
## Configuration Reference
|
|
||||||
|
|
||||||
### HermesAgentEnvConfig Fields
|
|
||||||
|
|
||||||
| Field | Type | Default | Description |
|
|
||||||
|-------|------|---------|-------------|
|
|
||||||
| `enabled_toolsets` | `List[str]` | `None` (all) | Which hermes toolsets to enable |
|
|
||||||
| `disabled_toolsets` | `List[str]` | `None` | Toolsets to filter out |
|
|
||||||
| `distribution` | `str` | `None` | Probabilistic toolset distribution name |
|
|
||||||
| `max_agent_turns` | `int` | `30` | Max LLM calls per rollout |
|
|
||||||
| `agent_temperature` | `float` | `1.0` | Sampling temperature |
|
|
||||||
| `system_prompt` | `str` | `None` | System message for the agent |
|
|
||||||
| `terminal_backend` | `str` | `"local"` | `local`, `docker`, `modal`, `daytona`, `ssh`, `singularity` |
|
|
||||||
| `terminal_timeout` | `int` | `120` | Seconds per terminal command |
|
|
||||||
| `terminal_lifetime` | `int` | `3600` | Max sandbox lifetime |
|
|
||||||
| `dataset_name` | `str` | `None` | HuggingFace dataset identifier |
|
|
||||||
| `tool_pool_size` | `int` | `128` | Thread pool size for tool execution |
|
|
||||||
| `tool_call_parser` | `str` | `"hermes"` | Parser for Phase 2 raw output |
|
|
||||||
| `extra_body` | `Dict` | `None` | Extra params for OpenAI API (e.g., OpenRouter provider prefs) |
|
|
||||||
| `eval_handling` | `Enum` | `STOP_TRAIN` | `STOP_TRAIN`, `LIMIT_TRAIN`, `NONE` |
|
|
||||||
|
|
||||||
### YAML Configuration
|
|
||||||
|
|
||||||
Environments can be configured via YAML files passed with `--config`:
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
env:
|
|
||||||
enabled_toolsets: ["terminal", "file"]
|
|
||||||
max_agent_turns: 60
|
|
||||||
max_token_length: 32000
|
|
||||||
agent_temperature: 0.8
|
|
||||||
terminal_backend: "modal"
|
|
||||||
terminal_timeout: 300
|
|
||||||
dataset_name: "NousResearch/terminal-bench-2"
|
|
||||||
tokenizer_name: "NousResearch/Hermes-3-Llama-3.1-8B"
|
|
||||||
use_wandb: true
|
|
||||||
wandb_name: "my-benchmark"
|
|
||||||
|
|
||||||
openai:
|
|
||||||
base_url: "https://openrouter.ai/api/v1"
|
|
||||||
model_name: "anthropic/claude-sonnet-4.6"
|
|
||||||
server_type: "openai"
|
|
||||||
health_check: false
|
|
||||||
```
|
|
||||||
|
|
||||||
YAML values override `config_init()` defaults. CLI arguments override YAML values:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python my_env.py evaluate \
|
|
||||||
--config my_config.yaml \
|
|
||||||
--openai.model_name anthropic/claude-opus-4.6 # overrides YAML
|
|
||||||
```
|
|
||||||
|
|
||||||
## Prerequisites
|
|
||||||
|
|
||||||
### For all environments
|
|
||||||
|
|
||||||
- Python >= 3.11
|
|
||||||
- `atroposlib`: `pip install git+https://github.com/NousResearch/atropos.git`
|
|
||||||
- An LLM API key (OpenRouter, OpenAI, or self-hosted VLLM/SGLang)
|
|
||||||
|
|
||||||
### For Modal-sandboxed benchmarks (TB2, TBLite)
|
|
||||||
|
|
||||||
- [Modal](https://modal.com) account and CLI: `pip install "hermes-agent[modal]"`
|
|
||||||
- `MODAL_TOKEN_ID` and `MODAL_TOKEN_SECRET` environment variables
|
|
||||||
|
|
||||||
### For YC-Bench
|
|
||||||
|
|
||||||
- `pip install "hermes-agent[yc-bench]"` (installs the yc-bench CLI + SQLAlchemy)
|
|
||||||
- No Modal needed — runs with local terminal backend
|
|
||||||
|
|
||||||
### For RL training
|
|
||||||
|
|
||||||
- `TINKER_API_KEY` — API key for the [Tinker](https://tinker.computer) training service
|
|
||||||
- `WANDB_API_KEY` — for Weights & Biases metrics tracking
|
|
||||||
- The `tinker-atropos` submodule (at `tinker-atropos/` in the repo)
|
|
||||||
|
|
||||||
See [RL Training](/user-guide/features/rl-training) for the agent-driven RL workflow.
|
|
||||||
|
|
||||||
## Directory Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
environments/
|
|
||||||
├── hermes_base_env.py # Abstract base class (HermesAgentBaseEnv)
|
|
||||||
├── agent_loop.py # Multi-turn agent engine (HermesAgentLoop)
|
|
||||||
├── tool_context.py # Per-rollout tool access for reward functions
|
|
||||||
├── patches.py # Async-safety patches for Modal backend
|
|
||||||
│
|
|
||||||
├── tool_call_parsers/ # Phase 2 client-side parsers
|
|
||||||
│ ├── hermes_parser.py # Hermes/ChatML <tool_call> format
|
|
||||||
│ ├── mistral_parser.py # Mistral [TOOL_CALLS] format
|
|
||||||
│ ├── llama_parser.py # Llama 3 JSON tool calling
|
|
||||||
│ ├── qwen_parser.py # Qwen format
|
|
||||||
│ ├── deepseek_v3_parser.py # DeepSeek V3 format
|
|
||||||
│ └── ... # + kimi_k2, longcat, glm45/47, etc.
|
|
||||||
│
|
|
||||||
├── terminal_test_env/ # Stack validation (inline tasks)
|
|
||||||
├── hermes_swe_env/ # SWE-bench training environment
|
|
||||||
│
|
|
||||||
└── benchmarks/ # Evaluation benchmarks
|
|
||||||
├── terminalbench_2/ # 89 terminal tasks, Modal sandboxes
|
|
||||||
├── tblite/ # 100 calibrated tasks (fast TB2 proxy)
|
|
||||||
└── yc_bench/ # Long-horizon strategic benchmark
|
|
||||||
```
|
|
||||||
|
|
@ -123,13 +123,11 @@ If you installed manually (not via the quick installer):
|
||||||
cd /path/to/hermes-agent
|
cd /path/to/hermes-agent
|
||||||
export VIRTUAL_ENV="$(pwd)/venv"
|
export VIRTUAL_ENV="$(pwd)/venv"
|
||||||
|
|
||||||
# Pull latest code and submodules
|
# Pull latest code
|
||||||
git pull origin main
|
git pull origin main
|
||||||
git submodule update --init --recursive
|
|
||||||
|
|
||||||
# Reinstall (picks up new dependencies)
|
# Reinstall (picks up new dependencies)
|
||||||
uv pip install -e ".[all]"
|
uv pip install -e ".[all]"
|
||||||
uv pip install -e "./tinker-atropos"
|
|
||||||
|
|
||||||
# Check for new config options
|
# Check for new config options
|
||||||
hermes config check
|
hermes config check
|
||||||
|
|
|
||||||
|
|
@ -97,5 +97,4 @@ See the [Messaging Gateway overview](/docs/user-guide/messaging) for the platfor
|
||||||
|
|
||||||
## Training & Evaluation
|
## Training & Evaluation
|
||||||
|
|
||||||
- **[RL Training](/docs/user-guide/features/rl-training)** — Generate trajectory data from agent sessions for reinforcement learning and model fine-tuning. Supports Atropos environments with customizable reward functions.
|
|
||||||
- **[Batch Processing](/docs/user-guide/features/batch-processing)** — Run the agent across hundreds of prompts in parallel, generating structured ShareGPT-format trajectory data for training data generation or evaluation.
|
- **[Batch Processing](/docs/user-guide/features/batch-processing)** — Run the agent across hundreds of prompts in parallel, generating structured ShareGPT-format trajectory data for training data generation or evaluation.
|
||||||
|
|
|
||||||
|
|
@ -1355,7 +1355,6 @@ You can switch between providers at any time with `hermes model` — no restart
|
||||||
| Premium TTS voices | [ElevenLabs](https://elevenlabs.io/) | `ELEVENLABS_API_KEY` |
|
| Premium TTS voices | [ElevenLabs](https://elevenlabs.io/) | `ELEVENLABS_API_KEY` |
|
||||||
| OpenAI TTS + voice transcription | [OpenAI](https://platform.openai.com/api-keys) | `VOICE_TOOLS_OPENAI_KEY` |
|
| OpenAI TTS + voice transcription | [OpenAI](https://platform.openai.com/api-keys) | `VOICE_TOOLS_OPENAI_KEY` |
|
||||||
| Mistral TTS + voice transcription | [Mistral](https://console.mistral.ai/) | `MISTRAL_API_KEY` |
|
| Mistral TTS + voice transcription | [Mistral](https://console.mistral.ai/) | `MISTRAL_API_KEY` |
|
||||||
| RL Training | [Tinker](https://tinker-console.thinkingmachines.ai/) + [WandB](https://wandb.ai/) | `TINKER_API_KEY`, `WANDB_API_KEY` |
|
|
||||||
| Cross-session user modeling | [Honcho](https://honcho.dev/) | `HONCHO_API_KEY` |
|
| Cross-session user modeling | [Honcho](https://honcho.dev/) | `HONCHO_API_KEY` |
|
||||||
| Semantic long-term memory | [Supermemory](https://supermemory.ai) | `SUPERMEMORY_API_KEY` |
|
| Semantic long-term memory | [Supermemory](https://supermemory.ai) | `SUPERMEMORY_API_KEY` |
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -148,8 +148,6 @@ For native Anthropic auth, Hermes prefers Claude Code's own credential files whe
|
||||||
| `HONCHO_BASE_URL` | Base URL for self-hosted Honcho instances (default: Honcho cloud). No API key required for local instances |
|
| `HONCHO_BASE_URL` | Base URL for self-hosted Honcho instances (default: Honcho cloud). No API key required for local instances |
|
||||||
| `HINDSIGHT_TIMEOUT` | Timeout in seconds for Hindsight memory-provider API calls (default: `60`). Bump this if your Hindsight instance is slow to respond during `/sync` or `on_session_switch` and you're seeing timeouts in `errors.log`. |
|
| `HINDSIGHT_TIMEOUT` | Timeout in seconds for Hindsight memory-provider API calls (default: `60`). Bump this if your Hindsight instance is slow to respond during `/sync` or `on_session_switch` and you're seeing timeouts in `errors.log`. |
|
||||||
| `SUPERMEMORY_API_KEY` | Semantic long-term memory with profile recall and session ingest ([supermemory.ai](https://supermemory.ai)) |
|
| `SUPERMEMORY_API_KEY` | Semantic long-term memory with profile recall and session ingest ([supermemory.ai](https://supermemory.ai)) |
|
||||||
| `TINKER_API_KEY` | RL training ([tinker-console.thinkingmachines.ai](https://tinker-console.thinkingmachines.ai/)) |
|
|
||||||
| `WANDB_API_KEY` | RL training metrics ([wandb.ai](https://wandb.ai/)) |
|
|
||||||
| `DAYTONA_API_KEY` | Daytona cloud sandboxes ([daytona.io](https://daytona.io/)) |
|
| `DAYTONA_API_KEY` | Daytona cloud sandboxes ([daytona.io](https://daytona.io/)) |
|
||||||
| `VERCEL_TOKEN` | Vercel Sandbox access token ([vercel.com](https://vercel.com/)) |
|
| `VERCEL_TOKEN` | Vercel Sandbox access token ([vercel.com](https://vercel.com/)) |
|
||||||
| `VERCEL_PROJECT_ID` | Vercel project ID (required with `VERCEL_TOKEN`) |
|
| `VERCEL_PROJECT_ID` | Vercel project ID (required with `VERCEL_TOKEN`) |
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,6 @@ hermes skills uninstall <skill-name>
|
||||||
| [**faiss**](/docs/user-guide/skills/optional/mlops/mlops-faiss) | Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or whe... |
|
| [**faiss**](/docs/user-guide/skills/optional/mlops/mlops-faiss) | Facebook's library for efficient similarity search and clustering of dense vectors. Supports billions of vectors, GPU acceleration, and various index types (Flat, IVF, HNSW). Use for fast k-NN search, large-scale vector retrieval, or whe... |
|
||||||
| [**optimizing-attention-flash**](/docs/user-guide/skills/optional/mlops/mlops-flash-attention) | Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster in... |
|
| [**optimizing-attention-flash**](/docs/user-guide/skills/optional/mlops/mlops-flash-attention) | Optimizes transformer attention with Flash Attention for 2-4x speedup and 10-20x memory reduction. Use when training/running transformers with long sequences (>512 tokens), encountering GPU memory issues with attention, or need faster in... |
|
||||||
| [**guidance**](/docs/user-guide/skills/optional/mlops/mlops-guidance) | Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework |
|
| [**guidance**](/docs/user-guide/skills/optional/mlops/mlops-guidance) | Control LLM output with regex and grammars, guarantee valid JSON/XML/code generation, enforce structured formats, and build multi-step workflows with Guidance - Microsoft Research's constrained generation framework |
|
||||||
| [**hermes-atropos-environments**](/docs/user-guide/skills/optional/mlops/mlops-hermes-atropos-environments) | Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/eva... |
|
|
||||||
| [**huggingface-tokenizers**](/docs/user-guide/skills/optional/mlops/mlops-huggingface-tokenizers) | Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integ... |
|
| [**huggingface-tokenizers**](/docs/user-guide/skills/optional/mlops/mlops-huggingface-tokenizers) | Fast tokenizers optimized for research and production. Rust-based implementation tokenizes 1GB in <20 seconds. Supports BPE, WordPiece, and Unigram algorithms. Train custom vocabularies, track alignments, handle padding/truncation. Integ... |
|
||||||
| [**instructor**](/docs/user-guide/skills/optional/mlops/mlops-instructor) | Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library |
|
| [**instructor**](/docs/user-guide/skills/optional/mlops/mlops-instructor) | Extract structured data from LLM responses with Pydantic validation, retry failed extractions automatically, parse complex JSON with type safety, and stream partial results with Instructor - battle-tested structured output library |
|
||||||
| [**lambda-labs-gpu-cloud**](/docs/user-guide/skills/optional/mlops/mlops-lambda-labs) | Reserved and on-demand GPU cloud instances for ML training and inference. Use when you need dedicated GPU instances with simple SSH access, persistent filesystems, or high-performance multi-node clusters for large-scale training. |
|
| [**lambda-labs-gpu-cloud**](/docs/user-guide/skills/optional/mlops/mlops-lambda-labs) | Reserved and on-demand GPU cloud instances for ML training and inference. Use when you need dedicated GPU instances with simple SSH access, persistent filesystems, or high-performance multi-node clusters for large-scale training. |
|
||||||
|
|
|
||||||
|
|
@ -148,21 +148,6 @@ Registered only when the agent is spawned by the kanban dispatcher (`HERMES_KANB
|
||||||
|------|-------------|----------------------|
|
|------|-------------|----------------------|
|
||||||
| `mixture_of_agents` | Route a hard problem through multiple frontier LLMs collaboratively. Makes 5 API calls (4 reference models + 1 aggregator) with maximum reasoning effort — use sparingly for genuinely difficult problems. Best for: complex math, advanced alg… | OPENROUTER_API_KEY |
|
| `mixture_of_agents` | Route a hard problem through multiple frontier LLMs collaboratively. Makes 5 API calls (4 reference models + 1 aggregator) with maximum reasoning effort — use sparingly for genuinely difficult problems. Best for: complex math, advanced alg… | OPENROUTER_API_KEY |
|
||||||
|
|
||||||
## `rl` toolset
|
|
||||||
|
|
||||||
| Tool | Description | Requires environment |
|
|
||||||
|------|-------------|----------------------|
|
|
||||||
| `rl_check_status` | Get status and metrics for a training run. RATE LIMITED: enforces 30-minute minimum between checks for the same run. Returns WandB metrics: step, state, reward_mean, loss, percent_correct. | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_edit_config` | Update a configuration field. Use rl_get_current_config() first to see all available fields for the selected environment. Each environment has different configurable options. Infrastructure settings (tokenizer, URLs, lora_rank, learning_ra… | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_get_current_config` | Get the current environment configuration. Returns only fields that can be modified: group_size, max_token_length, total_steps, steps_per_eval, use_wandb, wandb_name, max_num_workers. | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_get_results` | Get final results and metrics for a completed training run. Returns final metrics and path to trained weights. | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_list_environments` | List all available RL environments. Returns environment names, paths, and descriptions. TIP: Read the file_path with file tools to understand how each environment works (verifiers, data loading, rewards). | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_list_runs` | List all training runs (active and completed) with their status. | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_select_environment` | Select an RL environment for training. Loads the environment's default configuration. After selecting, use rl_get_current_config() to see settings and rl_edit_config() to modify them. | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_start_training` | Start a new RL training run with the current environment and config. Most training parameters (lora_rank, learning_rate, etc.) are fixed. Use rl_edit_config() to set group_size, batch_size, wandb_project before starting. WARNING: Training… | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_stop_training` | Stop a running training job. Use if metrics look bad, training is stagnant, or you want to try different settings. | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
| `rl_test_inference` | Quick inference test for any environment. Runs a few steps of inference + scoring using OpenRouter. Default: 3 steps x 16 completions = 48 rollouts per model, testing 3 models = 144 total. Tests environment loading, prompt construction, in… | TINKER_API_KEY, WANDB_API_KEY |
|
|
||||||
|
|
||||||
## `session_search` toolset
|
## `session_search` toolset
|
||||||
|
|
||||||
| Tool | Description | Requires environment |
|
| Tool | Description | Requires environment |
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ Or in-session:
|
||||||
```
|
```
|
||||||
/tools list
|
/tools list
|
||||||
/tools disable browser
|
/tools disable browser
|
||||||
/tools enable rl
|
/tools enable homeassistant
|
||||||
```
|
```
|
||||||
|
|
||||||
## Core Toolsets
|
## Core Toolsets
|
||||||
|
|
@ -71,7 +71,6 @@ Or in-session:
|
||||||
| `memory` | `memory` | Persistent cross-session memory management. |
|
| `memory` | `memory` | Persistent cross-session memory management. |
|
||||||
| `messaging` | `send_message` | Send messages to other platforms (Telegram, Discord, etc.) from within a session. |
|
| `messaging` | `send_message` | Send messages to other platforms (Telegram, Discord, etc.) from within a session. |
|
||||||
| `moa` | `mixture_of_agents` | Multi-model consensus via Mixture of Agents. |
|
| `moa` | `mixture_of_agents` | Multi-model consensus via Mixture of Agents. |
|
||||||
| `rl` | `rl_check_status`, `rl_edit_config`, `rl_get_current_config`, `rl_get_results`, `rl_list_environments`, `rl_list_runs`, `rl_select_environment`, `rl_start_training`, `rl_stop_training`, `rl_test_inference` | RL training environment management (Atropos). |
|
|
||||||
| `safe` | `image_generate`, `vision_analyze`, `web_extract`, `web_search` (via `includes`) | Read-only research + media generation. No file writes, no terminal, no code execution. |
|
| `safe` | `image_generate`, `vision_analyze`, `web_extract`, `web_search` (via `includes`) | Read-only research + media generation. No file writes, no terminal, no code execution. |
|
||||||
| `search` | `web_search` | Web search only (without extract). |
|
| `search` | `web_search` | Web search only (without extract). |
|
||||||
| `session_search` | `session_search` | Search past conversation sessions. |
|
| `session_search` | `session_search` | Search past conversation sessions. |
|
||||||
|
|
|
||||||
|
|
@ -1,234 +0,0 @@
|
||||||
---
|
|
||||||
sidebar_position: 13
|
|
||||||
title: "RL Training"
|
|
||||||
description: "Reinforcement learning on agent behaviors with Tinker-Atropos — environment discovery, training, and evaluation"
|
|
||||||
---
|
|
||||||
|
|
||||||
# RL Training
|
|
||||||
|
|
||||||
Hermes Agent includes an integrated RL (Reinforcement Learning) training pipeline built on **Tinker-Atropos**. This enables training language models on environment-specific tasks using GRPO (Group Relative Policy Optimization) with LoRA adapters, orchestrated entirely through the agent's tool interface.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
The RL training system consists of three components:
|
|
||||||
|
|
||||||
1. **[Atropos](https://github.com/NousResearch/atropos)** — A trajectory API server that coordinates environment interactions, manages rollout groups, and computes advantages
|
|
||||||
2. **[Tinker](https://thinkingmachines.ai/tinker/)** — A training service that handles model weights, LoRA training, sampling/inference, and optimizer steps
|
|
||||||
3. **Environments** — Python classes that define tasks, scoring, and reward functions (e.g., GSM8K math problems)
|
|
||||||
|
|
||||||
The agent can discover environments, configure training parameters, launch training runs, and monitor metrics — all through a set of `rl_*` tools.
|
|
||||||
|
|
||||||
## Requirements
|
|
||||||
|
|
||||||
RL training requires:
|
|
||||||
|
|
||||||
- **Python >= 3.11** (Tinker package requirement)
|
|
||||||
- **TINKER_API_KEY** — API key for the Tinker training service
|
|
||||||
- **WANDB_API_KEY** — API key for [Weights & Biases](https://wandb.ai/) metrics tracking
|
|
||||||
- The `tinker-atropos` submodule (at `tinker-atropos/` relative to the Hermes root)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Set up API keys
|
|
||||||
hermes config set TINKER_API_KEY your-tinker-key
|
|
||||||
hermes config set WANDB_API_KEY your-wandb-key
|
|
||||||
```
|
|
||||||
|
|
||||||
When both keys are present and Python >= 3.11 is available, the `rl` toolset is automatically enabled.
|
|
||||||
|
|
||||||
## Available Tools
|
|
||||||
|
|
||||||
| Tool | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `rl_list_environments` | Discover available RL environments |
|
|
||||||
| `rl_select_environment` | Select an environment and load its config |
|
|
||||||
| `rl_get_current_config` | View configurable and locked fields |
|
|
||||||
| `rl_edit_config` | Modify configurable training parameters |
|
|
||||||
| `rl_start_training` | Launch a training run (spawns 3 processes) |
|
|
||||||
| `rl_check_status` | Monitor training progress and WandB metrics |
|
|
||||||
| `rl_stop_training` | Stop a running training job |
|
|
||||||
| `rl_get_results` | Get final metrics and model weights path |
|
|
||||||
| `rl_list_runs` | List all active and completed runs |
|
|
||||||
| `rl_test_inference` | Quick inference test using OpenRouter |
|
|
||||||
|
|
||||||
## Workflow
|
|
||||||
|
|
||||||
### 1. Discover Environments
|
|
||||||
|
|
||||||
```
|
|
||||||
List the available RL environments
|
|
||||||
```
|
|
||||||
|
|
||||||
The agent calls `rl_list_environments()` which scans `tinker-atropos/tinker_atropos/environments/` using AST parsing to find Python classes inheriting from `BaseEnv`. Each environment defines:
|
|
||||||
|
|
||||||
- **Dataset loading** — where training data comes from (e.g., HuggingFace datasets)
|
|
||||||
- **Prompt construction** — how to format items for the model
|
|
||||||
- **Scoring/verification** — how to evaluate model outputs and assign rewards
|
|
||||||
|
|
||||||
### 2. Select and Configure
|
|
||||||
|
|
||||||
```
|
|
||||||
Select the GSM8K environment and show me the configuration
|
|
||||||
```
|
|
||||||
|
|
||||||
The agent calls `rl_select_environment("gsm8k_tinker")`, then `rl_get_current_config()` to see all parameters.
|
|
||||||
|
|
||||||
Configuration fields are divided into two categories:
|
|
||||||
|
|
||||||
**Configurable fields** (can be modified):
|
|
||||||
- `group_size` — Number of completions per item (default: 16)
|
|
||||||
- `batch_size` — Training batch size (default: 128)
|
|
||||||
- `wandb_name` — WandB run name (auto-set to `{env}-{timestamp}`)
|
|
||||||
- Other environment-specific parameters
|
|
||||||
|
|
||||||
**Locked fields** (infrastructure settings, cannot be changed):
|
|
||||||
- `tokenizer_name` — Model tokenizer (e.g., `Qwen/Qwen3-8B`)
|
|
||||||
- `rollout_server_url` — Atropos API URL (`http://localhost:8000`)
|
|
||||||
- `max_token_length` — Maximum token length (8192)
|
|
||||||
- `max_num_workers` — Maximum parallel workers (2048)
|
|
||||||
- `total_steps` — Total training steps (2500)
|
|
||||||
- `lora_rank` — LoRA adapter rank (32)
|
|
||||||
- `learning_rate` — Learning rate (4e-5)
|
|
||||||
- `max_token_trainer_length` — Max tokens for trainer (9000)
|
|
||||||
|
|
||||||
### 3. Start Training
|
|
||||||
|
|
||||||
```
|
|
||||||
Start the training run
|
|
||||||
```
|
|
||||||
|
|
||||||
The agent calls `rl_start_training()` which:
|
|
||||||
|
|
||||||
1. Generates a YAML config file merging locked settings with configurable overrides
|
|
||||||
2. Creates a unique run ID
|
|
||||||
3. Spawns three processes:
|
|
||||||
- **Atropos API server** (`run-api`) — trajectory coordination
|
|
||||||
- **Tinker trainer** (`launch_training.py`) — LoRA training + FastAPI inference server on port 8001
|
|
||||||
- **Environment** (`environment.py serve`) — the selected environment connecting to Atropos
|
|
||||||
|
|
||||||
The processes start with staggered delays (5s for API, 30s for trainer, 90s more for environment) to ensure proper initialization order.
|
|
||||||
|
|
||||||
### 4. Monitor Progress
|
|
||||||
|
|
||||||
```
|
|
||||||
Check the status of training run abc12345
|
|
||||||
```
|
|
||||||
|
|
||||||
The agent calls `rl_check_status(run_id)` which reports:
|
|
||||||
|
|
||||||
- Process status (running/exited for each of the 3 processes)
|
|
||||||
- Running time
|
|
||||||
- WandB metrics (step, reward mean, percent correct, eval accuracy)
|
|
||||||
- Log file locations for debugging
|
|
||||||
|
|
||||||
:::note Rate Limiting
|
|
||||||
Status checks are rate-limited to once every **30 minutes** per run ID. This prevents excessive polling during long-running training jobs that take hours.
|
|
||||||
:::
|
|
||||||
|
|
||||||
### 5. Stop or Get Results
|
|
||||||
|
|
||||||
```
|
|
||||||
Stop the training run
|
|
||||||
# or
|
|
||||||
Get the final results for run abc12345
|
|
||||||
```
|
|
||||||
|
|
||||||
`rl_stop_training()` terminates all three processes in reverse order (environment → trainer → API). `rl_get_results()` retrieves final WandB metrics and training history.
|
|
||||||
|
|
||||||
## Inference Testing
|
|
||||||
|
|
||||||
Before committing to a full training run, you can test if an environment works correctly using `rl_test_inference`. This runs a few steps of inference and scoring using OpenRouter — no Tinker API needed, just an `OPENROUTER_API_KEY`.
|
|
||||||
|
|
||||||
```
|
|
||||||
Test the selected environment with inference
|
|
||||||
```
|
|
||||||
|
|
||||||
Default configuration:
|
|
||||||
- **3 steps × 16 completions = 48 rollouts per model**
|
|
||||||
- Tests 3 models at different scales for robustness:
|
|
||||||
- `qwen/qwen3-8b` (small)
|
|
||||||
- `z-ai/glm-4.7-flash` (medium)
|
|
||||||
- `minimax/minimax-m2.7` (large)
|
|
||||||
- Total: ~144 rollouts
|
|
||||||
|
|
||||||
This validates:
|
|
||||||
- Environment loads correctly
|
|
||||||
- Prompt construction works
|
|
||||||
- Inference response parsing is robust across model scales
|
|
||||||
- Verifier/scoring logic produces valid rewards
|
|
||||||
|
|
||||||
## Tinker API Integration
|
|
||||||
|
|
||||||
The trainer uses the [Tinker](https://tinker.computer) API for model training operations:
|
|
||||||
|
|
||||||
- **ServiceClient** — Creates training and sampling clients
|
|
||||||
- **Training client** — Handles forward-backward passes with importance sampling loss, optimizer steps (Adam), and weight checkpointing
|
|
||||||
- **Sampling client** — Provides inference using the latest trained weights
|
|
||||||
|
|
||||||
The training loop:
|
|
||||||
1. Fetches a batch of rollouts from Atropos (prompt + completions + scores)
|
|
||||||
2. Converts to Tinker Datum objects with padded logprobs and advantages
|
|
||||||
3. Runs forward-backward pass with importance sampling loss
|
|
||||||
4. Takes an optimizer step (Adam: lr=4e-5, β1=0.9, β2=0.95)
|
|
||||||
5. Saves weights and creates a new sampling client for next-step inference
|
|
||||||
6. Logs metrics to WandB
|
|
||||||
|
|
||||||
## Architecture Diagram
|
|
||||||
|
|
||||||
```mermaid
|
|
||||||
flowchart LR
|
|
||||||
api["Atropos API<br/>run-api<br/>port 8000"]
|
|
||||||
env["Environment<br/>BaseEnv implementation"]
|
|
||||||
infer["OpenAI / sglang<br/>inference API<br/>port 8001"]
|
|
||||||
trainer["Tinker Trainer<br/>LoRA training + FastAPI"]
|
|
||||||
|
|
||||||
env <--> api
|
|
||||||
env --> infer
|
|
||||||
api -->|"batches: tokens, scores, logprobs"| trainer
|
|
||||||
trainer -->|"serves inference"| infer
|
|
||||||
```
|
|
||||||
|
|
||||||
## Creating Custom Environments
|
|
||||||
|
|
||||||
To create a new RL environment:
|
|
||||||
|
|
||||||
1. Create a Python file in `tinker-atropos/tinker_atropos/environments/`
|
|
||||||
2. Define a class that inherits from `BaseEnv`
|
|
||||||
3. Implement the required methods:
|
|
||||||
- `load_dataset()` — Load your training data
|
|
||||||
- `get_next_item()` — Provide the next item to the model
|
|
||||||
- `score_answer()` — Score model outputs and assign rewards
|
|
||||||
- `collect_trajectories()` — Collect and return trajectories
|
|
||||||
4. Optionally define a custom config class inheriting from `BaseEnvConfig`
|
|
||||||
|
|
||||||
Study the existing `gsm8k_tinker.py` as a template. The agent can help you create new environments — it can read existing environment files, inspect HuggingFace datasets, and write new environment code.
|
|
||||||
|
|
||||||
## WandB Metrics
|
|
||||||
|
|
||||||
Training runs log to Weights & Biases with these key metrics:
|
|
||||||
|
|
||||||
| Metric | Description |
|
|
||||||
|--------|-------------|
|
|
||||||
| `train/loss` | Training loss (importance sampling) |
|
|
||||||
| `train/learning_rate` | Current learning rate |
|
|
||||||
| `reward/mean` | Mean reward across groups |
|
|
||||||
| `logprobs/mean` | Mean reference logprobs |
|
|
||||||
| `logprobs/mean_training` | Mean training logprobs |
|
|
||||||
| `logprobs/diff` | Logprob drift (reference - training) |
|
|
||||||
| `advantages/mean` | Mean advantage values |
|
|
||||||
| `advantages/std` | Advantage standard deviation |
|
|
||||||
|
|
||||||
## Log Files
|
|
||||||
|
|
||||||
Each training run generates log files in `~/.hermes/logs/rl_training/`:
|
|
||||||
|
|
||||||
```
|
|
||||||
logs/
|
|
||||||
├── api_{run_id}.log # Atropos API server logs
|
|
||||||
├── trainer_{run_id}.log # Tinker trainer logs
|
|
||||||
├── env_{run_id}.log # Environment process logs
|
|
||||||
└── inference_tests/ # Inference test results
|
|
||||||
├── test_{env}_{model}.jsonl
|
|
||||||
└── test_{env}_{model}.log
|
|
||||||
```
|
|
||||||
|
|
||||||
These are invaluable for debugging when training fails or produces unexpected results.
|
|
||||||
|
|
@ -1,323 +0,0 @@
|
||||||
---
|
|
||||||
title: "Hermes Atropos Environments — Build, test, and debug Hermes Agent RL environments for Atropos training"
|
|
||||||
sidebar_label: "Hermes Atropos Environments"
|
|
||||||
description: "Build, test, and debug Hermes Agent RL environments for Atropos training"
|
|
||||||
---
|
|
||||||
|
|
||||||
{/* This page is auto-generated from the skill's SKILL.md by website/scripts/generate-skill-docs.py. Edit the source SKILL.md, not this page. */}
|
|
||||||
|
|
||||||
# Hermes Atropos Environments
|
|
||||||
|
|
||||||
Build, test, and debug Hermes Agent RL environments for Atropos training. Covers the HermesAgentBaseEnv interface, reward functions, agent loop integration, evaluation with tools, wandb logging, and the three CLI modes (serve/process/evaluate). Use when creating, reviewing, or fixing RL environments in the hermes-agent repo.
|
|
||||||
|
|
||||||
## Skill metadata
|
|
||||||
|
|
||||||
| | |
|
|
||||||
|---|---|
|
|
||||||
| Source | Optional — install with `hermes skills install official/mlops/hermes-atropos-environments` |
|
|
||||||
| Path | `optional-skills/mlops/hermes-atropos-environments` |
|
|
||||||
| Version | `1.1.0` |
|
|
||||||
| Author | Hermes Agent |
|
|
||||||
| License | MIT |
|
|
||||||
| Platforms | linux, macos, windows |
|
|
||||||
| Tags | `atropos`, `rl`, `environments`, `training`, `reinforcement-learning`, `reward-functions` |
|
|
||||||
| Related skills | [`axolotl`](/docs/user-guide/skills/optional/mlops/mlops-training-axolotl), [`fine-tuning-with-trl`](/docs/user-guide/skills/optional/mlops/mlops-training-trl-fine-tuning), `lm-evaluation-harness` |
|
|
||||||
|
|
||||||
## Reference: full SKILL.md
|
|
||||||
|
|
||||||
:::info
|
|
||||||
The following is the complete skill definition that Hermes loads when this skill is triggered. This is what the agent sees as instructions when the skill is active.
|
|
||||||
:::
|
|
||||||
|
|
||||||
# Hermes Agent Atropos Environments
|
|
||||||
|
|
||||||
Guide for building RL environments in the hermes-agent repo that integrate with the Atropos training framework.
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
<!-- ascii-guard-ignore -->
|
|
||||||
```
|
|
||||||
Atropos BaseEnv (atroposlib/envs/base.py)
|
|
||||||
└── HermesAgentBaseEnv (environments/hermes_base_env.py)
|
|
||||||
├── Handles agent loop orchestration
|
|
||||||
├── Handles tool resolution per group
|
|
||||||
├── Handles ToolContext for reward verification
|
|
||||||
└── YOUR ENVIRONMENT (environments/your_env.py)
|
|
||||||
Only implements: setup, get_next_item, format_prompt,
|
|
||||||
compute_reward, evaluate, wandb_log
|
|
||||||
```
|
|
||||||
<!-- ascii-guard-ignore-end -->
|
|
||||||
|
|
||||||
Hermes environments are special because they run a **multi-turn agent loop with tool calling** — not just single-turn completions. The base env handles the loop; you implement the task and scoring.
|
|
||||||
|
|
||||||
## File Locations
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|------|---------|
|
|
||||||
| `environments/hermes_base_env.py` | Base class with agent loop + tool resolution |
|
|
||||||
| `environments/agent_loop.py` | `HermesAgentLoop` + `AgentResult` dataclass |
|
|
||||||
| `environments/tool_context.py` | `ToolContext` for reward verification |
|
|
||||||
| `environments/tool_call_parsers.py` | Phase 2 tool call parsers (hermes, mistral, etc.) |
|
|
||||||
| `environments/your_env.py` | Your environment implementation |
|
|
||||||
|
|
||||||
## Inference Setup — Ask the User First
|
|
||||||
|
|
||||||
**IMPORTANT:** Before running any test, evaluation, or data generation command, always ask the user how they want to handle inference. Do NOT assume OpenRouter or any specific endpoint. Present these options:
|
|
||||||
|
|
||||||
1. **OpenRouter** — Ask which model they want to use (e.g., `anthropic/claude-sonnet-4.5`, `google/gemini-2.5-pro`, `meta-llama/llama-3.3-70b-instruct`, etc.). Requires `OPENROUTER_API_KEY` in environment.
|
|
||||||
2. **Self-hosted VLLM endpoint** — Ask for their base URL (e.g., `http://localhost:8000/v1`) and model name. Set `--openai.server_type vllm`.
|
|
||||||
3. **Other OpenAI-compatible API** — Ask for the base URL, model name, and any required API key. Set `--openai.server_type openai` and `--openai.health_check false`.
|
|
||||||
4. **Local Atropos training server** — For `serve` mode with a live training loop. Default `http://localhost:8000/v1`.
|
|
||||||
|
|
||||||
Once the user tells you their setup, use those values in all CLI commands for that session. Example prompts:
|
|
||||||
|
|
||||||
> "Before I run this, how would you like to handle inference?
|
|
||||||
> 1. OpenRouter (I'll need your preferred model, e.g. claude-sonnet-4.5)
|
|
||||||
> 2. A self-hosted VLLM endpoint (give me the URL and model name)
|
|
||||||
> 3. Another OpenAI-compatible API (give me the URL, model, and any auth details)
|
|
||||||
> 4. Local Atropos training server (serve mode)"
|
|
||||||
|
|
||||||
### Key flags by provider:
|
|
||||||
|
|
||||||
| Provider | `--openai.server_type` | `--openai.health_check` | `--openai.api_key` |
|
|
||||||
|----------|----------------------|------------------------|-------------------|
|
|
||||||
| OpenRouter | `openai` | `false` | `$OPENROUTER_API_KEY` |
|
|
||||||
| VLLM (self-hosted) | `vllm` | (default) | (not needed) |
|
|
||||||
| Other OpenAI-compatible | `openai` | `false` | As needed |
|
|
||||||
| Local Atropos | (default) | (default) | (not needed) |
|
|
||||||
|
|
||||||
## Required Methods
|
|
||||||
|
|
||||||
### 1. `setup()` — Load dataset and initialize state
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def setup(self) -> None:
|
|
||||||
"""Called once at startup. Load datasets, initialize state."""
|
|
||||||
# Try HuggingFace first, fallback to built-in samples
|
|
||||||
try:
|
|
||||||
from datasets import load_dataset
|
|
||||||
ds = load_dataset("your/dataset", split="test")
|
|
||||||
self._items = [...]
|
|
||||||
except Exception:
|
|
||||||
self._items = BUILTIN_SAMPLES
|
|
||||||
|
|
||||||
# Always split into train/eval
|
|
||||||
random.shuffle(self._items)
|
|
||||||
eval_size = max(20, int(len(self._items) * 0.1))
|
|
||||||
self._eval_items = self._items[:eval_size]
|
|
||||||
self._items = self._items[eval_size:]
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. `get_next_item()` — Return next training item
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def get_next_item(self) -> dict:
|
|
||||||
"""Return next item, cycling through dataset."""
|
|
||||||
item = self._items[self._index % len(self._items)]
|
|
||||||
self._index += 1
|
|
||||||
return item
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. `format_prompt(item)` — Convert item to user message
|
|
||||||
|
|
||||||
```python
|
|
||||||
def format_prompt(self, item: dict) -> str:
|
|
||||||
"""Convert a dataset item into the user-facing prompt."""
|
|
||||||
return f"Research this question: {item['question']}"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. `compute_reward(item, result, ctx)` — Score the rollout
|
|
||||||
|
|
||||||
**CRITICAL**: `result` is an `AgentResult`, NOT a dict. It has these attributes:
|
|
||||||
- `result.messages` — List of message dicts (OpenAI format)
|
|
||||||
- `result.turns_used` — Number of LLM calls made
|
|
||||||
- `result.finished_naturally` — True if model stopped voluntarily
|
|
||||||
- `result.tool_errors` — List of ToolError objects
|
|
||||||
|
|
||||||
**AgentResult does NOT have**: `final_response`, `tool_calls`, `tools_used`.
|
|
||||||
You must extract these from `result.messages`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def compute_reward(self, item, result: AgentResult, ctx: ToolContext) -> float:
|
|
||||||
# Extract final response (last assistant message with content)
|
|
||||||
final_response = ""
|
|
||||||
tools_used = []
|
|
||||||
for msg in reversed(result.messages):
|
|
||||||
if msg.get("role") == "assistant" and msg.get("content") and not final_response:
|
|
||||||
final_response = msg["content"]
|
|
||||||
if msg.get("role") == "assistant" and msg.get("tool_calls"):
|
|
||||||
for tc in msg["tool_calls"]:
|
|
||||||
fn = tc.get("function", {}) if isinstance(tc, dict) else {}
|
|
||||||
name = fn.get("name", "")
|
|
||||||
if name:
|
|
||||||
tools_used.append(name)
|
|
||||||
|
|
||||||
# Score using LLM judge, heuristic, or ToolContext verification
|
|
||||||
correctness = await self._llm_judge(item, final_response)
|
|
||||||
return correctness
|
|
||||||
```
|
|
||||||
|
|
||||||
`ctx` (ToolContext) gives you terminal/file access to the agent's sandbox for verification:
|
|
||||||
```python
|
|
||||||
# Run tests in the agent's sandbox
|
|
||||||
result = ctx.terminal("pytest /workspace/test.py")
|
|
||||||
return 1.0 if result["exit_code"] == 0 else 0.0
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. `evaluate()` — Periodic evaluation with full agent loop
|
|
||||||
|
|
||||||
**MUST use the full agent loop with tools**, not single-turn chat_completion.
|
|
||||||
The whole point of hermes-agent environments is agentic evaluation:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def evaluate(self, *args, **kwargs) -> None:
|
|
||||||
import time, uuid
|
|
||||||
from environments.agent_loop import HermesAgentLoop
|
|
||||||
from environments.tool_context import ToolContext
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
tools, valid_names = self._resolve_tools_for_group()
|
|
||||||
samples = []
|
|
||||||
|
|
||||||
for item in self._eval_items[:self.config.eval_size]:
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
messages = []
|
|
||||||
if self.config.system_prompt:
|
|
||||||
messages.append({"role": "system", "content": self.config.system_prompt})
|
|
||||||
messages.append({"role": "user", "content": self.format_prompt(item)})
|
|
||||||
|
|
||||||
agent = HermesAgentLoop(
|
|
||||||
server=self.server,
|
|
||||||
tool_schemas=tools,
|
|
||||||
valid_tool_names=valid_names,
|
|
||||||
max_turns=self.config.max_agent_turns,
|
|
||||||
task_id=task_id,
|
|
||||||
temperature=0.0, # Deterministic for eval
|
|
||||||
max_tokens=self.config.max_token_length,
|
|
||||||
extra_body=self.config.extra_body,
|
|
||||||
)
|
|
||||||
result = await agent.run(messages)
|
|
||||||
|
|
||||||
ctx = ToolContext(task_id)
|
|
||||||
try:
|
|
||||||
reward = await self.compute_reward(item, result, ctx)
|
|
||||||
finally:
|
|
||||||
ctx.cleanup()
|
|
||||||
|
|
||||||
samples.append({"prompt": ..., "response": ..., "reward": reward})
|
|
||||||
|
|
||||||
eval_metrics = {"eval/mean_reward": ...}
|
|
||||||
await self.evaluate_log(metrics=eval_metrics, samples=samples,
|
|
||||||
start_time=start_time, end_time=time.time())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6. `wandb_log()` — Custom metrics logging
|
|
||||||
|
|
||||||
Always call `super().wandb_log()` at the end:
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def wandb_log(self, wandb_metrics=None):
|
|
||||||
if wandb_metrics is None:
|
|
||||||
wandb_metrics = {}
|
|
||||||
if self._reward_buffer:
|
|
||||||
n = len(self._reward_buffer)
|
|
||||||
wandb_metrics["train/mean_reward"] = sum(self._reward_buffer) / n
|
|
||||||
self._reward_buffer.clear()
|
|
||||||
await super().wandb_log(wandb_metrics) # MUST call super
|
|
||||||
```
|
|
||||||
|
|
||||||
**Pitfall**: `compute_reward` appends to metric buffers. During eval, this pollutes training metrics. Roll back buffer entries added during eval.
|
|
||||||
|
|
||||||
## Config Class
|
|
||||||
|
|
||||||
Always create a custom config subclass with Pydantic Field descriptors. Key inherited fields you can tune: `enabled_toolsets`, `max_agent_turns`, `agent_temperature`, `system_prompt`, `terminal_backend`, `group_size`, `steps_per_eval`, `total_steps`.
|
|
||||||
|
|
||||||
## config_init() — Default Configuration
|
|
||||||
|
|
||||||
Classmethod returning `(YourEnvConfig, [APIServerConfig(...)])`. Set server_type to "openai" for OpenRouter/external APIs. Load API key from environment variable.
|
|
||||||
|
|
||||||
## Three CLI Modes
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# SERVE — Full training loop (connects to Atropos API server)
|
|
||||||
python environments/my_env.py serve --openai.base_url http://localhost:8000/v1
|
|
||||||
|
|
||||||
# PROCESS — Offline data generation (saves JSONL)
|
|
||||||
python environments/my_env.py process --env.total_steps 10 --env.group_size 1 \
|
|
||||||
--env.use_wandb false --env.data_path_to_save_groups output.jsonl \
|
|
||||||
--openai.base_url "<USER_BASE_URL>" \
|
|
||||||
--openai.model_name "<USER_MODEL>" \
|
|
||||||
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
|
|
||||||
|
|
||||||
# EVALUATE — Standalone eval (runs setup + evaluate only)
|
|
||||||
python environments/my_env.py evaluate --env.eval_size 20 \
|
|
||||||
--env.data_dir_to_save_evals /tmp/eval_results \
|
|
||||||
--openai.base_url "<USER_BASE_URL>" \
|
|
||||||
--openai.model_name "<USER_MODEL>" \
|
|
||||||
--openai.server_type <USER_SERVER_TYPE> --openai.health_check false
|
|
||||||
```
|
|
||||||
|
|
||||||
Config priority: CLI args > YAML file > config_init() defaults.
|
|
||||||
|
|
||||||
## Common Pitfalls
|
|
||||||
|
|
||||||
1. **AgentResult has .messages, not .final_response** — Extract the final response by iterating reversed(result.messages) looking for the last assistant message with content.
|
|
||||||
|
|
||||||
2. **evaluate() must use HermesAgentLoop, not chat_completion** — Single-turn chat_completion has no tools. The whole point of hermes-agent benchmarks is agentic evaluation with tool use.
|
|
||||||
|
|
||||||
3. **Don't call _llm_judge twice** — If compute_reward already calls it, extract the score from the buffer instead of calling judge separately in evaluate().
|
|
||||||
|
|
||||||
4. **Eval pollutes training buffers** — compute_reward appends to metric buffers. During eval, roll back buffer entries to keep training metrics clean.
|
|
||||||
|
|
||||||
5. **Always set health_check=false for OpenRouter** — OpenRouter has no /health endpoint.
|
|
||||||
|
|
||||||
6. **Set data_dir_to_save_evals in evaluate mode** — Without it, results aren't saved.
|
|
||||||
|
|
||||||
7. **default_toolsets class variable vs enabled_toolsets config** — The class variable is a hint; the config field is what actually controls tool resolution.
|
|
||||||
|
|
||||||
8. **Tool call parsing in messages** — Tool calls are dicts with `{"function": {"name": ..., "arguments": ...}}`. Always check `isinstance(tc, dict)`.
|
|
||||||
|
|
||||||
9. **ToolContext.cleanup()** — Always call in a finally block to release sandbox resources.
|
|
||||||
|
|
||||||
10. **server_type must be "openai" for external APIs** — Without it, Atropos assumes a local VLLM server.
|
|
||||||
|
|
||||||
11. **Always ask the user for their inference setup** — Never hardcode or assume a specific provider/model. See the "Inference Setup" section above.
|
|
||||||
|
|
||||||
## Reward Function Patterns
|
|
||||||
|
|
||||||
### LLM Judge (for open-ended tasks)
|
|
||||||
Use `self.server.chat_completion()` with a scoring prompt. Parse JSON response for score float. Always include a heuristic fallback (keyword overlap) for when the judge call fails.
|
|
||||||
|
|
||||||
### Binary Verification (for code/terminal tasks)
|
|
||||||
Use `ctx.terminal("pytest test.py -q")` to run tests in the agent's sandbox. Return 1.0 for pass, 0.0 for fail.
|
|
||||||
|
|
||||||
### Multi-Signal (combine multiple indicators)
|
|
||||||
Weight correctness (0.6) + tool usage (0.2) + efficiency (0.2) + optional bonuses. Clamp to [0, 1].
|
|
||||||
|
|
||||||
## Testing Your Environment
|
|
||||||
|
|
||||||
1. **Import test**: `python -c "from environments.my_env import MyEnv; print('OK')"`
|
|
||||||
2. **Ask the user for inference setup** (see "Inference Setup" section above)
|
|
||||||
3. **Process mode** (1 item): Verify JSONL output has valid tokens, masks, scores
|
|
||||||
4. **Evaluate mode**: Verify full agent loop runs with tools, metrics logged correctly
|
|
||||||
5. **Check reward range**: Scores should be in [0, 1], not all identical
|
|
||||||
|
|
||||||
## Minimum Implementation Checklist
|
|
||||||
|
|
||||||
```python
|
|
||||||
class MyEnv(HermesAgentBaseEnv):
|
|
||||||
name = "my-env"
|
|
||||||
env_config_cls = MyEnvConfig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def config_init(cls): ... # Default server + env config
|
|
||||||
async def setup(self): ... # Load dataset + train/eval split
|
|
||||||
async def get_next_item(self): ... # Cycle through training items
|
|
||||||
def format_prompt(self, item): ... # Item → user message string
|
|
||||||
async def compute_reward(self, item, result, ctx): ... # Score rollout
|
|
||||||
async def evaluate(self, *args, **kwargs): ... # Full agent loop eval
|
|
||||||
async def wandb_log(self, metrics=None): ... # Custom metrics + super()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
MyEnv.cli()
|
|
||||||
```
|
|
||||||
|
|
@ -103,7 +103,6 @@ const sidebars: SidebarsConfig = {
|
||||||
type: 'category',
|
type: 'category',
|
||||||
label: 'Advanced',
|
label: 'Advanced',
|
||||||
items: [
|
items: [
|
||||||
'user-guide/features/rl-training',
|
|
||||||
'user-guide/features/spotify',
|
'user-guide/features/spotify',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
@ -238,7 +237,6 @@ const sidebars: SidebarsConfig = {
|
||||||
'developer-guide/tools-runtime',
|
'developer-guide/tools-runtime',
|
||||||
'developer-guide/acp-internals',
|
'developer-guide/acp-internals',
|
||||||
'developer-guide/cron-internals',
|
'developer-guide/cron-internals',
|
||||||
'developer-guide/environments',
|
|
||||||
'developer-guide/trajectory-format',
|
'developer-guide/trajectory-format',
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue