mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-13 03:52:00 +00:00
Changing return type to be ScoredDataGroup to account for multiple trajectories
This commit is contained in:
parent
6fdb38ed29
commit
fe17b5ff08
1 changed files with 58 additions and 27 deletions
|
|
@ -404,6 +404,8 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
||||||
|
|
||||||
# Run the agent loop
|
# Run the agent loop
|
||||||
result: AgentResult
|
result: AgentResult
|
||||||
|
managed_state: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
if self._use_managed_server():
|
if self._use_managed_server():
|
||||||
# Phase 2: ManagedServer with parser
|
# Phase 2: ManagedServer with parser
|
||||||
from environments.tool_call_parsers import get_parser
|
from environments.tool_call_parsers import get_parser
|
||||||
|
|
@ -432,6 +434,9 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
||||||
extra_body=self.config.extra_body,
|
extra_body=self.config.extra_body,
|
||||||
)
|
)
|
||||||
result = await agent.run(messages)
|
result = await agent.run(messages)
|
||||||
|
|
||||||
|
# Get state directly from managed server while still in context
|
||||||
|
managed_state = managed.get_state()
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
# DummyManagedServer not allowed
|
# DummyManagedServer not allowed
|
||||||
logger.warning("ManagedServer not available. Falling back to direct server mode.")
|
logger.warning("ManagedServer not available. Falling back to direct server mode.")
|
||||||
|
|
@ -466,10 +471,11 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
||||||
)
|
)
|
||||||
if result.turns_used == 0 or only_system_and_user:
|
if result.turns_used == 0 or only_system_and_user:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Agent loop produced no output (turns=%d). Skipping reward.",
|
"Agent loop produced no output (turns=%d). Skipping trajectory.",
|
||||||
result.turns_used,
|
result.turns_used,
|
||||||
)
|
)
|
||||||
reward = 0.0
|
# Return None to skip this trajectory (likely an API failure)
|
||||||
|
return None, []
|
||||||
else:
|
else:
|
||||||
# Compute reward using ToolContext
|
# Compute reward using ToolContext
|
||||||
ctx = ToolContext(task_id)
|
ctx = ToolContext(task_id)
|
||||||
|
|
@ -504,26 +510,49 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
||||||
|
|
||||||
self._metrics_buffer.append(task_metrics)
|
self._metrics_buffer.append(task_metrics)
|
||||||
|
|
||||||
# Build ScoredDataItem from ManagedServer state
|
# ============================================================================
|
||||||
# Phase 2: real tokens/masks/logprobs from SequenceNodes
|
# Build ScoredDataGroup from ManagedServer state
|
||||||
# Phase 1: placeholder tokens
|
# ============================================================================
|
||||||
nodes = (result.managed_state or {}).get("nodes", [])
|
# Phase 2: Extract pre-computed data from SequenceNodes
|
||||||
|
# We may have multiple trajectories in the nodes due to how interesting
|
||||||
|
# agents can be, so iterate through all nodes and return multiple sequences.
|
||||||
|
#
|
||||||
|
# Each SequenceNode contains:
|
||||||
|
# - tokens: Full unmasked token sequence [1, 2, 3, ..., N]
|
||||||
|
# - masked_tokens: Training format [-100, -100, ..., -100, actual, actual, ...]
|
||||||
|
# - logprobs: Training format [1.0, 1.0, ..., 1.0, -0.5, -0.3, ...]
|
||||||
|
# - full_text: Complete text (prompt + all completions)
|
||||||
|
#
|
||||||
|
# Phase 1: Create placeholder tokens for OpenAI-style servers
|
||||||
|
# ============================================================================
|
||||||
|
nodes = (managed_state or {}).get("nodes", []) if managed_state else []
|
||||||
|
|
||||||
|
# Create ScoredDataGroup with lists for multiple trajectories
|
||||||
|
scored_group = ScoredDataGroup()
|
||||||
|
scored_group["tokens"] = []
|
||||||
|
scored_group["masks"] = []
|
||||||
|
scored_group["scores"] = []
|
||||||
|
scored_group["messages"] = []
|
||||||
|
scored_group["inference_logprobs"] = []
|
||||||
|
|
||||||
if nodes:
|
if nodes:
|
||||||
# Phase 2: use actual node data
|
# Phase 2: iterate through all nodes (may have multiple trajectories)
|
||||||
# nodes[-1] contains the full accumulated trajectory from all turns
|
for i, node in enumerate(nodes):
|
||||||
node = nodes[-1]
|
scored_group["tokens"].append(node.tokens)
|
||||||
scored_item: Dict[str, Any] = {
|
scored_group["masks"].append(node.masked_tokens)
|
||||||
"tokens": node.tokens,
|
scored_group["scores"].append(reward)
|
||||||
"masks": node.masked_tokens,
|
scored_group["messages"].append(result.messages)
|
||||||
"scores": reward,
|
|
||||||
}
|
if hasattr(node, "logprobs") and node.logprobs:
|
||||||
if hasattr(node, "logprobs") and node.logprobs:
|
scored_group["inference_logprobs"].append(node.logprobs)
|
||||||
scored_item["logprobs"] = node.logprobs
|
else:
|
||||||
scored_item["advantages"] = None
|
# Placeholder logprobs if not available
|
||||||
scored_item["ref_logprobs"] = None
|
scored_group["inference_logprobs"].append([1.0] * len(node.tokens))
|
||||||
|
|
||||||
|
logger.debug(f"Added trajectory {i+1}/{len(nodes)} with {len(node.tokens)} tokens")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Phase 1: create placeholder tokens
|
# Phase 1: create placeholder tokens for OpenAI-style servers
|
||||||
full_text = "\n".join(
|
full_text = "\n".join(
|
||||||
msg.get("content", "") for msg in result.messages if msg.get("content")
|
msg.get("content", "") for msg in result.messages if msg.get("content")
|
||||||
)
|
)
|
||||||
|
|
@ -532,16 +561,18 @@ class EndlessTerminalsEnv(HermesAgentBaseEnv):
|
||||||
else:
|
else:
|
||||||
tokens = list(range(min(len(full_text) // 4, 128)))
|
tokens = list(range(min(len(full_text) // 4, 128)))
|
||||||
|
|
||||||
scored_item = {
|
scored_group["tokens"].append(tokens)
|
||||||
"tokens": tokens,
|
scored_group["masks"].append([-100] + tokens[1:])
|
||||||
"masks": [-100] + tokens[1:],
|
scored_group["scores"].append(reward)
|
||||||
"scores": reward,
|
scored_group["messages"].append(result.messages)
|
||||||
}
|
scored_group["inference_logprobs"].append([1.0] * len(tokens))
|
||||||
|
|
||||||
# Include messages for wandb rollout display
|
# Return None if no trajectories collected
|
||||||
scored_item["messages"] = result.messages
|
if len(scored_group["tokens"]) == 0:
|
||||||
|
return None, []
|
||||||
|
|
||||||
return scored_item, []
|
logger.debug(f"Returning ScoredDataGroup with {len(scored_group['tokens'])} trajectories")
|
||||||
|
return scored_group, []
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
# Clean up task overrides and sandbox
|
# Clean up task overrides and sandbox
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue