This commit is contained in:
Shannon Sands 2026-02-05 11:42:58 +10:00
parent bf13a848ef
commit 661d8f4d6c

View file

@ -373,6 +373,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
"scores": 0.0,
}
if result.trajectory_data is not None:
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if self.config.include_messages:
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
import json
@ -423,6 +425,9 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
"masks": result.trajectory_data.masked_tokens,
"scores": score,
}
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if self.config.include_messages:
scored["messages"] = messages
@ -448,8 +453,6 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
if len(items) != self.config.group_size:
return None, backlog
# TODO: Make sure logprobs are included
group: ScoredDataGroup = ScoredDataGroup(
tokens=[],
@ -458,15 +461,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
advantages=[],
ref_logprobs=[],
messages=[] if self.config.include_messages else None,
inference_logprobs=[],
group_overrides={},
overrides=[],
images=[],
generation_params=None,
)
for it in items:
group["tokens"].append(it["tokens"])
group["masks"].append(it["masks"])
group["scores"].append(it["scores"])
# policy logprobs (for PPO/GRPO training) if present
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
if lp is not None:
group["inference_logprobs"].append(lp)
if group.get("messages") is not None and it.get("messages") is not None:
group["messages"].append(it["messages"])