diff --git a/atropos/agent/atropos_agent.py b/atropos/agent/atropos_agent.py index 0bc2637c4eb..88edc53d358 100644 --- a/atropos/agent/atropos_agent.py +++ b/atropos/agent/atropos_agent.py @@ -32,16 +32,13 @@ load_dotenv() # Default system prompt with tool calling instructions. -# -# IMPORTANT: In training-mode environments we want "raw text in -> raw text out" and we -# parse tool calls from completion text. Do not rely on server-specific `tool_calls` fields. AGENT_SYSTEM_PROMPT = """You are a deep thinking AI. You MUST enclose your internal reasoning inside ... tags. You are a function calling AI model. You are provided with function signatures within XML tags. -You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. - +You must call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. +You can ONLY respond without a tool call if you are totally certain you have the final answer to the user's question or task After calling & executing a function, you will be provided with function results within XML tags. Here are the available tools: @@ -314,6 +311,24 @@ class AtroposAgent: return model return None + def _infer_server_base_url_for_debug(self) -> Optional[str]: + """ + Best-effort inference of the configured base_url for debug logging. + + This is helpful when diagnosing hangs / retries at the transport layer. + """ + servers = getattr(self.server, "servers", None) + if isinstance(servers, list) and servers: + s0 = servers[0] + cfg = getattr(s0, "config", None) + base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None) + if isinstance(base_url, str) and base_url: + return base_url + base_url = getattr(self.server, "base_url", None) + if isinstance(base_url, str) and base_url: + return base_url + return None + def _extract_response_metadata(self, response: Any) -> Dict[str, Any]: """ Extract lightweight, JSON-serializable metadata from an OpenAI-style response. @@ -359,6 +374,8 @@ class AtroposAgent: # Avoid dumping megabytes by default; messages can be huge. meta = { "step": step_num, + "base_url": self._infer_server_base_url_for_debug(), + "model": chat_kwargs.get("model") or self._infer_server_model_for_debug(), "chat_kwargs_keys": sorted(list(chat_kwargs.keys())), "n": chat_kwargs.get("n"), "max_tokens": chat_kwargs.get("max_tokens"), @@ -421,8 +438,12 @@ class AtroposAgent: - `ATROPOS_AGENT_CHAT_TIMEOUT_S`: if set, wraps the await in `asyncio.wait_for`. - `ATROPOS_DEBUG_AGENT_WAIT_EVERY_S`: if set, prints a heartbeat while waiting. """ + # Hard guardrail: never allow a single chat completion to block for more than 2 minutes. + # This is essential for RL data-gen stability; long hangs should be treated as failures (score=0). timeout_s_raw = os.getenv("ATROPOS_AGENT_CHAT_TIMEOUT_S") - timeout_s = float(timeout_s_raw) if timeout_s_raw else None + timeout_s_default = 120.0 + timeout_s = float(timeout_s_raw) if timeout_s_raw else timeout_s_default + timeout_s = min(timeout_s, 120.0) wait_every_raw = os.getenv("ATROPOS_DEBUG_AGENT_WAIT_EVERY_S") wait_every_s = float(wait_every_raw) if wait_every_raw else None @@ -452,9 +473,15 @@ class AtroposAgent: raise try: - if timeout_s and timeout_s > 0: - return await asyncio.wait_for(_await_call(), timeout=timeout_s) - return await _await_call() + return await asyncio.wait_for(_await_call(), timeout=timeout_s) + except asyncio.TimeoutError as e: + print("\n=== ATROPOS_DEBUG_AGENT_CHAT_TIMEOUT ===", flush=True) + print({"step": step_num, "timeout_s": timeout_s}, flush=True) + raise RuntimeError(f"chat_completion timed out after {timeout_s:.1f}s") from e + except asyncio.CancelledError: + # Treat cancellation as a hard failure rather than crashing the whole env run. + # (Atropos/BaseEnv may cancel tasks during shutdown or retries.) + raise RuntimeError("chat_completion cancelled") from None except Exception as e: detail: Dict[str, Any] = { "step": step_num, diff --git a/atropos/envs/swe_smith_oracle_env.py b/atropos/envs/swe_smith_oracle_env.py index 587bd440d10..fee0c9ab7bb 100644 --- a/atropos/envs/swe_smith_oracle_env.py +++ b/atropos/envs/swe_smith_oracle_env.py @@ -17,7 +17,7 @@ from __future__ import annotations import os import random import time -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from pydantic import Field @@ -41,13 +41,14 @@ class SweSmithOracleEnvConfig(AgentEnvConfig): description="If true, score tests on PASS_TO_PASS ∪ FAIL_TO_PASS; else PASS_TO_PASS only.", ) + prompt_mode: str = Field( + default="problem_statement", + description="Task prompt content: 'problem_statement' (fast) or 'problem_statement+text' (slower, includes dataset 'text').", + ) + repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning") install_timeout_s: float = Field(default=600.0) test_timeout_s: float = Field(default=600.0) - verification_mode: Literal["pytest", "install"] = Field( - default="install", - description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.", - ) tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization") @@ -78,12 +79,15 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]: # Defaults for running the env via CLI in offline `process` mode. # Override via env vars or `--env.*` flags as needed. - base_url = ( + base_url_raw = ( os.getenv("ATROPOS_SERVER_BASE_URL") or os.getenv("OPENAI_BASE_URL") or os.getenv("LLM_BASE_URL") or "http://127.0.0.1:8080" ) + base_url = base_url_raw.rstrip("/") + if not base_url.endswith("/v1"): + base_url = f"{base_url}/v1" model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b" api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local" @@ -108,7 +112,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): server_configs = [ APIServerConfig( model_name=model, - base_url=f"{base_url.rstrip('/')}/v1", + base_url=base_url, api_key=api_key, num_max_requests_at_once=1, num_requests_for_eval=1, @@ -184,69 +188,43 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): problem = str(item.get("problem_statement") or "") context = str(item.get("text") or "") - # The dataset "text" field can be extremely large (e.g. includes large code blobs - # and long test lists). In local dev and bring-up runs this can make the first LLM - # call appear "hung" while the model chews through a massive prompt. Keep a cap. - - # TODO: Remove, smoke test only - def _cap(s: str, n: int) -> tuple[str, bool]: - if len(s) <= n: - return s, False - return s[:n], True - - problem, problem_trunc = _cap(problem, 8_000) - context, context_trunc = _cap(context, 12_000) - nodeids = self._tests_for_item(item) - tests_preview = "\n".join(f"- {t}" for t in nodeids[:50]) - if len(nodeids) > 50: - tests_preview += f"\n- ... ({len(nodeids) - 50} more)" + tests_list = "\n".join(f"- {t}" for t in nodeids) repo_dir = self._repo_name(item) - verify_note = "" - # TODO: Remove, smoke testing only - if self.config.verification_mode == "install": - verify_note = ( - "\nVerification for this run is INSTALL-ONLY:\n" - "- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n" - "- You may skip running pytest to save time.\n" - ) - - trunc_note = "" - if problem_trunc or context_trunc: - trunc_note = ( - "\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n" - f"- problem_statement_truncated={problem_trunc}\n" - f"- text_truncated={context_trunc}\n" - ) tests_block = ( "Run these tests to verify:\n" - f"{tests_preview}\n\n" + f"{tests_list}\n\n" "When done, briefly describe what you changed and confirm tests pass." ) - if self.config.verification_mode == "install": - # Keep install-only prompts short and avoid huge test lists. - tests_block = ( - "When done, briefly describe what you changed and confirm that " - "`python -m pip install -e .` succeeds." + + prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower() + if prompt_mode not in {"problem_statement", "problem_statement+text"}: + raise ValueError( + f"Invalid prompt_mode={self.config.prompt_mode!r}. " + "Expected 'problem_statement' or 'problem_statement+text'." ) + context_block = "" + if prompt_mode == "problem_statement+text" and context: + # Note: We intentionally do NOT truncate/cap here. This mode is for debugging / richer prompts and can be slow. + context_block = f"\nAdditional context:\n{context}\n" + return ( "You are a senior software engineer. Fix the repository so the specified tests pass.\n\n" f"Repository: {repo} (checked out at base_commit={base_commit})\n" f"Workspace path: ./{repo_dir}\n\n" "Constraints:\n" + "- You MUST use the terminal tool to inspect, edit, and verify the repository. Do not respond with a patch file.\n" + f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n" "- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n" "- Use non-interactive commands only.\n\n" "- Terminal commands run under POSIX /bin/sh and each tool call runs in a fresh shell (no persisted env vars).\n" " Avoid bash-only `source`; prefer `. .venv/bin/activate` or `.venv/bin/python ...`.\n\n" - f"{verify_note}\n" - f"{trunc_note}\n" "Problem statement:\n" f"{problem}\n\n" - "Additional context:\n" - f"{context}\n\n" + f"{context_block}\n" f"{tests_block}" ) @@ -275,74 +253,61 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): flush=True, ) - # Prefer a lightweight "fetch by sha" to avoid pulling full history. - # If it fails (some servers disallow fetching unadvertised objects, or we hit - # shallow-object edge cases), fall back to a full clone + # Repo setup strategy: + # - Maintain a shared, per-container bare repo cache under /data/repo_cache + # - For each trajectory, create an isolated git worktree under the slot workspace + # This avoids cloning/fetching full repos per trajectory and is crucial for multiplexing. - # TODO: tbh, should just do this before setting up worktree & after sandbox build - clone_attempts: list[tuple[str, str]] = [] - clone_attempts.append( - ( - "shallow_fetch_sha", - ( - f"rm -rf {repo_dir} && " - f"git init {repo_dir} && " - f"cd {repo_dir} && " - "export GIT_TERMINAL_PROMPT=0 && " - "export GIT_LFS_SKIP_SMUDGE=1 && " - "git config advice.detachedHead false && " - f"git remote add origin {clone_url} && " - f"git fetch --depth 1 origin {base_commit} && " - "git checkout -q FETCH_HEAD" - ), - ) - ) - clone_attempts.append( - ( - "full_clone_checkout", - ( - f"rm -rf {repo_dir} && " - f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && " - f"cd {repo_dir} && " - "git config advice.detachedHead false && " - f"git checkout -q {base_commit}" - ), - ) + def _repo_cache_slug(repo_name: str) -> str: + return repo_name.replace("/", "__") + + repo_slug = _repo_cache_slug(repo) + cache_root = "/data/repo_cache" + bare_repo = f"{cache_root}/{repo_slug}.git" + lock_file = f"{cache_root}/.locks/{repo_slug}.lock" + + # Use flock to serialize operations that mutate the shared bare repo (fetch/worktree). + # util-linux (flock) is included in the sandbox image. + worktree_cmd = ( + "set -e; " + f"rm -rf {repo_dir}; " + f"mkdir -p {cache_root}/.locks; " + f": > {lock_file}; " + f"flock -x {lock_file} sh -lc '" + f"set -e; " + "export GIT_TERMINAL_PROMPT=0; " + "export GIT_LFS_SKIP_SMUDGE=1; " + f"if [ ! -d \"{bare_repo}\" ]; then " + f" git init --bare \"{bare_repo}\"; " + f" git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; " + "fi; " + f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; " + f"git -C \"{bare_repo}\" worktree prune || true; " + f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then " + f" git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; " + "fi; " + f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then " + f" git -C \"{bare_repo}\" fetch --prune origin; " + "fi; " + f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; " + "'" ) - clone_res = None - for label, cmd in clone_attempts: - t_attempt = time.perf_counter() - print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True) - res = await exec_tool( - ToolCall( - name="terminal", - arguments={"command": cmd, "timeout": self.config.install_timeout_s}, - ) + print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True) + res = await exec_tool( + ToolCall( + name="terminal", + arguments={"command": worktree_cmd, "timeout": self.config.install_timeout_s}, ) - clone_res = res - if res.success: - print( - f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s", - flush=True, - ) - break - print( - f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: " - f"{res.error}", - flush=True, - ) - - if clone_res is None or not clone_res.success: - err = clone_res.error if clone_res is not None else "unknown" - out = clone_res.output if clone_res is not None else "" + ) + if not res.success: raise RuntimeError( - "git clone/checkout failed " - f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}" + "git worktree setup failed " + f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}" ) print( - f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s", + f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): worktree ready in {time.perf_counter() - t0:.2f}s", flush=True, ) return {"repo_dir": repo_dir, "base_commit": base_commit} @@ -380,43 +345,10 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): _ = trajectory_id repo_dir = self._repo_name(item) - if self.config.verification_mode == "install": - # Training correctness: do not reward trajectories that never actually used tools. - if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0: - return 0.0, { - "verification_mode": "install", - "error": "No tool calls were made by the agent", - } - - print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True) - t0 = time.perf_counter() - install_cmd = ( - f"cd {repo_dir} && " - "python -m venv .venv && " - ". .venv/bin/activate && " - "python -m pip install -U pip setuptools wheel && " - "python -m pip install -e ." - ) - res = await exec_tool( - ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s}) - ) - ok = bool(res.success) - print( - f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} " - f"in {time.perf_counter() - t0:.2f}s", - flush=True, - ) - return (1.0 if ok else 0.0), { - "verification_mode": "install", - "install_success": ok, - "error": res.error, - "verification_messages": [{"role": "user", "content": res.to_xml()}], - } - # Training correctness: do not reward trajectories that never actually used tools. if agent_result is not None and getattr(agent_result, "total_tool_calls", 0) <= 0: return 0.0, { - "verification_mode": "pytest", + "verification_mode": "dataset_tests", "error": "No tool calls were made by the agent", } @@ -424,7 +356,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): if not nodeids: return 0.0, {"error": "No tests provided"} - print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True) + print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (dataset_tests): ensuring venv + deps", flush=True) setup_cmd = ( f"cd {repo_dir} && " "python -m venv .venv && " @@ -439,7 +371,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): verification_messages = [{"role": "user", "content": setup_res.to_xml()}] if not setup_res.success: return 0.0, { - "verification_mode": "pytest", + "verification_mode": "dataset_tests", "phase": "install", "error": setup_res.error, "output": setup_res.output, @@ -459,7 +391,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): verification_messages.append({"role": "user", "content": res.to_xml()}) if not res.success: return 0.0, { - "verification_mode": "pytest", + "verification_mode": "dataset_tests", "phase": "pytest", "failed_chunk": chunk_idx, "error": res.error, @@ -467,7 +399,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]): "verification_messages": verification_messages, } - return 1.0, {"verification_mode": "pytest", "passed": True, "verification_messages": verification_messages} + return 1.0, {"verification_mode": "dataset_tests", "passed": True, "verification_messages": verification_messages} async def score_trajectory(self, item: Item, final_response: str) -> float: # Not used; scoring happens in verify_and_score_trajectory.