Change the return type to ScoredDataGroup to account for multiple trajectories

This commit is contained in:
Sam Herring 2026-03-02 11:35:06 -08:00
parent 6fdb38ed29
commit fe17b5ff08

View file

@ -404,6 +404,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
# Run the agent loop
result: AgentResult
managed_state: Optional[Dict[str, Any]] = None
if self._use_managed_server():
# Phase 2: ManagedServer with parser
from environments.tool_call_parsers import get_parser
@ -432,6 +434,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
extra_body=self.config.extra_body,
)
result = await agent.run(messages)
# Get state directly from managed server while still in context
managed_state = managed.get_state()
except NotImplementedError:
# DummyManagedServer not allowed
logger.warning("ManagedServer not available. Falling back to direct server mode.")
@ -466,10 +471,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
)
if result.turns_used == 0 or only_system_and_user:
logger.warning(
"Agent loop produced no output (turns=%d). Skipping reward.",
"Agent loop produced no output (turns=%d). Skipping trajectory.",
result.turns_used,
)
reward = 0.0
# Return None to skip this trajectory (likely an API failure)
return None, []
else:
# Compute reward using ToolContext
ctx = ToolContext(task_id)
@ -504,26 +510,49 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
self._metrics_buffer.append(task_metrics)
# Build ScoredDataItem from ManagedServer state
# Phase 2: real tokens/masks/logprobs from SequenceNodes
# Phase 1: placeholder tokens
nodes = (result.managed_state or {}).get("nodes", [])
# ============================================================================
# Build ScoredDataGroup from ManagedServer state
# ============================================================================
# Phase 2: Extract pre-computed data from SequenceNodes
# The nodes list may contain multiple trajectories, since a single agent run
# can produce more than one sequence; iterate through all nodes and return
# each one as its own entry in the group.
#
# Each SequenceNode contains:
# - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
# - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
# - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
# - full_text: Complete text (prompt + all completions)
#
# Phase 1: Create placeholder tokens for OpenAI-style servers
# ============================================================================
nodes = (managed_state or {}).get("nodes", []) if managed_state else []
# Create ScoredDataGroup with lists for multiple trajectories
scored_group = ScoredDataGroup()
scored_group["tokens"] = []
scored_group["masks"] = []
scored_group["scores"] = []
scored_group["messages"] = []
scored_group["inference_logprobs"] = []
if nodes:
# Phase 2: use actual node data
# nodes[-1] contains the full accumulated trajectory from all turns
node = nodes[-1]
scored_item: Dict[str, Any] = {
"tokens": node.tokens,
"masks": node.masked_tokens,
"scores": reward,
}
if hasattr(node, "logprobs") and node.logprobs:
scored_item["logprobs"] = node.logprobs
scored_item["advantages"] = None
scored_item["ref_logprobs"] = None
# Phase 2: iterate through all nodes (may have multiple trajectories)
for i, node in enumerate(nodes):
scored_group["tokens"].append(node.tokens)
scored_group["masks"].append(node.masked_tokens)
scored_group["scores"].append(reward)
scored_group["messages"].append(result.messages)
if hasattr(node, "logprobs") and node.logprobs:
scored_group["inference_logprobs"].append(node.logprobs)
else:
# Placeholder logprobs if not available
scored_group["inference_logprobs"].append([1.0] * len(node.tokens))
logger.debug(f"Added trajectory {i+1}/{len(nodes)} with {len(node.tokens)} tokens")
else:
# Phase 1: create placeholder tokens
# Phase 1: create placeholder tokens for OpenAI-style servers
full_text = "\n".join(
msg.get("content", "") for msg in result.messages if msg.get("content")
)
@ -532,16 +561,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
else:
tokens = list(range(min(len(full_text) // 4, 128)))
scored_item = {
"tokens": tokens,
"masks": [-100] + tokens[1:],
"scores": reward,
}
scored_group["tokens"].append(tokens)
scored_group["masks"].append([-100] + tokens[1:])
scored_group["scores"].append(reward)
scored_group["messages"].append(result.messages)
scored_group["inference_logprobs"].append([1.0] * len(tokens))
# Include messages for wandb rollout display
scored_item["messages"] = result.messages
# Return None if no trajectories collected
if len(scored_group["tokens"]) == 0:
return None, []
return scored_item, []
logger.debug(f"Returning ScoredDataGroup with {len(scored_group['tokens'])} trajectories")
return scored_group, []
finally:
# Clean up task overrides and sandbox