Eval splits for holdout sets

This commit is contained in:
Sam Herring 2026-03-03 14:42:45 -05:00
parent fe17b5ff08
commit dff5481e58
2 changed files with 62 additions and 22 deletions

View file

@ -44,7 +44,9 @@ env:
group_size: 8
total_steps: 10000
steps_per_eval: 500
num_eval_tasks: 10
eval_split_ratio: 0.1
# Logging
use_wandb: true

View file

@ -95,6 +95,10 @@ class EndlessTerminalsEnvConfig(HermesAgentEnvConfig):
default=10,
description="Number of tasks to run during periodic evaluation"
)
eval_split_ratio: float = Field(
default=0.1,
description="Fraction of dataset to hold out for evaluation (0.0-1.0)"
)
class EndlessTerminalsEnv(HermesAgentBaseEnv):
@ -144,6 +148,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
def __init__(self, *args, **kwargs):
    """Initialize environment state; the dataset containers below are filled in by setup()."""
    super().__init__(*args, **kwargs)
    self._dataset = None          # full dataset as loaded (plain list or HF dataset — see isinstance check in _split_dataset)
    self._train_dataset = None    # training subset, assigned by _split_dataset()
    self._eval_dataset = None     # held-out eval subset, assigned by _split_dataset()
    self._dataset_indices = []    # shuffled indices into the training subset
    self._current_index = 0       # cursor into _dataset_indices (advanced with wraparound by get_next_item)
@ -197,12 +203,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
for task_dir in task_dirs
]
# Create shuffled indices
self._dataset_indices = list(range(len(self._dataset)))
random.shuffle(self._dataset_indices)
self._current_index = 0
logger.info(f"Loaded {len(self._dataset)} tasks from local directory")
self._split_dataset()
return
# Otherwise, load from HuggingFace
@ -220,25 +223,54 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
)
)
# Create shuffled indices
self._dataset_indices = list(range(len(self._dataset)))
random.shuffle(self._dataset_indices)
self._current_index = 0
logger.info(f"Loaded {len(self._dataset)} tasks from HuggingFace")
self._split_dataset()
except Exception as e:
logger.error(f"ERROR loading dataset: {e}")
raise
def _split_dataset(self):
    """Partition the loaded dataset into disjoint train and eval subsets.

    Holds out ``int(total * eval_split_ratio)`` items for evaluation; the
    remainder becomes the training pool. Also (re)builds the shuffled
    training-index order and resets the sampling cursor used by
    ``get_next_item``.

    Raises:
        RuntimeError: if no dataset has been loaded or it is empty.
        ValueError: if ``eval_split_ratio`` is outside [0.0, 1.0], or the
            split would leave zero training examples (e.g. ratio == 1.0).
    """
    if self._dataset is None or len(self._dataset) == 0:
        raise RuntimeError("Cannot split empty dataset")

    ratio = self.config.eval_split_ratio
    # Guard against misconfiguration: a ratio outside [0, 1] would make the
    # slice arithmetic below silently produce a corrupt split (negative
    # train_size slices from the wrong end instead of failing).
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"eval_split_ratio must be in [0.0, 1.0], got {ratio}")

    total_size = len(self._dataset)
    # int() truncates, so tiny datasets may get an empty eval set; evaluate()
    # is expected to handle that case.
    eval_size = int(total_size * ratio)
    train_size = total_size - eval_size
    if train_size == 0:
        # An empty training pool would break sampling in get_next_item.
        raise ValueError(
            f"eval_split_ratio={ratio} leaves no training examples "
            f"(total={total_size}, eval={eval_size})"
        )

    # NOTE(review): this shuffle is unseeded, so the train/eval partition
    # differs across process restarts — confirm that is acceptable for
    # resumed runs (otherwise eval items may leak into training).
    all_indices = list(range(total_size))
    random.shuffle(all_indices)
    train_indices = all_indices[:train_size]
    eval_indices = all_indices[train_size:]

    # Local-directory loading produces a plain list; the HuggingFace path
    # produces a Dataset, which subsets via .select().
    if isinstance(self._dataset, list):
        self._train_dataset = [self._dataset[i] for i in train_indices]
        self._eval_dataset = [self._dataset[i] for i in eval_indices]
    else:
        self._train_dataset = self._dataset.select(train_indices)
        self._eval_dataset = self._dataset.select(eval_indices)

    # Rebuild the shuffled sampling order over the new training subset.
    self._dataset_indices = list(range(len(self._train_dataset)))
    random.shuffle(self._dataset_indices)
    self._current_index = 0

    logger.info(
        f"Split dataset: {len(self._train_dataset)} train, "
        f"{len(self._eval_dataset)} eval "
        f"(ratio={self.config.eval_split_ratio:.1%})"
    )
async def get_next_item(self) -> Item:
"""Sample next task from dataset."""
if self._dataset is None:
"""Sample next task from training dataset."""
if self._train_dataset is None:
raise RuntimeError("Dataset not loaded. Call setup() first.")
# Get next task (with wraparound)
idx = self._dataset_indices[self._current_index]
task = self._dataset[idx]
task = self._train_dataset[idx]
# Advance to next task
self._current_index += 1
@ -677,16 +709,22 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
async def evaluate(self):
"""
Periodic evaluation on a fixed set of tasks.
Periodic evaluation on holdout eval set.
Runs the agent on num_eval_tasks tasks and measures performance
without affecting training. Returns metrics for wandb logging.
Runs the agent on num_eval_tasks tasks from the held-out eval set
(never seen during training). Returns metrics for wandb logging.
"""
if self._dataset is None:
logger.warning("Cannot evaluate: dataset not loaded")
if self._eval_dataset is None:
logger.warning("Cannot evaluate: eval dataset not loaded")
return {}
logger.info(f"Starting evaluation on {self.config.num_eval_tasks} tasks...")
if len(self._eval_dataset) == 0:
logger.warning("Eval dataset is empty")
return {}
# Use min of num_eval_tasks and actual eval set size
num_tasks = min(self.config.num_eval_tasks, len(self._eval_dataset))
logger.info(f"Starting evaluation on {num_tasks} held-out tasks...")
eval_metrics = {
"rewards": [],
@ -695,12 +733,12 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
"natural_finishes": [],
}
# Sample eval tasks randomly
# Sample from eval set (holdout)
import random
eval_indices = random.sample(range(len(self._dataset)), min(self.config.num_eval_tasks, len(self._dataset)))
eval_indices = random.sample(range(len(self._eval_dataset)), num_tasks)
for idx in eval_indices:
task = self._dataset[idx]
task = self._eval_dataset[idx]
# Build item using same logic as get_next_item
task_dir = task.get("extra_info", {}).get("task_dir")