mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-15 04:12:25 +00:00
adjusted prompt again to make things more reliable, having api issues
This commit is contained in:
parent
87464821d8
commit
487487406d
2 changed files with 115 additions and 156 deletions
|
|
@ -32,16 +32,13 @@ load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
# Default system prompt with tool calling instructions.
|
# Default system prompt with tool calling instructions.
|
||||||
#
|
|
||||||
# IMPORTANT: In training-mode environments we want "raw text in -> raw text out" and we
|
|
||||||
# parse tool calls from completion text. Do not rely on server-specific `tool_calls` fields.
|
|
||||||
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
|
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
|
||||||
|
|
||||||
You are a function calling AI model.
|
You are a function calling AI model.
|
||||||
|
|
||||||
You are provided with function signatures within <tools></tools> XML tags.
|
You are provided with function signatures within <tools></tools> XML tags.
|
||||||
You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
|
You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
|
||||||
|
You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task
|
||||||
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
|
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
|
||||||
|
|
||||||
Here are the available tools:
|
Here are the available tools:
|
||||||
|
|
@ -314,6 +311,24 @@ class AtroposAgent:
|
||||||
return model
|
return model
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _infer_server_base_url_for_debug(self) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Best-effort inference of the configured base_url for debug logging.
|
||||||
|
|
||||||
|
This is helpful when diagnosing hangs / retries at the transport layer.
|
||||||
|
"""
|
||||||
|
servers = getattr(self.server, "servers", None)
|
||||||
|
if isinstance(servers, list) and servers:
|
||||||
|
s0 = servers[0]
|
||||||
|
cfg = getattr(s0, "config", None)
|
||||||
|
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
|
||||||
|
if isinstance(base_url, str) and base_url:
|
||||||
|
return base_url
|
||||||
|
base_url = getattr(self.server, "base_url", None)
|
||||||
|
if isinstance(base_url, str) and base_url:
|
||||||
|
return base_url
|
||||||
|
return None
|
||||||
|
|
||||||
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
|
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
|
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
|
||||||
|
|
@ -359,6 +374,8 @@ class AtroposAgent:
|
||||||
# Avoid dumping megabytes by default; messages can be huge.
|
# Avoid dumping megabytes by default; messages can be huge.
|
||||||
meta = {
|
meta = {
|
||||||
"step": step_num,
|
"step": step_num,
|
||||||
|
"base_url": self._infer_server_base_url_for_debug(),
|
||||||
|
"model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
|
||||||
"chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
|
"chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
|
||||||
"n": chat_kwargs.get("n"),
|
"n": chat_kwargs.get("n"),
|
||||||
"max_tokens": chat_kwargs.get("max_tokens"),
|
"max_tokens": chat_kwargs.get("max_tokens"),
|
||||||
|
|
@ -421,8 +438,12 @@ class AtroposAgent:
|
||||||
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
|
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
|
||||||
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
|
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
|
||||||
"""
|
"""
|
||||||
|
# Hard guardrail: never allow a single chat completion to block for more than 2 minutes.
|
||||||
|
# This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
|
||||||
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
|
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
|
||||||
timeout_s = float(timeout_s_raw) if timeout_s_raw else None
|
timeout_s_default = 120.0
|
||||||
|
timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
|
||||||
|
timeout_s = min(timeout_s, 120.0)
|
||||||
|
|
||||||
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
|
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
|
||||||
wait_every_s = float(wait_every_raw) if wait_every_raw else None
|
wait_every_s = float(wait_every_raw) if wait_every_raw else None
|
||||||
|
|
@ -452,9 +473,15 @@ class AtroposAgent:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if timeout_s and timeout_s > 0:
|
|
||||||
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
|
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
|
||||||
return await _await_call()
|
except asyncio.TimeoutError as e:
|
||||||
|
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
|
||||||
|
print({"step": step_num, "timeout_s": timeout_s}, flush=True)
|
||||||
|
raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Treat cancellation as a hard failure rather than crashing the whole env run.
|
||||||
|
# (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
|
||||||
|
raise RuntimeError("chat_completion cancelled") from None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
detail: Dict[str, Any] = {
|
detail: Dict[str, Any] = {
|
||||||
"step": step_num,
|
"step": step_num,
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ from __future__ import annotations
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, List, Literal, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
|
|
@ -41,13 +41,14 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
|
||||||
description="If true, score tests on PASS_TO_PASS ∪ FAIL_TO_PASS; else PASS_TO_PASS only.",
|
description="If true, score tests on PASS_TO_PASS ∪ FAIL_TO_PASS; else PASS_TO_PASS only.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
prompt_mode: str = Field(
|
||||||
|
default="problem_statement",
|
||||||
|
description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
|
||||||
|
)
|
||||||
|
|
||||||
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
|
||||||
install_timeout_s: float = Field(default=600.0)
|
install_timeout_s: float = Field(default=600.0)
|
||||||
test_timeout_s: float = Field(default=600.0)
|
test_timeout_s: float = Field(default=600.0)
|
||||||
verification_mode: Literal["pytest", "install"] = Field(
|
|
||||||
default="install",
|
|
||||||
description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.",
|
|
||||||
)
|
|
||||||
|
|
||||||
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
|
||||||
|
|
||||||
|
|
@ -78,12 +79,15 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
|
||||||
# Defaults for running the env via CLI in offline `process` mode.
|
# Defaults for running the env via CLI in offline `process` mode.
|
||||||
# Override via env vars or `--env.*` flags as needed.
|
# Override via env vars or `--env.*` flags as needed.
|
||||||
base_url = (
|
base_url_raw = (
|
||||||
os.getenv("ATROPOS_SERVER_BASE_URL")
|
os.getenv("ATROPOS_SERVER_BASE_URL")
|
||||||
or os.getenv("OPENAI_BASE_URL")
|
or os.getenv("OPENAI_BASE_URL")
|
||||||
or os.getenv("LLM_BASE_URL")
|
or os.getenv("LLM_BASE_URL")
|
||||||
or "http://127.0.0.1:8080"
|
or "http://127.0.0.1:8080"
|
||||||
)
|
)
|
||||||
|
base_url = base_url_raw.rstrip("/")
|
||||||
|
if not base_url.endswith("/v1"):
|
||||||
|
base_url = f"{base_url}/v1"
|
||||||
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
|
||||||
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
|
||||||
|
|
||||||
|
|
@ -108,7 +112,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
server_configs = [
|
server_configs = [
|
||||||
APIServerConfig(
|
APIServerConfig(
|
||||||
model_name=model,
|
model_name=model,
|
||||||
base_url=f"{base_url.rstrip('/')}/v1",
|
base_url=base_url,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
num_max_requests_at_once=1,
|
num_max_requests_at_once=1,
|
||||||
num_requests_for_eval=1,
|
num_requests_for_eval=1,
|
||||||
|
|
@ -184,69 +188,43 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
problem = str(item.get("problem_statement") or "")
|
problem = str(item.get("problem_statement") or "")
|
||||||
context = str(item.get("text") or "")
|
context = str(item.get("text") or "")
|
||||||
|
|
||||||
# The dataset "text" field can be extremely large (e.g. includes large code blobs
|
|
||||||
# and long test lists). In local dev and bring-up runs this can make the first LLM
|
|
||||||
# call appear "hung" while the model chews through a massive prompt. Keep a cap.
|
|
||||||
|
|
||||||
# TODO: Remove, smoke test only
|
|
||||||
def _cap(s: str, n: int) -> tuple[str, bool]:
|
|
||||||
if len(s) <= n:
|
|
||||||
return s, False
|
|
||||||
return s[:n], True
|
|
||||||
|
|
||||||
problem, problem_trunc = _cap(problem, 8_000)
|
|
||||||
context, context_trunc = _cap(context, 12_000)
|
|
||||||
|
|
||||||
nodeids = self._tests_for_item(item)
|
nodeids = self._tests_for_item(item)
|
||||||
tests_preview = "\n".join(f"- {t}" for t in nodeids[:50])
|
tests_list = "\n".join(f"- {t}" for t in nodeids)
|
||||||
if len(nodeids) > 50:
|
|
||||||
tests_preview += f"\n- ... ({len(nodeids) - 50} more)"
|
|
||||||
|
|
||||||
repo_dir = self._repo_name(item)
|
repo_dir = self._repo_name(item)
|
||||||
verify_note = ""
|
|
||||||
# TODO: Remove, smoke testing only
|
|
||||||
if self.config.verification_mode == "install":
|
|
||||||
verify_note = (
|
|
||||||
"\nVerification for this run is INSTALL-ONLY:\n"
|
|
||||||
"- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n"
|
|
||||||
"- You may skip running pytest to save time.\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
trunc_note = ""
|
|
||||||
if problem_trunc or context_trunc:
|
|
||||||
trunc_note = (
|
|
||||||
"\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n"
|
|
||||||
f"- problem_statement_truncated={problem_trunc}\n"
|
|
||||||
f"- text_truncated={context_trunc}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
tests_block = (
|
tests_block = (
|
||||||
"Run these tests to verify:\n"
|
"Run these tests to verify:\n"
|
||||||
f"{tests_preview}\n\n"
|
f"{tests_list}\n\n"
|
||||||
"When done, briefly describe what you changed and confirm tests pass."
|
"When done, briefly describe what you changed and confirm tests pass."
|
||||||
)
|
)
|
||||||
if self.config.verification_mode == "install":
|
|
||||||
# Keep install-only prompts short and avoid huge test lists.
|
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
|
||||||
tests_block = (
|
if prompt_mode not in {"problem_statement", "problem_statement+text"}:
|
||||||
"When done, briefly describe what you changed and confirm that "
|
raise ValueError(
|
||||||
"`python -m pip install -e .` succeeds."
|
f"Invalid prompt_mode={self.config.prompt_mode!r}. "
|
||||||
|
"Expected 'problem_statement' or 'problem_statement+text'."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
context_block = ""
|
||||||
|
if prompt_mode == "problem_statement+text" and context:
|
||||||
|
# Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
|
||||||
|
context_block = f"\nAdditional context:\n{context}\n"
|
||||||
|
|
||||||
return (
|
return (
|
||||||
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
|
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
|
||||||
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
|
||||||
f"Workspace path: ./{repo_dir}\n\n"
|
f"Workspace path: ./{repo_dir}\n\n"
|
||||||
"Constraints:\n"
|
"Constraints:\n"
|
||||||
|
"- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
|
||||||
|
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
|
||||||
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
|
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
|
||||||
"- Use non-interactive commands only.\n\n"
|
"- Use non-interactive commands only.\n\n"
|
||||||
"- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
|
"- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
|
||||||
" Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
|
" Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
|
||||||
f"{verify_note}\n"
|
|
||||||
f"{trunc_note}\n"
|
|
||||||
"Problem statement:\n"
|
"Problem statement:\n"
|
||||||
f"{problem}\n\n"
|
f"{problem}\n\n"
|
||||||
"Additional context:\n"
|
f"{context_block}\n"
|
||||||
f"{context}\n\n"
|
|
||||||
f"{tests_block}"
|
f"{tests_block}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -275,74 +253,61 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prefer a lightweight "fetch by sha" to avoid pulling full history.
|
# Repo setup strategy:
|
||||||
# If it fails (some servers disallow fetching unadvertised objects, or we hit
|
# - Maintain a shared, per-container bare repo cache under /data/repo_cache
|
||||||
# shallow-object edge cases), fall back to a full clone
|
# - For each trajectory, create an isolated git worktree under the slot workspace
|
||||||
|
# This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
|
||||||
|
|
||||||
# TODO: tbh, should just do this before setting up worktree & after sandbox build
|
def _repo_cache_slug(repo_name: str) -> str:
|
||||||
clone_attempts: list[tuple[str, str]] = []
|
return repo_name.replace("/", "__")
|
||||||
clone_attempts.append(
|
|
||||||
(
|
repo_slug = _repo_cache_slug(repo)
|
||||||
"shallow_fetch_sha",
|
cache_root = "/data/repo_cache"
|
||||||
(
|
bare_repo = f"{cache_root}/{repo_slug}.git"
|
||||||
f"rm -rf {repo_dir} && "
|
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
|
||||||
f"git init {repo_dir} && "
|
|
||||||
f"cd {repo_dir} && "
|
# Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
|
||||||
"export GIT_TERMINAL_PROMPT=0 && "
|
# util-linux (flock) is included in the sandbox image.
|
||||||
"export GIT_LFS_SKIP_SMUDGE=1 && "
|
worktree_cmd = (
|
||||||
"git config advice.detachedHead false && "
|
"set -e; "
|
||||||
f"git remote add origin {clone_url} && "
|
f"rm -rf {repo_dir}; "
|
||||||
f"git fetch --depth 1 origin {base_commit} && "
|
f"mkdir -p {cache_root}/.locks; "
|
||||||
"git checkout -q FETCH_HEAD"
|
f": > {lock_file}; "
|
||||||
),
|
f"flock -x {lock_file} sh -lc '"
|
||||||
)
|
f"set -e; "
|
||||||
)
|
"export GIT_TERMINAL_PROMPT=0; "
|
||||||
clone_attempts.append(
|
"export GIT_LFS_SKIP_SMUDGE=1; "
|
||||||
(
|
f"if [ ! -d \"{bare_repo}\" ]; then "
|
||||||
"full_clone_checkout",
|
f" git init --bare \"{bare_repo}\"; "
|
||||||
(
|
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
|
||||||
f"rm -rf {repo_dir} && "
|
"fi; "
|
||||||
f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && "
|
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
|
||||||
f"cd {repo_dir} && "
|
f"git -C \"{bare_repo}\" worktree prune || true; "
|
||||||
"git config advice.detachedHead false && "
|
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||||
f"git checkout -q {base_commit}"
|
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
|
||||||
),
|
"fi; "
|
||||||
)
|
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
|
||||||
|
f" git -C \"{bare_repo}\" fetch --prune origin; "
|
||||||
|
"fi; "
|
||||||
|
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
|
||||||
|
"'"
|
||||||
)
|
)
|
||||||
|
|
||||||
clone_res = None
|
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
|
||||||
for label, cmd in clone_attempts:
|
|
||||||
t_attempt = time.perf_counter()
|
|
||||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True)
|
|
||||||
res = await exec_tool(
|
res = await exec_tool(
|
||||||
ToolCall(
|
ToolCall(
|
||||||
name="terminal",
|
name="terminal",
|
||||||
arguments={"command": cmd, "timeout": self.config.install_timeout_s},
|
arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
clone_res = res
|
if not res.success:
|
||||||
if res.success:
|
|
||||||
print(
|
|
||||||
f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
break
|
|
||||||
print(
|
|
||||||
f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: "
|
|
||||||
f"{res.error}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if clone_res is None or not clone_res.success:
|
|
||||||
err = clone_res.error if clone_res is not None else "unknown"
|
|
||||||
out = clone_res.output if clone_res is not None else ""
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"git clone/checkout failed "
|
"git worktree setup failed "
|
||||||
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}"
|
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
|
||||||
)
|
)
|
||||||
|
|
||||||
print(
|
print(
|
||||||
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s",
|
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
return {"repo_dir": repo_dir, "base_commit": base_commit}
|
||||||
|
|
@ -380,43 +345,10 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
_ = trajectory_id
|
_ = trajectory_id
|
||||||
repo_dir = self._repo_name(item)
|
repo_dir = self._repo_name(item)
|
||||||
|
|
||||||
if self.config.verification_mode == "install":
|
|
||||||
# Training correctness: do not reward trajectories that never actually used tools.
|
# Training correctness: do not reward trajectories that never actually used tools.
|
||||||
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
|
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
|
||||||
return 0.0, {
|
return 0.0, {
|
||||||
"verification_mode": "install",
|
"verification_mode": "dataset_tests",
|
||||||
"error": "No tool calls were made by the agent",
|
|
||||||
}
|
|
||||||
|
|
||||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True)
|
|
||||||
t0 = time.perf_counter()
|
|
||||||
install_cmd = (
|
|
||||||
f"cd {repo_dir} && "
|
|
||||||
"python -m venv .venv && "
|
|
||||||
". .venv/bin/activate && "
|
|
||||||
"python -m pip install -U pip setuptools wheel && "
|
|
||||||
"python -m pip install -e ."
|
|
||||||
)
|
|
||||||
res = await exec_tool(
|
|
||||||
ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s})
|
|
||||||
)
|
|
||||||
ok = bool(res.success)
|
|
||||||
print(
|
|
||||||
f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} "
|
|
||||||
f"in {time.perf_counter() - t0:.2f}s",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
return (1.0 if ok else 0.0), {
|
|
||||||
"verification_mode": "install",
|
|
||||||
"install_success": ok,
|
|
||||||
"error": res.error,
|
|
||||||
"verification_messages": [{"role": "user", "content": res.to_xml()}],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Training correctness: do not reward trajectories that never actually used tools.
|
|
||||||
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
|
|
||||||
return 0.0, {
|
|
||||||
"verification_mode": "pytest",
|
|
||||||
"error": "No tool calls were made by the agent",
|
"error": "No tool calls were made by the agent",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -424,7 +356,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
if not nodeids:
|
if not nodeids:
|
||||||
return 0.0, {"error": "No tests provided"}
|
return 0.0, {"error": "No tests provided"}
|
||||||
|
|
||||||
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True)
|
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
|
||||||
setup_cmd = (
|
setup_cmd = (
|
||||||
f"cd {repo_dir} && "
|
f"cd {repo_dir} && "
|
||||||
"python -m venv .venv && "
|
"python -m venv .venv && "
|
||||||
|
|
@ -439,7 +371,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
|
verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
|
||||||
if not setup_res.success:
|
if not setup_res.success:
|
||||||
return 0.0, {
|
return 0.0, {
|
||||||
"verification_mode": "pytest",
|
"verification_mode": "dataset_tests",
|
||||||
"phase": "install",
|
"phase": "install",
|
||||||
"error": setup_res.error,
|
"error": setup_res.error,
|
||||||
"output": setup_res.output,
|
"output": setup_res.output,
|
||||||
|
|
@ -459,7 +391,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
verification_messages.append({"role": "user", "content": res.to_xml()})
|
verification_messages.append({"role": "user", "content": res.to_xml()})
|
||||||
if not res.success:
|
if not res.success:
|
||||||
return 0.0, {
|
return 0.0, {
|
||||||
"verification_mode": "pytest",
|
"verification_mode": "dataset_tests",
|
||||||
"phase": "pytest",
|
"phase": "pytest",
|
||||||
"failed_chunk": chunk_idx,
|
"failed_chunk": chunk_idx,
|
||||||
"error": res.error,
|
"error": res.error,
|
||||||
|
|
@ -467,7 +399,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
|
||||||
"verification_messages": verification_messages,
|
"verification_messages": verification_messages,
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1.0, {"verification_mode": "pytest", "passed": True, "verification_messages": verification_messages}
|
return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}
|
||||||
|
|
||||||
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
async def score_trajectory(self, item: Item, final_response: str) -> float:
|
||||||
# Not used; scoring happens in verify_and_score_trajectory.
|
# Not used; scoring happens in verify_and_score_trajectory.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue