hermes-agent/atropos/envs/test_env.py

217 lines
6.9 KiB
Python

"""
Simple test environment for validating the atropos-agent setup.

This environment talks to a local OpenAI-compatible server to verify that:
- the BaseEnv extension works correctly,
- API communication via an OpenAI-compatible endpoint works, and
- basic trajectory collection works.

This is a minimal environment for testing only, not for production use.
"""
import os
import re
from typing import Dict, List, Optional, Tuple

from dotenv import load_dotenv
from pydantic import Field
from atroposlib.envs.base import (
    APIServerConfig,
    Item,
)

from ..agent import AgentConfig
from .agent_env import AgentEnv, AgentEnvConfig
# Load environment variables from a .env file (if one is found) at import
# time, so the os.getenv() lookups in SimpleTestEnv.config_init can pick up
# values defined there.
load_dotenv()
# Validation items: each pairs a question ("prompt") with the answer
# ("expected") that the model's reply is scored against.
TEST_PROMPTS = [
    {"prompt": question, "expected": answer}
    for question, answer in (
        ("What is 2 + 2? Answer with just the number.", "4"),
        ("What is the capital of France? Answer with just the city name.", "Paris"),
        ("What color is the sky on a clear day? Answer with just the color.", "Blue"),
        ("How many days are in a week? Answer with just the number.", "7"),
        ("What is 10 * 5? Answer with just the number.", "50"),
    )
]

# System message shared by agent rollouts and evaluation requests.
SYSTEM_PROMPT = (
    "You are a helpful assistant. Answer questions concisely and directly. "
    "When asked for a simple answer, provide just that answer without explanation."
)
class SimpleTestEnvConfig(AgentEnvConfig):
    """Configuration for the simple test environment."""

    # Root URL of the OpenAI-compatible server; "/v1" is appended separately.
    server_base_url: str = Field(
        description="Base URL for an OpenAI-compatible server (without /v1)",
        default="http://127.0.0.1:8080",
    )
    # Model identifier sent to the server.
    server_model: str = Field(
        description="Model name",
        default="hermes-4-36b",
    )
    # Tokenizer used when tokenizing trajectories for RL.
    tokenizer_name: str = Field(
        description="Tokenizer name for RL tokenization",
        default="NousResearch/Hermes-4.3-36B",
    )
class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
    """
    A simple test environment to validate the atropos-agent setup.

    Uses a local OpenAI-compatible LLM endpoint with basic question-answering
    tasks. A response scores 1.0 when it contains the expected answer as a
    standalone token (case-insensitive) and 0.0 otherwise, so e.g. "14" is
    not accepted as the answer to "2 + 2".
    """

    name = "simple_test_env"
    env_config_cls = SimpleTestEnvConfig

    def __init__(
        self,
        config: SimpleTestEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        # Round-robin cursor into test_prompts, advanced by get_next_item().
        self.iter = 0
        self.test_prompts = TEST_PROMPTS
        # Per-trajectory training scores; averaged and cleared by wandb_log().
        self.percent_correct_buffer: List[float] = []

    @staticmethod
    def _normalize_base_url(url: str) -> str:
        """Return *url* without trailing slashes or a trailing ``/v1`` segment.

        ``config_init`` appends ``/v1`` itself, but ``OPENAI_BASE_URL`` values
        conventionally already end in ``/v1``; stripping it here avoids
        requests going to ``.../v1/v1`` (or ``...//v1`` for a trailing slash).
        """
        url = url.rstrip("/")
        if url.endswith("/v1"):
            url = url[: -len("/v1")]
        return url

    @staticmethod
    def _matches_expected(expected: str, response: Optional[str]) -> bool:
        """Return True if *expected* occurs in *response* as a standalone token.

        Matching is case-insensitive. Word-character lookarounds are used
        instead of a plain substring test so that "4" does not match "14" and
        "blue" does not match "blueish", while "Paris." or "2 + 2 = 4" still
        count as correct. A ``None`` response never matches a non-empty answer.
        """
        pattern = r"(?<!\w)" + re.escape(expected.lower()) + r"(?!\w)"
        return re.search(pattern, (response or "").lower()) is not None

    @classmethod
    def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]:
        """
        Initialize configuration with local server settings from environment variables.

        The base URL comes from ATROPOS_SERVER_BASE_URL, OPENAI_BASE_URL or
        LLM_BASE_URL (first non-empty wins) and is normalized to carry no
        trailing slash or ``/v1`` suffix, because ``/v1`` is appended below.

        Returns:
            The environment config and a single-element list of API server configs.
        """
        base_url = cls._normalize_base_url(
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
        api_key = (
            os.getenv("ATROPOS_SERVER_API_KEY")
            or os.getenv("NOUS_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or "local"
        )
        env_config = SimpleTestEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=4,
            use_wandb=False,  # Disable wandb for simple testing
            rollout_server_url="http://localhost:8000",
            total_steps=10,
            batch_size=16,
            steps_per_eval=5,
            max_token_length=2048,
            inference_weight=1.0,
            wandb_name="simple_test",
            server_base_url=base_url,
            server_model=model,
        )
        # OpenAI-compatible servers typically expose chat completions at /v1.
        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=f"{base_url}/v1",
                api_key=api_key,
                num_max_requests_at_once=4,
                num_requests_for_eval=8,
                timeout=120,  # Local models may be slower
            ),
        ]
        return env_config, server_configs

    async def setup_agent_env(self):
        """Setup the environment - load test data."""
        print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.")
        print(f"Using server at: {self.config.server_base_url}")
        print(f"Model: {self.config.server_model}")

    async def get_next_item(self) -> Item:
        """Get the next test prompt, cycling through test_prompts in order."""
        item = self.test_prompts[self.iter % len(self.test_prompts)]
        self.iter += 1
        return item

    def build_task(self, item: Item) -> str:
        """Return the user-facing task text for *item* (its "prompt")."""
        return item["prompt"]

    def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
        """Return the same short-answer AgentConfig for every item."""
        return AgentConfig(
            max_steps=5,
            temperature=0.7,
            max_tokens=256,
            system_prompt=SYSTEM_PROMPT,
        )

    async def score_trajectory(self, item: Item, final_response: str) -> float:
        """Score 1.0 if *final_response* contains item["expected"] as a token, else 0.0.

        The score is also appended to percent_correct_buffer for wandb_log().
        """
        score = 1.0 if self._matches_expected(item["expected"], final_response) else 0.0
        self.percent_correct_buffer.append(score)
        return score

    async def evaluate(self, *args, **kwargs):
        """
        Simple evaluation - run through all test prompts once with greedy decoding.

        Returns:
            {"eval_accuracy": fraction of prompts answered correctly}, which is
            0.0 when there are no test prompts.
        """
        correct = 0
        total = len(self.test_prompts)
        for item in self.test_prompts:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": item["prompt"]},
            ]
            response = await self.server.chat_completion(
                messages=messages,
                n=1,
                max_tokens=256,
                temperature=0.0,  # Greedy for eval
                split="eval",
            )
            response_text = response.choices[0].message.content or ""
            if self._matches_expected(item["expected"], response_text):
                correct += 1
        # Avoid ZeroDivisionError when the prompt list is empty.
        accuracy = correct / total if total else 0.0
        print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy")
        return {"eval_accuracy": accuracy}

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics (simplified for testing).

        Adds "train/percent_correct" (mean of the buffered training scores)
        when any scores were collected since the last call, clears the buffer,
        then delegates to the base class.
        """
        if wandb_metrics is None:
            wandb_metrics = {}
        if self.percent_correct_buffer:
            avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
            wandb_metrics["train/percent_correct"] = avg_correct
            print(f"Train accuracy: {avg_correct:.2%}")
            self.percent_correct_buffer = []
        await super().wandb_log(wandb_metrics)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to the environment's
    # command-line entry point, which SimpleTestEnv inherits from its base.
    SimpleTestEnv.cli()