Mirror of https://github.com/NousResearch/hermes-agent.git (synced 2026-04-27 01:11:40 +00:00)

Merge: WebResearchEnv evaluate() with full agent loop + tools

Commit ab7dc22984. 1 changed file with 78 additions and 25 deletions.
@@ -425,8 +425,16 @@ class WebResearchEnv(HermesAgentBaseEnv):
     # ------------------------------------------------------------------

     async def evaluate(self, *args, **kwargs) -> None:
-        """Run evaluation on the held-out split using the agent loop."""
+        """Run evaluation on the held-out split using the full agent loop with tools.
+
+        Each eval item runs through the same agent loop as training —
+        the model can use web_search, web_extract, etc. to research answers.
+        This measures actual agentic research capability, not just knowledge.
+        """
         import time
+        import uuid
+        from environments.agent_loop import HermesAgentLoop
+        from environments.tool_context import ToolContext

         items = self._eval_items
         if not items:
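Note, for orientation only: the new code assumes that HermesAgentLoop.run() returns a result object exposing .messages and .turns_used, and that ToolContext takes a task id and exposes cleanup(). Those interfaces are inferred from how this diff uses them, not copied from the hermes-agent source. A minimal Python sketch of the assumed shapes:

# Hypothetical sketch only: field names are inferred from usage in the diff
# (result.messages, result.turns_used, ToolContext(task_id).cleanup()).
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class AgentLoopResult:
    messages: List[Dict[str, Any]] = field(default_factory=list)  # full chat transcript
    turns_used: int = 0                                           # assistant turns consumed


class ToolContextSketch:
    """Stand-in for environments.tool_context.ToolContext (assumed interface)."""

    def __init__(self, task_id: str) -> None:
        self.task_id = task_id

    def cleanup(self) -> None:
        # Release any per-task resources (search sessions, temp files, etc.).
        pass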
@@ -436,43 +444,75 @@ class WebResearchEnv(HermesAgentBaseEnv):
         eval_size = min(self.config.eval_size, len(items))
         eval_items = items[:eval_size]

-        logger.info(f"Running eval on {len(eval_items)} questions...")
+        logger.info(f"Running eval on {len(eval_items)} questions (with agent loop + tools)...")
         start_time = time.time()
         samples = []

-        for item in eval_items:
+        # Resolve tools once for all eval items
+        tools, valid_names = self._resolve_tools_for_group()
+
+        for i, item in enumerate(eval_items):
+            task_id = str(uuid.uuid4())
+            logger.info(f"Eval [{i+1}/{len(eval_items)}]: {item['question'][:80]}...")
+
             try:
-                # Use the base env's agent loop for eval (same as training)
-                prompt = self.format_prompt(item)
-                completion = await self.server.chat_completion(
-                    messages=[
-                        {"role": "system", "content": self.config.system_prompt or ""},
-                        {"role": "user", "content": prompt},
-                    ],
-                    n=1,
-                    max_tokens=self.config.max_token_length,
-                    temperature=0.0,
-                    split="eval",
-                )
-
-                response_content = (
-                    completion.choices[0].message.content if completion.choices else ""
-                )
-
-                # Score the response
+                # Build messages
+                messages: List[Dict[str, Any]] = []
+                if self.config.system_prompt:
+                    messages.append({"role": "system", "content": self.config.system_prompt})
+                messages.append({"role": "user", "content": self.format_prompt(item)})
+
+                # Run the full agent loop with tools
+                agent = HermesAgentLoop(
+                    server=self.server,
+                    tool_schemas=tools,
+                    valid_tool_names=valid_names,
+                    max_turns=self.config.max_agent_turns,
+                    task_id=task_id,
+                    temperature=0.0,  # Deterministic for eval
+                    max_tokens=self.config.max_token_length,
+                    extra_body=self.config.extra_body,
+                )
+                result = await agent.run(messages)
+
+                # Extract final response and compute reward
+                ctx = ToolContext(task_id)
+                try:
+                    reward = await self.compute_reward(item, result, ctx)
+                finally:
+                    ctx.cleanup()
+
+                # Extract final response for logging
+                final_response = ""
+                tool_call_count = 0
+                for msg in reversed(result.messages):
+                    if msg.get("role") == "assistant" and msg.get("content") and not final_response:
+                        final_response = msg["content"]
+                    if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                        tool_call_count += len(msg["tool_calls"])
+
+                # Score correctness separately for the metric
                 correctness = await self._llm_judge(
                     question=item["question"],
                     expected=item["answer"],
-                    model_answer=response_content,
+                    model_answer=final_response,
                 )

                 samples.append({
                     "prompt": item["question"],
-                    "response": response_content,
+                    "response": final_response[:500],
                     "expected": item["answer"],
                     "correctness": correctness,
+                    "reward": reward,
+                    "tool_calls": tool_call_count,
+                    "turns": result.turns_used,
                 })

+                logger.info(
+                    f"  → correctness={correctness:.2f}, reward={reward:.3f}, "
+                    f"tools={tool_call_count}, turns={result.turns_used}"
+                )
+
             except Exception as e:
                 logger.error(f"Eval error on item: {e}")
                 samples.append({
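For illustration of the reverse scan in the new loop (the transcript below is invented, not taken from the repo), the final-response and tool-call extraction behaves like this:

# Toy transcript shaped like result.messages above; contents are made up.
toy_messages = [
    {"role": "user", "content": "Who wrote The Selfish Gene?"},
    {"role": "assistant", "content": None, "tool_calls": [
        {"function": {"name": "web_search", "arguments": '{"query": "The Selfish Gene author"}'}},
    ]},
    {"role": "tool", "content": "The Selfish Gene (1976) was written by Richard Dawkins."},
    {"role": "assistant", "content": "Richard Dawkins wrote The Selfish Gene."},
]

final_response = ""
tool_call_count = 0
for msg in reversed(toy_messages):
    # The last assistant message with text content wins, thanks to the reverse scan.
    if msg.get("role") == "assistant" and msg.get("content") and not final_response:
        final_response = msg["content"]
    if msg.get("role") == "assistant" and msg.get("tool_calls"):
        tool_call_count += len(msg["tool_calls"])

print(final_response)   # Richard Dawkins wrote The Selfish Gene.
print(tool_call_count)  # 1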
@@ -480,20 +520,33 @@ class WebResearchEnv(HermesAgentBaseEnv):
                     "response": f"ERROR: {e}",
                     "expected": item["answer"],
                     "correctness": 0.0,
+                    "reward": 0.0,
+                    "tool_calls": 0,
+                    "turns": 0,
                 })

         end_time = time.time()

-        # Compute metrics
+        # Compute aggregate metrics
         correctness_scores = [s["correctness"] for s in samples]
+        rewards = [s["reward"] for s in samples]
+        tool_counts = [s["tool_calls"] for s in samples]
+        n = len(samples)
+
         eval_metrics = {
-            "eval/mean_correctness": (
-                sum(correctness_scores) / len(correctness_scores)
-                if correctness_scores else 0.0
-            ),
-            "eval/n_items": len(samples),
+            "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
+            "eval/mean_reward": sum(rewards) / n if n else 0.0,
+            "eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,
+            "eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,
+            "eval/n_items": n,
         }

+        logger.info(
+            f"Eval complete — correctness={eval_metrics['eval/mean_correctness']:.3f}, "
+            f"reward={eval_metrics['eval/mean_reward']:.3f}, "
+            f"tool_usage={eval_metrics['eval/tool_usage_rate']:.0%}"
+        )
+
         await self.evaluate_log(
             metrics=eval_metrics,
             samples=samples,
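As a quick worked example of the new aggregation (the two sample dicts below are invented), the metrics come out as:

# Two made-up eval samples, shaped like the dicts appended in the diff above.
samples = [
    {"correctness": 1.0, "reward": 0.9, "tool_calls": 3, "turns": 4},
    {"correctness": 0.0, "reward": 0.1, "tool_calls": 0, "turns": 1},
]

correctness_scores = [s["correctness"] for s in samples]
rewards = [s["reward"] for s in samples]
tool_counts = [s["tool_calls"] for s in samples]
n = len(samples)

eval_metrics = {
    "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,              # 0.5
    "eval/mean_reward": sum(rewards) / n if n else 0.0,                               # 0.5
    "eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,                       # 1.5
    "eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,   # 0.5
    "eval/n_items": n,                                                                # 2
}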