adjusted prompt again to make things more reliable, having API issues

Shannon Sands 2026-02-05 14:42:10 +10:00
parent 87464821d8
commit 487487406d
2 changed files with 115 additions and 156 deletions


@@ -17,7 +17,7 @@ from __future__ import annotations
 import os
 import random
 import time
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 from pydantic import Field
@@ -41,13 +41,14 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
         description="If true, score tests on PASS_TO_PASS + FAIL_TO_PASS; else PASS_TO_PASS only.",
     )
+    prompt_mode: str = Field(
+        default="problem_statement",
+        description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
+    )
     repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
     install_timeout_s: float = Field(default=600.0)
     test_timeout_s: float = Field(default=600.0)
-    verification_mode: Literal["pytest", "install"] = Field(
-        default="install",
-        description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.",
-    )
     tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
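
The added prompt_mode field takes over prompt selection from the removed verification_mode switch. A minimal sketch of how it would be set (field names are from this diff; the import path is hypothetical):

    from swesmith_oracle_env import SweSmithOracleEnvConfig  # hypothetical import path

    # Fast default: the task prompt carries only the dataset problem_statement.
    cfg = SweSmithOracleEnvConfig(prompt_mode="problem_statement")

    # Slower debug mode: the prompt also embeds the full dataset "text" field.
    cfg_debug = SweSmithOracleEnvConfig(prompt_mode="problem_statement+text")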
@@ -78,12 +79,15 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
     def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
         # Defaults for running the env via CLI in offline `process` mode.
         # Override via env vars or `--env.*` flags as needed.
-        base_url = (
+        base_url_raw = (
             os.getenv("ATROPOS_SERVER_BASE_URL")
             or os.getenv("OPENAI_BASE_URL")
             or os.getenv("LLM_BASE_URL")
             or "http://127.0.0.1:8080"
         )
+        base_url = base_url_raw.rstrip("/")
+        if not base_url.endswith("/v1"):
+            base_url = f"{base_url}/v1"
         model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
         api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
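
The normalization above makes trailing slashes and a missing /v1 suffix equivalent. A standalone sketch of that behavior (it mirrors the diff's logic rather than importing it):

    def normalize_base_url(raw: str) -> str:
        # Strip trailing slashes, then append /v1 exactly once.
        base = raw.rstrip("/")
        if not base.endswith("/v1"):
            base = f"{base}/v1"
        return base

    assert normalize_base_url("http://127.0.0.1:8080") == "http://127.0.0.1:8080/v1"
    assert normalize_base_url("http://127.0.0.1:8080/") == "http://127.0.0.1:8080/v1"
    assert normalize_base_url("https://api.example.com/v1") == "https://api.example.com/v1"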
@@ -108,7 +112,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         server_configs = [
             APIServerConfig(
                 model_name=model,
-                base_url=f"{base_url.rstrip('/')}/v1",
+                base_url=base_url,
                 api_key=api_key,
                 num_max_requests_at_once=1,
                 num_requests_for_eval=1,
@@ -184,69 +188,43 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         problem = str(item.get("problem_statement") or "")
         context = str(item.get("text") or "")
-        # The dataset "text" field can be extremely large (e.g. includes large code blobs
-        # and long test lists). In local dev and bring-up runs this can make the first LLM
-        # call appear "hung" while the model chews through a massive prompt. Keep a cap.
-        # TODO: Remove, smoke test only
-        def _cap(s: str, n: int) -> tuple[str, bool]:
-            if len(s) <= n:
-                return s, False
-            return s[:n], True
-        problem, problem_trunc = _cap(problem, 8_000)
-        context, context_trunc = _cap(context, 12_000)
         nodeids = self._tests_for_item(item)
-        tests_preview = "\n".join(f"- {t}" for t in nodeids[:50])
-        if len(nodeids) > 50:
-            tests_preview += f"\n- ... ({len(nodeids) - 50} more)"
+        tests_list = "\n".join(f"- {t}" for t in nodeids)
         repo_dir = self._repo_name(item)
-        verify_note = ""
-        # TODO: Remove, smoke testing only
-        if self.config.verification_mode == "install":
-            verify_note = (
-                "\nVerification for this run is INSTALL-ONLY:\n"
-                "- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n"
-                "- You may skip running pytest to save time.\n"
-            )
-        trunc_note = ""
-        if problem_trunc or context_trunc:
-            trunc_note = (
-                "\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n"
-                f"- problem_statement_truncated={problem_trunc}\n"
-                f"- text_truncated={context_trunc}\n"
-            )
         tests_block = (
             "Run these tests to verify:\n"
-            f"{tests_preview}\n\n"
+            f"{tests_list}\n\n"
             "When done, briefly describe what you changed and confirm tests pass."
         )
-        if self.config.verification_mode == "install":
-            # Keep install-only prompts short and avoid huge test lists.
-            tests_block = (
-                "When done, briefly describe what you changed and confirm that "
-                "`python -m pip install -e .` succeeds."
-            )
+        prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
+        if prompt_mode not in {"problem_statement", "problem_statement+text"}:
+            raise ValueError(
+                f"Invalid prompt_mode={self.config.prompt_mode!r}. "
+                "Expected 'problem_statement' or 'problem_statement+text'."
+            )
+        context_block = ""
+        if prompt_mode == "problem_statement+text" and context:
+            # Note: We intentionally do NOT truncate/cap here. This mode is for
+            # debugging / richer prompts and can be slow.
+            context_block = f"\nAdditional context:\n{context}\n"
         return (
             "You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
             f"Repository: {repo} (checked out at base_commit={base_commit})\n"
             f"Workspace path: ./{repo_dir}\n\n"
             "Constraints:\n"
             "- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
             f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
             "- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
             "- Use non-interactive commands only.\n\n"
             "- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
             "  Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
-            f"{verify_note}\n"
-            f"{trunc_note}\n"
             "Problem statement:\n"
             f"{problem}\n\n"
-            "Additional context:\n"
-            f"{context}\n\n"
+            f"{context_block}\n"
             f"{tests_block}"
         )
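
The validation added above is tolerant of casing and surrounding whitespace. A standalone sketch of the same behavior (names are local to the example; it mirrors the diff rather than importing it):

    def resolve_prompt_mode(value: str | None) -> str:
        # None/empty falls back to the fast default; input is case/space tolerant.
        mode = (value or "problem_statement").strip().lower()
        if mode not in {"problem_statement", "problem_statement+text"}:
            raise ValueError(
                f"Invalid prompt_mode={value!r}. "
                "Expected 'problem_statement' or 'problem_statement+text'."
            )
        return mode

    assert resolve_prompt_mode(None) == "problem_statement"
    assert resolve_prompt_mode(" Problem_Statement+TEXT ") == "problem_statement+text"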
@@ -275,74 +253,61 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             flush=True,
         )
-        # Prefer a lightweight "fetch by sha" to avoid pulling full history.
-        # If it fails (some servers disallow fetching unadvertised objects, or we hit
-        # shallow-object edge cases), fall back to a full clone.
+        # Repo setup strategy:
+        # - Maintain a shared, per-container bare repo cache under /data/repo_cache.
+        # - For each trajectory, create an isolated git worktree under the slot workspace.
+        # This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
+        # TODO: tbh, should just do this before setting up the worktree & after the sandbox build
-        clone_attempts: list[tuple[str, str]] = []
-        clone_attempts.append(
-            (
-                "shallow_fetch_sha",
-                (
-                    f"rm -rf {repo_dir} && "
-                    f"git init {repo_dir} && "
-                    f"cd {repo_dir} && "
-                    "export GIT_TERMINAL_PROMPT=0 && "
-                    "export GIT_LFS_SKIP_SMUDGE=1 && "
-                    "git config advice.detachedHead false && "
-                    f"git remote add origin {clone_url} && "
-                    f"git fetch --depth 1 origin {base_commit} && "
-                    "git checkout -q FETCH_HEAD"
-                ),
-            )
-        )
-        clone_attempts.append(
-            (
-                "full_clone_checkout",
-                (
-                    f"rm -rf {repo_dir} && "
-                    f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && "
-                    f"cd {repo_dir} && "
-                    "git config advice.detachedHead false && "
-                    f"git checkout -q {base_commit}"
-                ),
-            )
-        )
+        def _repo_cache_slug(repo_name: str) -> str:
+            return repo_name.replace("/", "__")
+        repo_slug = _repo_cache_slug(repo)
+        cache_root = "/data/repo_cache"
+        bare_repo = f"{cache_root}/{repo_slug}.git"
+        lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
+        # Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
+        # util-linux (flock) is included in the sandbox image.
+        worktree_cmd = (
+            "set -e; "
+            f"rm -rf {repo_dir}; "
+            f"mkdir -p {cache_root}/.locks; "
+            f": > {lock_file}; "
+            f"flock -x {lock_file} sh -lc '"
+            "set -e; "
+            "export GIT_TERMINAL_PROMPT=0; "
+            "export GIT_LFS_SKIP_SMUDGE=1; "
+            f"if [ ! -d \"{bare_repo}\" ]; then "
+            f"  git init --bare \"{bare_repo}\"; "
+            f"  git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
+            "fi; "
+            f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
+            f"git -C \"{bare_repo}\" worktree prune || true; "
+            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
+            f"  git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
+            "fi; "
+            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
+            f"  git -C \"{bare_repo}\" fetch --prune origin; "
+            "fi; "
+            f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
+            "'"
+        )
-        clone_res = None
-        for label, cmd in clone_attempts:
-            t_attempt = time.perf_counter()
-            print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True)
-            res = await exec_tool(
-                ToolCall(
-                    name="terminal",
-                    arguments={"command": cmd, "timeout": self.config.install_timeout_s},
-                )
-            )
-            clone_res = res
-            if res.success:
-                print(
-                    f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s",
-                    flush=True,
-                )
-                break
-            print(
-                f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: "
-                f"{res.error}",
-                flush=True,
-            )
-        if clone_res is None or not clone_res.success:
-            err = clone_res.error if clone_res is not None else "unknown"
-            out = clone_res.output if clone_res is not None else ""
+        print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
+        res = await exec_tool(
+            ToolCall(
+                name="terminal",
+                arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
+            )
+        )
+        if not res.success:
             raise RuntimeError(
-                "git clone/checkout failed "
-                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}"
+                "git worktree setup failed "
+                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
             )
         print(
-            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s",
+            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
             flush=True,
         )
         return {"repo_dir": repo_dir, "base_commit": base_commit}
@@ -380,43 +345,10 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         _ = trajectory_id
         repo_dir = self._repo_name(item)
-        if self.config.verification_mode == "install":
-            # Training correctness: do not reward trajectories that never actually used tools.
-            if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
-                return 0.0, {
-                    "verification_mode": "install",
-                    "error": "No tool calls were made by the agent",
-                }
-            print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True)
-            t0 = time.perf_counter()
-            install_cmd = (
-                f"cd {repo_dir} && "
-                "python -m venv .venv && "
-                ". .venv/bin/activate && "
-                "python -m pip install -U pip setuptools wheel && "
-                "python -m pip install -e ."
-            )
-            res = await exec_tool(
-                ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s})
-            )
-            ok = bool(res.success)
-            print(
-                f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} "
-                f"in {time.perf_counter() - t0:.2f}s",
-                flush=True,
-            )
-            return (1.0 if ok else 0.0), {
-                "verification_mode": "install",
-                "install_success": ok,
-                "error": res.error,
-                "verification_messages": [{"role": "user", "content": res.to_xml()}],
-            }
         # Training correctness: do not reward trajectories that never actually used tools.
         if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
             return 0.0, {
-                "verification_mode": "pytest",
+                "verification_mode": "dataset_tests",
                 "error": "No tool calls were made by the agent",
             }
@@ -424,7 +356,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         if not nodeids:
             return 0.0, {"error": "No tests provided"}
-        print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True)
+        print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
         setup_cmd = (
             f"cd {repo_dir} && "
             "python -m venv .venv && "
@@ -439,7 +371,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
         if not setup_res.success:
             return 0.0, {
-                "verification_mode": "pytest",
+                "verification_mode": "dataset_tests",
                 "phase": "install",
                 "error": setup_res.error,
                 "output": setup_res.output,
@@ -459,7 +391,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             verification_messages.append({"role": "user", "content": res.to_xml()})
             if not res.success:
                 return 0.0, {
-                    "verification_mode": "pytest",
+                    "verification_mode": "dataset_tests",
                     "phase": "pytest",
                     "failed_chunk": chunk_idx,
                     "error": res.error,
@@ -467,7 +399,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
                     "verification_messages": verification_messages,
                 }
-        return 1.0, {"verification_mode": "pytest", "passed": True, "verification_messages": verification_messages}
+        return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}
 
     async def score_trajectory(self, item: Item, final_response: str) -> float:
         # Not used; scoring happens in verify_and_score_trajectory.
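
For intuition, the chunked test execution that produces the failed_chunk field above amounts to something like this sketch (the chunk size and exact pytest invocation are illustrative assumptions, not taken from this diff):

    def chunked(nodeids: list[str], size: int = 25) -> list[list[str]]:
        # Split test node IDs into fixed-size batches.
        return [nodeids[i : i + size] for i in range(0, len(nodeids), size)]

    def pytest_commands(repo_dir: str, nodeids: list[str]) -> list[str]:
        # One non-interactive pytest run per chunk keeps command lines short and
        # lets the scorer report which chunk failed first.
        return [
            f"cd {repo_dir} && .venv/bin/python -m pytest -q " + " ".join(chunk)
            for chunk in chunked(nodeids)
        ]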