Merge: WebResearchEnv evaluate() with full agent loop + tools

teknium1 2026-03-09 19:53:36 -07:00
commit ab7dc22984


@@ -425,8 +425,16 @@ class WebResearchEnv(HermesAgentBaseEnv):
     # ------------------------------------------------------------------
     async def evaluate(self, *args, **kwargs) -> None:
-        """Run evaluation on the held-out split using the agent loop."""
+        """Run evaluation on the held-out split using the full agent loop with tools.
+
+        Each eval item runs through the same agent loop as training:
+        the model can use web_search, web_extract, etc. to research answers.
+        This measures actual agentic research capability, not just knowledge.
+        """
         import time
+        import uuid
+
+        from environments.agent_loop import HermesAgentLoop
+        from environments.tool_context import ToolContext
 
         items = self._eval_items
         if not items:
@@ -436,43 +444,75 @@ class WebResearchEnv(HermesAgentBaseEnv):
         eval_size = min(self.config.eval_size, len(items))
         eval_items = items[:eval_size]
 
-        logger.info(f"Running eval on {len(eval_items)} questions...")
+        logger.info(f"Running eval on {len(eval_items)} questions (with agent loop + tools)...")
         start_time = time.time()
         samples = []
 
-        for item in eval_items:
+        # Resolve tools once for all eval items
+        tools, valid_names = self._resolve_tools_for_group()
+
+        for i, item in enumerate(eval_items):
+            task_id = str(uuid.uuid4())
+            logger.info(f"Eval [{i+1}/{len(eval_items)}]: {item['question'][:80]}...")
             try:
-                # Use the base env's agent loop for eval (same as training)
-                prompt = self.format_prompt(item)
-                completion = await self.server.chat_completion(
-                    messages=[
-                        {"role": "system", "content": self.config.system_prompt or ""},
-                        {"role": "user", "content": prompt},
-                    ],
-                    n=1,
+                # Build messages
+                messages: List[Dict[str, Any]] = []
+                if self.config.system_prompt:
+                    messages.append({"role": "system", "content": self.config.system_prompt})
+                messages.append({"role": "user", "content": self.format_prompt(item)})
+
+                # Run the full agent loop with tools
+                agent = HermesAgentLoop(
+                    server=self.server,
+                    tool_schemas=tools,
+                    valid_tool_names=valid_names,
+                    max_turns=self.config.max_agent_turns,
+                    task_id=task_id,
+                    temperature=0.0,  # Deterministic for eval
                     max_tokens=self.config.max_token_length,
-                    temperature=0.0,
+                    extra_body=self.config.extra_body,
+                    split="eval",
                 )
-
-                response_content = (
-                    completion.choices[0].message.content if completion.choices else ""
-                )
-
-                # Score the response
+                result = await agent.run(messages)
+
+                # Extract final response and compute reward
+                ctx = ToolContext(task_id)
+                try:
+                    reward = await self.compute_reward(item, result, ctx)
+                finally:
+                    ctx.cleanup()
+
+                # Extract final response for logging
+                final_response = ""
+                tool_call_count = 0
+                for msg in reversed(result.messages):
+                    if msg.get("role") == "assistant" and msg.get("content") and not final_response:
+                        final_response = msg["content"]
+                    if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                        tool_call_count += len(msg["tool_calls"])
+
+                # Score correctness separately for the metric
                 correctness = await self._llm_judge(
                     question=item["question"],
                     expected=item["answer"],
-                    model_answer=response_content,
+                    model_answer=final_response,
                 )
 
                 samples.append({
                     "prompt": item["question"],
-                    "response": response_content,
+                    "response": final_response[:500],
                     "expected": item["answer"],
                     "correctness": correctness,
+                    "reward": reward,
+                    "tool_calls": tool_call_count,
+                    "turns": result.turns_used,
                 })
+                logger.info(
+                    f"  → correctness={correctness:.2f}, reward={reward:.3f}, "
+                    f"tools={tool_call_count}, turns={result.turns_used}"
+                )
             except Exception as e:
                 logger.error(f"Eval error on item: {e}")
                 samples.append({
@@ -480,20 +520,33 @@ class WebResearchEnv(HermesAgentBaseEnv):
                     "response": f"ERROR: {e}",
                     "expected": item["answer"],
                     "correctness": 0.0,
+                    "reward": 0.0,
+                    "tool_calls": 0,
+                    "turns": 0,
                 })
 
         end_time = time.time()
 
-        # Compute metrics
+        # Compute aggregate metrics
        correctness_scores = [s["correctness"] for s in samples]
+        rewards = [s["reward"] for s in samples]
+        tool_counts = [s["tool_calls"] for s in samples]
+        n = len(samples)
+
         eval_metrics = {
-            "eval/mean_correctness": (
-                sum(correctness_scores) / len(correctness_scores)
-                if correctness_scores else 0.0
-            ),
-            "eval/n_items": len(samples),
+            "eval/mean_correctness": sum(correctness_scores) / n if n else 0.0,
+            "eval/mean_reward": sum(rewards) / n if n else 0.0,
+            "eval/mean_tool_calls": sum(tool_counts) / n if n else 0.0,
+            "eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n if n else 0.0,
+            "eval/n_items": n,
         }
 
+        logger.info(
+            f"Eval complete — correctness={eval_metrics['eval/mean_correctness']:.3f}, "
+            f"reward={eval_metrics['eval/mean_reward']:.3f}, "
+            f"tool_usage={eval_metrics['eval/tool_usage_rate']:.0%}"
+        )
+
         await self.evaluate_log(
             metrics=eval_metrics,
             samples=samples,
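
For context, a minimal standalone sketch of the per-sample bookkeeping and metric aggregation the new evaluate() performs, assuming OpenAI-style chat message dicts for the agent transcript; the helper names extract_final_response and aggregate_eval_metrics are illustrative only and not part of the repository's API.

# Standalone sketch (illustrative names, not repo API), assuming OpenAI-style
# chat message dicts as produced by the agent loop.
from typing import Any, Dict, List, Tuple


def extract_final_response(messages: List[Dict[str, Any]]) -> Tuple[str, int]:
    """Return (last assistant text, total tool calls) from a transcript."""
    final_response = ""
    tool_call_count = 0
    for msg in reversed(messages):
        if msg.get("role") != "assistant":
            continue
        if msg.get("content") and not final_response:
            final_response = msg["content"]
        if msg.get("tool_calls"):
            tool_call_count += len(msg["tool_calls"])
    return final_response, tool_call_count


def aggregate_eval_metrics(samples: List[Dict[str, Any]]) -> Dict[str, float]:
    """Mirror the eval/* aggregates computed at the end of evaluate()."""
    n = len(samples)
    if n == 0:
        return {"eval/n_items": 0}
    tool_counts = [s["tool_calls"] for s in samples]
    return {
        "eval/mean_correctness": sum(s["correctness"] for s in samples) / n,
        "eval/mean_reward": sum(s["reward"] for s in samples) / n,
        "eval/mean_tool_calls": sum(tool_counts) / n,
        "eval/tool_usage_rate": sum(1 for t in tool_counts if t > 0) / n,
        "eval/n_items": n,
    }


if __name__ == "__main__":
    transcript = [
        {"role": "user", "content": "Who wrote Dune?"},
        {"role": "assistant", "tool_calls": [{"function": {"name": "web_search"}}]},
        {"role": "tool", "content": "...search results..."},
        {"role": "assistant", "content": "Frank Herbert wrote Dune."},
    ]
    answer, n_tools = extract_final_response(transcript)
    print(answer, n_tools)  # Frank Herbert wrote Dune. 1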