From 661d8f4d6cf73fadf849b2b0c409e8555b0d6751 Mon Sep 17 00:00:00 2001 From: Shannon Sands Date: Thu, 5 Feb 2026 11:42:58 +1000 Subject: [PATCH] logprobs --- atropos/envs/agent_env.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/atropos/envs/agent_env.py b/atropos/envs/agent_env.py index 5596266eb32..6a69d114466 100644 --- a/atropos/envs/agent_env.py +++ b/atropos/envs/agent_env.py @@ -373,6 +373,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): "masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []), "scores": 0.0, } + if result.trajectory_data is not None: + scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key] if self.config.include_messages: # Record a final failure marker as a user-side tool_response-like block so it survives templates. import json @@ -423,6 +425,9 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): "masks": result.trajectory_data.masked_tokens, "scores": score, } + # Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`. + # We stash per-item values here and lift them into the group in `collect_trajectories()`. + scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key] if self.config.include_messages: scored["messages"] = messages @@ -448,8 +453,6 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): if len(items) != self.config.group_size: return None, backlog - - # TODO: Mack sure logprobs included group: ScoredDataGroup = ScoredDataGroup( tokens=[], @@ -458,15 +461,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): advantages=[], ref_logprobs=[], messages=[] if self.config.include_messages else None, + inference_logprobs=[], group_overrides={}, overrides=[], images=[], + generation_params=None, ) for it in items: group["tokens"].append(it["tokens"]) group["masks"].append(it["masks"]) group["scores"].append(it["scores"]) + # policy logprobs (for PPO/GRPO training) if present + lp = it.get("inference_logprobs") # type: ignore[typeddict-item] + if lp is not None: + group["inference_logprobs"].append(lp) if group.get("messages") is not None and it.get("messages") is not None: group["messages"].append(it["messages"])