diff --git a/environments/endless_terminals/default.yaml b/environments/endless_terminals/default.yaml
index 1932e212fe..4ae8da550c 100644
--- a/environments/endless_terminals/default.yaml
+++ b/environments/endless_terminals/default.yaml
@@ -44,7 +44,9 @@ env:
   group_size: 8
   total_steps: 10000
   steps_per_eval: 500
+  num_eval_tasks: 10
+  eval_split_ratio: 0.1
 
 # Logging
 use_wandb: true
 
diff --git a/environments/endless_terminals/endless_terminals_env.py b/environments/endless_terminals/endless_terminals_env.py
index 02c6b44e05..ac8bfbc801 100644
--- a/environments/endless_terminals/endless_terminals_env.py
+++ b/environments/endless_terminals/endless_terminals_env.py
@@ -95,6 +95,10 @@ class EndlessTerminalsEnvConfig(HermesAgentEnvConfig):
         default=10,
         description="Number of tasks to run during periodic evaluation"
     )
+    eval_split_ratio: float = Field(
+        default=0.1, ge=0.0, le=1.0,
+        description="Fraction of dataset to hold out for evaluation (0.0-1.0)"
+    )
 
 
 class EndlessTerminalsEnv(HermesAgentBaseEnv):
@@ -144,6 +148,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._dataset = None
+        self._train_dataset = None
+        self._eval_dataset = None
         self._dataset_indices = []
         self._current_index = 0
 
@@ -197,12 +203,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
                 for task_dir in task_dirs
             ]
 
-            # Create shuffled indices
-            self._dataset_indices = list(range(len(self._dataset)))
-            random.shuffle(self._dataset_indices)
-            self._current_index = 0
-
             logger.info(f"Loaded {len(self._dataset)} tasks from local directory")
+
+            self._split_dataset()
             return
 
         # Otherwise, load from HuggingFace
@@ -220,25 +223,54 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
                 )
             )
 
-            # Create shuffled indices
-            self._dataset_indices = list(range(len(self._dataset)))
-            random.shuffle(self._dataset_indices)
-            self._current_index = 0
-
             logger.info(f"Loaded {len(self._dataset)} tasks from HuggingFace")
+
+            self._split_dataset()
+
         except Exception as e:
             logger.error(f"ERROR loading dataset: {e}")
             raise
 
+    def _split_dataset(self):
+        """Split dataset into train and eval sets based on eval_split_ratio."""
+        if self._dataset is None or len(self._dataset) == 0:
+            raise RuntimeError("Cannot split empty dataset")
+
+        total_size = len(self._dataset)
+        eval_size = int(total_size * self.config.eval_split_ratio)
+        train_size = total_size - eval_size
+
+        all_indices = list(range(total_size))
+        random.shuffle(all_indices)
+        train_indices = all_indices[:train_size]
+        eval_indices = all_indices[train_size:]
+
+        if isinstance(self._dataset, list):
+            self._train_dataset = [self._dataset[i] for i in train_indices]
+            self._eval_dataset = [self._dataset[i] for i in eval_indices]
+        else:
+            self._train_dataset = self._dataset.select(train_indices)
+            self._eval_dataset = self._dataset.select(eval_indices)
+
+        self._dataset_indices = list(range(len(self._train_dataset)))
+        random.shuffle(self._dataset_indices)
+        self._current_index = 0
+
+        logger.info(
+            f"Split dataset: {len(self._train_dataset)} train, "
+            f"{len(self._eval_dataset)} eval "
+            f"(ratio={self.config.eval_split_ratio:.1%})"
+        )
+
     async def get_next_item(self) -> Item:
-        """Sample next task from dataset."""
-        if self._dataset is None:
+        """Sample next task from training dataset."""
+        if self._train_dataset is None:
             raise RuntimeError("Dataset not loaded. Call setup() first.")
 
         # Get next task (with wraparound)
         idx = self._dataset_indices[self._current_index]
-        task = self._dataset[idx]
+        task = self._train_dataset[idx]
 
         # Advance to next task
         self._current_index += 1
 
@@ -677,16 +709,22 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
 
     async def evaluate(self):
         """
-        Periodic evaluation on a fixed set of tasks.
+        Periodic evaluation on holdout eval set.
 
-        Runs the agent on num_eval_tasks tasks and measures performance
-        without affecting training. Returns metrics for wandb logging.
+        Runs the agent on num_eval_tasks from the held-out eval set
+        (never seen during training). Returns metrics for wandb logging.
         """
-        if self._dataset is None:
-            logger.warning("Cannot evaluate: dataset not loaded")
+        if self._eval_dataset is None:
+            logger.warning("Cannot evaluate: eval dataset not loaded")
             return {}
 
-        logger.info(f"Starting evaluation on {self.config.num_eval_tasks} tasks...")
+        if len(self._eval_dataset) == 0:
+            logger.warning("Eval dataset is empty")
+            return {}
+
+        # Use min of num_eval_tasks and actual eval set size
+        num_tasks = min(self.config.num_eval_tasks, len(self._eval_dataset))
+        logger.info(f"Starting evaluation on {num_tasks} held-out tasks...")
 
         eval_metrics = {
             "rewards": [],
@@ -695,12 +733,12 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             "natural_finishes": [],
         }
 
-        # Sample eval tasks randomly
+        # Sample from eval set (holdout)
         import random
-        eval_indices = random.sample(range(len(self._dataset)), min(self.config.num_eval_tasks, len(self._dataset)))
+        eval_indices = random.sample(range(len(self._eval_dataset)), num_tasks)
 
         for idx in eval_indices:
-            task = self._dataset[idx]
+            task = self._eval_dataset[idx]
 
             # Build item using same logic as get_next_item
             task_dir = task.get("extra_info", {}).get("task_dir")