added metadata capture

This commit is contained in:
Shannon Sands 2026-02-05 12:00:31 +10:00
parent 661d8f4d6c
commit 87464821d8
2 changed files with 82 additions and 1 deletions

View file

@ -144,6 +144,7 @@ class SequenceData:
tokens: List[int] tokens: List[int]
masked_tokens: List[int] # -100 for prompt, actual IDs for completion masked_tokens: List[int] # -100 for prompt, actual IDs for completion
logprobs: List[float] # 1.0 for prompt, actual values for completion logprobs: List[float] # 1.0 for prompt, actual values for completion
metadata: Optional[Dict[str, Any]] = None
@classmethod @classmethod
def from_sequence_node(cls, node) -> "SequenceData": def from_sequence_node(cls, node) -> "SequenceData":
@ -153,6 +154,7 @@ class SequenceData:
tokens=node.tokens, tokens=node.tokens,
masked_tokens=node.masked_tokens, masked_tokens=node.masked_tokens,
logprobs=node.logprobs, logprobs=node.logprobs,
metadata=getattr(node, "metadata", None),
) )
@ -312,6 +314,44 @@ class AtroposAgent:
return model return model
return None return None
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
"""
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
This is useful for debugging training runs, especially when ManagedServer state
tracking is unavailable (e.g. OpenAI-compatible chat endpoints).
"""
meta: Dict[str, Any] = {}
try:
rid = getattr(response, "id", None)
if isinstance(rid, str) and rid:
meta["id"] = rid
model = getattr(response, "model", None)
if isinstance(model, str) and model:
meta["model"] = model
created = getattr(response, "created", None)
if isinstance(created, int):
meta["created"] = created
system_fingerprint = getattr(response, "system_fingerprint", None)
if isinstance(system_fingerprint, str) and system_fingerprint:
meta["system_fingerprint"] = system_fingerprint
choices = getattr(response, "choices", None)
if isinstance(choices, list) and choices:
fr = getattr(choices[0], "finish_reason", None)
if isinstance(fr, str) and fr:
meta["finish_reason"] = fr
usage = getattr(response, "usage", None)
if usage is not None:
if hasattr(usage, "model_dump"):
meta["usage"] = usage.model_dump()
elif isinstance(usage, dict):
meta["usage"] = usage
except Exception:
pass
return meta
def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None: def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1": if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
return return
@ -491,6 +531,7 @@ class AtroposAgent:
managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
) )
self._debug_dump_response(step_num=step_num + 1, response=response) self._debug_dump_response(step_num=step_num + 1, response=response)
response_meta = self._extract_response_metadata(response)
print( print(
f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s", f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
flush=True, flush=True,
@ -518,11 +559,29 @@ class AtroposAgent:
last_prompt_messages = prompt_messages last_prompt_messages = prompt_messages
last_response_text = response_text last_response_text = response_text
step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
if step_sequence_data is None:
if response_meta:
# We still want metadata for debugging even if token/logprob state tracking is unavailable.
step_sequence_data = SequenceData(
full_text=response_text,
tokens=[],
masked_tokens=[],
logprobs=[],
metadata=response_meta,
)
else:
merged = dict(response_meta)
node_meta = step_sequence_data.metadata
if isinstance(node_meta, dict):
merged.update(node_meta)
step_sequence_data.metadata = merged or step_sequence_data.metadata
step = AgentStep( step = AgentStep(
step_number=step_num + 1, step_number=step_num + 1,
assistant_message=response_text, assistant_message=response_text,
tool_calls=tool_calls, tool_calls=tool_calls,
sequence_data=SequenceData.from_sequence_node(current_node) if current_node else None, sequence_data=step_sequence_data,
) )
if not tool_calls: if not tool_calls:
@ -579,6 +638,14 @@ class AtroposAgent:
masked_tokens=masked_tokens, masked_tokens=masked_tokens,
logprobs=logprobs, logprobs=logprobs,
) )
# Preserve response metadata (if any) even on failure trajectories.
try:
if trajectory_data is not None and steps:
last_step = steps[-1]
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
except Exception:
pass
return AgentResult( return AgentResult(
success=False, success=False,
final_response=final_response, final_response=final_response,
@ -611,6 +678,15 @@ class AtroposAgent:
logprobs=logprobs, logprobs=logprobs,
) )
# Ensure trajectory_data carries the most recent metadata we observed (if any).
try:
if trajectory_data is not None and steps:
last_step = steps[-1]
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
except Exception:
pass
return AgentResult( return AgentResult(
success=True, success=True,
final_response=final_response, final_response=final_response,

View file

@ -375,6 +375,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
} }
if result.trajectory_data is not None: if result.trajectory_data is not None:
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key] scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if getattr(result.trajectory_data, "metadata", None):
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
if self.config.include_messages: if self.config.include_messages:
# Record a final failure marker as a user-side tool_response-like block so it survives templates. # Record a final failure marker as a user-side tool_response-like block so it survives templates.
import json import json
@ -428,6 +430,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`. # Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
# We stash per-item values here and lift them into the group in `collect_trajectories()`. # We stash per-item values here and lift them into the group in `collect_trajectories()`.
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key] scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
if getattr(result.trajectory_data, "metadata", None):
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
if self.config.include_messages: if self.config.include_messages:
scored["messages"] = messages scored["messages"] = messages
@ -476,6 +480,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
lp = it.get("inference_logprobs") # type: ignore[typeddict-item] lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
if lp is not None: if lp is not None:
group["inference_logprobs"].append(lp) group["inference_logprobs"].append(lp)
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
if group.get("messages") is not None and it.get("messages") is not None: if group.get("messages") is not None and it.get("messages") is not None:
group["messages"].append(it["messages"]) group["messages"].append(it["messages"])