From 87464821d836850d09534ea4ea40046a5844a84b Mon Sep 17 00:00:00 2001 From: Shannon Sands Date: Thu, 5 Feb 2026 12:00:31 +1000 Subject: [PATCH] added metadata capture --- atropos/agent/atropos_agent.py | 78 +++++++++++++++++++++++++++++++++- atropos/envs/agent_env.py | 5 +++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/atropos/agent/atropos_agent.py b/atropos/agent/atropos_agent.py index 18e24d4d8d..0bc2637c4e 100644 --- a/atropos/agent/atropos_agent.py +++ b/atropos/agent/atropos_agent.py @@ -144,6 +144,7 @@ class SequenceData: tokens: List[int] masked_tokens: List[int] # -100 for prompt, actual IDs for completion logprobs: List[float] # 1.0 for prompt, actual values for completion + metadata: Optional[Dict[str, Any]] = None @classmethod def from_sequence_node(cls, node) -> "SequenceData": @@ -153,6 +154,7 @@ class SequenceData: tokens=node.tokens, masked_tokens=node.masked_tokens, logprobs=node.logprobs, + metadata=getattr(node, "metadata", None), ) @@ -312,6 +314,44 @@ class AtroposAgent: return model return None + def _extract_response_metadata(self, response: Any) -> Dict[str, Any]: + """ + Extract lightweight, JSON-serializable metadata from an OpenAI-style response. + + This is useful for debugging training runs, especially when ManagedServer state + tracking is unavailable (e.g. OpenAI-compatible chat endpoints). + """ + meta: Dict[str, Any] = {} + try: + rid = getattr(response, "id", None) + if isinstance(rid, str) and rid: + meta["id"] = rid + model = getattr(response, "model", None) + if isinstance(model, str) and model: + meta["model"] = model + created = getattr(response, "created", None) + if isinstance(created, int): + meta["created"] = created + system_fingerprint = getattr(response, "system_fingerprint", None) + if isinstance(system_fingerprint, str) and system_fingerprint: + meta["system_fingerprint"] = system_fingerprint + + choices = getattr(response, "choices", None) + if isinstance(choices, list) and choices: + fr = getattr(choices[0], "finish_reason", None) + if isinstance(fr, str) and fr: + meta["finish_reason"] = fr + + usage = getattr(response, "usage", None) + if usage is not None: + if hasattr(usage, "model_dump"): + meta["usage"] = usage.model_dump() + elif isinstance(usage, dict): + meta["usage"] = usage + except Exception: + pass + return meta + def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None: if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1": return @@ -491,6 +531,7 @@ class AtroposAgent: managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs ) self._debug_dump_response(step_num=step_num + 1, response=response) + response_meta = self._extract_response_metadata(response) print( f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s", flush=True, @@ -517,12 +558,30 @@ class AtroposAgent: last_node = current_node last_prompt_messages = prompt_messages last_response_text = response_text + + step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None + if step_sequence_data is None: + if response_meta: + # We still want metadata for debugging even if token/logprob state tracking is unavailable. + step_sequence_data = SequenceData( + full_text=response_text, + tokens=[], + masked_tokens=[], + logprobs=[], + metadata=response_meta, + ) + else: + merged = dict(response_meta) + node_meta = step_sequence_data.metadata + if isinstance(node_meta, dict): + merged.update(node_meta) + step_sequence_data.metadata = merged or step_sequence_data.metadata step = AgentStep( step_number=step_num + 1, assistant_message=response_text, tool_calls=tool_calls, - sequence_data=SequenceData.from_sequence_node(current_node) if current_node else None, + sequence_data=step_sequence_data, ) if not tool_calls: @@ -579,6 +638,14 @@ class AtroposAgent: masked_tokens=masked_tokens, logprobs=logprobs, ) + # Preserve response metadata (if any) even on failure trajectories. + try: + if trajectory_data is not None and steps: + last_step = steps[-1] + if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict): + trajectory_data.metadata = dict(last_step.sequence_data.metadata) + except Exception: + pass return AgentResult( success=False, final_response=final_response, @@ -610,6 +677,15 @@ class AtroposAgent: masked_tokens=masked_tokens, logprobs=logprobs, ) + + # Ensure trajectory_data carries the most recent metadata we observed (if any). + try: + if trajectory_data is not None and steps: + last_step = steps[-1] + if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict): + trajectory_data.metadata = dict(last_step.sequence_data.metadata) + except Exception: + pass return AgentResult( success=True, diff --git a/atropos/envs/agent_env.py b/atropos/envs/agent_env.py index 6a69d11446..7d7d14b916 100644 --- a/atropos/envs/agent_env.py +++ b/atropos/envs/agent_env.py @@ -375,6 +375,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): } if result.trajectory_data is not None: scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key] + if getattr(result.trajectory_data, "metadata", None): + scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata} if self.config.include_messages: # Record a final failure marker as a user-side tool_response-like block so it survives templates. import json @@ -428,6 +430,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): # Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`. # We stash per-item values here and lift them into the group in `collect_trajectories()`. scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key] + if getattr(result.trajectory_data, "metadata", None): + scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata} if self.config.include_messages: scored["messages"] = messages @@ -476,6 +480,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): lp = it.get("inference_logprobs") # type: ignore[typeddict-item] if lp is not None: group["inference_logprobs"].append(lp) + group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item] if group.get("messages") is not None and it.get("messages") is not None: group["messages"].append(it["messages"])