""" SWE-smith-oracle environment (ported to HermesAgentBaseEnv). Trains models to fix real GitHub repositories: - Clones a public GitHub repo at a specific commit - Runs an agent loop with terminal tool to apply a fix - Verifies by running pytest with nodeids from the dataset - Reward: 1.0 if all tests pass, 0.0 otherwise Dataset: NousResearch/SWE-smith-oracle (train split; does NOT use SWE-bench eval set). Usage: # Process mode (OpenAI server, no training): python environments/swe_smith_oracle_env.py process \\ --env.data_path_to_save_groups data/swe_oracle_output.jsonl # With Modal sandbox backend: python environments/swe_smith_oracle_env.py process \\ --env.tool_pool_mode modal \\ --env.modal_image python:3.11 """ from __future__ import annotations import logging import os import random import sys import time from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union _repo_root = Path(__file__).resolve().parent.parent if str(_repo_root) not in sys.path: sys.path.insert(0, str(_repo_root)) from pydantic import Field from atroposlib.envs.base import ScoredDataGroup from atroposlib.envs.server_handling.server_manager import APIServerConfig from atroposlib.type_definitions import Item from environments.agent_loop import AgentResult from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig from environments.tool_context import ToolContext logger = logging.getLogger(__name__) # ============================================================================= # Config # ============================================================================= class SweSmithOracleEnvConfig(HermesAgentEnvConfig): """Config for SWE-smith-oracle environment.""" dataset_name: str = Field(default="NousResearch/SWE-smith-oracle") dataset_split: str = Field(default="train") max_items: int = Field(default=0, description="0 = no limit") shuffle: bool = Field(default=True) seed: int = Field(default=0) python_only: bool = Field(default=True, description="Filter to Python-evaluable rows") score_include_fail_to_pass: bool = Field( default=True, description="Score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. " "Disable to only run PASS_TO_PASS (faster but weaker signal).", ) prompt_mode: str = Field( default="problem_statement", description="'problem_statement' (fast) or 'problem_statement+text' (includes dataset 'text').", ) repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning") install_timeout_s: float = Field(default=600.0) test_timeout_s: float = Field(default=600.0) # ============================================================================= # Environment # ============================================================================= class SweSmithOracleEnv(HermesAgentBaseEnv): """ SWE-smith-oracle environment for training models to fix real GitHub repos. Uses proper OpenAI-spec tool calling via HermesAgentBaseEnv. The model gets terminal access to inspect, edit, and test the repository. 
""" name = "swe-smith-oracle" env_config_cls = SweSmithOracleEnvConfig def __init__( self, config: SweSmithOracleEnvConfig, server_configs, slurm=False, testing=False, ): super().__init__(config, server_configs, slurm, testing) self._dataset = None self._indices: List[int] = [] self._cursor = 0 @classmethod def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]: """Default config — reads from ATROPOS_SERVER_* env vars.""" base_url = ( os.getenv("ATROPOS_SERVER_BASE_URL") or os.getenv("OPENAI_BASE_URL") or os.getenv("LLM_BASE_URL") or "http://127.0.0.1:8080" ) if not base_url.rstrip("/").endswith("/v1"): base_url = base_url.rstrip("/") + "/v1" model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "Hermes-4.3-36B" api_key = ( os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local" ) env_config = SweSmithOracleEnvConfig( tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B", group_size=1, use_wandb=False, rollout_server_url="http://localhost:8000", total_steps=1, batch_size=1, steps_per_eval=1, max_token_length=8192, wandb_name="swe_smith_oracle", # Terminal tool for the agent enabled_toolsets=["terminal"], terminal_backend=os.getenv("TERMINAL_ENV", "local"), # Longer agent turns for SWE tasks max_agent_turns=50, agent_temperature=0.7, system_prompt=( "You are a senior software engineer. You have access to a terminal " "to inspect and fix repositories. Use non-interactive commands only. " "Each terminal command runs in a fresh shell." ), tool_call_parser="hermes", # Sandbox settings (used when tool_pool_mode != "default") sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local", purge_job_on_start=True, purge_job_on_shutdown=True, ) server_configs = [ APIServerConfig( model_name=model, base_url=base_url, api_key=api_key, server_type="openai", health_check=False, timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"), ), ] return env_config, server_configs # ========================================================================= # Dataset loading # ========================================================================= async def setup(self): """Load SWE-smith-oracle dataset.""" from datasets import load_dataset t0 = time.perf_counter() print( f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} " f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})", flush=True, ) ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split) self._dataset = ds indices: List[int] = [] for idx in range(len(ds)): row = ds[idx] if self.config.python_only and not self._is_python_row(row): continue indices.append(idx) if self.config.shuffle: rnd = random.Random(self.config.seed) rnd.shuffle(indices) if self.config.max_items and self.config.max_items > 0: indices = indices[: self.config.max_items] self._indices = indices self._cursor = 0 print( f"[SweSmithOracleEnv] loaded {len(self._indices)} items " f"in {time.perf_counter() - t0:.2f}s", flush=True, ) def _is_python_row(self, row: Dict[str, Any]) -> bool: nodeids = row.get("PASS_TO_PASS") if not isinstance(nodeids, list) or not nodeids: return False return all(isinstance(nid, str) and ".py::" in nid for nid in nodeids) async def get_next_item(self) -> Item: if not self._dataset or not self._indices: raise RuntimeError("Dataset not initialized") if self._cursor >= len(self._indices): self._cursor = 0 idx = self._indices[self._cursor] 
    # =========================================================================
    # Prompt formatting
    # =========================================================================

    def _repo_name(self, item: Item) -> str:
        repo = item.get("repo") or ""
        if isinstance(repo, str) and "/" in repo:
            return repo.split("/")[-1]
        return "repo"

    def format_prompt(self, item: Item) -> str:
        """Build the SWE task prompt."""
        repo = item.get("repo") or ""
        base_commit = item.get("base_commit") or ""
        problem = str(item.get("problem_statement") or "")
        context = str(item.get("text") or "")
        repo_dir = self._repo_name(item)
        nodeids = self._tests_for_item(item)
        tests_list = "\n".join(f"- {t}" for t in nodeids)

        context_block = ""
        prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
        if prompt_mode == "problem_statement+text" and context:
            context_block = f"\nAdditional context:\n{context}\n"

        return (
            f"Fix the repository so the specified tests pass.\n\n"
            f"Repository: {repo} (checked out at base_commit={base_commit})\n"
            f"Workspace path: ./{repo_dir}\n\n"
            "Constraints:\n"
            "- Use the terminal tool to inspect, edit, and verify the repository.\n"
            f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
            "- Use a workspace-local virtualenv (.venv) to avoid cross-run contamination.\n"
            "- Use non-interactive commands only.\n"
            "- Prefer `. .venv/bin/activate` or `.venv/bin/python ...` (POSIX compatible).\n\n"
            f"Problem statement:\n{problem}\n\n"
            f"{context_block}"
            f"Run these tests to verify:\n{tests_list}\n\n"
            "When done, briefly describe what you changed and confirm tests pass."
        )

    # =========================================================================
    # Test helpers
    # =========================================================================

    def _tests_for_item(self, item: Item) -> List[str]:
        tests: List[str] = []
        if self.config.score_include_fail_to_pass:
            for key in ("PASS_TO_PASS", "FAIL_TO_PASS"):
                nodeids = item.get(key)
                if isinstance(nodeids, list):
                    tests.extend([n for n in nodeids if isinstance(n, str)])
        else:
            nodeids = item.get("PASS_TO_PASS")
            if isinstance(nodeids, list):
                tests.extend([n for n in nodeids if isinstance(n, str)])
        # Deduplicate (dict.fromkeys preserves order), then sort for stable chunking
        return sorted(dict.fromkeys(tests))

    def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
        return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]
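    # Example of the chunking above (hypothetical nodeids):
    #   _chunk_nodeids(["t1", "t2", "t3"], max_per_chunk=2) -> [["t1", "t2"], ["t3"]]
    # This bounds the length of each pytest command line when a row has many
    # test nodeids.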
    # =========================================================================
    # Sandbox hooks: setup_trajectory_workspace + verify_and_score_trajectory
    # =========================================================================

    async def setup_trajectory_workspace(
        self, item: Item, *, trajectory_id: str, exec_tool
    ) -> Dict[str, Any]:
        """
        Prepare a sandbox workspace: bare repo cache + git worktree.

        Uses a flock-serialized bare repo cache under /data/repo_cache so
        multiple trajectories sharing a sandbox don't clone the same repo in
        parallel. Each trajectory gets an isolated worktree at the specified
        base_commit.

        Args:
            item: Dataset row with repo, base_commit, etc.
            trajectory_id: Unique trajectory ID
            exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult

        Returns:
            Dict with repo_dir and base_commit metadata
        """
        import time as _time

        t0 = _time.perf_counter()
        repo = item.get("repo")
        base_commit = item.get("base_commit")
        instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
        if not isinstance(repo, str) or not isinstance(base_commit, str):
            raise RuntimeError("Invalid dataset row: missing repo/base_commit")

        repo_dir = self._repo_name(item)
        clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
            flush=True,
        )

        # Bare repo cache + worktree strategy (same as atropos/envs/swe_smith_oracle_env.py)
        repo_slug = repo.replace("/", "__")
        cache_root = "/data/repo_cache"
        bare_repo = f"{cache_root}/{repo_slug}.git"
        lock_file = f"{cache_root}/.locks/{repo_slug}.lock"
        worktree_cmd = (
            "set -e; "
            f"rm -rf {repo_dir}; "
            f"mkdir -p {cache_root}/.locks; "
            f": > {lock_file}; "
            f"flock -x {lock_file} sh -lc '"
            "set -e; "
            "export GIT_TERMINAL_PROMPT=0; "
            "export GIT_LFS_SKIP_SMUDGE=1; "
            f"if [ ! -d \"{bare_repo}\" ]; then "
            f"  git init --bare \"{bare_repo}\"; "
            f"  git -C \"{bare_repo}\" remote add origin \"{clone_url}\"; "
            "fi; "
            f"git -C \"{bare_repo}\" remote set-url origin \"{clone_url}\"; "
            f"git -C \"{bare_repo}\" worktree prune || true; "
            # Try a cheap shallow fetch of just the pinned commit first...
            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
            f"  git -C \"{bare_repo}\" fetch --depth 1 origin \"{base_commit}\" || true; "
            "fi; "
            # ...then fall back to a full fetch if the commit still isn't present.
            f"if ! git -C \"{bare_repo}\" cat-file -e \"{base_commit}^{{commit}}\" 2>/dev/null; then "
            f"  git -C \"{bare_repo}\" fetch --prune origin; "
            "fi; "
            f"git --git-dir=\"{bare_repo}\" worktree add --detach \"{repo_dir}\" \"{base_commit}\"; "
            "'"
        )
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache",
            flush=True,
        )
        res = await exec_tool(
            "bash",
            {"command": worktree_cmd},
            timeout=self.config.install_timeout_s,
        )
        if not res.success:
            raise RuntimeError(
                f"git worktree setup failed "
                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): "
                f"{res.error}\n{res.output}"
            )
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"worktree ready in {_time.perf_counter() - t0:.2f}s",
            flush=True,
        )
        return {"repo_dir": repo_dir, "base_commit": base_commit}
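    # Resulting sandbox layout (illustrative, for a repo "owner/name"):
    #   /data/repo_cache/owner__name.git          # shared bare clone, guarded by flock
    #   /data/repo_cache/.locks/owner__name.lock
    #   ./name                                    # this trajectory's detached worktree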
    async def verify_and_score_trajectory(
        self,
        item: Item,
        result: AgentResult,
        *,
        trajectory_id: str,
        exec_tool,
        workspace_meta: Optional[Dict[str, Any]] = None,
    ) -> Tuple[float, Dict[str, Any]]:
        """
        In-sandbox verification: install deps + run pytest with dataset nodeids.

        Args:
            item: Dataset row
            result: Agent's rollout result
            trajectory_id: Unique trajectory ID
            exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
            workspace_meta: From setup_trajectory_workspace (has repo_dir)

        Returns:
            (reward, metadata) tuple
        """
        repo_dir = (workspace_meta or {}).get("repo_dir") or self._repo_name(item)

        # Don't reward trajectories that never used tools
        # (`or []` guards against messages where tool_calls is present but None)
        tool_call_count = sum(
            len(msg.get("tool_calls") or [])
            for msg in result.messages
            if msg.get("role") == "assistant"
        )
        if tool_call_count == 0:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} verify: no tool calls; score=0.0",
                flush=True,
            )
            return 0.0, {"error": "No tool calls were made by the agent"}

        nodeids = self._tests_for_item(item)
        if not nodeids:
            return 0.0, {"error": "No tests provided"}

        # Install dependencies in a fresh venv so verification doesn't depend on
        # whatever environment state the agent left behind
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} verify: installing deps + running tests",
            flush=True,
        )
        setup_cmd = (
            f"cd {repo_dir} && "
            "python -m venv .venv && "
            ". .venv/bin/activate && "
            "python -m pip install -U pip setuptools wheel && "
            "python -m pip install -e . && "
            "python -m pip install pytest"
        )
        setup_res = await exec_tool(
            "bash", {"command": setup_cmd}, timeout=self.config.install_timeout_s
        )
        if not setup_res.success:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} install failed; score=0.0",
                flush=True,
            )
            return 0.0, {
                "phase": "install",
                "error": setup_res.error,
                "output": setup_res.output,
            }

        # Run test chunks
        chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
        for chunk_idx, chunk in enumerate(chunks):
            joined = " ".join(chunk)
            cmd = f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}"
            res = await exec_tool(
                "bash", {"command": cmd}, timeout=self.config.test_timeout_s
            )
            if not res.success:
                print(
                    f"[SweSmithOracleEnv] tid={trajectory_id} tests failed (chunk {chunk_idx}); score=0.0",
                    flush=True,
                )
                return 0.0, {
                    "phase": "pytest",
                    "failed_chunk": chunk_idx,
                    "error": res.error,
                    "output": res.output,
                }

        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} all tests passed; score=1.0",
            flush=True,
        )
        return 1.0, {"passed": True}

    # =========================================================================
    # Reward: run pytest in the terminal (local / non-sandbox path)
    # =========================================================================

    async def compute_reward(
        self, item: Item, result: AgentResult, ctx: ToolContext
    ) -> float:
        """
        Verify by running pytest with the dataset's nodeids.

        Uses ToolContext.terminal() to run commands in the same terminal session
        as the agent (same task_id = same sandbox). This is the local
        (non-sandbox) path; when tool_pool_mode is 'modal' or 'nomad',
        verify_and_score_trajectory() is used instead (runs in the sandbox slot).
        """
        repo_dir = self._repo_name(item)

        # Don't reward trajectories that never used tools
        tool_call_count = sum(
            len(msg.get("tool_calls") or [])
            for msg in result.messages
            if msg.get("role") == "assistant"
        )
        if tool_call_count == 0:
            print("[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True)
            return 0.0

        nodeids = self._tests_for_item(item)
        if not nodeids:
            return 0.0

        # Install deps + run tests
        print("[SweSmithOracleEnv] Verifying: installing deps + running tests", flush=True)
        setup_result = ctx.terminal(
            f"cd {repo_dir} && "
            "python -m venv .venv && "
            ". .venv/bin/activate && "
            "python -m pip install -U pip setuptools wheel && "
            "python -m pip install -e . && "
            "python -m pip install pytest",
            timeout=int(self.config.install_timeout_s),
        )
        if setup_result.get("exit_code", 1) != 0:
            print("[SweSmithOracleEnv] Install failed; score=0.0", flush=True)
            return 0.0

        # Run test chunks
        chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
        for chunk_idx, chunk in enumerate(chunks):
            joined = " ".join(chunk)
            test_result = ctx.terminal(
                f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}",
                timeout=int(self.config.test_timeout_s),
            )
            if test_result.get("exit_code", 1) != 0:
                print(
                    f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score=0.0",
                    flush=True,
                )
                return 0.0

        print("[SweSmithOracleEnv] All tests passed; score=1.0", flush=True)
        return 1.0
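    # Note: both verification paths rebuild .venv from scratch and install only
    # `pip install -e .` plus pytest. A repo whose tests require extra
    # dependencies (e.g. a hypothetical `.[test]` extra) would score 0.0 even
    # for a correct fix; this env assumes the base install is sufficient.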
&& " "python -m pip install pytest", timeout=int(self.config.install_timeout_s), ) if setup_result.get("exit_code", 1) != 0: print(f"[SweSmithOracleEnv] Install failed; score=0.0", flush=True) return 0.0 # Run test chunks chunks = self._chunk_nodeids(nodeids, max_per_chunk=50) for chunk_idx, chunk in enumerate(chunks): joined = " ".join(chunk) test_result = ctx.terminal( f"cd {repo_dir} && . .venv/bin/activate && python -m pytest -q {joined}", timeout=int(self.config.test_timeout_s), ) if test_result.get("exit_code", 1) != 0: print(f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score=0.0", flush=True) return 0.0 print(f"[SweSmithOracleEnv] All tests passed; score=1.0", flush=True) return 1.0 # ========================================================================= # Evaluation (minimal for now) # ========================================================================= async def evaluate(self, *args, **kwargs): """Placeholder evaluation — SWE tasks are too expensive for frequent eval.""" start_time = time.time() await self.evaluate_log( metrics={"eval/placeholder": 0.0}, samples=[], start_time=start_time, end_time=time.time(), ) if __name__ == "__main__": SweSmithOracleEnv.cli()