Added task sppecific metris and evals

This commit is contained in:
Sam Herring 2026-02-27 11:20:18 -08:00
parent b7e713b101
commit 6fdb38ed29
2 changed files with 249 additions and 33 deletions

View file

@ -90,6 +90,12 @@ class EndlessTerminalsEnvConfig(HermesAgentEnvConfig):
# Agent defaults
max_agent_turns: int = Field(default=32, description="Max turns for agent (increased for long traces)")
# Evaluation settings
num_eval_tasks: int = Field(
default=10,
description="Number of tasks to run during periodic evaluation"
)
class EndlessTerminalsEnv(HermesAgentBaseEnv):
"""
@ -141,15 +147,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
self._dataset_indices = []
self._current_index = 0
# Metrics tracking for wandb - single buffer with dicts
self._metrics_buffer = []
# Debug: check server config
if hasattr(self, 'server') and hasattr(self.server, 'servers'):
for i, srv in enumerate(self.server.servers):
print(f"[DEBUG] Server {i}: model_name={getattr(srv.config, 'model_name', 'NONE')}", flush=True)
logger.debug(f"Server {i}: model_name={getattr(srv.config, 'model_name', 'NONE')}")
async def setup(self):
"""Load dataset from HuggingFace or local directory."""
if not self.config.use_dataset:
print("[EndlessTerminalsEnv] Using procedural task generation (not implemented yet)", flush=True)
logger.info("Using procedural task generation (not implemented yet)")
return
# If tasks_base_dir is set, load from local directory instead of HuggingFace
@ -165,18 +174,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
if not tasks_base.exists():
raise RuntimeError(f"tasks_base_dir not found: {tasks_base}")
print(f"[EndlessTerminalsEnv] Loading tasks from local directory: {tasks_base}", flush=True)
logger.info(f"Loading tasks from local directory: {tasks_base}")
# Find all task_* directories
task_dirs = sorted(tasks_base.glob("task_*"))
print(f"[EndlessTerminalsEnv] Found {len(task_dirs)} task directories", flush=True)
logger.info(f"Found {len(task_dirs)} task directories")
if not task_dirs:
# Debug: show what's actually in the directory
all_items = list(tasks_base.iterdir())
print(f"[EndlessTerminalsEnv] Directory contains {len(all_items)} items:", flush=True)
logger.warning(f"Directory contains {len(all_items)} items:")
for item in all_items[:10]:
print(f" - {item.name} ({'dir' if item.is_dir() else 'file'})", flush=True)
logger.warning(f" - {item.name} ({'dir' if item.is_dir() else 'file'})")
raise RuntimeError(f"No task_* directories found in {tasks_base}")
# Create fake dataset items (just the directory paths)
@ -193,11 +202,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
random.shuffle(self._dataset_indices)
self._current_index = 0
print(f"[EndlessTerminalsEnv] Loaded {len(self._dataset)} tasks from local directory", flush=True)
logger.info(f"Loaded {len(self._dataset)} tasks from local directory")
return
# Otherwise, load from HuggingFace
print(f"[EndlessTerminalsEnv] Loading dataset from HuggingFace: {self.config.dataset_name}", flush=True)
logger.info(f"Loading dataset from HuggingFace: {self.config.dataset_name}")
try:
from datasets import load_dataset
@ -216,10 +225,10 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
random.shuffle(self._dataset_indices)
self._current_index = 0
print(f"[EndlessTerminalsEnv] Loaded {len(self._dataset)} tasks from HuggingFace", flush=True)
logger.info(f"Loaded {len(self._dataset)} tasks from HuggingFace")
except Exception as e:
print(f"[EndlessTerminalsEnv] ERROR loading dataset: {e}", flush=True)
logger.error(f"ERROR loading dataset: {e}")
raise
async def get_next_item(self) -> Item:
@ -237,7 +246,7 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
# Reshuffle for next epoch
random.shuffle(self._dataset_indices)
self._current_index = 0
print("[EndlessTerminalsEnv] Reshuffled dataset (completed one epoch)", flush=True)
logger.info("Reshuffled dataset (completed one epoch)")
# Extract task directory path
task_dir = task.get("extra_info", {}).get("task_dir")
@ -258,8 +267,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
# Verify directory exists
if not task_dir_path.exists():
print(f"[EndlessTerminalsEnv] WARNING: Task dir not found: {task_dir_path}", flush=True)
print(f"[EndlessTerminalsEnv] Hint: Set tasks_base_dir to directory containing task_* folders", flush=True)
logger.warning(f"Task dir not found: {task_dir_path}")
logger.warning("Hint: Set tasks_base_dir to directory containing task_* folders")
return await self.get_next_item() # Try next task
# Look for test file in tests/ subdirectory first, then at root
@ -366,32 +375,32 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
task_name = item.get("task_name", "unknown")
docker_image = item.get("docker_image", self.config.default_docker_image)
print(f"[DEBUG] collect_trajectory START for {task_name}", flush=True)
logger.debug(f"collect_trajectory START for {task_name}")
# Register Docker image override for this task_id
print(f"[DEBUG] Registering Docker image: {docker_image}", flush=True)
logger.debug(f"Registering Docker image: {docker_image}")
register_task_env_overrides(task_id, {"modal_image": docker_image})
logger.info(
f"Task {task_name}: registered Docker image {docker_image} for task_id {task_id[:8]}"
)
print(f"[DEBUG] Docker image registered", flush=True)
logger.debug("Docker image registered")
try:
# Get group-level tools (resolved once in collect_trajectories)
print(f"[DEBUG] Resolving tools...", flush=True)
logger.debug("Resolving tools...")
if self._current_group_tools is None:
tools, valid_names = self._resolve_tools_for_group()
else:
tools, valid_names = self._current_group_tools
print(f"[DEBUG] Tools resolved: {len(tools)} tools", flush=True)
logger.debug(f"Tools resolved: {len(tools)} tools")
# Build initial messages
print(f"[DEBUG] Building initial messages...", flush=True)
logger.debug("Building initial messages...")
messages: List[Dict[str, Any]] = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": self.format_prompt(item)})
print(f"[DEBUG] Messages built, starting agent loop...", flush=True)
logger.debug("Messages built, starting agent loop...")
# Run the agent loop
result: AgentResult
@ -472,16 +481,28 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
finally:
ctx.cleanup()
# Track tool errors for wandb logging
# Track metrics for wandb logging
task_metrics = {
"test_passed": 1.0 if reward > 0.5 else 0.0,
"reward": reward,
"turns_used": result.turns_used,
"finished_naturally": result.finished_naturally,
"docker_image": docker_image,
"num_tool_errors": len(result.tool_errors),
}
# Include detailed tool errors if any occurred
if result.tool_errors:
for err in result.tool_errors:
self._tool_error_buffer.append({
task_metrics["tool_errors"] = [
{
"turn": err.turn,
"tool": err.tool_name,
"args": err.arguments[:150],
"error": err.error[:300],
"result": err.tool_result[:300],
})
"error": err.error[:200],
}
for err in result.tool_errors
]
self._metrics_buffer.append(task_metrics)
# Build ScoredDataItem from ManagedServer state
# Phase 2: real tokens/masks/logprobs from SequenceNodes
@ -490,6 +511,7 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
if nodes:
# Phase 2: use actual node data
# nodes[-1] contains the full accumulated trajectory from all turns
node = nodes[-1]
scored_item: Dict[str, Any] = {
"tokens": node.tokens,
@ -497,6 +519,7 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
"scores": reward,
}
if hasattr(node, "logprobs") and node.logprobs:
scored_item["logprobs"] = node.logprobs
scored_item["advantages"] = None
scored_item["ref_logprobs"] = None
else:
@ -622,14 +645,206 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
return 0.0
async def evaluate(self):
"""Periodic evaluation (optional)."""
return {}
"""
Periodic evaluation on a fixed set of tasks.
Runs the agent on num_eval_tasks tasks and measures performance
without affecting training. Returns metrics for wandb logging.
"""
if self._dataset is None:
logger.warning("Cannot evaluate: dataset not loaded")
return {}
logger.info(f"Starting evaluation on {self.config.num_eval_tasks} tasks...")
eval_metrics = {
"rewards": [],
"passes": [],
"turns": [],
"natural_finishes": [],
}
# Sample eval tasks randomly
import random
eval_indices = random.sample(range(len(self._dataset)), min(self.config.num_eval_tasks, len(self._dataset)))
for idx in eval_indices:
task = self._dataset[idx]
# Build item using same logic as get_next_item
task_dir = task.get("extra_info", {}).get("task_dir")
if not task_dir:
task_dir = task.get("reward_spec", {}).get("ground_truth")
if not task_dir:
continue
task_dir_path = Path(task_dir)
if self.config.tasks_base_dir and not task_dir_path.exists():
original_path = Path(task_dir)
task_name = original_path.name
task_dir_path = Path(os.path.expanduser(self.config.tasks_base_dir)) / task_name
if not task_dir_path.exists():
continue
# Find test file
final_test = task_dir_path / "tests" / "test_final_state.py"
if not final_test.exists():
final_test = task_dir_path / "test_final_state.py"
if not final_test.exists():
continue
# Parse Docker image
container_def = task_dir_path / "environment" / "container.def"
if not container_def.exists():
container_def = task_dir_path / "container.def"
docker_image = self._parse_docker_image_from_def(container_def)
# Load description
description = task.get("description", "")
instruction_md = task_dir_path / "instruction.md"
if not description and instruction_md.exists():
try:
description = instruction_md.read_text().strip()
except Exception:
pass
item = {
"description": description,
"final_test": str(final_test),
"docker_image": docker_image,
}
# Run agent on this task
try:
import uuid
task_id = str(uuid.uuid4())
# Register task environment
from model_tools import register_task_env_overrides
register_task_env_overrides(task_id, {"modal_image": docker_image})
# Build messages
messages = [
{"role": "system", "content": self.config.system_prompt},
{"role": "user", "content": description or "Complete the task."},
]
# Get tools
from model_tools import get_tool_definitions
tools = get_tool_definitions(self.config.enabled_toolsets)
valid_names = {t["function"]["name"] for t in tools}
# Run agent
from environments.agent_loop import HermesAgentLoop
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# Compute reward
from environments.tool_context import ToolContext
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
except Exception as e:
logger.warning(f"Eval reward computation failed: {e}")
reward = 0.0
finally:
ctx.cleanup()
# Track metrics
eval_metrics["rewards"].append(reward)
eval_metrics["passes"].append(1.0 if reward > 0.5 else 0.0)
eval_metrics["turns"].append(result.turns_used)
eval_metrics["natural_finishes"].append(1.0 if result.finished_naturally else 0.0)
except Exception as e:
logger.error(f"Eval task failed: {e}")
continue
finally:
# Cleanup
from model_tools import clear_task_env_overrides, cleanup_vm
clear_task_env_overrides(task_id)
cleanup_vm(task_id)
# Aggregate metrics
if not eval_metrics["rewards"]:
logger.warning("No eval tasks completed successfully")
return {}
aggregated = {
"eval/pass_rate": sum(eval_metrics["passes"]) / len(eval_metrics["passes"]),
"eval/avg_reward": sum(eval_metrics["rewards"]) / len(eval_metrics["rewards"]),
"eval/avg_turns": sum(eval_metrics["turns"]) / len(eval_metrics["turns"]),
"eval/natural_finish_rate": sum(eval_metrics["natural_finishes"]) / len(eval_metrics["natural_finishes"]),
"eval/num_tasks": len(eval_metrics["rewards"]),
}
logger.info(f"Evaluation complete: pass_rate={aggregated['eval/pass_rate']:.2%}, avg_turns={aggregated['eval/avg_turns']:.1f}")
return aggregated
async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
"""Log Endless Terminals specific metrics to wandb."""
if wandb_metrics is None:
wandb_metrics = {}
# Aggregate metrics from buffer
if self._metrics_buffer:
# Test pass rate
test_passes = [m["test_passed"] for m in self._metrics_buffer]
wandb_metrics["endless_terminals/test_pass_rate"] = sum(test_passes) / len(test_passes)
wandb_metrics["endless_terminals/num_tests_passed"] = sum(test_passes)
wandb_metrics["endless_terminals/num_tests_total"] = len(test_passes)
# Turns used statistics
turns = [m["turns_used"] for m in self._metrics_buffer]
wandb_metrics["endless_terminals/avg_turns_used"] = sum(turns) / len(turns)
wandb_metrics["endless_terminals/max_turns_used"] = max(turns)
wandb_metrics["endless_terminals/min_turns_used"] = min(turns)
# Natural finish rate (did model stop on its own vs hitting max turns)
natural_finishes = [1.0 if m["finished_naturally"] else 0.0 for m in self._metrics_buffer]
wandb_metrics["endless_terminals/natural_finish_rate"] = sum(natural_finishes) / len(natural_finishes)
# Tool error statistics
total_tool_errors = sum(m["num_tool_errors"] for m in self._metrics_buffer)
wandb_metrics["endless_terminals/total_tool_errors"] = total_tool_errors
wandb_metrics["endless_terminals/avg_tool_errors_per_task"] = total_tool_errors / len(self._metrics_buffer)
# Docker image distribution (count unique images used)
docker_images = [m["docker_image"] for m in self._metrics_buffer]
unique_images = set(docker_images)
wandb_metrics["endless_terminals/num_unique_docker_images"] = len(unique_images)
# Log most common errors if any
all_errors = []
for m in self._metrics_buffer:
if "tool_errors" in m:
all_errors.extend(m["tool_errors"])
if all_errors:
# Count error types
error_tools = {}
for err in all_errors:
tool = err["tool"]
error_tools[tool] = error_tools.get(tool, 0) + 1
# Log top 3 error-prone tools
for i, (tool, count) in enumerate(sorted(error_tools.items(), key=lambda x: x[1], reverse=True)[:3]):
wandb_metrics[f"endless_terminals/errors_by_tool/{tool}"] = count
# Clear buffer after logging
self._metrics_buffer = []
await super().wandb_log(wandb_metrics)