"""
|
|
Simple test environment for validating the atropos-agent setup.
|
|
|
|
This environment uses a local OpenAI-compatible server for LLM testing to verify:
|
|
- BaseEnv extension works correctly
|
|
- API communication via OpenAI-compatible endpoint
|
|
- Basic trajectory collection
|
|
|
|
This is a minimal environment for testing, not production use.
|
|
"""

import os
from typing import Dict, List, Optional, Tuple

from dotenv import load_dotenv
from pydantic import Field

from atroposlib.envs.base import (
    APIServerConfig,
    Item,
)

from ..agent import AgentConfig
from .agent_env import AgentEnv, AgentEnvConfig

# Load environment variables from .env file
load_dotenv()


# Simple test prompts for validation
TEST_PROMPTS = [
    {
        "prompt": "What is 2 + 2? Answer with just the number.",
        "expected": "4",
    },
    {
        "prompt": "What is the capital of France? Answer with just the city name.",
        "expected": "Paris",
    },
    {
        "prompt": "What color is the sky on a clear day? Answer with just the color.",
        "expected": "Blue",
    },
    {
        "prompt": "How many days are in a week? Answer with just the number.",
        "expected": "7",
    },
    {
        "prompt": "What is 10 * 5? Answer with just the number.",
        "expected": "50",
    },
]

SYSTEM_PROMPT = (
    "You are a helpful assistant. Answer questions concisely and directly. "
    "When asked for a simple answer, provide just that answer without explanation."
)


class SimpleTestEnvConfig(AgentEnvConfig):
    """Configuration for the simple test environment."""

    server_base_url: str = Field(
        default="http://127.0.0.1:8080",
        description="Base URL for an OpenAI-compatible server (without /v1)",
    )
    server_model: str = Field(
        default="hermes-4-36b",
        description="Model name",
    )
    tokenizer_name: str = Field(
        default="NousResearch/Hermes-4.3-36B",
        description="Tokenizer name for RL tokenization",
    )
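
    # These defaults can be overridden at runtime via the environment variables
    # read in SimpleTestEnv.config_init() below, e.g.:
    #   export ATROPOS_SERVER_BASE_URL=http://127.0.0.1:8080
    #   export ATROPOS_SERVER_MODEL=hermes-4-36b
    #   export ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B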


class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
    """
    A simple test environment to validate the atropos-agent setup.

    Uses a local OpenAI-compatible LLM endpoint with basic question-answering tasks.
    Scoring is based on whether the response contains the expected answer.
    """

    name = "simple_test_env"
    env_config_cls = SimpleTestEnvConfig

    def __init__(
        self,
        config: SimpleTestEnvConfig,
        server_configs: List[APIServerConfig],
        slurm: bool = False,
        testing: bool = False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.iter = 0
        self.test_prompts = TEST_PROMPTS
        self.percent_correct_buffer: List[float] = []

    @classmethod
    def config_init(cls) -> Tuple[SimpleTestEnvConfig, List[APIServerConfig]]:
        """
        Initialize configuration with local server settings from environment variables.
        """
        base_url = (
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
        api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"

        env_config = SimpleTestEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=4,
            use_wandb=False,  # Disable wandb for simple testing
            rollout_server_url="http://localhost:8000",
            total_steps=10,
            batch_size=16,
            steps_per_eval=5,
            max_token_length=2048,
            inference_weight=1.0,
            wandb_name="simple_test",
            server_base_url=base_url,
            server_model=model,
        )

        # OpenAI-compatible servers typically expose chat completions at /v1.
        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=f"{base_url}/v1",
                api_key=api_key,
                num_max_requests_at_once=4,
                num_requests_for_eval=8,
                timeout=120,  # Local models may be slower
            ),
        ]

        return env_config, server_configs
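
    # A standalone sanity check might look like this (hypothetical usage
    # sketch, mirroring __init__'s signature above):
    #   cfg, servers = SimpleTestEnv.config_init()
    #   env = SimpleTestEnv(cfg, servers, testing=True)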

    async def setup_agent_env(self):
        """Set up the environment: load the test data."""
        print(f"SimpleTestEnv setup complete. {len(self.test_prompts)} test prompts loaded.")
        print(f"Using server at: {self.config.server_base_url}")
        print(f"Model: {self.config.server_model}")

    async def get_next_item(self) -> Item:
        """Get the next test prompt, cycling through TEST_PROMPTS round-robin."""
        item = self.test_prompts[self.iter % len(self.test_prompts)]
        self.iter += 1
        return item

    def build_task(self, item: Item) -> str:
        return item["prompt"]

    def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
        return AgentConfig(
            max_steps=5,
            temperature=0.7,
            max_tokens=256,
            system_prompt=SYSTEM_PROMPT,
        )

    async def score_trajectory(self, item: Item, final_response: str) -> float:
        expected = item["expected"].lower()
        response_lower = (final_response or "").lower()
        score = 1.0 if expected in response_lower else 0.0
        self.percent_correct_buffer.append(score)
        return score
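
    # score_trajectory uses lenient substring matching: for "What is 2 + 2?",
    # "The answer is 4." scores 1.0, but so would any response that merely
    # contains "4". Good enough for a smoke test, too loose for a benchmark.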

    async def evaluate(self, *args, **kwargs):
        """
        Simple evaluation - run through all test prompts once.
        """
        correct = 0
        total = len(self.test_prompts)

        for item in self.test_prompts:
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": item["prompt"]},
            ]

            response = await self.server.chat_completion(
                messages=messages,
                n=1,
                max_tokens=256,
                temperature=0.0,  # Greedy for eval
                split="eval",
            )

            response_text = response.choices[0].message.content or ""
            expected = item["expected"].lower()

            if expected in response_text.lower():
                correct += 1

        accuracy = correct / total
        print(f"Evaluation: {correct}/{total} = {accuracy:.2%} accuracy")
        return {"eval_accuracy": accuracy}

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        """Log metrics (simplified for testing)."""
        if wandb_metrics is None:
            wandb_metrics = {}

        if self.percent_correct_buffer:
            avg_correct = sum(self.percent_correct_buffer) / len(self.percent_correct_buffer)
            wandb_metrics["train/percent_correct"] = avg_correct
            print(f"Train accuracy: {avg_correct:.2%}")
            self.percent_correct_buffer = []

        await super().wandb_log(wandb_metrics)


if __name__ == "__main__":
    # Allow running as CLI
    SimpleTestEnv.cli()
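    # Example invocation (a sketch; the exact subcommands are defined by
    # atroposlib's BaseEnv CLI, and the package path is project-specific):
    #   python -m <package>.simple_test_env --help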