added metadata capture

This commit is contained in:
Shannon Sands 2026-02-05 12:00:31 +10:00
parent 661d8f4d6c
commit 87464821d8
2 changed files with 82 additions and 1 deletions

View file

@ -375,6 +375,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
}
if result.trajectory_data is not None:
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if getattr(result.trajectory_data, "metadata", None):
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
if self.config.include_messages:
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
import json
@ -428,6 +430,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if getattr(result.trajectory_data, "metadata", None):
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
if self.config.include_messages:
scored["messages"] = messages
@ -476,6 +480,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
if lp is not None:
group["inference_logprobs"].append(lp)
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
if group.get("messages") is not None and it.get("messages") is not None:
group["messages"].append(it["messages"])