adjusted prompt again to make things more reliable, having api issues

This commit is contained in:
Shannon Sands 2026-02-05 14:42:10 +10:00
parent 87464821d8
commit 487487406d
2 changed files with 115 additions and 156 deletions

View file

@ -32,16 +32,13 @@ load_dotenv()
# Default system prompt with tool calling instructions.
#
# IMPORTANT: In training-mode environments we want "raw text in -> raw text out" and we
# parse tool calls from completion text. Do not rely on server-specific `tool_calls` fields.
AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
You are a function calling AI model.
You are provided with function signatures within <tools></tools> XML tags.
You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task
After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
Here are the available tools:
@ -314,6 +311,24 @@ class AtroposAgent:
return model
return None
def _infer_server_base_url_for_debug(self) -> Optional[str]:
"""
Best-effort inference of the configured base_url for debug logging.
This is helpful when diagnosing hangs / retries at the transport layer.
"""
servers = getattr(self.server, "servers", None)
if isinstance(servers, list) and servers:
s0 = servers[0]
cfg = getattr(s0, "config", None)
base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
if isinstance(base_url, str) and base_url:
return base_url
base_url = getattr(self.server, "base_url", None)
if isinstance(base_url, str) and base_url:
return base_url
return None
def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
"""
Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
@ -359,6 +374,8 @@ class AtroposAgent:
# Avoid dumping megabytes by default; messages can be huge.
meta = {
"step": step_num,
"base_url": self._infer_server_base_url_for_debug(),
"model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
"chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
"n": chat_kwargs.get("n"),
"max_tokens": chat_kwargs.get("max_tokens"),
@ -421,8 +438,12 @@ class AtroposAgent:
- `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
- `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
"""
# Hard guardrail: never allow a single chat completion to block for more than 2 minutes.
# This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
timeout_s = float(timeout_s_raw) if timeout_s_raw else None
timeout_s_default = 120.0
timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
timeout_s = min(timeout_s, 120.0)
wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
wait_every_s = float(wait_every_raw) if wait_every_raw else None
@ -452,9 +473,15 @@ class AtroposAgent:
raise
try:
if timeout_s and timeout_s > 0:
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
return await _await_call()
return await asyncio.wait_for(_await_call(), timeout=timeout_s)
except asyncio.TimeoutError as e:
print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
print({"step": step_num, "timeout_s": timeout_s}, flush=True)
raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
except asyncio.CancelledError:
# Treat cancellation as a hard failure rather than crashing the whole env run.
# (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
raise RuntimeError("chat_completion cancelled") from None
except Exception as e:
detail: Dict[str, Any] = {
"step": step_num,

View file

@ -17,7 +17,7 @@ from __future__ import annotations
import os
import random
import time
from typing import Any, Dict, List, Literal, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from pydantic import Field
@ -41,13 +41,14 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
description="If true, score tests on PASS_TO_PASS FAIL_TO_PASS; else PASS_TO_PASS only.",
)
prompt_mode: str = Field(
default="problem_statement",
description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
)
repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
install_timeout_s: float = Field(default=600.0)
test_timeout_s: float = Field(default=600.0)
verification_mode: Literal["pytest", "install"] = Field(
default="install",
description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.",
)
tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
@ -78,12 +79,15 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
# Defaults for running the env via CLI in offline `process` mode.
# Override via env vars or `--env.*` flags as needed.
base_url = (
base_url_raw = (
os.getenv("ATROPOS_SERVER_BASE_URL")
or os.getenv("OPENAI_BASE_URL")
or os.getenv("LLM_BASE_URL")
or "http://127.0.0.1:8080"
)
base_url = base_url_raw.rstrip("/")
if not base_url.endswith("/v1"):
base_url = f"{base_url}/v1"
model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
@ -108,7 +112,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
server_configs = [
APIServerConfig(
model_name=model,
base_url=f"{base_url.rstrip('/')}/v1",
base_url=base_url,
api_key=api_key,
num_max_requests_at_once=1,
num_requests_for_eval=1,
@ -184,69 +188,43 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
problem = str(item.get("problem_statement") or "")
context = str(item.get("text") or "")
# The dataset "text" field can be extremely large (e.g. includes large code blobs
# and long test lists). In local dev and bring-up runs this can make the first LLM
# call appear "hung" while the model chews through a massive prompt. Keep a cap.
# TODO: Remove, smoke test only
def _cap(s: str, n: int) -> tuple[str, bool]:
if len(s) <= n:
return s, False
return s[:n], True
problem, problem_trunc = _cap(problem, 8_000)
context, context_trunc = _cap(context, 12_000)
nodeids = self._tests_for_item(item)
tests_preview = "\n".join(f"- {t}" for t in nodeids[:50])
if len(nodeids) > 50:
tests_preview += f"\n- ... ({len(nodeids) - 50} more)"
tests_list = "\n".join(f"- {t}" for t in nodeids)
repo_dir = self._repo_name(item)
verify_note = ""
# TODO: Remove, smoke testing only
if self.config.verification_mode == "install":
verify_note = (
"\nVerification for this run is INSTALL-ONLY:\n"
"- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n"
"- You may skip running pytest to save time.\n"
)
trunc_note = ""
if problem_trunc or context_trunc:
trunc_note = (
"\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n"
f"- problem_statement_truncated={problem_trunc}\n"
f"- text_truncated={context_trunc}\n"
)
tests_block = (
"Run these tests to verify:\n"
f"{tests_preview}\n\n"
f"{tests_list}\n\n"
"When done, briefly describe what you changed and confirm tests pass."
)
if self.config.verification_mode == "install":
# Keep install-only prompts short and avoid huge test lists.
tests_block = (
"When done, briefly describe what you changed and confirm that "
"`python -m pip install -e .` succeeds."
prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
if prompt_mode not in {"problem_statement", "problem_statement+text"}:
raise ValueError(
f"Invalid prompt_mode={self.config.prompt_mode!r}. "
"Expected 'problem_statement' or 'problem_statement+text'."
)
context_block = ""
if prompt_mode == "problem_statement+text" and context:
# Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
context_block = f"\nAdditional context:\n{context}\n"
return (
"You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
f"Repository: {repo} (checked out at base_commit={base_commit})\n"
f"Workspace path: ./{repo_dir}\n\n"
"Constraints:\n"
"- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
"- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
"- Use non-interactive commands only.\n\n"
"- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
" Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
f"{verify_note}\n"
f"{trunc_note}\n"
"Problem statement:\n"
f"{problem}\n\n"
"Additional context:\n"
f"{context}\n\n"
f"{context_block}\n"
f"{tests_block}"
)
@ -275,74 +253,61 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
flush=True,
)
# Prefer a lightweight "fetch by sha" to avoid pulling full history.
# If it fails (some servers disallow fetching unadvertised objects, or we hit
# shallow-object edge cases), fall back to a full clone
# Repo setup strategy:
# - Maintain a shared, per-container bare repo cache under /data/repo_cache
# - For each trajectory, create an isolated git worktree under the slot workspace
# This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
# TODO: tbh, should just do this before setting up worktree & after sandbox build
clone_attempts: list[tuple[str, str]] = []
clone_attempts.append(
(
"shallow_fetch_sha",
(
f"rm -rf {repo_dir} && "
f"git init {repo_dir} && "
f"cd {repo_dir} && "
"export GIT_TERMINAL_PROMPT=0 && "
"export GIT_LFS_SKIP_SMUDGE=1 && "
"git config advice.detachedHead false && "
f"git remote add origin {clone_url} && "
f"git fetch --depth 1 origin {base_commit} && "
"git checkout -q FETCH_HEAD"
),
)
)
clone_attempts.append(
(
"full_clone_checkout",
(
f"rm -rf {repo_dir} && "
f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && "
f"cd {repo_dir} && "
"git config advice.detachedHead false && "
f"git checkout -q {base_commit}"
),
)
def _repo_cache_slug(repo_name: str) -> str:
return repo_name.replace("/", "__")
repo_slug = _repo_cache_slug(repo)
cache_root = "/data/repo_cache"
bare_repo = f"{cache_root}/{repo_slug}.git"
lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
# Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
# util-linux (flock) is included in the sandbox image.
worktree_cmd = (
"set -e; "
f"rm -rf {repo_dir}; "
f"mkdir -p {cache_root}/.locks; "
f": > {lock_file}; "
f"flock -x {lock_file} sh -lc '"
f"set -e; "
"export GIT_TERMINAL_PROMPT=0; "
"export GIT_LFS_SKIP_SMUDGE=1; "
f"if [ ! -d \"{bare_repo}\" ]; then "
f" git init --bare \"{bare_repo}\"; "
f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
"fi; "
f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
f"git -C \"{bare_repo}\" worktree prune || true; "
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
"fi; "
f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
f" git -C \"{bare_repo}\" fetch --prune origin; "
"fi; "
f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
"'"
)
clone_res = None
for label, cmd in clone_attempts:
t_attempt = time.perf_counter()
print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True)
res = await exec_tool(
ToolCall(
name="terminal",
arguments={"command": cmd, "timeout": self.config.install_timeout_s},
)
print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
res = await exec_tool(
ToolCall(
name="terminal",
arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
)
clone_res = res
if res.success:
print(
f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s",
flush=True,
)
break
print(
f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: "
f"{res.error}",
flush=True,
)
if clone_res is None or not clone_res.success:
err = clone_res.error if clone_res is not None else "unknown"
out = clone_res.output if clone_res is not None else ""
)
if not res.success:
raise RuntimeError(
"git clone/checkout failed "
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}"
"git worktree setup failed "
f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
)
print(
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s",
f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
flush=True,
)
return {"repo_dir": repo_dir, "base_commit": base_commit}
@ -380,43 +345,10 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
_ = trajectory_id
repo_dir = self._repo_name(item)
if self.config.verification_mode == "install":
# Training correctness: do not reward trajectories that never actually used tools.
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
return 0.0, {
"verification_mode": "install",
"error": "No tool calls were made by the agent",
}
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True)
t0 = time.perf_counter()
install_cmd = (
f"cd {repo_dir} && "
"python -m venv .venv && "
". .venv/bin/activate && "
"python -m pip install -U pip setuptools wheel && "
"python -m pip install -e ."
)
res = await exec_tool(
ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s})
)
ok = bool(res.success)
print(
f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} "
f"in {time.perf_counter() - t0:.2f}s",
flush=True,
)
return (1.0 if ok else 0.0), {
"verification_mode": "install",
"install_success": ok,
"error": res.error,
"verification_messages": [{"role": "user", "content": res.to_xml()}],
}
# Training correctness: do not reward trajectories that never actually used tools.
if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
return 0.0, {
"verification_mode": "pytest",
"verification_mode": "dataset_tests",
"error": "No tool calls were made by the agent",
}
@ -424,7 +356,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
if not nodeids:
return 0.0, {"error": "No tests provided"}
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True)
print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
setup_cmd = (
f"cd {repo_dir} && "
"python -m venv .venv && "
@ -439,7 +371,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
if not setup_res.success:
return 0.0, {
"verification_mode": "pytest",
"verification_mode": "dataset_tests",
"phase": "install",
"error": setup_res.error,
"output": setup_res.output,
@ -459,7 +391,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
verification_messages.append({"role": "user", "content": res.to_xml()})
if not res.success:
return 0.0, {
"verification_mode": "pytest",
"verification_mode": "dataset_tests",
"phase": "pytest",
"failed_chunk": chunk_idx,
"error": res.error,
@ -467,7 +399,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
"verification_messages": verification_messages,
}
return 1.0, {"verification_mode": "pytest", "passed": True, "verification_messages": verification_messages}
return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}
async def score_trajectory(self, item: Item, final_response: str) -> float:
# Not used; scoring happens in verify_and_score_trajectory.