Update to use the hermes-agent backend and parse the container definition out of the provided .sif task files

Sam Herring 2026-02-24 16:35:18 -08:00
parent 9139eeaa60
commit f1c2f8a414
2 changed files with 311 additions and 61 deletions


@@ -2,7 +2,11 @@
Endless Terminals Environment for Hermes-Agent + Atropos RL.
Loads pre-generated terminal tasks from a HuggingFace dataset and scores
agent performance using test execution in Apptainer containers.
agent performance using test execution in the agent's sandbox.
Uses hermes-agent backends (modal, docker, local) with per-task Docker images
extracted from container.def files. Tests run in the same sandbox the agent
used, following the Terminal Bench 2 pattern.
Dataset: https://huggingface.co/datasets/obiwan96/endless-terminals-train
@@ -12,12 +16,13 @@ Run:
"""
import asyncio
import logging
import os
import random
import subprocess
import re
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
@@ -26,12 +31,19 @@ _repo_root = Path(__file__).resolve().parent.parent.parent
if str(_repo_root) not in sys.path:
sys.path.insert(0, str(_repo_root))
from atroposlib.envs.base import ScoredDataItem
from atroposlib.envs.base import ScoredDataGroup, ScoredDataItem
from atroposlib.type_definitions import Item
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
from environments.agent_loop import AgentResult
from environments.tool_context import ToolContext
from tools.terminal_tool import (
register_task_env_overrides,
clear_task_env_overrides,
cleanup_vm,
)
logger = logging.getLogger(__name__)
# Add endless-terminals to path for imports
ENDLESS_TERMINALS_PATH = os.getenv(
@@ -69,6 +81,12 @@ class EndlessTerminalsEnvConfig(HermesAgentEnvConfig):
# Test execution
test_timeout_s: int = Field(default=60, description="Test execution timeout (seconds)")
# Docker image fallback
default_docker_image: str = Field(
default="ubuntu:22.04",
description="Default Docker image if container.def parsing fails"
)
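# A hypothetical override at construction time (pydantic Fields accept
# keyword arguments; any other required config fields are omitted here):
#   config = EndlessTerminalsEnvConfig(default_docker_image="python:3.11-slim")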
# Agent defaults
max_agent_turns: int = Field(default=32, description="Max turns for agent (increased for long traces)")
@@ -160,17 +178,22 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv[EndlessTerminalsEnvConfig]):
container_sif = task_dir_path / "container.sif"
final_test = task_dir_path / "test_final_state.py"
# Verify files exist
if not container_sif.exists() or not final_test.exists():
print(f"[EndlessTerminalsEnv] WARNING: Missing files in {task_dir_path}", flush=True)
# Verify test file exists
if not final_test.exists():
print(f"[EndlessTerminalsEnv] WARNING: Missing test file in {task_dir_path}", flush=True)
return await self.get_next_item()
# Parse container.def to extract Docker image
container_def = task_dir_path / "container.def"
docker_image = self._parse_docker_image_from_def(container_def)
return {
"task_id": f"{task_dir_path.name}",
"task_name": task_dir_path.name,
"description": task.get("description", ""),
"task_dir": str(task_dir_path),
"container_sif": str(container_sif),
"final_test": str(final_test),
"docker_image": docker_image,
"dataset_index": idx,
}
@@ -178,6 +201,208 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv[EndlessTerminalsEnvConfig]):
"""Return the task description for the agent."""
return str(item.get("description", ""))
def _parse_docker_image_from_def(self, container_def_path: Path) -> str:
"""
Parse container.def file to extract the Docker base image.
Apptainer definition files typically look like:
Bootstrap: docker
From: ubuntu:22.04
Returns the image from the "From:" line, or falls back to the configured default.
"""
if not container_def_path.exists():
logger.warning(f"container.def not found at {container_def_path}, using default image")
return self.config.default_docker_image
try:
content = container_def_path.read_text()
# Look for "From: <image>" line (case-insensitive)
match = re.search(r'^From:\s*(.+)$', content, re.MULTILINE | re.IGNORECASE)
if match:
image = match.group(1).strip()
logger.info(f"Extracted Docker image from container.def: {image}")
return image
except Exception as e:
logger.warning(f"Failed to parse {container_def_path}: {e}")
logger.warning(f"Could not extract image from {container_def_path}, using default")
return self.config.default_docker_image
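# A doctest-style sketch of the extraction above (uses the same regex;
# illustrative only, not part of the original module):
#
#   >>> import re
#   >>> text = "Bootstrap: docker\nFrom: python:3.11-slim\n%post\napt-get update"
#   >>> re.search(r'^From:\s*(.+)$', text, re.MULTILINE | re.IGNORECASE).group(1).strip()
#   'python:3.11-slim'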
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[ScoredDataItem], List[Item]]:
"""
Override to register per-task Docker image before running the agent.
Follows Terminal Bench 2 pattern: register_task_env_overrides() tells
the hermes-agent terminal backend to use a specific Docker image for
this task_id.
This is a copy of HermesAgentBaseEnv.collect_trajectory with Docker
image registration added after task_id generation.
"""
import uuid
from environments.agent_loop import HermesAgentLoop
task_id = str(uuid.uuid4())
task_name = item.get("task_name", "unknown")
docker_image = item.get("docker_image", self.config.default_docker_image)
# Register Docker image override for this task_id
register_task_env_overrides(task_id, {"modal_image": docker_image})
logger.info(
f"Task {task_name}: registered Docker image {docker_image} for task_id {task_id[:8]}"
)
try:
# Get group-level tools (resolved once in collect_trajectories)
if self._current_group_tools is None:
tools, valid_names = self._resolve_tools_for_group()
else:
tools, valid_names = self._current_group_tools
# Build initial messages
messages: List[Dict[str, Any]] = []
if self.config.system_prompt:
messages.append({"role": "system", "content": self.config.system_prompt})
messages.append({"role": "user", "content": self.format_prompt(item)})
# Run the agent loop
result: AgentResult
if self._use_managed_server():
# Phase 2: ManagedServer with parser
from environments.tool_call_parsers import get_parser
try:
tc_parser = get_parser(self.config.tool_call_parser)
except KeyError:
logger.warning(
"Tool call parser '%s' not found, falling back to 'hermes'",
self.config.tool_call_parser,
)
tc_parser = get_parser("hermes")
try:
async with self.server.managed_server(
tokenizer=self.tokenizer,
tool_call_parser=tc_parser,
) as managed:
agent = HermesAgentLoop(
server=managed,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
except NotImplementedError:
# managed_server raised NotImplementedError (e.g., DummyManagedServer is not allowed)
logger.warning("ManagedServer not available. Falling back to direct server mode.")
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
else:
# Phase 1: OpenAI server
agent = HermesAgentLoop(
server=self.server,
tool_schemas=tools,
valid_tool_names=valid_names,
max_turns=self.config.max_agent_turns,
task_id=task_id,
temperature=self.config.agent_temperature,
max_tokens=self.config.max_token_length,
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# Skip reward computation if agent produced no output
only_system_and_user = all(
msg.get("role") in ("system", "user") for msg in result.messages
)
if result.turns_used == 0 or only_system_and_user:
logger.warning(
"Agent loop produced no output (turns=%d). Skipping reward.",
result.turns_used,
)
reward = 0.0
else:
# Compute reward using ToolContext
ctx = ToolContext(task_id)
try:
reward = await self.compute_reward(item, result, ctx)
except Exception as e:
logger.error("compute_reward failed: %s", e)
reward = 0.0
finally:
ctx.cleanup()
# Track tool errors for wandb logging
if result.tool_errors:
for err in result.tool_errors:
self._tool_error_buffer.append({
"turn": err.turn,
"tool": err.tool_name,
"args": err.arguments[:150],
"error": err.error[:300],
"result": err.tool_result[:300],
})
# Build ScoredDataItem from ManagedServer state
# Phase 2: real tokens/masks/logprobs from SequenceNodes
# Phase 1: placeholder tokens
nodes = (result.managed_state or {}).get("nodes", [])
if nodes:
# Phase 2: use actual node data
node = nodes[-1]
scored_item: Dict[str, Any] = {
"tokens": node.tokens,
"masks": node.masked_tokens,
"scores": reward,
}
if hasattr(node, "logprobs") and node.logprobs:
scored_item["advantages"] = None
scored_item["ref_logprobs"] = None
else:
# Phase 1: create placeholder tokens
full_text = "\n".join(
msg.get("content", "") for msg in result.messages if msg.get("content")
)
if self.tokenizer:
tokens = self.tokenizer.encode(full_text, add_special_tokens=True)
else:
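# No tokenizer available: fabricate a placeholder sequence, approximating
# one token per ~4 characters and capping at 128 tokens.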
tokens = list(range(min(len(full_text) // 4, 128)))
scored_item = {
"tokens": tokens,
"masks": [-100] + tokens[1:],
"scores": reward,
}
# Include messages for wandb rollout display
scored_item["messages"] = result.messages
return scored_item, []
finally:
# Clean up task overrides and sandbox
clear_task_env_overrides(task_id)
try:
cleanup_vm(task_id)
except Exception as e:
logger.debug(f"VM cleanup for {task_id[:8]}: {e}")
async def compute_reward(
self,
item: Item,
@@ -185,68 +410,91 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv[EndlessTerminalsEnvConfig]):
ctx: ToolContext
) -> float:
"""
Run final tests in container and return binary reward.
Run final tests in the agent's sandbox and return binary reward.
Uses ToolContext to execute pytest in the SAME sandbox the agent used,
following the Terminal Bench 2 verification pattern. No separate
Apptainer execution needed.
Returns 1.0 if tests pass, 0.0 otherwise.
"""
task_id = item.get("task_id", "unknown")
container_sif = Path(item.get("container_sif", ""))
final_test = Path(item.get("final_test", ""))
task_name = item.get("task_name", "unknown")
final_test_path = Path(item.get("final_test", ""))
if not container_sif.exists() or not final_test.exists():
print(f"[EndlessTerminalsEnv] ERROR: Missing test files for {task_id}", flush=True)
if not final_test_path.exists():
logger.error(f"Task {task_name}: test file not found at {final_test_path}")
return 0.0
print(f"[EndlessTerminalsEnv] Running tests for {task_id}...", flush=True)
logger.info(f"Task {task_name}: running tests in sandbox...")
try:
# Run final tests in container
success = await self._run_tests_in_container(container_sif, final_test)
score = 1.0 if success else 0.0
print(f"[EndlessTerminalsEnv] Task {task_id} score: {score}", flush=True)
return score
except Exception as e:
print(f"[EndlessTerminalsEnv] ERROR scoring {task_id}: {e}", flush=True)
return 0.0
async def _run_tests_in_container(
self,
container_sif: Path,
final_test_path: Path
) -> bool:
"""Run pytest in Apptainer container."""
loop = asyncio.get_event_loop()
try:
result = await loop.run_in_executor(
# Run tests in a thread to avoid blocking the event loop
loop = asyncio.get_event_loop()
reward = await loop.run_in_executor(
None,
lambda: subprocess.run(
[
"apptainer", "exec",
"--fakeroot",
"--userns",
"--writable-tmpfs",
"--cleanenv",
str(container_sif),
"pytest", "-q",
str(final_test_path.name),
],
capture_output=True,
text=True,
timeout=self.config.test_timeout_s,
cwd=str(final_test_path.parent),
)
self._run_tests_in_sandbox,
final_test_path,
ctx,
task_name,
)
return result.returncode == 0
except subprocess.TimeoutExpired:
print(f"[EndlessTerminalsEnv] Test timeout for {final_test_path}", flush=True)
return False
status = "PASS" if reward == 1.0 else "FAIL"
logger.info(f"Task {task_name}: {status} (reward={reward})")
return reward
except Exception as e:
print(f"[EndlessTerminalsEnv] Test execution error: {e}", flush=True)
return False
logger.error(f"Task {task_name}: test execution failed: {e}", exc_info=True)
return 0.0
def _run_tests_in_sandbox(
self,
test_file_path: Path,
ctx: ToolContext,
task_name: str,
) -> float:
"""
Upload test file to sandbox and execute pytest.
Runs in a thread pool (via run_in_executor) to avoid blocking the event loop
with synchronous ToolContext calls.
Args:
test_file_path: Local path to test_final_state.py
ctx: ToolContext scoped to the agent's sandbox
task_name: For logging
Returns:
1.0 if tests pass, 0.0 otherwise
"""
try:
# Upload test file to sandbox
test_content = test_file_path.read_text()
ctx.write_file("/workspace/test_final_state.py", test_content)
logger.debug(f"Task {task_name}: uploaded test file to /workspace/test_final_state.py")
# Run pytest in the sandbox
result = ctx.terminal(
"cd /workspace && python -m pytest -q test_final_state.py",
timeout=self.config.test_timeout_s,
)
exit_code = result.get("exit_code", -1)
output = result.get("output", "")
if exit_code == 0:
logger.debug(f"Task {task_name}: tests passed")
return 1.0
else:
# Log failure output (last 500 chars for debugging)
output_preview = output[-500:] if output else "(no output)"
logger.info(
f"Task {task_name}: tests failed (exit_code={exit_code})\n{output_preview}"
)
return 0.0
except Exception as e:
logger.error(f"Task {task_name}: error running tests: {e}")
return 0.0
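# For orientation, a hypothetical test_final_state.py of the shape this
# runner executes (illustrative only; real tests ship with each dataset
# task and assert on the final sandbox state):
#
#   from pathlib import Path
#
#   def test_output_file_created():
#       assert Path("/workspace/result.txt").exists()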
async def evaluate(self):
"""Periodic evaluation (optional)."""