mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-03 02:11:48 +00:00
- Revert compute_reward() tool-call shaping to simple count-based reward (0.05 per tool call, capped at 0.3) - Keep new agent-loop metrics available but only print them for debugging, so environments/users can decide their own tool-call validity policy
620 lines
24 KiB
Python
620 lines
24 KiB
Python
"""
SWE-smith-oracle environment (ported to HermesAgentBaseEnv).

Trains models to fix real GitHub repositories:
- Clones a public GitHub repo at a specific commit
- Runs an agent loop with a terminal tool to apply a fix
- Verifies the fix by running pytest on the node IDs listed in the dataset
- Reward:
  - sandbox path (verify_and_score_trajectory): 1.0 if all tests pass, 0.0 otherwise
  - local path (compute_reward): shaped — 0.0 with no tool calls, 0.05 per tool
    call (capped at 0.3) if install fails or no tests are defined, 0.4 if
    dependencies install but tests fail, 1.0 if all tests pass

Dataset: NousResearch/SWE-smith-oracle (train split; does NOT use the SWE-bench eval set).

Usage:
    # Process mode (OpenAI server, no training):
    python environments/swe_smith_oracle_env.py process \\
        --env.data_path_to_save_groups data/swe_oracle_output.jsonl

    # With Modal sandbox backend:
    python environments/swe_smith_oracle_env.py process \\
        --env.tool_pool_mode modal \\
        --env.modal_image python:3.11
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
import os
import random
import shlex
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
|
||
|
||
# Put the repository root (the grandparent directory of this file) at the
# front of sys.path so the ``environments.*`` imports below resolve when this
# module is executed directly as a script.
_repo_root = Path(__file__).resolve().parent.parent
if not any(entry == str(_repo_root) for entry in sys.path):
    sys.path.insert(0, str(_repo_root))
|
||
|
||
from pydantic import Field
|
||
|
||
from atroposlib.envs.base import ScoredDataGroup
|
||
from atroposlib.envs.server_handling.server_manager import APIServerConfig
|
||
from atroposlib.type_definitions import Item
|
||
|
||
from environments.agent_loop import AgentResult
|
||
from environments.hermes_base_env import HermesAgentBaseEnv, HermesAgentEnvConfig
|
||
from environments.tool_context import ToolContext
|
||
|
||
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
# =============================================================================
|
||
# Config
|
||
# =============================================================================
|
||
|
||
class SweSmithOracleEnvConfig(HermesAgentEnvConfig):
    """Settings specific to the SWE-smith-oracle environment.

    Groups dataset selection and sampling, which dataset test node IDs are
    used for scoring, how the task prompt is assembled, where repositories
    are cloned from, and the install / test timeouts (the ``_s`` suffix
    suggests seconds).
    """

    # --- dataset selection and sampling ---------------------------------
    dataset_name: str = Field("NousResearch/SWE-smith-oracle")
    dataset_split: str = Field("train")
    max_items: int = Field(0, description="0 = no limit")
    shuffle: bool = Field(True)
    seed: int = Field(0)

    # --- row filtering and test selection -------------------------------
    python_only: bool = Field(True, description="Filter to Python-evaluable rows")
    score_include_fail_to_pass: bool = Field(
        True,
        description=(
            "Score tests on PASS_TO_PASS ∪ FAIL_TO_PASS. "
            "Disable to only run PASS_TO_PASS (faster but weaker signal)."
        ),
    )

    # --- prompt construction --------------------------------------------
    prompt_mode: str = Field(
        "problem_statement",
        description=(
            "'problem_statement' (fast) or 'problem_statement+text' "
            "(includes dataset 'text')."
        ),
    )

    # --- repository source and command timeouts -------------------------
    repo_base_url: str = Field("https://github.com", description="Base URL for repo cloning")
    install_timeout_s: float = Field(600.0)
    test_timeout_s: float = Field(600.0)
|
||
|
||
|
||
# =============================================================================
|
||
# Environment
|
||
# =============================================================================
|
||
|
||
class SweSmithOracleEnv(HermesAgentBaseEnv):
    """
    SWE-smith-oracle environment for training models to fix real GitHub repos.

    Uses proper OpenAI-spec tool calling via HermesAgentBaseEnv.
    The model gets terminal access to inspect, edit, and test the repository.

    Two scoring paths exist:

    - ``verify_and_score_trajectory`` (sandbox hook): binary reward, 1.0 only
      when dependency install and every pytest chunk succeed, else 0.0.
    - ``compute_reward`` (local / non-sandbox path): shaped reward based on the
      number of tool calls, install success, and test success.
    """

    name = "swe-smith-oracle"
    env_config_cls = SweSmithOracleEnvConfig

    # Maximum number of pytest node IDs per pytest invocation; keeps each
    # shell command line bounded in length.
    _TEST_CHUNK_SIZE = 50

    def __init__(
        self,
        config: SweSmithOracleEnvConfig,
        server_configs,
        slurm=False,
        testing=False,
    ):
        super().__init__(config, server_configs, slurm, testing)
        self._dataset = None
        # Dataset row indices to serve (filtered, optionally shuffled/capped).
        self._indices: List[int] = []
        # Position in ``self._indices``; wraps around in get_next_item().
        self._cursor = 0

    @classmethod
    def config_init(cls) -> Tuple[SweSmithOracleEnvConfig, List[APIServerConfig]]:
        """Build the default env config and server list from environment variables.

        Base URL: ATROPOS_SERVER_BASE_URL, OPENAI_BASE_URL, LLM_BASE_URL, else
        http://127.0.0.1:8080 (``/v1`` is appended when missing).
        Model: ATROPOS_SERVER_MODEL, LLM_MODEL, else "Hermes-4.3-36B".
        API key: ATROPOS_SERVER_API_KEY, NOUS_API_KEY, OPENAI_API_KEY, else "local".
        """
        base_url = (
            os.getenv("ATROPOS_SERVER_BASE_URL")
            or os.getenv("OPENAI_BASE_URL")
            or os.getenv("LLM_BASE_URL")
            or "http://127.0.0.1:8080"
        )
        if not base_url.rstrip("/").endswith("/v1"):
            base_url = base_url.rstrip("/") + "/v1"

        model = os.getenv("ATROPOS_SERVER_MODEL") or os.getenv("LLM_MODEL") or "Hermes-4.3-36B"
        api_key = (
            os.getenv("ATROPOS_SERVER_API_KEY")
            or os.getenv("NOUS_API_KEY")
            or os.getenv("OPENAI_API_KEY")
            or "local"
        )

        env_config = SweSmithOracleEnvConfig(
            tokenizer_name=os.getenv("ATROPOS_TOKENIZER_NAME") or "NousResearch/Hermes-4.3-36B",
            group_size=1,
            use_wandb=False,
            rollout_server_url="http://localhost:8000",
            total_steps=1,
            batch_size=1,
            steps_per_eval=1,
            max_token_length=8192,
            wandb_name="swe_smith_oracle",
            enabled_toolsets=["terminal", "file"],
            terminal_backend=os.getenv("TERMINAL_ENV", "local"),
            # Longer agent turns for SWE tasks
            max_agent_turns=50,
            agent_temperature=0.7,
            system_prompt=(
                "You are a senior software engineer. You have access to a terminal "
                "to inspect and fix repositories. Use non-interactive commands only. "
                "Each terminal command runs in a fresh shell."
            ),
            tool_call_parser="hermes",
            # Sandbox settings (used when tool_pool_mode != "default")
            sandbox_image=os.getenv("ATROPOS_SANDBOX_IMAGE") or "atropos-sandbox:local",
            purge_job_on_start=True,
            purge_job_on_shutdown=True,
        )

        server_configs = [
            APIServerConfig(
                model_name=model,
                base_url=base_url,
                api_key=api_key,
                server_type="vllm",
                health_check=False,
                timeout=int(os.getenv("ATROPOS_SERVER_TIMEOUT_S") or "300"),
            ),
        ]

        return env_config, server_configs

    # =========================================================================
    # Dataset loading
    # =========================================================================

    async def setup(self):
        """Load the dataset, filter/shuffle/cap row indices, and reset the cursor."""
        from datasets import load_dataset

        t0 = time.perf_counter()
        print(
            f"[SweSmithOracleEnv] loading dataset {self.config.dataset_name}:{self.config.dataset_split} "
            f"(python_only={self.config.python_only}, max_items={self.config.max_items or 'all'})",
            flush=True,
        )
        ds = load_dataset(self.config.dataset_name, split=self.config.dataset_split)
        self._dataset = ds

        indices: List[int] = []
        for idx in range(len(ds)):
            row = ds[idx]
            if self.config.python_only and not self._is_python_row(row):
                continue
            indices.append(idx)

        if self.config.shuffle:
            # Seeded local RNG so the item order is reproducible across runs.
            rnd = random.Random(self.config.seed)
            rnd.shuffle(indices)

        if self.config.max_items and self.config.max_items > 0:
            indices = indices[: self.config.max_items]

        self._indices = indices
        self._cursor = 0

        print(
            f"[SweSmithOracleEnv] loaded {len(self._indices)} items "
            f"in {time.perf_counter() - t0:.2f}s",
            flush=True,
        )

    def _is_python_row(self, row: Dict[str, Any]) -> bool:
        """Return True when every PASS_TO_PASS node ID looks like a pytest ``*.py::`` ID."""
        nodeids = row.get("PASS_TO_PASS")
        if not isinstance(nodeids, list) or not nodeids:
            return False
        return all(isinstance(nid, str) and ".py::" in nid for nid in nodeids)

    async def get_next_item(self) -> Item:
        """Return the next dataset row as a dict, cycling through ``self._indices``.

        Raises:
            RuntimeError: if setup() has not loaded any usable rows.
        """
        if self._dataset is None or not self._indices:
            raise RuntimeError("Dataset not initialized")
        if self._cursor >= len(self._indices):
            self._cursor = 0
        idx = self._indices[self._cursor]
        self._cursor += 1
        return dict(self._dataset[idx])

    # =========================================================================
    # Prompt formatting
    # =========================================================================

    def _repo_name(self, item: Item) -> str:
        """Return the workspace directory name: the last segment of ``item["repo"]``.

        Falls back to ``"repo"`` when the field is missing, has no ``/``, or
        has an empty last segment (e.g. a trailing slash).
        """
        repo = item.get("repo") or ""
        if isinstance(repo, str) and "/" in repo:
            name = repo.rstrip("/").rsplit("/", 1)[-1]
            if name:
                return name
        return "repo"

    def format_prompt(self, item: Item) -> str:
        """Build the SWE task prompt (repo, commit, constraints, problem, tests)."""
        repo = item.get("repo") or ""
        base_commit = item.get("base_commit") or ""
        problem = str(item.get("problem_statement") or "")
        context = str(item.get("text") or "")
        repo_dir = self._repo_name(item)

        nodeids = self._tests_for_item(item)
        tests_list = "\n".join(f"- {t}" for t in nodeids)

        context_block = ""
        prompt_mode = (self.config.prompt_mode or "problem_statement").strip().lower()
        if prompt_mode == "problem_statement+text" and context:
            context_block = f"\nAdditional context:\n{context}\n"

        return (
            f"Fix the repository so the specified tests pass.\n\n"
            f"Repository: {repo} (checked out at base_commit={base_commit})\n"
            f"Workspace path: ./{repo_dir}\n\n"
            "Constraints:\n"
            "- Use the terminal tool to inspect, edit, and verify the repository.\n"
            f"- Start by inspecting the repo (e.g. `ls`, `cd ./{repo_dir}`, `git status`).\n"
            "- Use a workspace-local virtualenv (.venv) to avoid cross-run contamination.\n"
            "- Use non-interactive commands only.\n"
            "- Prefer `. .venv/bin/activate` or `.venv/bin/python ...` (POSIX compatible).\n\n"
            f"Problem statement:\n{problem}\n\n"
            f"{context_block}"
            f"Run these tests to verify:\n{tests_list}\n\n"
            "When done, briefly describe what you changed and confirm tests pass."
        )

    # =========================================================================
    # Test helpers
    # =========================================================================

    def _tests_for_item(self, item: Item) -> List[str]:
        """Return the sorted, de-duplicated pytest node IDs used to score ``item``.

        Uses PASS_TO_PASS, plus FAIL_TO_PASS when
        ``config.score_include_fail_to_pass`` is set. Non-string entries are skipped.
        """
        keys = (
            ("PASS_TO_PASS", "FAIL_TO_PASS")
            if self.config.score_include_fail_to_pass
            else ("PASS_TO_PASS",)
        )
        tests: List[str] = []
        for key in keys:
            nodeids = item.get(key)
            if isinstance(nodeids, list):
                tests.extend(n for n in nodeids if isinstance(n, str))
        return sorted(dict.fromkeys(tests))

    def _chunk_nodeids(self, nodeids: List[str], max_per_chunk: int = 50) -> List[List[str]]:
        """Split ``nodeids`` into consecutive chunks of at most ``max_per_chunk``.

        A non-positive ``max_per_chunk`` is treated as 1 (range() rejects step 0).
        """
        step = max(1, int(max_per_chunk))
        return [nodeids[i : i + step] for i in range(0, len(nodeids), step)]

    @staticmethod
    def _count_tool_calls(result: AgentResult) -> int:
        """Count tool calls across all assistant messages of a rollout.

        A missing ``tool_calls`` key and an explicit ``"tool_calls": None`` are
        both treated as zero calls (``len(None)`` would raise TypeError).
        """
        return sum(
            len(msg.get("tool_calls") or [])
            for msg in result.messages
            if msg.get("role") == "assistant"
        )

    @staticmethod
    def _install_command(repo_dir: str) -> str:
        """Shell command that creates ``.venv`` in ``repo_dir`` and installs the repo + pytest."""
        return (
            f"cd {shlex.quote(repo_dir)} && "
            "python -m venv .venv && "
            ". .venv/bin/activate && "
            "python -m pip install -U pip setuptools wheel && "
            "python -m pip install -e . && "
            "python -m pip install pytest"
        )

    @staticmethod
    def _pytest_command(repo_dir: str, nodeids: List[str]) -> str:
        """Shell command running pytest on ``nodeids`` inside ``repo_dir``'s ``.venv``.

        Node IDs come from the dataset and may contain shell metacharacters
        (parametrize brackets, spaces, ``<``, quotes), so each one is quoted.
        """
        joined = " ".join(shlex.quote(n) for n in nodeids)
        return (
            f"cd {shlex.quote(repo_dir)} && . .venv/bin/activate && "
            f"python -m pytest -q {joined}"
        )

    # =========================================================================
    # Sandbox hooks: setup_trajectory_workspace + verify_and_score_trajectory
    # =========================================================================

    async def setup_trajectory_workspace(
        self, item: Item, *, trajectory_id: str, exec_tool
    ) -> Dict[str, Any]:
        """
        Prepare a sandbox workspace: bare repo cache + git worktree.

        Uses flock-serialized bare repo cache under /data/repo_cache so
        multiple trajectories sharing a sandbox don't clone the same repo
        in parallel. Each trajectory gets an isolated worktree at the
        specified base_commit.

        Args:
            item: Dataset row with repo, base_commit, etc.
            trajectory_id: Unique trajectory ID
            exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult

        Returns:
            Dict with repo_dir, base_commit metadata

        Raises:
            RuntimeError: if the row lacks repo/base_commit or the git command fails.
        """
        t0 = time.perf_counter()
        repo = item.get("repo")
        base_commit = item.get("base_commit")
        instance_id = item.get("instance_id") or item.get("id") or item.get("problem_id")
        if not isinstance(repo, str) or not isinstance(base_commit, str):
            raise RuntimeError("Invalid dataset row: missing repo/base_commit")

        repo_dir = self._repo_name(item)
        clone_url = f"{self.config.repo_base_url.rstrip('/')}/{repo}.git"
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"repo={repo} base_commit={base_commit} instance_id={instance_id} dir=./{repo_dir}",
            flush=True,
        )

        # Bare repo cache + worktree strategy (same as atropos/envs/swe_smith_oracle_env.py)
        repo_slug = repo.replace("/", "__")
        cache_root = "/data/repo_cache"
        locks_dir = f"{cache_root}/.locks"
        bare_repo = f"{cache_root}/{repo_slug}.git"
        lock_file = f"{locks_dir}/{repo_slug}.lock"

        # Every interpolated value is shell-quoted so dataset-provided strings
        # cannot break out of the command.
        q_bare = shlex.quote(bare_repo)
        q_url = shlex.quote(clone_url)
        q_commit = shlex.quote(base_commit)
        q_commit_obj = shlex.quote(f"{base_commit}^{{commit}}")
        q_dir = shlex.quote(repo_dir)
        q_lock = shlex.quote(lock_file)

        # Runs while holding the per-repo lock: create/refresh the bare cache,
        # fetch the commit if missing (shallow first, then full), add the worktree.
        locked_script = (
            "set -e; "
            "export GIT_TERMINAL_PROMPT=0; "
            "export GIT_LFS_SKIP_SMUDGE=1; "
            f"if [ ! -d {q_bare} ]; then "
            f"  git init --bare {q_bare}; "
            f"  git -C {q_bare} remote add origin {q_url}; "
            "fi; "
            f"git -C {q_bare} remote set-url origin {q_url}; "
            f"git -C {q_bare} worktree prune || true; "
            f"if ! git -C {q_bare} cat-file -e {q_commit_obj} 2>/dev/null; then "
            f"  git -C {q_bare} fetch --depth 1 origin {q_commit} || true; "
            "fi; "
            f"if ! git -C {q_bare} cat-file -e {q_commit_obj} 2>/dev/null; then "
            f"  git -C {q_bare} fetch --prune origin; "
            "fi; "
            f"git --git-dir={q_bare} worktree add --detach {q_dir} {q_commit}; "
        )
        worktree_cmd = (
            "set -e; "
            f"rm -rf {q_dir}; "
            f"mkdir -p {shlex.quote(locks_dir)}; "
            f": > {q_lock}; "
            f"flock -x {q_lock} sh -lc {shlex.quote(locked_script)}"
        )

        print(f"[SweSmithOracleEnv] tid={trajectory_id} preparing worktree from repo cache", flush=True)
        res = await exec_tool(
            "bash",
            {"command": worktree_cmd},
            timeout=self.config.install_timeout_s,
        )
        if not res.success:
            raise RuntimeError(
                f"git worktree setup failed "
                f"(repo={repo}, base_commit={base_commit}, instance_id={instance_id}): "
                f"{res.error}\n{res.output}"
            )

        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} setup_trajectory_workspace(): "
            f"worktree ready in {time.perf_counter() - t0:.2f}s",
            flush=True,
        )
        return {"repo_dir": repo_dir, "base_commit": base_commit}

    async def verify_and_score_trajectory(
        self,
        item: Item,
        result: AgentResult,
        *,
        trajectory_id: str,
        exec_tool,
        workspace_meta: Optional[Dict[str, Any]] = None,
    ) -> Tuple[float, Dict[str, Any]]:
        """
        In-sandbox verification: install deps + run pytest with dataset nodeids.

        Args:
            item: Dataset row
            result: Agent's rollout result
            trajectory_id: Unique trajectory ID
            exec_tool: async callable(tool_name, args, timeout) -> ExecutionResult
            workspace_meta: From setup_trajectory_workspace (has repo_dir)

        Returns:
            (reward, metadata) tuple. Reward is 1.0 only if install and every
            test chunk succeed; otherwise 0.0 with metadata describing the failure.
        """
        repo_dir = (workspace_meta or {}).get("repo_dir") or self._repo_name(item)

        # Don't reward trajectories that never used tools
        if self._count_tool_calls(result) == 0:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} verify: no tool calls; score=0.0",
                flush=True,
            )
            return 0.0, {"error": "No tool calls were made by the agent"}

        nodeids = self._tests_for_item(item)
        if not nodeids:
            return 0.0, {"error": "No tests provided"}

        # Install dependencies
        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} verify: installing deps + running tests",
            flush=True,
        )
        setup_res = await exec_tool(
            "bash",
            {"command": self._install_command(repo_dir)},
            timeout=self.config.install_timeout_s,
        )
        if not setup_res.success:
            print(
                f"[SweSmithOracleEnv] tid={trajectory_id} install failed; score=0.0",
                flush=True,
            )
            return 0.0, {
                "phase": "install",
                "error": setup_res.error,
                "output": setup_res.output,
            }

        # Run test chunks; the first failing chunk ends verification.
        chunks = self._chunk_nodeids(nodeids, max_per_chunk=self._TEST_CHUNK_SIZE)
        for chunk_idx, chunk in enumerate(chunks):
            res = await exec_tool(
                "bash",
                {"command": self._pytest_command(repo_dir, chunk)},
                timeout=self.config.test_timeout_s,
            )
            if not res.success:
                print(
                    f"[SweSmithOracleEnv] tid={trajectory_id} tests failed (chunk {chunk_idx}); score=0.0",
                    flush=True,
                )
                return 0.0, {
                    "phase": "pytest",
                    "failed_chunk": chunk_idx,
                    "error": res.error,
                    "output": res.output,
                }

        print(
            f"[SweSmithOracleEnv] tid={trajectory_id} all tests passed; score=1.0",
            flush=True,
        )
        return 1.0, {"passed": True}

    # =========================================================================
    # Reward: run pytest in the terminal (local / non-sandbox path)
    # =========================================================================

    async def compute_reward(
        self, item: Item, result: AgentResult, ctx: ToolContext
    ) -> float:
        """
        Verify by running pytest with the dataset's nodeids.

        Reward structure (shaped to give training signal even when the model
        can't solve the task):
        - 0.0: no tool calls at all
        - min(0.05 * tool_calls, 0.3): tools were used but no tests are defined
          or dependency install failed (every tool call counts, valid or not)
        - 0.4: dependencies installed but at least one test chunk failed
        - 1.0: all tests pass

        The partial rewards for tool calls help the model learn to USE tools
        before it can learn to use them CORRECTLY. This is critical for cold-start
        training where the base model barely makes any tool calls.
        """
        repo_dir = self._repo_name(item)

        # NOTE: we keep scoring policy here intentionally simple and env-specific.
        # The agent loop exposes additional tool-call metrics (attempted/schema_valid/
        # executed_ok/exec_error) that other environments may choose to use for
        # reward shaping, but we don't hard-require any particular calling format here.
        tool_call_count = self._count_tool_calls(result)
        if tool_call_count == 0:
            print("[SweSmithOracleEnv] No tool calls made; score=0.0", flush=True)
            return 0.0

        # Partial reward: 0.05 per tool call, capped at 0.3
        tool_call_reward = min(tool_call_count * 0.05, 0.3)

        # Debug only: log tool-call quality metrics if the agent result has them.
        attempted = getattr(result, "tool_calls_attempted", None)
        schema_valid = getattr(result, "tool_calls_schema_valid", None)
        executed_ok = getattr(result, "tool_calls_executed_ok", None)
        exec_error = getattr(result, "tool_calls_exec_error", None)
        if attempted is not None:
            print(
                f"[SweSmithOracleEnv] Tool calls: total={tool_call_count}, attempted={attempted}, schema_valid={schema_valid}, ok={executed_ok}, err={exec_error}",
                flush=True,
            )

        nodeids = self._tests_for_item(item)
        if not nodeids:
            # No tests defined — just reward tool usage
            print(f"[SweSmithOracleEnv] No tests defined; score={tool_call_reward:.2f} (tool calls)", flush=True)
            return tool_call_reward

        # Install deps + run tests
        print("[SweSmithOracleEnv] Verifying: installing deps + running tests", flush=True)
        setup_result = ctx.terminal(
            self._install_command(repo_dir),
            timeout=int(self.config.install_timeout_s),
        )
        if setup_result.get("exit_code", 1) != 0:
            print(f"[SweSmithOracleEnv] Install failed; score={tool_call_reward:.2f} (tool calls only)", flush=True)
            return tool_call_reward

        # Partial reward for successful install (replaces, not adds to, the tool-call reward)
        install_reward = 0.4

        chunks = self._chunk_nodeids(nodeids, max_per_chunk=self._TEST_CHUNK_SIZE)
        for chunk_idx, chunk in enumerate(chunks):
            test_result = ctx.terminal(
                self._pytest_command(repo_dir, chunk),
                timeout=int(self.config.test_timeout_s),
            )
            if test_result.get("exit_code", 1) != 0:
                print(f"[SweSmithOracleEnv] Tests failed (chunk {chunk_idx}); score={install_reward:.2f} (install ok)", flush=True)
                return install_reward

        print("[SweSmithOracleEnv] All tests passed; score=1.0", flush=True)
        return 1.0

    # =========================================================================
    # Token truncation — keep start of trajectory, truncate from end
    # =========================================================================

    def _build_scored_item(self, item, result, reward):
        """
        Override to truncate tokens/masks from the END to fit within max_token_len.

        Intuition (from NeurIPS finding): the start of the trajectory is most important
        for shifting the model distribution. Truncating from the end only costs ~2-3%
        vs handling the full sequence, but avoids the "Token length is too long" discard
        that throws away entire groups including valid training signal.

        Whenever ``tokens`` are truncated, ``masks`` are truncated to the same
        length so the two stay aligned.
        """
        scored_item, remaining = super()._build_scored_item(item, result, reward)
        if scored_item is None:
            return scored_item, remaining

        # self.max_token_len comes from the trainer via /info, but may be -1
        # (or unset) if the trainer hasn't registered yet (race condition), so
        # fall back to config.max_token_length.
        max_len = getattr(self, "max_token_len", -1) or -1
        if max_len <= 0:
            max_len = getattr(self.config, "max_token_length", 0) or 0
        if max_len <= 0:
            return scored_item, remaining

        # Leave a 64-token margin to avoid edge cases with padding alignment,
        # but never slice to a non-positive length for tiny limits.
        truncate_to = max(max_len - 64, max_len // 2, 1)

        tokens = scored_item.get("tokens")
        if tokens is not None and len(tokens) >= max_len:
            orig_len = len(tokens)
            scored_item["tokens"] = tokens[:truncate_to]
            masks = scored_item.get("masks")
            if masks is not None:
                scored_item["masks"] = masks[:truncate_to]
            logger.info(
                "Truncated trajectory from %d to %d tokens (max_token_len=%d)",
                orig_len, truncate_to, max_len,
            )

        return scored_item, remaining

    # =========================================================================
    # Evaluation (minimal for now)
    # =========================================================================

    async def evaluate(self, *args, **kwargs):
        """Placeholder evaluation — SWE tasks are too expensive for frequent eval."""
        start_time = time.time()
        await self.evaluate_log(
            metrics={"eval/placeholder": 0.0},
            samples=[],
            start_time=start_time,
            end_time=time.time(),
        )
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: delegate argument parsing and subcommand dispatch
    # (e.g. ``process`` in the usage examples) to ``cli()``, which is
    # presumably inherited from the Atropos base environment — not defined here.
    SweSmithOracleEnv.cli()
|