Changing return type to ScoredDataGroup to account for multiple trajectories

Sam Herring 2026-03-02 11:35:06 -08:00
parent 6fdb38ed29
commit fe17b5ff08


@@ -404,6 +404,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             # Run the agent loop
             result: AgentResult
+            managed_state: Optional[Dict[str, Any]] = None
             if self._use_managed_server():
                 # Phase 2: ManagedServer with parser
                 from environments.tool_call_parsers import get_parser
@@ -432,6 +434,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
                         extra_body=self.config.extra_body,
                     )
                     result = await agent.run(messages)
+                    # Get state directly from managed server while still in context
+                    managed_state = managed.get_state()
             except NotImplementedError:
                 # DummyManagedServer not allowed
                 logger.warning("ManagedServer not available. Falling back to direct server mode.")
@@ -466,10 +471,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             )
             if result.turns_used == 0 or only_system_and_user:
                 logger.warning(
-                    "Agent loop produced no output (turns=%d). Skipping reward.",
+                    "Agent loop produced no output (turns=%d). Skipping trajectory.",
                     result.turns_used,
                 )
-                reward = 0.0
+                # Return None to skip this trajectory (likely an API failure)
+                return None, []
             else:
                 # Compute reward using ToolContext
                 ctx = ToolContext(task_id)
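The switch from reward = 0.0 to return None, [] changes the failure semantics: an empty agent loop is now dropped entirely instead of being scored as zero, so a legitimate zero reward can no longer be conflated with an API failure. A hypothetical caller-side filter, with env.collect_trajectory and batch assumed rather than taken from this repo:

    async def gather_groups(env, batch):
        # collect_trajectory (name assumed) now returns (ScoredDataGroup | None, backlog)
        results = [await env.collect_trajectory(item) for item in batch]
        # Skipped rollouts come back as (None, []) and are dropped, not zero-scored
        return [group for group, _backlog in results if group is not None]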
@@ -504,26 +510,49 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
             self._metrics_buffer.append(task_metrics)
-            # Build ScoredDataItem from ManagedServer state
-            # Phase 2: real tokens/masks/logprobs from SequenceNodes
-            # Phase 1: placeholder tokens
-            nodes = (result.managed_state or {}).get("nodes", [])
+            # ============================================================================
+            # Build ScoredDataGroup from ManagedServer state
+            # ============================================================================
+            # Phase 2: Extract pre-computed data from SequenceNodes.
+            # The nodes may contain multiple trajectories, so iterate through all of
+            # them and return multiple sequences.
+            #
+            # Each SequenceNode contains:
+            #   - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
+            #   - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
+            #   - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
+            #   - full_text: Complete text (prompt + all completions)
+            #
+            # Phase 1: Create placeholder tokens for OpenAI-style servers
+            # ============================================================================
+            nodes = (managed_state or {}).get("nodes", []) if managed_state else []
+
+            # Create ScoredDataGroup with lists for multiple trajectories
+            scored_group = ScoredDataGroup()
+            scored_group["tokens"] = []
+            scored_group["masks"] = []
+            scored_group["scores"] = []
+            scored_group["messages"] = []
+            scored_group["inference_logprobs"] = []
             if nodes:
-                # Phase 2: use actual node data
-                # nodes[-1] contains the full accumulated trajectory from all turns
-                node = nodes[-1]
-                scored_item: Dict[str, Any] = {
-                    "tokens": node.tokens,
-                    "masks": node.masked_tokens,
-                    "scores": reward,
-                }
-                if hasattr(node, "logprobs") and node.logprobs:
-                    scored_item["logprobs"] = node.logprobs
-                scored_item["advantages"] = None
-                scored_item["ref_logprobs"] = None
+                # Phase 2: iterate through all nodes (may have multiple trajectories)
+                for i, node in enumerate(nodes):
+                    scored_group["tokens"].append(node.tokens)
+                    scored_group["masks"].append(node.masked_tokens)
+                    scored_group["scores"].append(reward)
+                    scored_group["messages"].append(result.messages)
+
+                    if hasattr(node, "logprobs") and node.logprobs:
+                        scored_group["inference_logprobs"].append(node.logprobs)
+                    else:
+                        # Placeholder logprobs if not available
+                        scored_group["inference_logprobs"].append([1.0] * len(node.tokens))
+                    logger.debug(f"Added trajectory {i+1}/{len(nodes)} with {len(node.tokens)} tokens")
             else:
-                # Phase 1: create placeholder tokens
+                # Phase 1: create placeholder tokens for OpenAI-style servers
                 full_text = "\n".join(
                     msg.get("content", "") for msg in result.messages if msg.get("content")
                 )
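The SequenceNode fields described in the comment block above translate into parallel, equal-length per-token lists. A toy sketch of that shape, using only the fields the diff names; the dataclass itself is illustrative, the real SequenceNode lives in the managed-server code:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class SequenceNodeSketch:
        tokens: List[int]         # full unmasked sequence, prompt + completions
        masked_tokens: List[int]  # prompt positions replaced by -100 for the loss
        logprobs: List[float]     # 1.0 placeholders on prompt, real values on completions
        full_text: str

    prompt = [101, 7592, 2088]
    completion = [2023, 2003]
    node = SequenceNodeSketch(
        tokens=prompt + completion,
        masked_tokens=[-100] * len(prompt) + completion,
        logprobs=[1.0] * len(prompt) + [-0.5, -0.3],
        full_text="<prompt text><completion text>",
    )
    # The three per-token lists must stay aligned position by position
    assert len(node.tokens) == len(node.masked_tokens) == len(node.logprobs)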
@@ -532,16 +561,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
                 else:
                     tokens = list(range(min(len(full_text) // 4, 128)))
-                scored_item = {
-                    "tokens": tokens,
-                    "masks": [-100] + tokens[1:],
-                    "scores": reward,
-                }
-                # Include messages for wandb rollout display
-                scored_item["messages"] = result.messages
-
-                return scored_item, []
+                scored_group["tokens"].append(tokens)
+                scored_group["masks"].append([-100] + tokens[1:])
+                scored_group["scores"].append(reward)
+                scored_group["messages"].append(result.messages)
+                scored_group["inference_logprobs"].append([1.0] * len(tokens))
+
+            # Return None if no trajectories collected
+            if len(scored_group["tokens"]) == 0:
+                return None, []
+
+            logger.debug(f"Returning ScoredDataGroup with {len(scored_group['tokens'])} trajectories")
+            return scored_group, []
         finally:
             # Clean up task overrides and sandbox
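After this change, every field of the returned ScoredDataGroup is a parallel list with one entry per trajectory, and the per-trajectory token, mask, and logprob lists line up token for token. A small invariant checker, treating the group as the dict-like mapping the diff indexes into; the function is a sketch, not part of the repo:

    from typing import Any, Dict, List

    def trajectory_count(group: Dict[str, List[Any]]) -> int:
        """Validate the parallel-list invariant and return the trajectory count."""
        n = len(group["tokens"])
        for key in ("masks", "scores", "messages", "inference_logprobs"):
            assert len(group[key]) == n, f"'{key}' has {len(group[key])} entries, expected {n}"
        # Per-trajectory lists must also be aligned token for token
        for toks, mask, lps in zip(group["tokens"], group["masks"], group["inference_logprobs"]):
            assert len(toks) == len(mask) == len(lps)
        return n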