mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-03 02:11:48 +00:00
- environments/gsm8k_agent_env.py: Math reasoning with Python REPL tool - Subclasses HermesAgentBaseEnv (proper tools= parameter, not ICL) - Uses ATROPOS_SERVER_* env vars from .env - Hermes tool call parser, configurable per model - Math verification via math_verify with string fallback - Tested: process mode works, both trajectories scored 1.0 - Updated memory bank with consolidation plan: - environments/ is the canonical env system (proper tool calling) - atropos/backends/ kept as sandbox infrastructure - atropos/agent/ and atropos/envs/agent_env.py marked for removal
350 lines
12 KiB
Python
350 lines
12 KiB
Python
"""
|
|
GSM8kAgentEnv -- Math Reasoning with Tool Use (Python REPL)
|
|
|
|
An agentic RL environment where models solve GSM8k math problems using
|
|
a Python interpreter tool. Uses proper OpenAI-spec tool calling via
|
|
HermesAgentBaseEnv (not ICL).
|
|
|
|
The model:
|
|
1. Receives a math problem
|
|
2. Can call the `terminal` tool to run Python code (`python3 -c "..."`)
|
|
3. Provides a final answer in \\boxed{} format
|
|
4. Gets reward: 1.0 if correct, 0.0 if wrong
|
|
|
|
Usage:
|
|
# Phase 1 (OpenRouter, no training):
|
|
python environments/gsm8k_agent_env.py process \\
|
|
--env.data_path_to_save_groups gsm8k_agent_output.jsonl
|
|
|
|
# Phase 2 (VLLM + Tinker training):
|
|
run-api
|
|
python launch_training.py --config configs/gsm8k_agent.yaml
|
|
python environments/gsm8k_agent_env.py serve
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
# Ensure repo root is on sys.path
|
|
_repo_root = Path(__file__).resolve().parent.parent
|
|
if str(_repo_root) not in sys.path:
|
|
sys.path.insert(0, str(_repo_root))
|
|
|
|
from atroposlib.envs.base import ScoredDataGroup
|
|
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
|
from atroposlib.type_definitions import Item
|
|
|
|
from environments.agent_loop import AgentResult
|
|
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
|
from environments.tool_context import ToolContext
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
|
|
# Math verification helpers
|
|
# =============================================================================
|
|
|
|
def _verify_math_answer(model_response: str, gold_answer: str) -> bool:
|
|
"""
|
|
Verify if the model's response contains the correct answer.
|
|
Uses math_verify for robust LaTeX comparison, falls back to string matching.
|
|
"""
|
|
try:
|
|
from latex2sympy2_extended import NormalizationConfig
|
|
from math_verify import LatexExtractionConfig, parse, verify
|
|
|
|
gold_parsed = parse(
|
|
f"\\boxed{{{gold_answer}}}",
|
|
extraction_mode="first_match",
|
|
extraction_config=[LatexExtractionConfig()],
|
|
)
|
|
|
|
# Strip <think> blocks if present
|
|
answer_text = model_response
|
|
if "</think>" in answer_text:
|
|
answer_text = answer_text.split("</think>")[-1]
|
|
|
|
answer_parsed = parse(
|
|
answer_text,
|
|
extraction_config=[
|
|
LatexExtractionConfig(
|
|
normalization_config=NormalizationConfig(
|
|
nits=False,
|
|
malformed_operators=False,
|
|
basic_latex=True,
|
|
boxed="all",
|
|
units=True,
|
|
),
|
|
boxed_match_priority=0,
|
|
try_extract_without_anchor=False,
|
|
)
|
|
],
|
|
extraction_mode="first_match",
|
|
)
|
|
|
|
return bool(verify(answer_parsed, gold_parsed))
|
|
|
|
except ImportError:
|
|
# Fallback: simple string matching for \\boxed{answer}
|
|
import re
|
|
pattern = r'\\boxed\{([^}]+)\}'
|
|
matches = re.findall(pattern, model_response)
|
|
if matches:
|
|
model_answer = matches[-1].strip().replace(",", "")
|
|
gold_clean = gold_answer.strip().replace(",", "")
|
|
return model_answer == gold_clean
|
|
return False
|
|
|
|
|
|
# =============================================================================
|
|
# Environment Config
|
|
# =============================================================================
|
|
|
|
class GSM8kAgentEnvConfig(HermesAgentEnvConfig):
|
|
"""Config with defaults for GSM8k agent environment."""
|
|
pass
|
|
|
|
|
|
# =============================================================================
|
|
# Environment
|
|
# =============================================================================
|
|
|
|
class GSM8kAgentEnv(HermesAgentBaseEnv):
|
|
"""
|
|
GSM8k math environment with Python REPL tool calling.
|
|
|
|
Models solve grade-school math problems by reasoning step by step
|
|
and using Python (via the terminal tool) for calculations.
|
|
|
|
Exercises the full agentic RL training loop:
|
|
- Model receives math problem
|
|
- Makes tool calls to compute (python3 -c "...")
|
|
- Provides final answer in \\boxed{}
|
|
- Reward: binary (1.0 correct, 0.0 wrong)
|
|
"""
|
|
|
|
name = "gsm8k-agent"
|
|
env_config_cls = GSM8kAgentEnvConfig
|
|
|
|
@classmethod
|
|
def config_init(cls) -> Tuple[GSM8kAgentEnvConfig, List[APIServerConfig]]:
|
|
"""
|
|
Default config using terminal tool.
|
|
|
|
Reads from environment variables (set in .env):
|
|
ATROPOS_SERVER_BASE_URL - Inference server URL
|
|
ATROPOS_SERVER_MODEL - Model name on the server
|
|
ATROPOS_TOKENIZER_NAME - HuggingFace tokenizer name
|
|
ATROPOS_SERVER_API_KEY - API key for the server
|
|
"""
|
|
# Resolve inference server settings from env
|
|
base_url = (
|
|
os.getenv("ATROPOS_SERVER_BASE_URL")
|
|
or os.getenv("OPENAI_BASE_URL")
|
|
or os.getenv("LLM_BASE_URL")
|
|
or "https://openrouter.ai/api/v1"
|
|
)
|
|
if not base_url.rstrip("/").endswith("/v1"):
|
|
base_url = base_url.rstrip("/") + "/v1"
|
|
|
|
model = (
|
|
os.getenv("ATROPOS_SERVER_MODEL")
|
|
or os.getenv("LLM_MODEL")
|
|
or "Hermes-4.3-36B"
|
|
)
|
|
|
|
api_key = (
|
|
os.getenv("ATROPOS_SERVER_API_KEY")
|
|
or os.getenv("NOUS_API_KEY")
|
|
or os.getenv("OPENROUTER_API_KEY")
|
|
or os.getenv("OPENAI_API_KEY")
|
|
or ""
|
|
)
|
|
|
|
tokenizer = (
|
|
os.getenv("ATROPOS_TOKENIZER_NAME")
|
|
or os.getenv("ATROPOS_TOKENIZER")
|
|
or "NousResearch/Hermes-4.3-36B"
|
|
)
|
|
|
|
env_config = GSM8kAgentEnvConfig(
|
|
# Terminal tool only -- model uses `python3 -c "..."` for math
|
|
enabled_toolsets=["terminal"],
|
|
disabled_toolsets=None,
|
|
distribution=None,
|
|
# Agent settings
|
|
max_agent_turns=5, # Math problems don't need many turns
|
|
max_token_length=2048, # Room for reasoning + code
|
|
agent_temperature=1.0,
|
|
system_prompt=(
|
|
"You are a helpful math assistant. You have access to a terminal "
|
|
"where you can run Python code to help solve problems.\n\n"
|
|
"When you need to calculate something, use the terminal tool with "
|
|
"a command like: python3 -c \"print(2 + 2)\"\n\n"
|
|
"When you have the final answer, write it inside \\boxed{} like: \\boxed{42}\n\n"
|
|
"Work step by step. Use Python to verify your reasoning."
|
|
),
|
|
# Terminal backend (local for testing, modal for production)
|
|
terminal_backend=os.getenv("TERMINAL_ENV", "local"),
|
|
# Parser -- hermes format for Hermes models
|
|
tool_call_parser="hermes",
|
|
# Atropos settings
|
|
group_size=4,
|
|
tokenizer_name=tokenizer,
|
|
steps_per_eval=5,
|
|
total_steps=10,
|
|
use_wandb=bool(os.getenv("WANDB_API_KEY")),
|
|
wandb_name="gsm8k-agent",
|
|
ensure_scores_are_not_same=False,
|
|
# No external dataset (we load GSM8k ourselves)
|
|
dataset_name=None,
|
|
)
|
|
|
|
server_configs = [
|
|
APIServerConfig(
|
|
base_url=base_url,
|
|
model_name=model,
|
|
server_type="openai",
|
|
api_key=api_key,
|
|
health_check=False,
|
|
)
|
|
]
|
|
|
|
return env_config, server_configs
|
|
|
|
async def setup(self):
|
|
"""Load GSM8k dataset."""
|
|
from datasets import load_dataset
|
|
|
|
self.train = load_dataset("gsm8k", "main", split="train").shuffle(seed=42)
|
|
test_data = load_dataset("gsm8k", "main", split="test").shuffle(seed=42)
|
|
self.test = [
|
|
{
|
|
"question": item["question"],
|
|
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
|
}
|
|
for item in test_data
|
|
]
|
|
self.iter = 0
|
|
self.reward_buffer: List[float] = []
|
|
self.tool_use_buffer: List[int] = []
|
|
print(f"[GSM8kAgentEnv] Loaded {len(self.train)} train, {len(self.test)} test examples")
|
|
|
|
async def get_next_item(self) -> Dict[str, str]:
|
|
"""Cycle through training problems."""
|
|
item = self.train[self.iter % len(self.train)]
|
|
self.iter += 1
|
|
return {
|
|
"question": item["question"],
|
|
"gold_answer": item["answer"].split("#")[-1].strip().replace(",", ""),
|
|
}
|
|
|
|
def format_prompt(self, item: Dict[str, str]) -> str:
|
|
"""Format the math problem as a user message."""
|
|
return item["question"]
|
|
|
|
async def compute_reward(
|
|
self, item: Dict[str, str], result: AgentResult, ctx: ToolContext
|
|
) -> float:
|
|
"""
|
|
Score: verify the model's \\boxed{} answer against the gold answer.
|
|
|
|
The agent has full access to terminal via ctx, but for GSM8k we just
|
|
check the final answer from the conversation.
|
|
"""
|
|
# Get the last assistant message content
|
|
final_text = ""
|
|
for msg in reversed(result.messages):
|
|
if msg.get("role") == "assistant" and msg.get("content"):
|
|
final_text = msg["content"]
|
|
break
|
|
|
|
correct = _verify_math_answer(final_text, item["gold_answer"])
|
|
reward = 1.0 if correct else 0.0
|
|
|
|
self.reward_buffer.append(reward)
|
|
# Count tool calls in this trajectory
|
|
tool_call_count = sum(
|
|
len(msg.get("tool_calls", []))
|
|
for msg in result.messages
|
|
if msg.get("role") == "assistant"
|
|
)
|
|
self.tool_use_buffer.append(tool_call_count)
|
|
|
|
return reward
|
|
|
|
async def evaluate(self, *args, **kwargs):
|
|
"""Evaluate on a subset of the test set (greedy, no tools for speed)."""
|
|
start_time = time.time()
|
|
correct = 0
|
|
total = 0
|
|
samples = []
|
|
|
|
eval_subset = self.test[:30] # Small subset for quick eval
|
|
|
|
for item in eval_subset:
|
|
try:
|
|
completion = await self.server.chat_completion(
|
|
messages=[
|
|
{"role": "system", "content": self.config.system_prompt or ""},
|
|
{"role": "user", "content": item["question"]},
|
|
],
|
|
n=1,
|
|
max_tokens=self.config.max_token_length,
|
|
temperature=0.0,
|
|
split="eval",
|
|
)
|
|
|
|
response = completion.choices[0].message.content or ""
|
|
is_correct = _verify_math_answer(response, item["gold_answer"])
|
|
|
|
if is_correct:
|
|
correct += 1
|
|
total += 1
|
|
|
|
samples.append({
|
|
"question": item["question"],
|
|
"gold_answer": item["gold_answer"],
|
|
"response": response[:500],
|
|
"correct": is_correct,
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error("Eval failed: %s", e)
|
|
total += 1
|
|
|
|
percent_correct = correct / total if total > 0 else 0
|
|
end_time = time.time()
|
|
|
|
await self.evaluate_log(
|
|
metrics={"eval/percent_correct": percent_correct, "eval/total": total},
|
|
samples=samples,
|
|
start_time=start_time,
|
|
end_time=end_time,
|
|
)
|
|
|
|
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
|
|
"""Log training metrics."""
|
|
if wandb_metrics is None:
|
|
wandb_metrics = {}
|
|
|
|
if self.reward_buffer:
|
|
wandb_metrics["train/percent_correct"] = sum(self.reward_buffer) / len(self.reward_buffer)
|
|
wandb_metrics["train/total_rollouts"] = len(self.reward_buffer)
|
|
self.reward_buffer = []
|
|
|
|
if self.tool_use_buffer:
|
|
wandb_metrics["train/avg_tool_calls"] = sum(self.tool_use_buffer) / len(self.tool_use_buffer)
|
|
wandb_metrics["train/tool_use_rate"] = sum(1 for t in self.tool_use_buffer if t > 0) / len(self.tool_use_buffer)
|
|
self.tool_use_buffer = []
|
|
|
|
await super().wandb_log(wandb_metrics)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
GSM8kAgentEnv.cli()
|