adjusted prompt again to make things more reliable, having api issues

This commit is contained in:
Shannon Sands 2026-02-05 14:42:10 +10:00
parent 87464821d8
commit 487487406d
2 changed files with 115 additions and 156 deletions

View file

@@ -32,16 +32,13 @@ load_dotenv()
 # Default system prompt with tool calling instructions.
-#
-# IMPORTANT: In training-mode environments we want "raw text in -> raw text out" and we
-# parse tool calls from completion text. Do not rely on server-specific `tool_calls` fields.
 AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside <think>...</think> tags.
 You are a function calling AI model.
 You are provided with function signatures within <tools></tools> XML tags.
-You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
-You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task.
+You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions.
 After calling & executing a function, you will be provided with function results within <tool_response></tool_response> XML tags.
 Here are the available tools:
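A note on the function-calling contract above: tool calls are expected in the completion text itself, not in server-side `tool_calls` fields. A minimal sketch of what text-level parsing can look like, assuming Hermes-style `<tool_call>{...}</tool_call>` JSON blocks — the tag format and helper name are assumptions for illustration, not part of this commit:

    import json
    import re
    from typing import Any, Dict, List

    # Hypothetical tag format: <tool_call>{"name": ..., "arguments": {...}}</tool_call>
    TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

    def parse_tool_calls(completion_text: str) -> List[Dict[str, Any]]:
        """Extract tool-call JSON objects from raw completion text."""
        calls: List[Dict[str, Any]] = []
        for match in TOOL_CALL_RE.finditer(completion_text):
            try:
                calls.append(json.loads(match.group(1)))
            except json.JSONDecodeError:
                continue  # skip malformed blocks instead of failing the whole step
        return calls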
@@ -314,6 +311,24 @@ class AtroposAgent:
             return model
         return None

+    def _infer_server_base_url_for_debug(self) -> Optional[str]:
+        """
+        Best-effort inference of the configured base_url for debug logging.
+        This is helpful when diagnosing hangs / retries at the transport layer.
+        """
+        servers = getattr(self.server, "servers", None)
+        if isinstance(servers, list) and servers:
+            s0 = servers[0]
+            cfg = getattr(s0, "config", None)
+            base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
+            if isinstance(base_url, str) and base_url:
+                return base_url
+        base_url = getattr(self.server, "base_url", None)
+        if isinstance(base_url, str) and base_url:
+            return base_url
+        return None
+
     def _extract_response_metadata(self, response: Any) -> Dict[str, Any]:
         """
         Extract lightweight, JSON-serializable metadata from an OpenAI-style response.
@@ -359,6 +374,8 @@ class AtroposAgent:
         # Avoid dumping megabytes by default; messages can be huge.
         meta = {
             "step": step_num,
+            "base_url": self._infer_server_base_url_for_debug(),
+            "model": chat_kwargs.get("model") or self._infer_server_model_for_debug(),
             "chat_kwargs_keys": sorted(list(chat_kwargs.keys())),
             "n": chat_kwargs.get("n"),
             "max_tokens": chat_kwargs.get("max_tokens"),
@@ -421,8 +438,12 @@ class AtroposAgent:
         - `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`.
         - `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting.
         """
+        # Hard guardrail: never allow a single chat completion to block for more than 2 minutes.
+        # This is essential for RL data-gen stability; long hangs should be treated as failures (score=0).
         timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
-        timeout_s = float(timeout_s_raw) if timeout_s_raw else None
+        timeout_s_default = 120.0
+        timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default
+        timeout_s = min(timeout_s, 120.0)

         wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S")
         wait_every_s = float(wait_every_raw) if wait_every_raw else None
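Net effect of the default-plus-clamp change: `ATROPOS_AGENT_CHAT_TIMEOUT_S` can shorten the timeout but can no longer extend it past the two-minute guardrail. A self-contained restatement (the helper name is illustrative):

    import os

    def resolve_chat_timeout_s(default_s: float = 120.0, cap_s: float = 120.0) -> float:
        """The env var may lower the chat timeout, but the hard cap always wins."""
        raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S")
        timeout_s = float(raw) if raw else default_s
        return min(timeout_s, cap_s)

    # unset -> 120.0
    # "30"  -> 30.0
    # "600" -> 120.0 (clamped)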
@@ -452,9 +473,15 @@ class AtroposAgent:
             raise

         try:
-            if timeout_s and timeout_s > 0:
-                return await asyncio.wait_for(_await_call(), timeout=timeout_s)
-            return await _await_call()
+            return await asyncio.wait_for(_await_call(), timeout=timeout_s)
+        except asyncio.TimeoutError as e:
+            print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True)
+            print({"step": step_num, "timeout_s": timeout_s}, flush=True)
+            raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e
+        except asyncio.CancelledError:
+            # Treat cancellation as a hard failure rather than crashing the whole env run.
+            # (Atropos/BaseEnv may cancel tasks during shutdown or retries.)
+            raise RuntimeError("chat_completion cancelled") from None
         except Exception as e:
             detail: Dict[str, Any] = {
                 "step": step_num,

View file

@@ -17,7 +17,7 @@ from __future__ import annotations
 import os
 import random
 import time
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 from pydantic import Field
@@ -41,13 +41,14 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
         description="If true, score tests on PASS_TO_PASS + FAIL_TO_PASS; else PASS_TO_PASS only.",
     )
+    prompt_mode: str = Field(
+        default="problem_statement",
+        description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').",
+    )
     repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
     install_timeout_s: float = Field(default=600.0)
     test_timeout_s: float = Field(default=600.0)
-    verification_mode: Literal["pytest", "install"] = Field(
-        default="install",
-        description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.",
-    )
     tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
@@ -78,12 +79,15 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
     def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
         # Defaults for running the env via CLI in offline `process` mode.
         # Override via env vars or `--env.*` flags as needed.
-        base_url = (
+        base_url_raw = (
             os.getenv("ATROPOS_SERVER_BASE_URL")
             or os.getenv("OPENAI_BASE_URL")
             or os.getenv("LLM_BASE_URL")
             or "http://127.0.0.1:8080"
         )
+        base_url = base_url_raw.rstrip("/")
+        if not base_url.endswith("/v1"):
+            base_url = f"{base_url}/v1"
         model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
         api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
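The `/v1` normalization makes `config_init` accept both bare hosts and already-suffixed URLs. A quick restatement with the edge cases it covers (the function name is illustrative):

    def normalize_base_url(base_url_raw: str) -> str:
        """Strip trailing slashes, then ensure exactly one /v1 suffix."""
        base_url = base_url_raw.rstrip("/")
        if not base_url.endswith("/v1"):
            base_url = f"{base_url}/v1"
        return base_url

    assert normalize_base_url("http://127.0.0.1:8080") == "http://127.0.0.1:8080/v1"
    assert normalize_base_url("http://127.0.0.1:8080/v1/") == "http://127.0.0.1:8080/v1"
    assert normalize_base_url("https://api.example.com/v1") == "https://api.example.com/v1"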
@@ -108,7 +112,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         server_configs = [
             APIServerConfig(
                 model_name=model,
-                base_url=f"{base_url.rstrip('/')}/v1",
+                base_url=base_url,
                 api_key=api_key,
                 num_max_requests_at_once=1,
                 num_requests_for_eval=1,
@@ -184,69 +188,43 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         problem = str(item.get("problem_statement") or "")
         context = str(item.get("text") or "")

-        # The dataset "text" field can be extremely large (e.g. includes large code blobs
-        # and long test lists). In local dev and bring-up runs this can make the first LLM
-        # call appear "hung" while the model chews through a massive prompt. Keep a cap.
-        # TODO: Remove, smoke test only
-        def _cap(s: str, n: int) -> tuple[str, bool]:
-            if len(s) <= n:
-                return s, False
-            return s[:n], True
-
-        problem, problem_trunc = _cap(problem, 8_000)
-        context, context_trunc = _cap(context, 12_000)
-
         nodeids = self._tests_for_item(item)
-        tests_preview = "\n".join(f"- {t}" for t in nodeids[:50])
-        if len(nodeids) > 50:
-            tests_preview += f"\n- ... ({len(nodeids) - 50} more)"
+        tests_list = "\n".join(f"- {t}" for t in nodeids)
         repo_dir = self._repo_name(item)

-        verify_note = ""
-        # TODO: Remove, smoke testing only
-        if self.config.verification_mode == "install":
-            verify_note = (
-                "\nVerification for this run is INSTALL-ONLY:\n"
-                "- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n"
-                "- You may skip running pytest to save time.\n"
-            )
-
-        trunc_note = ""
-        if problem_trunc or context_trunc:
-            trunc_note = (
-                "\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n"
-                f"- problem_statement_truncated={problem_trunc}\n"
-                f"- text_truncated={context_trunc}\n"
-            )
-
         tests_block = (
             "Run these tests to verify:\n"
-            f"{tests_preview}\n\n"
+            f"{tests_list}\n\n"
             "When done, briefly describe what you changed and confirm tests pass."
         )
-        if self.config.verification_mode == "install":
-            # Keep install-only prompts short and avoid huge test lists.
-            tests_block = (
-                "When done, briefly describe what you changed and confirm that "
-                "`python -m pip install -e .` succeeds."
-            )
+
+        prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
+        if prompt_mode not in {"problem_statement", "problem_statement+text"}:
+            raise ValueError(
+                f"Invalid prompt_mode={self.config.prompt_mode!r}. "
+                "Expected 'problem_statement' or 'problem_statement+text'."
+            )
+        context_block = ""
+        if prompt_mode == "problem_statement+text" and context:
+            # Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow.
+            context_block = f"\nAdditional context:\n{context}\n"

         return (
             "You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
             f"Repository: {repo} (checked out at base_commit={base_commit})\n"
             f"Workspace path: ./{repo_dir}\n\n"
             "Constraints:\n"
+            "- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n"
+            f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
             "- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
             "- Use non-interactive commands only.\n\n"
             "- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n"
             "  Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n"
-            f"{verify_note}\n"
-            f"{trunc_note}\n"
             "Problem statement:\n"
             f"{problem}\n\n"
-            "Additional context:\n"
-            f"{context}\n\n"
+            f"{context_block}\n"
             f"{tests_block}"
         )
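The prompt_mode handling normalizes before validating, so case and whitespace variants of the two accepted values pass. A standalone restatement of that check (the helper name is illustrative):

    def validate_prompt_mode(value: str | None) -> str:
        """Mirror the normalization + membership check used when building the task prompt."""
        mode = (value or "problem_statement").strip().lower()
        if mode not in {"problem_statement", "problem_statement+text"}:
            raise ValueError(
                f"Invalid prompt_mode={value!r}. "
                "Expected 'problem_statement' or 'problem_statement+text'."
            )
        return mode

    assert validate_prompt_mode(None) == "problem_statement"
    assert validate_prompt_mode(" Problem_Statement+TEXT ") == "problem_statement+text"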
@@ -275,74 +253,61 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             flush=True,
         )

-        # Prefer a lightweight "fetch by sha" to avoid pulling full history.
-        # If it fails (some servers disallow fetching unadvertised objects, or we hit
-        # shallow-object edge cases), fall back to a full clone.
-        # TODO: tbh, should just do this before setting up worktree & after sandbox build
-        clone_attempts: list[tuple[str, str]] = []
-        clone_attempts.append(
-            (
-                "shallow_fetch_sha",
-                (
-                    f"rm -rf {repo_dir} && "
-                    f"git init {repo_dir} && "
-                    f"cd {repo_dir} && "
-                    "export GIT_TERMINAL_PROMPT=0 && "
-                    "export GIT_LFS_SKIP_SMUDGE=1 && "
-                    "git config advice.detachedHead false && "
-                    f"git remote add origin {clone_url} && "
-                    f"git fetch --depth 1 origin {base_commit} && "
-                    "git checkout -q FETCH_HEAD"
-                ),
-            )
-        )
-        clone_attempts.append(
-            (
-                "full_clone_checkout",
-                (
-                    f"rm -rf {repo_dir} && "
-                    f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && "
-                    f"cd {repo_dir} && "
-                    "git config advice.detachedHead false && "
-                    f"git checkout -q {base_commit}"
-                ),
-            )
-        )
-
-        clone_res = None
-        for label, cmd in clone_attempts:
-            t_attempt = time.perf_counter()
-            print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True)
-            res = await exec_tool(
-                ToolCall(
-                    name="terminal",
-                    arguments={"command": cmd, "timeout": self.config.install_timeout_s},
-                )
-            )
-            clone_res = res
-            if res.success:
-                print(
-                    f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s",
-                    flush=True,
-                )
-                break
-            print(
-                f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: "
-                f"{res.error}",
-                flush=True,
-            )
-
-        if clone_res is None or not clone_res.success:
-            err = clone_res.error if clone_res is not None else "unknown"
-            out = clone_res.output if clone_res is not None else ""
+        # Repo setup strategy:
+        # - Maintain a shared, per-container bare repo cache under /data/repo_cache
+        # - For each trajectory, create an isolated git worktree under the slot workspace
+        # This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing.
+        def _repo_cache_slug(repo_name: str) -> str:
+            return repo_name.replace("/", "__")
+
+        repo_slug = _repo_cache_slug(repo)
+        cache_root = "/data/repo_cache"
+        bare_repo = f"{cache_root}/{repo_slug}.git"
+        lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
+
+        # Use flock to serialize operations that mutate the shared bare repo (fetch/worktree).
+        # util-linux (flock) is included in the sandbox image.
+        worktree_cmd = (
+            "set -e; "
+            f"rm -rf {repo_dir}; "
+            f"mkdir -p {cache_root}/.locks; "
+            f": > {lock_file}; "
+            f"flock -x {lock_file} sh -lc '"
+            f"set -e; "
+            "export GIT_TERMINAL_PROMPT=0; "
+            "export GIT_LFS_SKIP_SMUDGE=1; "
+            f"if [ ! -d \"{bare_repo}\" ]; then "
+            f"  git init --bare \"{bare_repo}\"; "
+            f"  git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
+            "fi; "
+            f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
+            f"git -C \"{bare_repo}\" worktree prune || true; "
+            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
+            f"  git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
+            "fi; "
+            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
+            f"  git -C \"{bare_repo}\" fetch --prune origin; "
+            "fi; "
+            f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
+            "'"
+        )
+
+        print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
+        res = await exec_tool(
+            ToolCall(
+                name="terminal",
+                arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s},
+            )
+        )
+        if not res.success:
             raise RuntimeError(
-                "git clone/checkout failed "
-                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}"
+                "git worktree setup failed "
+                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
             )

         print(
-            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s",
+            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s",
             flush=True,
         )
         return {"repo_dir": repo_dir, "base_commit": base_commit}
@@ -380,43 +345,10 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         _ = trajectory_id
         repo_dir = self._repo_name(item)

-        if self.config.verification_mode == "install":
-            # Training correctness: do not reward trajectories that never actually used tools.
-            if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
-                return 0.0, {
-                    "verification_mode": "install",
-                    "error": "No tool calls were made by the agent",
-                }
-
-            print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True)
-            t0 = time.perf_counter()
-            install_cmd = (
-                f"cd {repo_dir} && "
-                "python -m venv .venv && "
-                ". .venv/bin/activate && "
-                "python -m pip install -U pip setuptools wheel && "
-                "python -m pip install -e ."
-            )
-            res = await exec_tool(
-                ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s})
-            )
-            ok = bool(res.success)
-            print(
-                f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} "
-                f"in {time.perf_counter() - t0:.2f}s",
-                flush=True,
-            )
-            return (1.0 if ok else 0.0), {
-                "verification_mode": "install",
-                "install_success": ok,
-                "error": res.error,
-                "verification_messages": [{"role": "user", "content": res.to_xml()}],
-            }
-
         # Training correctness: do not reward trajectories that never actually used tools.
         if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0:
             return 0.0, {
-                "verification_mode": "pytest",
+                "verification_mode": "dataset_tests",
                 "error": "No tool calls were made by the agent",
             }
@@ -424,7 +356,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         if not nodeids:
             return 0.0, {"error": "No tests provided"}

-        print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True)
+        print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True)
         setup_cmd = (
             f"cd {repo_dir} && "
             "python -m venv .venv && "
@@ -439,7 +371,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         verification_messages = [{"role": "user", "content": setup_res.to_xml()}]
         if not setup_res.success:
             return 0.0, {
-                "verification_mode": "pytest",
+                "verification_mode": "dataset_tests",
                 "phase": "install",
                 "error": setup_res.error,
                 "output": setup_res.output,
@@ -459,7 +391,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             verification_messages.append({"role": "user", "content": res.to_xml()})
             if not res.success:
                 return 0.0, {
-                    "verification_mode": "pytest",
+                    "verification_mode": "dataset_tests",
                     "phase": "pytest",
                     "failed_chunk": chunk_idx,
                     "error": res.error,
@@ -467,7 +399,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
                     "verification_messages": verification_messages,
                 }

-        return 1.0, {"verification_mode": "pytest", "passed": True, "verification_messages": verification_messages}
+        return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages}

     async def score_trajectory(self, item: Item, final_response: str) -> float:
         # Not used; scoring happens in verify_and_score_trajectory.