mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-08 03:01:47 +00:00
added metadata capture
This commit is contained in:
parent
661d8f4d6c
commit
87464821d8
2 changed files with 82 additions and 1 deletions
|
|
@ -144,6 +144,7 @@ class SequenceData:
|
||||||
tokens: List[int]
|
tokens: List[int]
|
||||||
masked_tokens: List[int] # -100 for prompt, actual IDs for completion
|
masked_tokens: List[int] # -100 for prompt, actual IDs for completion
|
||||||
logprobs: List[float] # 1.0 for prompt, actual values for completion
|
logprobs: List[float] # 1.0 for prompt, actual values for completion
|
||||||
|
metadata: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_sequence_node(cls, node) -> "SequenceData":
|
def from_sequence_node(cls, node) -> "SequenceData":
|
||||||
|
|
@ -153,6 +154,7 @@ class SequenceData:
|
||||||
tokens=node.tokens,
|
tokens=node.tokens,
|
||||||
masked_tokens=node.masked_tokens,
|
masked_tokens=node.masked_tokens,
|
||||||
logprobs=node.logprobs,
|
logprobs=node.logprobs,
|
||||||
|
metadata=getattr(node, "metadata", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -312,6 +314,44 @@ class AtroposAgent:
|
||||||
return model
|
return model
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
|
||||||
|
|
||||||
|
This is useful for debugging training runs, especially when ManagedServer state
|
||||||
|
tracking is unavailable (e.g. OpenAI-compatible chat endpoints).
|
||||||
|
"""
|
||||||
|
meta: Dict[str, Any] = {}
|
||||||
|
try:
|
||||||
|
rid = getattr(response, "id", None)
|
||||||
|
if isinstance(rid, str) and rid:
|
||||||
|
meta["id"] = rid
|
||||||
|
model = getattr(response, "model", None)
|
||||||
|
if isinstance(model, str) and model:
|
||||||
|
meta["model"] = model
|
||||||
|
created = getattr(response, "created", None)
|
||||||
|
if isinstance(created, int):
|
||||||
|
meta["created"] = created
|
||||||
|
system_fingerprint = getattr(response, "system_fingerprint", None)
|
||||||
|
if isinstance(system_fingerprint, str) and system_fingerprint:
|
||||||
|
meta["system_fingerprint"] = system_fingerprint
|
||||||
|
|
||||||
|
choices = getattr(response, "choices", None)
|
||||||
|
if isinstance(choices, list) and choices:
|
||||||
|
fr = getattr(choices[0], "finish_reason", None)
|
||||||
|
if isinstance(fr, str) and fr:
|
||||||
|
meta["finish_reason"] = fr
|
||||||
|
|
||||||
|
usage = getattr(response, "usage", None)
|
||||||
|
if usage is not None:
|
||||||
|
if hasattr(usage, "model_dump"):
|
||||||
|
meta["usage"] = usage.model_dump()
|
||||||
|
elif isinstance(usage, dict):
|
||||||
|
meta["usage"] = usage
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return meta
|
||||||
|
|
||||||
def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
|
def _debug_dump_request(self, *, step_num: int, chat_kwargs: Dict[str, Any]) -> None:
|
||||||
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
|
if os.getenv("ATROPOS_DEBUG_AGENT_REQUEST") != "1":
|
||||||
return
|
return
|
||||||
|
|
@ -491,6 +531,7 @@ class AtroposAgent:
|
||||||
managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
|
managed=managed, step_num=step_num + 1, chat_kwargs=chat_kwargs
|
||||||
)
|
)
|
||||||
self._debug_dump_response(step_num=step_num + 1, response=response)
|
self._debug_dump_response(step_num=step_num + 1, response=response)
|
||||||
|
response_meta = self._extract_response_metadata(response)
|
||||||
print(
|
print(
|
||||||
f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
|
f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
|
||||||
flush=True,
|
flush=True,
|
||||||
|
|
@ -517,12 +558,30 @@ class AtroposAgent:
|
||||||
last_node = current_node
|
last_node = current_node
|
||||||
last_prompt_messages = prompt_messages
|
last_prompt_messages = prompt_messages
|
||||||
last_response_text = response_text
|
last_response_text = response_text
|
||||||
|
|
||||||
|
step_sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
|
||||||
|
if step_sequence_data is None:
|
||||||
|
if response_meta:
|
||||||
|
# We still want metadata for debugging even if token/logprob state tracking is unavailable.
|
||||||
|
step_sequence_data = SequenceData(
|
||||||
|
full_text=response_text,
|
||||||
|
tokens=[],
|
||||||
|
masked_tokens=[],
|
||||||
|
logprobs=[],
|
||||||
|
metadata=response_meta,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
merged = dict(response_meta)
|
||||||
|
node_meta = step_sequence_data.metadata
|
||||||
|
if isinstance(node_meta, dict):
|
||||||
|
merged.update(node_meta)
|
||||||
|
step_sequence_data.metadata = merged or step_sequence_data.metadata
|
||||||
|
|
||||||
step = AgentStep(
|
step = AgentStep(
|
||||||
step_number=step_num + 1,
|
step_number=step_num + 1,
|
||||||
assistant_message=response_text,
|
assistant_message=response_text,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
sequence_data=SequenceData.from_sequence_node(current_node) if current_node else None,
|
sequence_data=step_sequence_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not tool_calls:
|
if not tool_calls:
|
||||||
|
|
@ -579,6 +638,14 @@ class AtroposAgent:
|
||||||
masked_tokens=masked_tokens,
|
masked_tokens=masked_tokens,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
# Preserve response metadata (if any) even on failure trajectories.
|
||||||
|
try:
|
||||||
|
if trajectory_data is not None and steps:
|
||||||
|
last_step = steps[-1]
|
||||||
|
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
|
||||||
|
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
return AgentResult(
|
return AgentResult(
|
||||||
success=False,
|
success=False,
|
||||||
final_response=final_response,
|
final_response=final_response,
|
||||||
|
|
@ -610,6 +677,15 @@ class AtroposAgent:
|
||||||
masked_tokens=masked_tokens,
|
masked_tokens=masked_tokens,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Ensure trajectory_data carries the most recent metadata we observed (if any).
|
||||||
|
try:
|
||||||
|
if trajectory_data is not None and steps:
|
||||||
|
last_step = steps[-1]
|
||||||
|
if last_step.sequence_data and isinstance(last_step.sequence_data.metadata, dict):
|
||||||
|
trajectory_data.metadata = dict(last_step.sequence_data.metadata)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
return AgentResult(
|
return AgentResult(
|
||||||
success=True,
|
success=True,
|
||||||
|
|
|
||||||
|
|
@ -375,6 +375,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
}
|
}
|
||||||
if result.trajectory_data is not None:
|
if result.trajectory_data is not None:
|
||||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||||
|
if getattr(result.trajectory_data, "metadata", None):
|
||||||
|
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||||
if self.config.include_messages:
|
if self.config.include_messages:
|
||||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||||
import json
|
import json
|
||||||
|
|
@ -428,6 +430,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||||
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||||
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||||
|
if getattr(result.trajectory_data, "metadata", None):
|
||||||
|
scored["overrides"] = {"managed_metadata": result.trajectory_data.metadata}
|
||||||
if self.config.include_messages:
|
if self.config.include_messages:
|
||||||
scored["messages"] = messages
|
scored["messages"] = messages
|
||||||
|
|
||||||
|
|
@ -476,6 +480,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||||
if lp is not None:
|
if lp is not None:
|
||||||
group["inference_logprobs"].append(lp)
|
group["inference_logprobs"].append(lp)
|
||||||
|
group["overrides"].append(it.get("overrides") or {}) # type: ignore[typeddict-item]
|
||||||
if group.get("messages") is not None and it.get("messages") is not None:
|
if group.get("messages") is not None and it.get("messages") is not None:
|
||||||
group["messages"].append(it["messages"])
|
group["messages"].append(it["messages"])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue