This commit is contained in:
Shannon Sands 2026-02-05 11:42:58 +10:00
parent bf13a848ef
commit 661d8f4d6c

View file

@ -373,6 +373,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []), "masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
"scores": 0.0, "scores": 0.0,
} }
if result.trajectory_data is not None:
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if self.config.include_messages: if self.config.include_messages:
# Record a final failure marker as a user-side tool_response-like block so it survives templates. # Record a final failure marker as a user-side tool_response-like block so it survives templates.
import json import json
@ -423,6 +425,9 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
"masks": result.trajectory_data.masked_tokens, "masks": result.trajectory_data.masked_tokens,
"scores": score, "scores": score,
} }
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if self.config.include_messages: if self.config.include_messages:
scored["messages"] = messages scored["messages"] = messages
@ -448,8 +453,6 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
if len(items) != self.config.group_size: if len(items) != self.config.group_size:
return None, backlog return None, backlog
# TODO: Make sure logprobs are included
group: ScoredDataGroup = ScoredDataGroup( group: ScoredDataGroup = ScoredDataGroup(
tokens=[], tokens=[],
@ -458,15 +461,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
advantages=[], advantages=[],
ref_logprobs=[], ref_logprobs=[],
messages=[] if self.config.include_messages else None, messages=[] if self.config.include_messages else None,
inference_logprobs=[],
group_overrides={}, group_overrides={},
overrides=[], overrides=[],
images=[], images=[],
generation_params=None,
) )
for it in items: for it in items:
group["tokens"].append(it["tokens"]) group["tokens"].append(it["tokens"])
group["masks"].append(it["masks"]) group["masks"].append(it["masks"])
group["scores"].append(it["scores"]) group["scores"].append(it["scores"])
# policy logprobs (for PPO/GRPO training) if present
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
if lp is not None:
group["inference_logprobs"].append(lp)
if group.get("messages") is not None and it.get("messages") is not None: if group.get("messages") is not None and it.get("messages") is not None:
group["messages"].append(it["messages"]) group["messages"].append(it["messages"])