From 5a9c98a771b82545b91d1341b4cb026100b54c71 Mon Sep 17 00:00:00 2001
From: Shannon Sands
Date: Wed, 4 Feb 2026 11:22:45 +1000
Subject: [PATCH] swe-smith-oracle runs a 1-step process. The llama server kept
 breaking locally (unclear why); it works fine through the Hermes endpoint &
 ManagedServer
---
 .env.example                               |   6 +
 atropos/agent/atropos_agent.py             |  95 +++++++++
 atropos/envs/agent_env.py                  |  77 ++++++-
 atropos/envs/hermes_compat_test_env.py     |   3 +-
 atropos/envs/sandbox_terminal_smoke_env.py |   3 +-
 atropos/envs/swe_smith_oracle_env.py       | 221 +++++++++++++++++----
 atropos/envs/test_env.py                   |   2 +-
 atropos/envs/toolserver_smoke_env.py       |   3 +-
 8 files changed, 359 insertions(+), 51 deletions(-)

diff --git a/.env.example b/.env.example
index 24ae87366f..5b70dd8d35 100644
--- a/.env.example
+++ b/.env.example
@@ -28,6 +28,12 @@ HERMES_BACKEND=openai
 # ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
 # ATROPOS_SERVER_API_KEY=local
 #
+# Hosted Nous inference API:
+# ATROPOS_SERVER_BASE_URL=https://inference-api.nousresearch.com
+# ATROPOS_SERVER_MODEL=Hermes-4.3-36B
+# ATROPOS_TOKENIZER_NAME=NousResearch/Hermes-4.3-36B
+# ATROPOS_SERVER_API_KEY=sk-... (Bearer token)
+#
 # If you plan to run GRPO-style group sampling (e.g. `--env.group_size 4`) against
 # llama.cpp, start the server with at least that many slots, e.g.:
 # LLAMA_CPP_PARALLEL=4 Hermes-Agent/scripts/launch_llama_cpp_hermes_4_36b.sh
diff --git a/atropos/agent/atropos_agent.py b/atropos/agent/atropos_agent.py
index 2fb1de6d68..48b9039c8a 100644
--- a/atropos/agent/atropos_agent.py
+++ b/atropos/agent/atropos_agent.py
@@ -16,11 +16,13 @@ The agent uses Hermes-style XML tags for tool calls:
 import asyncio
 import os
 import json
+import time
 from contextlib import asynccontextmanager
 from dataclasses import dataclass, field
 from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Union
 
 from dotenv import load_dotenv
+import httpx
 
 from ..tools import ToolCall, ToolRegistry, ToolResult
 from atroposlib.envs.server_handling.managed_server import ManagedServer
@@ -243,6 +245,9 @@ class AtroposAgent:
         - If `self.server` is a ServerManager, use its `managed_server()` context manager.
         - If `self.server` is a single APIServer, wrap it in `ManagedServer` directly.
         """
+        if os.getenv("ATROPOS_BYPASS_MANAGED_SERVER") == "1":
+            yield _DirectChatCompletionClient(server=self.server)
+            return
         if hasattr(self.server, "managed_server"):
             async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
                 yield managed
@@ -336,6 +341,7 @@ class AtroposAgent:
         # Use ManagedServer for automatic token tracking
         async with self._managed() as managed:
             for step_num in range(self.config.max_steps):
+                # ReAct loop iteration: call the model -> run tools -> observe, until done (no tools called).
                 try:
                     # Keep a copy of the prompt messages used for this completion.
                     # Useful for reconstructing tokens/masks when state tracking is unavailable.
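For local debugging, the bypass added above is toggled per run. A minimal sketch of how it is exercised (the two environment variables are the ones introduced in this patch; the surrounding launch code is illustrative only):

import os

# Route AtroposAgent._managed() to _DirectChatCompletionClient instead of ManagedServer.
# NOTE: this skips token/logprob tracking, so it is only for isolating hangs.
os.environ["ATROPOS_BYPASS_MANAGED_SERVER"] = "1"
# Per-request timeout used by the direct httpx client (defaults to 120s).
os.environ["ATROPOS_DIRECT_REQUEST_TIMEOUT_S"] = "180"
# ...then construct and run the env/agent as usual.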
@@ -346,9 +352,19 @@ class AtroposAgent:
                     if self.config.temperature is not None:
                         chat_kwargs["temperature"] = self.config.temperature
 
+                    t_req = time.perf_counter()
+                    print(
+                        f"[AtroposAgent] step={step_num+1} chat_completion start "
+                        f"(messages={len(messages)}, max_tokens={self.config.max_tokens}, temp={self.config.temperature})",
+                        flush=True,
+                    )
                     self._debug_dump_request(step_num=step_num + 1, chat_kwargs=chat_kwargs)
                     response = await managed.chat_completion(**chat_kwargs)
                     self._debug_dump_response(step_num=step_num + 1, response=response)
+                    print(
+                        f"[AtroposAgent] step={step_num+1} chat_completion done in {time.perf_counter() - t_req:.2f}s",
+                        flush=True,
+                    )
 
                     current_node = None
                     if hasattr(managed, "get_state"):
@@ -489,3 +505,82 @@ class AtroposAgent:
 
         sequence_data = SequenceData.from_sequence_node(current_node) if current_node else None
         return response_text, tool_results, sequence_data
+
+
+class _DirectChatCompletionClient:
+    """
+    Minimal stand-in for ManagedServer that calls the OpenAI-compatible endpoint directly.
+
+    This is for isolating issues where `ManagedServer.chat_completion()` hangs or misbehaves.
+    It intentionally does NOT do token/logprob tracking.
+    """
+
+    def __init__(self, server: Any):
+        self._server = server
+
+    def _server_config(self) -> tuple[str, str, str]:
+        # ServerManager case: first configured server.
+        servers = getattr(self._server, "servers", None)
+        if isinstance(servers, list) and servers:
+            s0 = servers[0]
+            cfg = getattr(s0, "config", None)
+            base_url = getattr(cfg, "base_url", None) or getattr(s0, "base_url", None)
+            api_key = getattr(cfg, "api_key", None) or getattr(s0, "api_key", None)
+            model = getattr(cfg, "model_name", None) or getattr(s0, "model_name", None)
+            if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
+                return base_url.rstrip("/"), api_key, model
+
+        # APIServer-like fallback.
+        base_url = getattr(self._server, "base_url", None)
+        api_key = getattr(self._server, "api_key", None)
+        model = getattr(self._server, "model_name", None) or getattr(self._server, "model", None)
+        if isinstance(base_url, str) and isinstance(api_key, str) and isinstance(model, str):
+            return base_url.rstrip("/"), api_key, model
+
+        raise RuntimeError("Unable to resolve server base_url/api_key/model for direct chat completion")
+
+    async def chat_completion(self, *, messages: List[Dict[str, str]], n: int = 1, **kwargs: Any) -> Any:
+        base_url, api_key, model = self._server_config()
+        url = f"{base_url}/chat/completions"
+
+        payload: Dict[str, Any] = {
+            "model": model,
+            "messages": messages,
+            "n": n,
+        }
+        # Pass through common generation kwargs.
+        for k in ("max_tokens", "temperature", "top_p", "presence_penalty", "frequency_penalty", "stop"):
+            if k in kwargs and kwargs[k] is not None:
+                payload[k] = kwargs[k]
+
+        timeout_s = float(os.getenv("ATROPOS_DIRECT_REQUEST_TIMEOUT_S") or "120")
+        print(f"[AtroposAgent] DIRECT chat_completion POST {url} (timeout={timeout_s}s)", flush=True)
+        async with httpx.AsyncClient(timeout=timeout_s) as client:
+            resp = await client.post(
+                url,
+                headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
+                json=payload,
+            )
+            resp.raise_for_status()
+            data = resp.json()
+
+        # Return a very small object compatible with the code paths that read
+        # `response.choices[0].message.content`.
+        class _Msg:
+            def __init__(self, d: Dict[str, Any]):
+                self.content = d.get("content")
+                self.reasoning = d.get("reasoning")
+
+        class _Choice:
+            def __init__(self, d: Dict[str, Any]):
+                self.message = _Msg(d.get("message") or {})
+
+        class _Resp:
+            def __init__(self, d: Dict[str, Any]):
+                self._d = d
+                self.choices = [_Choice(c) for c in (d.get("choices") or [])]
+
+            def model_dump(self) -> Dict[str, Any]:
+                return self._d
+
+        return _Resp(data)
diff --git a/atropos/envs/agent_env.py b/atropos/envs/agent_env.py
index e9bdc9d4e5..da80f5f50d 100644
--- a/atropos/envs/agent_env.py
+++ b/atropos/envs/agent_env.py
@@ -3,13 +3,12 @@ AgentEnv - Atropos BaseEnv extension for agent/tool-call workloads.
 
 AgentEnv is responsible for starting the sandbox tool execution backend and
 providing helpers for running agent trajectories with queued/batched tool calls.
-
-For Phase 4 we support a Nomad-backed `SlotPool` for true container sandboxing.
 """
 
 from __future__ import annotations
 
 import asyncio
+import time
 import uuid
 from abc import ABC, abstractmethod
 from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, TypeVar
@@ -17,6 +16,7 @@ from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tupl
 from pydantic import Field
 
 from atroposlib.envs.base import APIServerConfig, BaseEnv, BaseEnvConfig, Item, ScoredDataGroup, ScoredDataItem
+from atroposlib.envs.server_handling.server_baseline import AsyncSemWithAdaptiveWeight
 
 from ..agent import AgentConfig, AgentResult, AtroposAgent
 from ..slots import SlotPool, SlotPoolConfig
@@ -43,7 +43,7 @@ class AgentEnvConfig(BaseEnvConfig):
     tool_batch_window_ms: int = Field(default=20, description="ToolExecutor batching window (ms)")
     tool_max_batch_size: int = Field(default=200, description="ToolExecutor maximum batch size")
 
-    # nomad mode settings
+    # nomad mode settings. TODO: add Modal support and split this into its own config.
     nomad_address: str = Field(default="http://localhost:4646", description="Nomad API address")
     sandbox_job_id: str = Field(default="atropos-sandbox-agent-env", description="Nomad job id for sandbox containers")
     sandbox_image: str = Field(default="atropos-sandbox:local", description="Docker image for sandbox containers")
@@ -111,8 +111,12 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
         self._pool: Optional[Any] = None
         self._tool_executor: Optional[ToolExecutor] = None
         self._tool_server_inprocess: bool = False
+        self._trajectory_workspace_meta: Dict[str, Dict[str, Any]] = {}
 
     def build_tools(self) -> ToolRegistry:
+        """Wrap the original Hermes-Agent ToolRegistry for atropos AgentEnv use.
+        See the Hermes-Agent docs for the available toolsets and tools.
+        """
         return build_tool_registry(
             enabled_toolsets=self.config.enabled_toolsets or ["default"],
             disabled_toolsets=self.config.disabled_toolsets or None,
@@ -155,6 +159,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
         trajectory_id: str,
         exec_tool: Callable[["ToolCall"], Awaitable["ToolResult"]],
         agent_result: Optional[AgentResult] = None,
+        workspace_meta: Optional[Dict[str, Any]] = None,
     ) -> tuple[float, Dict[str, Any]]:
         """
         Optional hook: run in-sandbox verification before scoring.
 
         Default: calls `score_trajectory()` and returns empty metadata.
""" - _ = (trajectory_id, exec_tool, agent_result) # default ignores in-workspace verification + _ = (trajectory_id, exec_tool, agent_result, workspace_meta) # default ignores in-workspace verification score = await self.score_trajectory(item, final_response) return score, {} @@ -177,8 +182,42 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): ) async def setup(self) -> None: + print(f"[AgentEnv] setup(): starting tool backend ({self.config.tool_pool_mode})", flush=True) await self._start_tool_backend() + print("[AgentEnv] setup(): configuring server concurrency", flush=True) + self._configure_server_concurrency() + print("[AgentEnv] setup(): running env-specific setup_agent_env()", flush=True) await self.setup_agent_env() + print("[AgentEnv] setup(): done", flush=True) + + def _configure_server_concurrency(self) -> None: + """ + Ensure the LLM server concurrency isn't accidentally capped below `group_size`. + + In `BaseEnv process` mode, groups are collected concurrently and if the underlying + ServerManager/OpenAIServer semaphore is left at 1, we serialize inference even + when `--env.group_size` is > 1. + """ + desired = int(getattr(self.config, "group_size", 1) or 1) + if desired <= 1: + return + + servers = getattr(self.server, "servers", None) + if not isinstance(servers, list) or not servers: + return + + for s in servers: + sem = getattr(s, "sem", None) + eval_sem = getattr(s, "eval_sem", None) + # Only increase; never shrink. + if sem is not None and getattr(sem, "max_val", 0) < desired: + s.sem = AsyncSemWithAdaptiveWeight(desired) + if hasattr(s, "config") and hasattr(s.config, "num_max_requests_at_once"): + s.config.num_max_requests_at_once = desired + if eval_sem is not None and getattr(eval_sem, "max_val", 0) < desired: + s.eval_sem = AsyncSemWithAdaptiveWeight(desired) + if hasattr(s, "config") and hasattr(s.config, "num_requests_for_eval"): + s.config.num_requests_for_eval = desired @abstractmethod async def setup_agent_env(self) -> None: @@ -225,6 +264,7 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): self._tool_server_inprocess = True if self.config.tool_pool_mode != "nomad": + # TODO Add Modal here, maybe in-process, but not safe to have that tbh raise RuntimeError("tool_pool_mode must be 'nomad' (local/in-process pools are not supported)") pool = SlotPool( @@ -286,8 +326,11 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): raise RuntimeError("Tool backend not started") trajectory_id = str(uuid.uuid4()) + t0 = time.perf_counter() + print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} start", flush=True) task = self.build_task(item) agent_config = self.build_agent_config(item) + print(f"Starting trajectory {trajectory_id} with task: {task}") async def _exec(call): return await self._tool_executor.execute(trajectory_id, call) @@ -301,18 +344,39 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]): ) try: - await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec) + print(f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() start", flush=True) + workspace_meta = await self.setup_trajectory_workspace(item, trajectory_id=trajectory_id, exec_tool=_exec) + if not isinstance(workspace_meta, dict): + workspace_meta = {} + self._trajectory_workspace_meta[trajectory_id] = workspace_meta + print( + f"[AgentEnv] tid={trajectory_id} setup_trajectory_workspace() done in {time.perf_counter() - t0:.2f}s", + flush=True, + ) + print(f"[AgentEnv] tid={trajectory_id} agent.run() start", flush=True) result = 
+            print(
+                f"[AgentEnv] tid={trajectory_id} agent.run() done in {time.perf_counter() - t0:.2f}s "
+                f"success={result.success} tool_calls={result.total_tool_calls}",
+                flush=True,
+            )
 
             if not result.success or result.trajectory_data is None:
                 return None, []
 
+            print(f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() start", flush=True)
             score, _score_metadata = await self.verify_and_score_trajectory(
                 item,
                 result.final_response,
                 trajectory_id=trajectory_id,
                 exec_tool=_exec,
                 agent_result=result,
+                workspace_meta=workspace_meta,
+            )
+            print(
+                f"[AgentEnv] tid={trajectory_id} verify_and_score_trajectory() done in {time.perf_counter() - t0:.2f}s "
+                f"score={score}",
+                flush=True,
             )
 
             messages = [{"role": "system", "content": agent._build_system_prompt()}]  # noqa: SLF001
@@ -333,7 +397,10 @@ class AgentEnv(BaseEnv, ABC, Generic[AgentEnvConfigT]):
 
             return scored, []
         finally:
+            self._trajectory_workspace_meta.pop(trajectory_id, None)
+            print(f"[AgentEnv] tid={trajectory_id} release_trajectory(reset_workspace=True)", flush=True)
             await self._tool_executor.release_trajectory(trajectory_id, reset_workspace=True)
+            print(f"[AgentEnv] collect_trajectory(): tid={trajectory_id} done in {time.perf_counter() - t0:.2f}s", flush=True)
 
     async def collect_trajectories(
         self, item: Item
diff --git a/atropos/envs/hermes_compat_test_env.py b/atropos/envs/hermes_compat_test_env.py
index 2e2963ab9e..e38c3ddd83 100644
--- a/atropos/envs/hermes_compat_test_env.py
+++ b/atropos/envs/hermes_compat_test_env.py
@@ -81,7 +81,7 @@ class HermesCompatTestEnv(AgentEnv[HermesCompatTestEnvConfig]):
         or "http://127.0.0.1:8080"
     )
     model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
-    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
+    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
 
     env_config = HermesCompatTestEnvConfig(
         tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
@@ -146,6 +146,7 @@ class HermesCompatTestEnv(AgentEnv[HermesCompatTestEnvConfig]):
         trajectory_id: str,  # noqa: ARG002
         exec_tool,  # noqa: ARG002
         agent_result: AgentResult | None = None,
+        workspace_meta: Dict[str, Any] | None = None,  # noqa: ARG002
     ) -> tuple[float, Dict[str, Any]]:
         if agent_result is None:
             return 0.0, {"error": "Missing agent_result"}
diff --git a/atropos/envs/sandbox_terminal_smoke_env.py b/atropos/envs/sandbox_terminal_smoke_env.py
index 13bf53ec3c..8c267e8792 100644
--- a/atropos/envs/sandbox_terminal_smoke_env.py
+++ b/atropos/envs/sandbox_terminal_smoke_env.py
@@ -82,7 +82,7 @@ class SandboxTerminalSmokeEnv(AgentEnv[SandboxTerminalSmokeEnvConfig]):
         or "http://127.0.0.1:8080"
     )
     model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
-    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
+    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
 
     env_config = SandboxTerminalSmokeEnvConfig(
         tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
@@ -147,6 +147,7 @@ class SandboxTerminalSmokeEnv(AgentEnv[SandboxTerminalSmokeEnvConfig]):
         trajectory_id: str,  # noqa: ARG002
         exec_tool,  # noqa: ARG002
         agent_result: AgentResult | None = None,
+        workspace_meta: Dict[str, Any] | None = None,  # noqa: ARG002
     ) -> tuple[float, Dict[str, Any]]:
         if agent_result is None:
             return 0.0, {"error": "Missing agent_result"}
diff --git a/atropos/envs/swe_smith_oracle_env.py b/atropos/envs/swe_smith_oracle_env.py
index b1603e1eb5..d0483bb200 100644
--- a/atropos/envs/swe_smith_oracle_env.py
+++ b/atropos/envs/swe_smith_oracle_env.py
@@ -1,10 +1,11 @@
 """
-SWE-smith-oracle benchmark environment (Phase 4.7).
+SWE-smith-oracle environment.
 
 This environment is intentionally minimal:
 - prepares a sandbox workspace by cloning a public GitHub repo at `base_commit`
 - runs an AtroposAgent tool loop to apply a fix
 - verifies by running pytest nodeids from the dataset (reward = pass/fail)
+- Python only (no multi-language support currently; still need to properly build & add to dropbox)
 
 Dataset: NousResearch/SWE-smith-oracle (train; does NOT use SWE-bench eval set).
 """
@@ -13,7 +14,8 @@ from __future__ import annotations
 
 import os
 import random
-from typing import Any, Dict, List, Optional, Tuple
+import time
+from typing import Any, Dict, List, Literal, Optional, Tuple
 
 from pydantic import Field
 
@@ -40,8 +42,11 @@ class SweSmithOracleEnvConfig(AgentEnvConfig):
     repo_base_url: str = Field(default="https://github.com", description="Base URL for repo cloning")
     install_timeout_s: float = Field(default=600.0)
     test_timeout_s: float = Field(default=600.0)
+    verification_mode: Literal["pytest", "install"] = Field(
+        default="install",
+        description="How to score trajectories: 'pytest' runs dataset tests, 'install' scores based on repo install success.",
+    )
 
-    # Tokenization: should match the model used for training.
     tokenizer_name: str = Field(default="NousResearch/Hermes-4.3-36B", description="Tokenizer name for RL tokenization")
 
@@ -78,7 +83,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         or "http://127.0.0.1:8080"
     )
     model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
-    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
+    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
 
     env_config = SweSmithOracleEnvConfig(
         tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
@@ -114,6 +119,12 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
     async def setup_agent_env(self) -> None:
         from datasets import load_dataset
 
+        t0 = time.perf_counter()
+        print(
+            f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
+            f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
+            flush=True,
+        )
         ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
         self._dataset = ds
 
@@ -135,7 +146,9 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         self._cursor = 0
 
         print(
-            f"SweSmithOracleEnv loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split}"
+            f"[SweSmithOracleEnv] loaded {len(self._indices)} items from {self.config.dataset_name}:{self.config.dataset_split} "
+            f"in {time.perf_counter() - t0:.2f}s",
+            flush=True,
         )
 
     def _is_python_row(self, row: Dict[str, Any]) -> bool:
@@ -148,6 +161,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         return True
 
     async def get_next_item(self) -> Item:
+        print(f"[SweSmithOracleEnv] get_next_item() cursor={self._cursor}/{len(self._indices)}", flush=True)
         if not self._dataset or not self._indices:
             raise RuntimeError("Dataset not initialized (did setup() run?)")
         if self._cursor >= len(self._indices):
@@ -165,8 +179,19 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
     def build_task(self, item: Item) -> str:
         repo = item.get("repo") or ""
         base_commit = item.get("base_commit") or ""
-        problem = item.get("problem_statement") or ""
-        context = item.get("text") or ""
+        problem = str(item.get("problem_statement") or "")
+        context = str(item.get("text") or "")
+
+        # The dataset "text" field can be extremely large (e.g. includes large code blobs
+        # and long test lists). In local dev and bring-up runs this can make the first LLM
+        # call appear "hung" while the model chews through a massive prompt. Keep a cap.
+        def _cap(s: str, n: int) -> tuple[str, bool]:
+            if len(s) <= n:
+                return s, False
+            return s[:n], True
+
+        problem, problem_trunc = _cap(problem, 8_000)
+        context, context_trunc = _cap(context, 12_000)
 
         nodeids = self._tests_for_item(item)
         tests_preview = "\n".join(f"- {t}" for t in nodeids[:50])
@@ -174,6 +199,34 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             tests_preview += f"\n- ... ({len(nodeids) - 50} more)"
 
         repo_dir = self._repo_name(item)
+        verify_note = ""
+        if self.config.verification_mode == "install":
+            verify_note = (
+                "\nVerification for this run is INSTALL-ONLY:\n"
+                "- Your goal is to make `python -m pip install -e .` succeed in a repo-local venv (./.venv).\n"
+                "- You may skip running pytest to save time.\n"
+            )
+
+        trunc_note = ""
+        if problem_trunc or context_trunc:
+            trunc_note = (
+                "\nNOTE: Some context was truncated to keep prompts manageable in local dev.\n"
+                f"- problem_statement_truncated={problem_trunc}\n"
+                f"- text_truncated={context_trunc}\n"
+            )
+
+        tests_block = (
+            "Run these tests to verify:\n"
+            f"{tests_preview}\n\n"
+            "When done, briefly describe what you changed and confirm tests pass."
+        )
+        if self.config.verification_mode == "install":
+            # Keep install-only prompts short and avoid huge test lists.
+            tests_block = (
+                "When done, briefly describe what you changed and confirm that "
+                "`python -m pip install -e .` succeeds."
+            )
+
         return (
             "You are a senior software engineer. Fix the repository so the specified tests pass.\n\n"
             f"Repository: {repo} (checked out at base_commit={base_commit})\n"
@@ -181,13 +234,13 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             "Constraints:\n"
             "- Use a workspace-local virtualenv (e.g. inside the repo at ./.venv) to avoid cross-run contamination.\n"
             "- Use non-interactive commands only.\n\n"
+            f"{verify_note}\n"
+            f"{trunc_note}\n"
             "Problem statement:\n"
             f"{problem}\n\n"
             "Additional context:\n"
             f"{context}\n\n"
-            "Run these tests to verify:\n"
-            f"{tests_preview}\n\n"
-            "When done, briefly describe what you changed and confirm tests pass."
+            f"{tests_block}"
         )
 
     def build_agent_config(self, item: Item) -> AgentConfig:  # noqa: ARG002
@@ -200,7 +253,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         )
 
     async def setup_trajectory_workspace(self, item: Item, *, trajectory_id: str, exec_tool) -> Dict[str, Any]:
-        _ = trajectory_id
+        t0 = time.perf_counter()
         repo = item.get("repo")
         base_commit = item.get("base_commit")
         instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
@@ -209,43 +262,80 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         repo_dir = self._repo_name(item)
         clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
-
-        # Fetch only the requested base commit (much lighter than a full clone and robust
-        # even if the commit is not on the default branch ref we cloned).
-        #
-        # This also avoids some failure modes where `git clone` fetches only default-branch
-        # refs and then `git checkout <commit>` fails because the commit isn't present locally.
-        clone_and_checkout_cmd = (
-            f"rm -rf {repo_dir} && "
-            f"git init {repo_dir} && "
-            f"cd {repo_dir} && "
-            f"git remote add origin {clone_url} && "
-            f"git fetch --depth 1 origin {base_commit} && "
-            "git -c advice.detachedHead=false checkout -q FETCH_HEAD"
+        print(
+            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
+            f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
+            flush=True,
         )
-        res = await exec_tool(
-            ToolCall(
-                name="terminal",
-                arguments={"command": clone_and_checkout_cmd, "timeout": self.config.install_timeout_s},
+
+        # Prefer a lightweight "fetch by sha" to avoid pulling full history.
+        # If it fails (some servers disallow fetching unadvertised objects, or we hit
+        # shallow-object edge cases), fall back to a full clone.
+        clone_attempts: list[tuple[str, str]] = []
+        clone_attempts.append(
+            (
+                "shallow_fetch_sha",
+                (
+                    f"rm -rf {repo_dir} && "
+                    f"git init {repo_dir} && "
+                    f"cd {repo_dir} && "
+                    "export GIT_TERMINAL_PROMPT=0 && "
+                    "export GIT_LFS_SKIP_SMUDGE=1 && "
+                    "git config advice.detachedHead false && "
+                    f"git remote add origin {clone_url} && "
+                    f"git fetch --depth 1 origin {base_commit} && "
+                    "git checkout -q FETCH_HEAD"
+                ),
             )
         )
-        if not res.success:
+        clone_attempts.append(
+            (
+                "full_clone_checkout",
+                (
+                    f"rm -rf {repo_dir} && "
+                    f"GIT_TERMINAL_PROMPT=0 GIT_LFS_SKIP_SMUDGE=1 git clone {clone_url} {repo_dir} && "
+                    f"cd {repo_dir} && "
+                    "git config advice.detachedHead false && "
+                    f"git checkout -q {base_commit}"
+                ),
+            )
+        )
+
+        clone_res = None
+        for label, cmd in clone_attempts:
+            t_attempt = time.perf_counter()
+            print(f"[SweSmithOracleEnv] tid={trajectory_id} clone attempt: {label}", flush=True)
+            res = await exec_tool(
+                ToolCall(
+                    name="terminal",
+                    arguments={"command": cmd, "timeout": self.config.install_timeout_s},
+                )
+            )
+            clone_res = res
+            if res.success:
+                print(
+                    f"[SweSmithOracleEnv] tid={trajectory_id} clone ok ({label}) in {time.perf_counter() - t_attempt:.2f}s",
+                    flush=True,
+                )
+                break
+            print(
+                f"[SweSmithOracleEnv] tid={trajectory_id} clone failed ({label}) in {time.perf_counter() - t_attempt:.2f}s: "
+                f"{res.error}",
+                flush=True,
+            )
+
+        if clone_res is None or not clone_res.success:
+            err = clone_res.error if clone_res is not None else "unknown"
+            out = clone_res.output if clone_res is not None else ""
             raise RuntimeError(
-                "git fetch/checkout failed "
-                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {res.error}\n{res.output}"
+                "git clone/checkout failed "
+                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): {err}\n{out}"
             )
 
-        # Best-effort baseline python env.
-        setup_cmd = (
-            f"cd {repo_dir} && "
-            "python -m venv .venv && "
-            ". .venv/bin/activate && "
-            "python -m pip install -U pip setuptools wheel && "
-            "python -m pip install -e . && "
-            "python -m pip install pytest"
+        print(
+            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): clone complete in {time.perf_counter() - t0:.2f}s",
+            flush=True,
         )
-        await exec_tool(ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s}))
-
         return {"repo_dir": repo_dir, "base_commit": base_commit}
 
     def _tests_for_item(self, item: Item) -> List[str]:
@@ -276,13 +366,60 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
         trajectory_id: str,
         exec_tool,
         agent_result=None,  # noqa: ARG002
+        workspace_meta: Optional[Dict[str, Any]] = None,
     ) -> tuple[float, Dict[str, Any]]:
         _ = trajectory_id
         repo_dir = self._repo_name(item)
+
+        if self.config.verification_mode == "install":
+            print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): running pip install -e .", flush=True)
+            t0 = time.perf_counter()
+            install_cmd = (
+                f"cd {repo_dir} && "
+                "python -m venv .venv && "
+                ". .venv/bin/activate && "
+                "python -m pip install -U pip setuptools wheel && "
+                "python -m pip install -e ."
+            )
+            res = await exec_tool(
+                ToolCall(name="terminal", arguments={"command": install_cmd, "timeout": self.config.install_timeout_s})
+            )
+            ok = bool(res.success)
+            print(
+                f"[SweSmithOracleEnv] tid={trajectory_id} verify (install): {'ok' if ok else 'fail'} "
+                f"in {time.perf_counter() - t0:.2f}s",
+                flush=True,
+            )
+            return (1.0 if ok else 0.0), {
+                "verification_mode": "install",
+                "install_success": ok,
+                "error": res.error,
+            }
+
         nodeids = self._tests_for_item(item)
         if not nodeids:
             return 0.0, {"error": "No tests provided"}
 
+        print(f"[SweSmithOracleEnv] tid={trajectory_id} verify (pytest): ensuring venv + deps", flush=True)
+        setup_cmd = (
+            f"cd {repo_dir} && "
+            "python -m venv .venv && "
+            ". .venv/bin/activate && "
+            "python -m pip install -U pip setuptools wheel && "
+            "python -m pip install -e . && "
+            "python -m pip install pytest"
+        )
+        setup_res = await exec_tool(
+            ToolCall(name="terminal", arguments={"command": setup_cmd, "timeout": self.config.install_timeout_s})
+        )
+        if not setup_res.success:
+            return 0.0, {
+                "verification_mode": "pytest",
+                "phase": "install",
+                "error": setup_res.error,
+                "output": setup_res.output,
+            }
+
         chunks = self._chunk_nodeids(nodeids, max_per_chunk=50)
         for chunk_idx, chunk in enumerate(chunks):
             joined = " ".join(chunk)
@@ -296,7 +433,7 @@ class SweSmithOracleEnv(AgentEnv[SweSmithOracleEnvConfig]):
             if not res.success:
                 return 0.0, {"failed_chunk": chunk_idx, "error": res.error, "output": res.output}
 
-        return 1.0, {"passed": True}
+        return 1.0, {"verification_mode": "pytest", "passed": True}
 
     async def score_trajectory(self, item: Item, final_response: str) -> float:
         # Not used; scoring happens in verify_and_score_trajectory.
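For reference, the pytest verification path above batches node IDs via `_chunk_nodeids(nodeids, max_per_chunk=50)` so no single shell command carries more than 50 tests. A standalone sketch of that chunking (the helper already exists in this env; this reimplementation is illustrative only):

from typing import List

def chunk_nodeids(nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
    # Fixed-size batches keep each in-venv pytest command line short.
    return [nodeids[i : i + max_per_chunk] for i in range(0, len(nodeids), max_per_chunk)]

# Each chunk becomes one `python -m pytest <ids...>` invocation inside ./.venv;
# the first failing chunk short-circuits the score to 0.0.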
diff --git a/atropos/envs/test_env.py b/atropos/envs/test_env.py
index 8b0a5df74d..9a107fc99c 100644
--- a/atropos/envs/test_env.py
+++ b/atropos/envs/test_env.py
@@ -106,7 +106,7 @@ class SimpleTestEnv(AgentEnv[SimpleTestEnvConfig]):
         or "http://127.0.0.1:8080"
     )
     model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
-    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
+    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
 
     env_config = SimpleTestEnvConfig(
         tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
diff --git a/atropos/envs/toolserver_smoke_env.py b/atropos/envs/toolserver_smoke_env.py
index ec17698953..d3fa8b12bb 100644
--- a/atropos/envs/toolserver_smoke_env.py
+++ b/atropos/envs/toolserver_smoke_env.py
@@ -60,7 +60,7 @@ class ToolServerSmokeEnv(AgentEnv[ToolServerSmokeEnvConfig]):
         or "http://127.0.0.1:8080"
     )
     model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "hermes-4-36b"
-    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
+    api_key = os.getenv("ATROPOS_SERVER_API_KEY") or os.getenv("NOUS_API_KEY") or os.getenv("OPENAI_API_KEY") or "local"
 
     env_config = ToolServerSmokeEnvConfig(
         tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
@@ -139,6 +139,7 @@ class ToolServerSmokeEnv(AgentEnv[ToolServerSmokeEnvConfig]):
         trajectory_id: str,  # noqa: ARG002
         exec_tool,  # noqa: ARG002
         agent_result: AgentResult | None = None,
+        workspace_meta: Dict[str, Any] | None = None,  # noqa: ARG002
     ) -> tuple[float, Dict[str, Any]]:
         if agent_result is None:
             return 0.0, {"error": "Missing agent_result"}
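The `api_key` fallback chain patched into every env entrypoint above is the same four-step lookup. A minimal sketch of the shared order, should it later be factored into one helper (`resolve_api_key` is a hypothetical name; the variable order matches the patch):

import os

def resolve_api_key() -> str:
    # Order matches the patched envs: explicit Atropos key, hosted Nous key,
    # generic OpenAI key, then the llama.cpp-style "local" placeholder.
    for var in ("ATROPOS_SERVER_API_KEY", "NOUS_API_KEY", "OPENAI_API_KEY"):
        value = os.getenv(var)
        if value:
            return value
    return "local"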