swe-smith-oracle now runs as a 1-step process. The local llama server was breaking again (cause unknown), but it works fine through the Hermes endpoint & ManagedServer.

This commit is contained in:
Shannon Sands 2026-02-04 11:22:45 +10:00
parent 6cb4fe948a
commit 5a9c98a771
8 changed files with 359 additions and 51 deletions

View file

@ -1,10 +1,11 @@
"""
SWE-smith-oracle benchmark environment (Phase 4.7).
SWE-smith-oracle environment.
This environment is intentionally minimal:
- prepares a sandbox workspace by cloning a public GitHub repo at `base_commit`
- runs an AtroposAgent tool loop to apply a fix
- verifies by running pytest nodeids from the dataset (reward = pass/fail)
- Python only (no multi-language support currently; need to properly build & add to dropbox)
Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set).
"""
@ -13,7 +14,8 @@ from __future__ import annotations
import os
import random
from typing import Any, Dict, List, Optional, Tuple
import time
from typing import Any, Dict, List, Literal, Optional, Tuple
from pydantic import Field
@ -40,8 +42,11 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
install_timeout_s: float = Field(default=600.0)
test_timeout_s: float = Field(default=600.0)
verification_mode: Literal["pytest", "install"] = Field(
default="install",
description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.",
)
# Tokenization: should match the model used for training.
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
@ -78,7 +83,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
or "http://127.0.0.1:8080"
)
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
env_config = SweSmithOracleEnvConfig(
tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
@ -114,6 +119,12 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
async def setup_agent_env(self) -> None:
from datasets import load_dataset
t0 = time.perf_counter()
print(
f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
flush=True,
)
ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
self._dataset = ds
@ -135,7 +146,9 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
self._cursor = 0
print(
f"SweSmithOracleEnv loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split}"
f"[SweSmithOracleEnv] loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split} "
f"in {time.perf_counter() - t0:.2f}s",
flush=True,
)
def _is_python_row(self, row: Dict[str, Any]) -> bool:
@ -148,6 +161,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
return True
async def get_next_item(self) -> Item:
print(f"[SweSmithOracleEnv] get_next_item() cursor={self._cursor}/{len(self._indices)}", flush=True)
if not self._dataset or not self._indices:
raise RuntimeError("Dataset not initialized (did setup() run?)")
if self._cursor >= len(self._indices):
@ -165,8 +179,19 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
def build_task(self, item: Item) -> str:
repo = item.get("repo") or ""
base_commit = item.get("base_commit") or ""
problem = item.get("problem_statement") or ""
context = item.get("text") or ""
problem = str(item.get("problem_statement") or "")
context = str(item.get("text") or "")
# The dataset "text" field can be extremely large (e.g. includes large code blobs
# and long test lists). In local dev and bring-up runs this can make the first LLM
# call appear "hung" while the model chews through a massive prompt. Keep a cap.
def _cap(s: str, n: int) -> tuple[str, bool]:
if len(s) <= n:
return s, False
return s[:n], True
problem, problem_trunc = _cap(problem, 8_000)
context, context_trunc = _cap(context, 12_000)
nodeids = self._tests_for_item(item)
tests_preview = "\n".join(f"- {t}" for t in nodeids[:50])
@ -174,6 +199,34 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
tests_preview += f"\n- ... ({len(nodeids) - 50} more)"
repo_dir = self._repo_name(item)
verify_note = ""
if self.config.verification_mode == "install":
verify_note = (
"\nVerification for this run is INSTALL-ONLY:\n"
"- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n"
"- You may skip running pytest to save time.\n"
)
trunc_note = ""
if problem_trunc or context_trunc:
trunc_note = (
"\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n"
f"- problem_statement_truncated={problem_trunc}\n"
f"- text_truncated={context_trunc}\n"
)
tests_block = (
"Run these tests to verify:\n"
f"{tests_preview}\n\n"
"When done, briefly describe what you changed and confirm tests pass."
)
if self.config.verification_mode == "install":
# Keep install-only prompts short and avoid huge test lists.
tests_block = (
"When done, briefly describe what you changed and confirm that "
"`python -m pip install -e .` succeeds."
)
return (
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
@ -181,13 +234,13 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
"Constraints:\n"
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
"- Use non-interactive commands only.\n\n"
f"{verify_note}\n"
f"{trunc_note}\n"
"Problem statement:\n"
f"{problem}\n\n"
"Additional context:\n"
f"{context}\n\n"
"Run these tests to verify:\n"
f"{tests_preview}\n\n"
"When done, briefly describe what you changed and confirm tests pass."
f"{tests_block}"
)
def build_agent_config(self, item: Item) -> AgentConfig: # noqa: ARG002
@ -200,7 +253,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
)
async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]:
_ = trajectory_id
t0 = time.perf_counter()
repo = item.get("repo")
base_commit = item.get("base_commit")
instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
@ -209,43 +262,80 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
repo_dir = self._repo_name(item)
clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
# Fetch only the requested base commit (much lighter than a full clone and robust
# even if the commit is not on the default branch ref we cloned).
#
# This also avoids some failure modes where `git clone` fetches only default-branch
# refs and then `git checkout <sha>` fails because the commit isn't present locally.
clone_and_checkout_cmd = (
f"rm -rf {repo_dir} && "
f"git init {repo_dir} && "
f"cd {repo_dir} && "
f"git remote add origin {clone_url} && "
f"git fetch --depth 1 origin {base_commit} && "
"git -c advice.detachedHead=false checkout -q FETCH_HEAD"
print(
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
flush=True,
)
res = await exec_tool(
ToolCall(
name="terminal",
arguments={"command": clone_and_checkout_cmd, "timeout": self.config.install_timeout_s},
# Prefer a lightweight "fetch by sha" to avoid pulling full history.
# If it fails (some servers disallow fetching unadvertised objects, or we hit
# shallow-object edge cases), fall back to a full clone.
clone_attempts: list[tuple[str, str]] = []
clone_attempts.append(
(
"shallow_fetch_sha",
(
f"rm -rf {repo_dir} && "
f"git init {repo_dir} && "
f"cd {repo_dir} && "
"export GIT_TERMINAL_PROMPT=0 && "
"export GIT_LFS_SKIP_SMUDGE=1 && "
"git config advice.detachedHead false && "
f"git remote add origin {clone_url} && "
f"git fetch --depth 1 origin {base_commit} && "
"git checkout -q FETCH_HEAD"
),
)
)
if not res.success:
clone_attempts.append(
(
"full_clone_checkout",
(
f"rm -rf {repo_dir} && "
f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && "
f"cd {repo_dir} && "
"git config advice.detachedHead false && "
f"git checkout -q {base_commit}"
),
)
)
clone_res = None
for label, cmd in clone_attempts:
t_attempt = time.perf_counter()
print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True)
res = await exec_tool(
ToolCall(
name="terminal",
arguments={"command": cmd, "timeout": self.config.install_timeout_s},
)
)
clone_res = res
if res.success:
print(
f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s",
flush=True,
)
break
print(
f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: "
f"{res.error}",
flush=True,
)
if clone_res is None or not clone_res.success:
err = clone_res.error if clone_res is not None else "unknown"
out = clone_res.output if clone_res is not None else ""
raise RuntimeError(
"git fetch/checkout failed "
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
"git clone/checkout failed "
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}"
)
# Best-effort baseline python env.
setup_cmd = (
f"cd {repo_dir} && "
"python -m venv .venv && "
". .venv/bin/activate && "
"python -m pip install -U pip setuptools wheel && "
"python -m pip install -e . && "
"python -m pip install pytest"
print(
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s",
flush=True,
)
await exec_tool(ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s}))
return {"repo_dir": repo_dir, "base_commit": base_commit}
def _tests_for_item(self, item: Item) -> List[str]:
@ -276,13 +366,60 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
trajectory_id: str,
exec_tool,
agent_result=None, # noqa: ARG002
workspace_meta: Optional[Dict[str, Any]] = None,
) -> tuple[float, Dict[str, Any]]:
_ = trajectory_id
repo_dir = self._repo_name(item)
if self.config.verification_mode == "install":
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True)
t0 = time.perf_counter()
install_cmd = (
f"cd {repo_dir} && "
"python -m venv .venv && "
". .venv/bin/activate && "
"python -m pip install -U pip setuptools wheel && "
"python -m pip install -e ."
)
res = await exec_tool(
ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s})
)
ok = bool(res.success)
print(
f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} "
f"in {time.perf_counter() - t0:.2f}s",
flush=True,
)
return (1.0 if ok else 0.0), {
"verification_mode": "install",
"install_success": ok,
"error": res.error,
}
nodeids = self._tests_for_item(item)
if not nodeids:
return 0.0, {"error": "No tests provided"}
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True)
setup_cmd = (
f"cd {repo_dir} && "
"python -m venv .venv && "
". .venv/bin/activate && "
"python -m pip install -U pip setuptools wheel && "
"python -m pip install -e . && "
"python -m pip install pytest"
)
setup_res = await exec_tool(
ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})
)
if not setup_res.success:
return 0.0, {
"verification_mode": "pytest",
"phase": "install",
"error": setup_res.error,
"output": setup_res.output,
}
chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
for chunk_idx, chunk in enumerate(chunks):
joined = " ".join(chunk)
@ -296,7 +433,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
if not res.success:
return 0.0, {"failed_chunk": chunk_idx, "error": res.error, "output": res.output}
return 1.0, {"passed": True}
return 1.0, {"verification_mode": "pytest", "passed": True}
async def score_trajectory(self, item: Item, final_response: str) -> float:
# Not used; scoring happens in verify_and_score_trajectory.