mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-15 04:12:25 +00:00
logprobs
This commit is contained in:
parent
bf13a848ef
commit
661d8f4d6c
1 changed files with 11 additions and 2 deletions
|
|
@ -373,6 +373,8 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
"masks": (result.trajectory_data.masked_tokens if result.trajectory_data else []),
|
||||||
"scores": 0.0,
|
"scores": 0.0,
|
||||||
}
|
}
|
||||||
|
if result.trajectory_data is not None:
|
||||||
|
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||||
if self.config.include_messages:
|
if self.config.include_messages:
|
||||||
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
# Record a final failure marker as a user-side tool_response-like block so it survives templates.
|
||||||
import json
|
import json
|
||||||
|
|
@ -423,6 +425,9 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
"masks": result.trajectory_data.masked_tokens,
|
"masks": result.trajectory_data.masked_tokens,
|
||||||
"scores": score,
|
"scores": score,
|
||||||
}
|
}
|
||||||
|
# Atroposlib expects policy logprobs at the *group* level under `inference_logprobs`.
|
||||||
|
# We stash per-item values here and lift them into the group in `collect_trajectories()`.
|
||||||
|
scored["inference_logprobs"] = result.trajectory_data.logprobs # type: ignore[typeddict-unknown-key]
|
||||||
if self.config.include_messages:
|
if self.config.include_messages:
|
||||||
scored["messages"] = messages
|
scored["messages"] = messages
|
||||||
|
|
||||||
|
|
@ -448,8 +453,6 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
|
|
||||||
if len(items) != self.config.group_size:
|
if len(items) != self.config.group_size:
|
||||||
return None, backlog
|
return None, backlog
|
||||||
|
|
||||||
# TODO: Mack sure logprobs included
|
|
||||||
|
|
||||||
group: ScoredDataGroup = ScoredDataGroup(
|
group: ScoredDataGroup = ScoredDataGroup(
|
||||||
tokens=[],
|
tokens=[],
|
||||||
|
|
@ -458,15 +461,21 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
|
||||||
advantages=[],
|
advantages=[],
|
||||||
ref_logprobs=[],
|
ref_logprobs=[],
|
||||||
messages=[] if self.config.include_messages else None,
|
messages=[] if self.config.include_messages else None,
|
||||||
|
inference_logprobs=[],
|
||||||
group_overrides={},
|
group_overrides={},
|
||||||
overrides=[],
|
overrides=[],
|
||||||
images=[],
|
images=[],
|
||||||
|
generation_params=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
for it in items:
|
for it in items:
|
||||||
group["tokens"].append(it["tokens"])
|
group["tokens"].append(it["tokens"])
|
||||||
group["masks"].append(it["masks"])
|
group["masks"].append(it["masks"])
|
||||||
group["scores"].append(it["scores"])
|
group["scores"].append(it["scores"])
|
||||||
|
# policy logprobs (for PPO/GRPO training) if present
|
||||||
|
lp = it.get("inference_logprobs") # type: ignore[typeddict-item]
|
||||||
|
if lp is not None:
|
||||||
|
group["inference_logprobs"].append(lp)
|
||||||
if group.get("messages") is not None and it.get("messages") is not None:
|
if group.get("messages") is not None and it.get("messages") is not None:
|
||||||
group["messages"].append(it["messages"])
|
group["messages"].append(it["messages"])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue