mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-15 04:12:25 +00:00
logprobs
This commit is contained in:
parent
bf13a848ef
commit
661d8f4d6c
1 changed file with 11 additions and 2 deletions
|
|
@@ -373,6 +373,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
|||
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
||||
"scores": 0.0,
|
||||
}
|
||||
if result.trajectory_data is not None:
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if self.config.include_messages:
|
||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||
import json
|
||||
|
|
@@ -423,6 +425,9 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
|||
"masks": result.trajectory_data.masked_tokens,
|
||||
"scores": score,
|
||||
}
|
||||
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||
if self.config.include_messages:
|
||||
scored["messages"] = messages
|
||||
|
||||
|
|
@@ -448,8 +453,6 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
|||
|
||||
if len(items) != self.config.group_size:
|
||||
return None, backlog
|
||||
|
||||
# TODO: Make sure logprobs are included
|
||||
|
||||
group: ScoredDataGroup = ScoredDataGroup(
|
||||
tokens=[],
|
||||
|
|
@@ -458,15 +461,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
|||
advantages=[],
|
||||
ref_logprobs=[],
|
||||
messages=[] if self.config.include_messages else None,
|
||||
inference_logprobs=[],
|
||||
group_overrides={},
|
||||
overrides=[],
|
||||
images=[],
|
||||
generation_params=None,
|
||||
)
|
||||
|
||||
for it in items:
|
||||
group["tokens"].append(it["tokens"])
|
||||
group["masks"].append(it["masks"])
|
||||
group["scores"].append(it["scores"])
|
||||
# policy logprobs (for PPO/GRPO training) if present
|
||||
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||
if lp is not None:
|
||||
group["inference_logprobs"].append(lp)
|
||||
if group.get("messages") is not None and it.get("messages") is not None:
|
||||
group["messages"].append(it["messages"])
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue