From fe17b5ff080d26c60e07298a17f0397e4bede6ac Mon Sep 17 00:00:00 2001
From: Sam Herring
Date: Mon, 2 Mar 2026 11:35:06 -0800
Subject: [PATCH] Changing return type to be ScoredDataGroup to account for
 multiple trajectories

---
 .../endless_terminals_env.py                  | 85 +++++++++++++------
 1 file changed, 58 insertions(+), 27 deletions(-)

diff --git a/environments/endless_terminals/endless_terminals_env.py b/environments/endless_terminals/endless_terminals_env.py
index d85ef9cd78..02c6b44e05 100644
--- a/environments/endless_terminals/endless_terminals_env.py
+++ b/environments/endless_terminals/endless_terminals_env.py
@@ -404,6 +404,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
 
         # Run the agent loop
         result: AgentResult
+        managed_state: Optional[Dict[str, Any]] = None
+
         if self._use_managed_server():
             # Phase 2: ManagedServer with parser
             from environments.tool_call_parsers import get_parser
@@ -432,6 +434,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
                     extra_body=self.config.extra_body,
                 )
                 result = await agent.run(messages)
+
+                # Get state directly from managed server while still in context
+                managed_state = managed.get_state()
         except NotImplementedError:
             # DummyManagedServer not allowed
             logger.warning("ManagedServer not available. Falling back to direct server mode.")
@@ -466,10 +471,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             )
             if result.turns_used == 0 or only_system_and_user:
                 logger.warning(
-                    "Agent loop produced no output (turns=%d). Skipping reward.",
+                    "Agent loop produced no output (turns=%d). Skipping trajectory.",
                     result.turns_used,
                 )
-                reward = 0.0
+                # Return None to skip this trajectory (likely an API failure)
+                return None, []
             else:
                 # Compute reward using ToolContext
                 ctx = ToolContext(task_id)
@@ -504,26 +510,49 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
 
             self._metrics_buffer.append(task_metrics)
 
-            # Build ScoredDataItem from ManagedServer state
-            # Phase 2: real tokens/masks/logprobs from SequenceNodes
-            # Phase 1: placeholder tokens
-            nodes = (result.managed_state or {}).get("nodes", [])
+            # ============================================================================
+            # Build ScoredDataGroup from ManagedServer state
+            # ============================================================================
+            # Phase 2: Extract pre-computed data from SequenceNodes
+            # We may have multiple trajectories in the nodes due to how interesting
+            # agents can be, so iterate through all nodes and return multiple sequences.
+            #
+            # Each SequenceNode contains:
+            #   - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
+            #   - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
+            #   - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
+            #   - full_text: Complete text (prompt + all completions)
+            #
+            # Phase 1: Create placeholder tokens for OpenAI-style servers
+            # ============================================================================
+            nodes = (managed_state or {}).get("nodes", []) if managed_state else []
+
+            # Create ScoredDataGroup with lists for multiple trajectories
+            scored_group = ScoredDataGroup()
+            scored_group["tokens"] = []
+            scored_group["masks"] = []
+            scored_group["scores"] = []
+            scored_group["messages"] = []
+            scored_group["inference_logprobs"] = []
 
             if nodes:
-                # Phase 2: use actual node data
-                # nodes[-1] contains the full accumulated trajectory from all turns
-                node = nodes[-1]
-                scored_item: Dict[str, Any] = {
-                    "tokens": node.tokens,
-                    "masks": node.masked_tokens,
-                    "scores": reward,
-                }
-                if hasattr(node, "logprobs") and node.logprobs:
-                    scored_item["logprobs"] = node.logprobs
-                scored_item["advantages"] = None
-                scored_item["ref_logprobs"] = None
+                # Phase 2: iterate through all nodes (may have multiple trajectories)
+                for i, node in enumerate(nodes):
+                    scored_group["tokens"].append(node.tokens)
+                    scored_group["masks"].append(node.masked_tokens)
+                    scored_group["scores"].append(reward)
+                    scored_group["messages"].append(result.messages)
+
+                    if hasattr(node, "logprobs") and node.logprobs:
+                        scored_group["inference_logprobs"].append(node.logprobs)
+                    else:
+                        # Placeholder logprobs if not available
+                        scored_group["inference_logprobs"].append([1.0] * len(node.tokens))
+
+                    logger.debug(f"Added trajectory {i+1}/{len(nodes)} with {len(node.tokens)} tokens")
+
             else:
-                # Phase 1: create placeholder tokens
+                # Phase 1: create placeholder tokens for OpenAI-style servers
                 full_text = "\n".join(
                     msg.get("content", "") for msg in result.messages if msg.get("content")
                 )
@@ -532,16 +561,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             else:
                 tokens = list(range(min(len(full_text) // 4, 128)))
 
-            scored_item = {
-                "tokens": tokens,
-                "masks": [-100] + tokens[1:],
-                "scores": reward,
-            }
+            scored_group["tokens"].append(tokens)
+            scored_group["masks"].append([-100] + tokens[1:])
+            scored_group["scores"].append(reward)
+            scored_group["messages"].append(result.messages)
+            scored_group["inference_logprobs"].append([1.0] * len(tokens))
 
-            # Include messages for wandb rollout display
-            scored_item["messages"] = result.messages
+            # Return None if no trajectories collected
+            if len(scored_group["tokens"]) == 0:
+                return None, []
 
-            return scored_item, []
+            logger.debug(f"Returning ScoredDataGroup with {len(scored_group['tokens'])} trajectories")
+            return scored_group, []
 
         finally:
            # Clean up task overrides and sandbox