mirror of https://github.com/NousResearch/hermes-agent.git
synced 2026-04-27 01:11:40 +00:00
Added task specific metrics and evals
This commit is contained in:
parent b7e713b101
commit 6fdb38ed29
2 changed files with 249 additions and 33 deletions
@@ -90,6 +90,12 @@ class EndlessTerminalsEnvConfig(HermesAgentEnvConfig):
     # Agent defaults
     max_agent_turns: int = Field(default=32, description="Max turns for agent (increased for long traces)")
 
+    # Evaluation settings
+    num_eval_tasks: int = Field(
+        default=10,
+        description="Number of tasks to run during periodic evaluation"
+    )
+
 
 class EndlessTerminalsEnv(HermesAgentBaseEnv):
     """
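As a quick usage sketch (not part of the commit): EndlessTerminalsEnvConfig is a pydantic model, so the new knob should be overridable at construction time like any other field. The names below come from the hunk above; the values are illustrative.

    # Hypothetical override of the new eval setting; Field(default=10) applies only when omitted.
    config = EndlessTerminalsEnvConfig(num_eval_tasks=25, max_agent_turns=64)
    assert config.num_eval_tasks == 25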
@@ -141,15 +147,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
         self._dataset_indices = []
         self._current_index = 0
 
+        # Metrics tracking for wandb - single buffer with dicts
+        self._metrics_buffer = []
+
         # Debug: check server config
         if hasattr(self, 'server') and hasattr(self.server, 'servers'):
             for i, srv in enumerate(self.server.servers):
                 print(f"[DEBUG] Server {i}: model_name={getattr(srv.config, 'model_name', 'NONE')}", flush=True)
                 logger.debug(f"Server {i}: model_name={getattr(srv.config, 'model_name', 'NONE')}")
 
     async def setup(self):
         """Load dataset from HuggingFace or local directory."""
         if not self.config.use_dataset:
             print("[EndlessTerminalsEnv] Using procedural task generation (not implemented yet)", flush=True)
             logger.info("Using procedural task generation (not implemented yet)")
             return
 
         # If tasks_base_dir is set, load from local directory instead of HuggingFace
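The new buffer follows an accumulate-then-drain pattern: each trajectory appends one dict, and wandb_log (later in this diff) aggregates and clears it. A minimal standalone sketch of the pattern, with hypothetical names, not the env's API:

    # Minimal accumulate-then-drain metrics buffer (illustrative).
    buffer: list[dict] = []

    def record(reward: float, turns: int) -> None:
        buffer.append({"reward": reward, "turns": turns})

    def drain() -> dict:
        if not buffer:
            return {}
        out = {"avg_reward": sum(m["reward"] for m in buffer) / len(buffer)}
        buffer.clear()  # each log window only sees fresh rollouts
        return out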
@@ -165,18 +174,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             if not tasks_base.exists():
                 raise RuntimeError(f"tasks_base_dir not found: {tasks_base}")
 
             print(f"[EndlessTerminalsEnv] Loading tasks from local directory: {tasks_base}", flush=True)
             logger.info(f"Loading tasks from local directory: {tasks_base}")
 
             # Find all task_* directories
             task_dirs = sorted(tasks_base.glob("task_*"))
             print(f"[EndlessTerminalsEnv] Found {len(task_dirs)} task directories", flush=True)
             logger.info(f"Found {len(task_dirs)} task directories")
 
             if not task_dirs:
                 # Debug: show what's actually in the directory
                 all_items = list(tasks_base.iterdir())
                 print(f"[EndlessTerminalsEnv] Directory contains {len(all_items)} items:", flush=True)
                 logger.warning(f"Directory contains {len(all_items)} items:")
                 for item in all_items[:10]:
                     print(f"  - {item.name} ({'dir' if item.is_dir() else 'file'})", flush=True)
                     logger.warning(f"  - {item.name} ({'dir' if item.is_dir() else 'file'})")
                 raise RuntimeError(f"No task_* directories found in {tasks_base}")
 
             # Create fake dataset items (just the directory paths)
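For orientation, this is the on-disk layout the loader expects, pieced together from the paths probed elsewhere in this file (tests/test_final_state.py, environment/container.def, instruction.md); anything beyond these files is unspecified by the diff:

    tasks_base_dir/
        task_0001/
            instruction.md                # task description (fallback if the dataset has none)
            environment/container.def     # Docker image definition (or container.def at task root)
            tests/test_final_state.py     # final-state test (or test_final_state.py at task root)
        task_0002/
            ...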
@@ -193,11 +202,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             random.shuffle(self._dataset_indices)
             self._current_index = 0
 
             print(f"[EndlessTerminalsEnv] Loaded {len(self._dataset)} tasks from local directory", flush=True)
             logger.info(f"Loaded {len(self._dataset)} tasks from local directory")
             return
 
         # Otherwise, load from HuggingFace
         print(f"[EndlessTerminalsEnv] Loading dataset from HuggingFace: {self.config.dataset_name}", flush=True)
         logger.info(f"Loading dataset from HuggingFace: {self.config.dataset_name}")
 
         try:
             from datasets import load_dataset
@@ -216,10 +225,10 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             random.shuffle(self._dataset_indices)
             self._current_index = 0
 
             print(f"[EndlessTerminalsEnv] Loaded {len(self._dataset)} tasks from HuggingFace", flush=True)
             logger.info(f"Loaded {len(self._dataset)} tasks from HuggingFace")
 
         except Exception as e:
             print(f"[EndlessTerminalsEnv] ERROR loading dataset: {e}", flush=True)
             logger.error(f"ERROR loading dataset: {e}")
             raise
 
     async def get_next_item(self) -> Item:
@@ -237,7 +246,7 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             # Reshuffle for next epoch
             random.shuffle(self._dataset_indices)
             self._current_index = 0
             print("[EndlessTerminalsEnv] Reshuffled dataset (completed one epoch)", flush=True)
             logger.info("Reshuffled dataset (completed one epoch)")
 
         # Extract task directory path
         task_dir = task.get("extra_info", {}).get("task_dir")
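The index-shuffle bookkeeping above amounts to epoch-style sampling without copying the dataset. The same idea as a self-contained sketch (names hypothetical):

    import random

    dataset = [f"task_{i:04d}" for i in range(5)]  # stand-in for the loaded task list
    indices = list(range(len(dataset)))            # one slot per task
    random.shuffle(indices)
    cursor = 0

    def next_index() -> int:
        global cursor
        if cursor >= len(indices):   # epoch done: reshuffle and restart
            random.shuffle(indices)
            cursor = 0
        i = indices[cursor]
        cursor += 1
        return i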
@@ -258,8 +267,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
 
         # Verify directory exists
         if not task_dir_path.exists():
             print(f"[EndlessTerminalsEnv] WARNING: Task dir not found: {task_dir_path}", flush=True)
             print(f"[EndlessTerminalsEnv] Hint: Set tasks_base_dir to directory containing task_* folders", flush=True)
             logger.warning(f"Task dir not found: {task_dir_path}")
             logger.warning("Hint: Set tasks_base_dir to directory containing task_* folders")
             return await self.get_next_item()  # Try next task
 
         # Look for test file in tests/ subdirectory first, then at root
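Worth noting: the `return await self.get_next_item()` fallback recurses, so a long run of missing task dirs deepens the stack by one frame per miss. A bounded iterative variant would avoid that; this sketch assumes a hypothetical `_pick_next_task()` helper for the index bookkeeping, plus `Path` and `logger` from the surrounding module:

    # Hypothetical bounded-retry alternative to the recursive skip.
    async def get_next_item_iterative(self, max_skips: int = 100):
        for _ in range(max_skips):
            item = self._pick_next_task()  # assumed helper, not in the commit
            if Path(item["task_dir"]).exists():
                return item
            logger.warning(f"Skipping missing task dir: {item['task_dir']}")
        raise RuntimeError("Too many consecutive missing task directories")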
@@ -366,32 +375,32 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
         task_name = item.get("task_name", "unknown")
         docker_image = item.get("docker_image", self.config.default_docker_image)
 
         print(f"[DEBUG] collect_trajectory START for {task_name}", flush=True)
         logger.debug(f"collect_trajectory START for {task_name}")
 
         # Register Docker image override for this task_id
         print(f"[DEBUG] Registering Docker image: {docker_image}", flush=True)
         logger.debug(f"Registering Docker image: {docker_image}")
         register_task_env_overrides(task_id, {"modal_image": docker_image})
         logger.info(
             f"Task {task_name}: registered Docker image {docker_image} for task_id {task_id[:8]}"
         )
         print(f"[DEBUG] Docker image registered", flush=True)
         logger.debug("Docker image registered")
 
         try:
             # Get group-level tools (resolved once in collect_trajectories)
             print(f"[DEBUG] Resolving tools...", flush=True)
             logger.debug("Resolving tools...")
             if self._current_group_tools is None:
                 tools, valid_names = self._resolve_tools_for_group()
             else:
                 tools, valid_names = self._current_group_tools
             print(f"[DEBUG] Tools resolved: {len(tools)} tools", flush=True)
             logger.debug(f"Tools resolved: {len(tools)} tools")
 
             # Build initial messages
             print(f"[DEBUG] Building initial messages...", flush=True)
             logger.debug("Building initial messages...")
             messages: List[Dict[str, Any]] = []
             if self.config.system_prompt:
                 messages.append({"role": "system", "content": self.config.system_prompt})
             messages.append({"role": "user", "content": self.format_prompt(item)})
             print(f"[DEBUG] Messages built, starting agent loop...", flush=True)
             logger.debug("Messages built, starting agent loop...")
 
             # Run the agent loop
             result: AgentResult
@@ -472,16 +481,28 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
         finally:
             ctx.cleanup()
 
-        # Track tool errors for wandb logging
-        if result.tool_errors:
-            for err in result.tool_errors:
-                self._tool_error_buffer.append({
-                    "turn": err.turn,
-                    "tool": err.tool_name,
-                    "args": err.arguments[:150],
-                    "error": err.error[:300],
-                    "result": err.tool_result[:300],
-                })
+        # Track metrics for wandb logging
+        task_metrics = {
+            "test_passed": 1.0 if reward > 0.5 else 0.0,
+            "reward": reward,
+            "turns_used": result.turns_used,
+            "finished_naturally": result.finished_naturally,
+            "docker_image": docker_image,
+            "num_tool_errors": len(result.tool_errors),
+        }
+
+        # Include detailed tool errors if any occurred
+        if result.tool_errors:
+            task_metrics["tool_errors"] = [
+                {
+                    "turn": err.turn,
+                    "tool": err.tool_name,
+                    "error": err.error[:200],
+                }
+                for err in result.tool_errors
+            ]
+
+        self._metrics_buffer.append(task_metrics)
 
         # Build ScoredDataItem from ManagedServer state
         # Phase 2: real tokens/masks/logprobs from SequenceNodes
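For concreteness, one buffered record from a hypothetical rollout would look like this (values illustrative, not from the diff); note that the `reward > 0.5` threshold means any graded reward above 0.5 counts as a pass:

    task_metrics = {
        "test_passed": 1.0,          # reward 0.83 > 0.5, so this rollout counts as a pass
        "reward": 0.83,
        "turns_used": 17,
        "finished_naturally": True,  # stopped on its own rather than hitting max_agent_turns
        "docker_image": "python:3.11-slim",
        "num_tool_errors": 1,
        "tool_errors": [{"turn": 4, "tool": "bash", "error": "command timed out"}],
    }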
@@ -490,6 +511,7 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
 
         if nodes:
             # Phase 2: use actual node data
+            # nodes[-1] contains the full accumulated trajectory from all turns
             node = nodes[-1]
             scored_item: Dict[str, Any] = {
                 "tokens": node.tokens,
@@ -497,6 +519,7 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
                 "scores": reward,
             }
             if hasattr(node, "logprobs") and node.logprobs:
                 scored_item["logprobs"] = node.logprobs
             scored_item["advantages"] = None
             scored_item["ref_logprobs"] = None
         else:
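Pieced together from this and the previous hunk, the per-trajectory training record has roughly this shape; the masks entry is implied by the "real tokens/masks/logprobs" comment rather than shown here, and all values are illustrative:

    scored_item = {
        "tokens": [...],       # full accumulated trajectory from nodes[-1]
        "masks": [...],        # assumed present per the Phase 2 comment
        "scores": 0.83,        # scalar reward for the whole trajectory
        "logprobs": [...],     # only attached when the node carries them
        "advantages": None,    # left for the trainer to fill in
        "ref_logprobs": None,
    }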
@@ -622,14 +645,206 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
         return 0.0
 
     async def evaluate(self):
-        """Periodic evaluation (optional)."""
-        return {}
+        """
+        Periodic evaluation on a fixed set of tasks.
+
+        Runs the agent on num_eval_tasks tasks and measures performance
+        without affecting training. Returns metrics for wandb logging.
+        """
+        if self._dataset is None:
+            logger.warning("Cannot evaluate: dataset not loaded")
+            return {}
+
+        logger.info(f"Starting evaluation on {self.config.num_eval_tasks} tasks...")
+
+        eval_metrics = {
+            "rewards": [],
+            "passes": [],
+            "turns": [],
+            "natural_finishes": [],
+        }
+
+        # Sample eval tasks randomly
+        import random
+        eval_indices = random.sample(range(len(self._dataset)), min(self.config.num_eval_tasks, len(self._dataset)))
+
+        for idx in eval_indices:
+            task = self._dataset[idx]
+
+            # Build item using same logic as get_next_item
+            task_dir = task.get("extra_info", {}).get("task_dir")
+            if not task_dir:
+                task_dir = task.get("reward_spec", {}).get("ground_truth")
+
+            if not task_dir:
+                continue
+
+            task_dir_path = Path(task_dir)
+            if self.config.tasks_base_dir and not task_dir_path.exists():
+                original_path = Path(task_dir)
+                task_name = original_path.name
+                task_dir_path = Path(os.path.expanduser(self.config.tasks_base_dir)) / task_name
+
+            if not task_dir_path.exists():
+                continue
+
+            # Find test file
+            final_test = task_dir_path / "tests" / "test_final_state.py"
+            if not final_test.exists():
+                final_test = task_dir_path / "test_final_state.py"
+            if not final_test.exists():
+                continue
+
+            # Parse Docker image
+            container_def = task_dir_path / "environment" / "container.def"
+            if not container_def.exists():
+                container_def = task_dir_path / "container.def"
+            docker_image = self._parse_docker_image_from_def(container_def)
+
+            # Load description
+            description = task.get("description", "")
+            instruction_md = task_dir_path / "instruction.md"
+            if not description and instruction_md.exists():
+                try:
+                    description = instruction_md.read_text().strip()
+                except Exception:
+                    pass
+
+            item = {
+                "description": description,
+                "final_test": str(final_test),
+                "docker_image": docker_image,
+            }
+
+            # Run agent on this task
+            try:
+                import uuid
+                task_id = str(uuid.uuid4())
+
+                # Register task environment
+                from model_tools import register_task_env_overrides
+                register_task_env_overrides(task_id, {"modal_image": docker_image})
+
+                # Build messages
+                messages = [
+                    {"role": "system", "content": self.config.system_prompt},
+                    {"role": "user", "content": description or "Complete the task."},
+                ]
+
+                # Get tools
+                from model_tools import get_tool_definitions
+                tools = get_tool_definitions(self.config.enabled_toolsets)
+                valid_names = {t["function"]["name"] for t in tools}
+
+                # Run agent
+                from environments.agent_loop import HermesAgentLoop
+                agent = HermesAgentLoop(
+                    server=self.server,
+                    tool_schemas=tools,
+                    valid_tool_names=valid_names,
+                    max_turns=self.config.max_agent_turns,
+                    task_id=task_id,
+                    temperature=self.config.agent_temperature,
+                    max_tokens=self.config.max_token_length,
+                    extra_body=self.config.extra_body,
+                )
+                result = await agent.run(messages)
+
+                # Compute reward
+                from environments.tool_context import ToolContext
+                ctx = ToolContext(task_id)
+                try:
+                    reward = await self.compute_reward(item, result, ctx)
+                except Exception as e:
+                    logger.warning(f"Eval reward computation failed: {e}")
+                    reward = 0.0
+                finally:
+                    ctx.cleanup()
+
+                # Track metrics
+                eval_metrics["rewards"].append(reward)
+                eval_metrics["passes"].append(1.0 if reward > 0.5 else 0.0)
+                eval_metrics["turns"].append(result.turns_used)
+                eval_metrics["natural_finishes"].append(1.0 if result.finished_naturally else 0.0)
+
+            except Exception as e:
+                logger.error(f"Eval task failed: {e}")
+                continue
+            finally:
+                # Cleanup
+                from model_tools import clear_task_env_overrides, cleanup_vm
+                clear_task_env_overrides(task_id)
+                cleanup_vm(task_id)
+
+        # Aggregate metrics
+        if not eval_metrics["rewards"]:
+            logger.warning("No eval tasks completed successfully")
+            return {}
+
+        aggregated = {
+            "eval/pass_rate": sum(eval_metrics["passes"]) / len(eval_metrics["passes"]),
+            "eval/avg_reward": sum(eval_metrics["rewards"]) / len(eval_metrics["rewards"]),
+            "eval/avg_turns": sum(eval_metrics["turns"]) / len(eval_metrics["turns"]),
+            "eval/natural_finish_rate": sum(eval_metrics["natural_finishes"]) / len(eval_metrics["natural_finishes"]),
+            "eval/num_tasks": len(eval_metrics["rewards"]),
+        }
+
+        logger.info(f"Evaluation complete: pass_rate={aggregated['eval/pass_rate']:.2%}, avg_turns={aggregated['eval/avg_turns']:.1f}")
+        return aggregated
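Assuming eight of ten sampled tasks complete, the returned dict would look something like this (numbers illustrative):

    {
        "eval/pass_rate": 0.375,
        "eval/avg_reward": 0.41,
        "eval/avg_turns": 21.5,
        "eval/natural_finish_rate": 0.625,
        "eval/num_tasks": 8,
    }

One caveat that follows from the `continue` paths above: tasks skipped for missing dirs, missing tests, or runtime failures silently shrink the denominator, so eval/pass_rate is a rate over completed tasks, not over num_eval_tasks.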
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        """Log Endless Terminals specific metrics to wandb."""
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Aggregate metrics from buffer
+        if self._metrics_buffer:
+            # Test pass rate
+            test_passes = [m["test_passed"] for m in self._metrics_buffer]
+            wandb_metrics["endless_terminals/test_pass_rate"] = sum(test_passes) / len(test_passes)
+            wandb_metrics["endless_terminals/num_tests_passed"] = sum(test_passes)
+            wandb_metrics["endless_terminals/num_tests_total"] = len(test_passes)
+
+            # Turns used statistics
+            turns = [m["turns_used"] for m in self._metrics_buffer]
+            wandb_metrics["endless_terminals/avg_turns_used"] = sum(turns) / len(turns)
+            wandb_metrics["endless_terminals/max_turns_used"] = max(turns)
+            wandb_metrics["endless_terminals/min_turns_used"] = min(turns)
+
+            # Natural finish rate (did model stop on its own vs hitting max turns)
+            natural_finishes = [1.0 if m["finished_naturally"] else 0.0 for m in self._metrics_buffer]
+            wandb_metrics["endless_terminals/natural_finish_rate"] = sum(natural_finishes) / len(natural_finishes)
+
+            # Tool error statistics
+            total_tool_errors = sum(m["num_tool_errors"] for m in self._metrics_buffer)
+            wandb_metrics["endless_terminals/total_tool_errors"] = total_tool_errors
+            wandb_metrics["endless_terminals/avg_tool_errors_per_task"] = total_tool_errors / len(self._metrics_buffer)
+
+            # Docker image distribution (count unique images used)
+            docker_images = [m["docker_image"] for m in self._metrics_buffer]
+            unique_images = set(docker_images)
+            wandb_metrics["endless_terminals/num_unique_docker_images"] = len(unique_images)
+
+            # Log most common errors if any
+            all_errors = []
+            for m in self._metrics_buffer:
+                if "tool_errors" in m:
+                    all_errors.extend(m["tool_errors"])
+
+            if all_errors:
+                # Count error types
+                error_tools = {}
+                for err in all_errors:
+                    tool = err["tool"]
+                    error_tools[tool] = error_tools.get(tool, 0) + 1
+
+                # Log top 3 error-prone tools
+                for i, (tool, count) in enumerate(sorted(error_tools.items(), key=lambda x: x[1], reverse=True)[:3]):
+                    wandb_metrics[f"endless_terminals/errors_by_tool/{tool}"] = count
+
+        # Clear buffer after logging
+        self._metrics_buffer = []
+
+        await super().wandb_log(wandb_metrics)
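A side note on the error tally in wandb_log: the manual dict counting could equivalently use collections.Counter from the standard library; a sketch, not what the commit ships:

    from collections import Counter

    # Equivalent to the error_tools dict plus the sorted(...)[:3] loop above.
    for tool, count in Counter(err["tool"] for err in all_errors).most_common(3):
        wandb_metrics[f"endless_terminals/errors_by_tool/{tool}"] = count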