mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Eval splits for holdout sets
This commit is contained in:
parent
fe17b5ff08
commit
dff5481e58
2 changed files with 62 additions and 22 deletions
|
|
@ -44,7 +44,9 @@ env:
|
|||
group_size: 8
|
||||
total_steps: 10000
|
||||
steps_per_eval: 500
|
||||
|
||||
num_eval_tasks: 10
|
||||
eval_split_ratio: 0.1
|
||||
|
||||
# Logging
|
||||
use_wandb: true
|
||||
|
|
|
|||
|
|
@ -95,6 +95,10 @@ class EndlessTerminalsEnvConfig(HermesAgentEnvConfig):
|
|||
default=10,
|
||||
description="Number of tasks to run during periodic evaluation"
|
||||
)
|
||||
eval_split_ratio: float = Field(
|
||||
default=0.1,
|
||||
description="Fraction of dataset to hold out for evaluation (0.0-1.0)"
|
||||
)
|
||||
|
||||
|
||||
class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
||||
|
|
@ -144,6 +148,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._dataset = None
|
||||
self._train_dataset = None
|
||||
self._eval_dataset = None
|
||||
self._dataset_indices = []
|
||||
self._current_index = 0
|
||||
|
||||
|
|
@ -197,12 +203,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
for task_dir in task_dirs
|
||||
]
|
||||
|
||||
# Create shuffled indices
|
||||
self._dataset_indices = list(range(len(self._dataset)))
|
||||
random.shuffle(self._dataset_indices)
|
||||
self._current_index = 0
|
||||
|
||||
logger.info(f"Loaded {len(self._dataset)} tasks from local directory")
|
||||
|
||||
self._split_dataset()
|
||||
return
|
||||
|
||||
# Otherwise, load from HuggingFace
|
||||
|
|
@ -220,25 +223,54 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
)
|
||||
)
|
||||
|
||||
# Create shuffled indices
|
||||
self._dataset_indices = list(range(len(self._dataset)))
|
||||
random.shuffle(self._dataset_indices)
|
||||
self._current_index = 0
|
||||
|
||||
logger.info(f"Loaded {len(self._dataset)} tasks from HuggingFace")
|
||||
|
||||
self._split_dataset()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ERROR loading dataset: {e}")
|
||||
raise
|
||||
|
||||
def _split_dataset(self):
|
||||
"""Split dataset into train and eval sets based on eval_split_ratio."""
|
||||
if self._dataset is None or len(self._dataset) == 0:
|
||||
raise RuntimeError("Cannot split empty dataset")
|
||||
|
||||
total_size = len(self._dataset)
|
||||
eval_size = int(total_size * self.config.eval_split_ratio)
|
||||
train_size = total_size - eval_size
|
||||
|
||||
all_indices = list(range(total_size))
|
||||
random.shuffle(all_indices)
|
||||
|
||||
train_indices = all_indices[:train_size]
|
||||
eval_indices = all_indices[train_size:]
|
||||
|
||||
if isinstance(self._dataset, list):
|
||||
self._train_dataset = [self._dataset[i] for i in train_indices]
|
||||
self._eval_dataset = [self._dataset[i] for i in eval_indices]
|
||||
else:
|
||||
self._train_dataset = self._dataset.select(train_indices)
|
||||
self._eval_dataset = self._dataset.select(eval_indices)
|
||||
|
||||
self._dataset_indices = list(range(len(self._train_dataset)))
|
||||
random.shuffle(self._dataset_indices)
|
||||
self._current_index = 0
|
||||
|
||||
logger.info(
|
||||
f"Split dataset: {len(self._train_dataset)} train, "
|
||||
f"{len(self._eval_dataset)} eval "
|
||||
f"(ratio={self.config.eval_split_ratio:.1%})"
|
||||
)
|
||||
|
||||
async def get_next_item(self) -> Item:
|
||||
"""Sample next task from dataset."""
|
||||
if self._dataset is None:
|
||||
"""Sample next task from training dataset."""
|
||||
if self._train_dataset is None:
|
||||
raise RuntimeError("Dataset not loaded. Call setup() first.")
|
||||
|
||||
# Get next task (with wraparound)
|
||||
idx = self._dataset_indices[self._current_index]
|
||||
task = self._dataset[idx]
|
||||
task = self._train_dataset[idx]
|
||||
|
||||
# Advance to next task
|
||||
self._current_index += 1
|
||||
|
|
@ -677,16 +709,22 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
|
||||
async def evaluate(self):
|
||||
"""
|
||||
Periodic evaluation on a fixed set of tasks.
|
||||
Periodic evaluation on holdout eval set.
|
||||
|
||||
Runs the agent on num_eval_tasks tasks and measures performance
|
||||
without affecting training. Returns metrics for wandb logging.
|
||||
Runs the agent on num_eval_tasks from the held-out eval set
|
||||
(never seen during training). Returns metrics for wandb logging.
|
||||
"""
|
||||
if self._dataset is None:
|
||||
logger.warning("Cannot evaluate: dataset not loaded")
|
||||
if self._eval_dataset is None:
|
||||
logger.warning("Cannot evaluate: eval dataset not loaded")
|
||||
return {}
|
||||
|
||||
logger.info(f"Starting evaluation on {self.config.num_eval_tasks} tasks...")
|
||||
if len(self._eval_dataset) == 0:
|
||||
logger.warning("Eval dataset is empty")
|
||||
return {}
|
||||
|
||||
# Use min of num_eval_tasks and actual eval set size
|
||||
num_tasks = min(self.config.num_eval_tasks, len(self._eval_dataset))
|
||||
logger.info(f"Starting evaluation on {num_tasks} held-out tasks...")
|
||||
|
||||
eval_metrics = {
|
||||
"rewards": [],
|
||||
|
|
@ -695,12 +733,12 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
"natural_finishes": [],
|
||||
}
|
||||
|
||||
# Sample eval tasks randomly
|
||||
# Sample from eval set (holdout)
|
||||
import random
|
||||
eval_indices = random.sample(range(len(self._dataset)), min(self.config.num_eval_tasks, len(self._dataset)))
|
||||
eval_indices = random.sample(range(len(self._eval_dataset)), num_tasks)
|
||||
|
||||
for idx in eval_indices:
|
||||
task = self._dataset[idx]
|
||||
task = self._eval_dataset[idx]
|
||||
|
||||
# Build item using same logic as get_next_item
|
||||
task_dir = task.get("extra_info", {}).get("task_dir")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue