mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-10 03:22:05 +00:00
Changing return type to be ScoredDataGroup to account for multiple trajectories
This commit is contained in:
parent
6fdb38ed29
commit
fe17b5ff08
1 changed files with 58 additions and 27 deletions
|
|
@ -404,6 +404,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
|
||||
# Run the agent loop
|
||||
result: AgentResult
|
||||
managed_state: Optional[Dict[str, Any]] = None
|
||||
|
||||
if self._use_managed_server():
|
||||
# Phase 2: ManagedServer with parser
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
|
@ -432,6 +434,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
extra_body=self.config.extra_body,
|
||||
)
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Get state directly from managed server while still in context
|
||||
managed_state = managed.get_state()
|
||||
except NotImplementedError:
|
||||
# DummyManagedServer not allowed
|
||||
logger.warning("ManagedServer not available. Falling back to direct server mode.")
|
||||
|
|
@ -466,10 +471,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
)
|
||||
if result.turns_used == 0 or only_system_and_user:
|
||||
logger.warning(
|
||||
"Agent loop produced no output (turns=%d). Skipping reward.",
|
||||
"Agent loop produced no output (turns=%d). Skipping trajectory.",
|
||||
result.turns_used,
|
||||
)
|
||||
reward = 0.0
|
||||
# Return None to skip this trajectory (likely an API failure)
|
||||
return None, []
|
||||
else:
|
||||
# Compute reward using ToolContext
|
||||
ctx = ToolContext(task_id)
|
||||
|
|
@ -504,26 +510,49 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
|
||||
self._metrics_buffer.append(task_metrics)
|
||||
|
||||
# Build ScoredDataItem from ManagedServer state
|
||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
||||
# Phase 1: placeholder tokens
|
||||
nodes = (result.managed_state or {}).get("nodes", [])
|
||||
# ============================================================================
|
||||
# Build ScoredDataGroup from ManagedServer state
|
||||
# ============================================================================
|
||||
# Phase 2: Extract pre-computed data from SequenceNodes
|
||||
# We may have multiple trajectories in the nodes due to how interesting
|
||||
# agents can be, so iterate through all nodes and return multiple sequences.
|
||||
#
|
||||
# Each SequenceNode contains:
|
||||
# - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
|
||||
# - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
|
||||
# - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
|
||||
# - full_text: Complete text (prompt + all completions)
|
||||
#
|
||||
# Phase 1: Create placeholder tokens for OpenAI-style servers
|
||||
# ============================================================================
|
||||
nodes = (managed_state or {}).get("nodes", []) if managed_state else []
|
||||
|
||||
# Create ScoredDataGroup with lists for multiple trajectories
|
||||
scored_group = ScoredDataGroup()
|
||||
scored_group["tokens"] = []
|
||||
scored_group["masks"] = []
|
||||
scored_group["scores"] = []
|
||||
scored_group["messages"] = []
|
||||
scored_group["inference_logprobs"] = []
|
||||
|
||||
if nodes:
|
||||
# Phase 2: use actual node data
|
||||
# nodes[-1] contains the full accumulated trajectory from all turns
|
||||
node = nodes[-1]
|
||||
scored_item: Dict[str, Any] = {
|
||||
"tokens": node.tokens,
|
||||
"masks": node.masked_tokens,
|
||||
"scores": reward,
|
||||
}
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_item["logprobs"] = node.logprobs
|
||||
scored_item["advantages"] = None
|
||||
scored_item["ref_logprobs"] = None
|
||||
# Phase 2: iterate through all nodes (may have multiple trajectories)
|
||||
for i, node in enumerate(nodes):
|
||||
scored_group["tokens"].append(node.tokens)
|
||||
scored_group["masks"].append(node.masked_tokens)
|
||||
scored_group["scores"].append(reward)
|
||||
scored_group["messages"].append(result.messages)
|
||||
|
||||
if hasattr(node, "logprobs") and node.logprobs:
|
||||
scored_group["inference_logprobs"].append(node.logprobs)
|
||||
else:
|
||||
# Placeholder logprobs if not available
|
||||
scored_group["inference_logprobs"].append([1.0] * len(node.tokens))
|
||||
|
||||
logger.debug(f"Added trajectory {i+1}/{len(nodes)} with {len(node.tokens)} tokens")
|
||||
|
||||
else:
|
||||
# Phase 1: create placeholder tokens
|
||||
# Phase 1: create placeholder tokens for OpenAI-style servers
|
||||
full_text = "\n".join(
|
||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||
)
|
||||
|
|
@ -532,16 +561,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
|||
else:
|
||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
||||
|
||||
scored_item = {
|
||||
"tokens": tokens,
|
||||
"masks": [-100] + tokens[1:],
|
||||
"scores": reward,
|
||||
}
|
||||
scored_group["tokens"].append(tokens)
|
||||
scored_group["masks"].append([-100] + tokens[1:])
|
||||
scored_group["scores"].append(reward)
|
||||
scored_group["messages"].append(result.messages)
|
||||
scored_group["inference_logprobs"].append([1.0] * len(tokens))
|
||||
|
||||
# Include messages for wandb rollout display
|
||||
scored_item["messages"] = result.messages
|
||||
# Return None if no trajectories collected
|
||||
if len(scored_group["tokens"]) == 0:
|
||||
return None, []
|
||||
|
||||
return scored_item, []
|
||||
logger.debug(f"Returning ScoredDataGroup with {len(scored_group['tokens'])} trajectories")
|
||||
return scored_group, []
|
||||
|
||||
finally:
|
||||
# Clean up task overrides and sandbox
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue