fix(file_tools): resolve bookkeeping paths against live terminal cwd

This commit is contained in:
Yukipukii1 2026-04-23 22:36:07 +03:00 committed by Teknium
parent 83859b4da0
commit 4a0c02b7dc
3 changed files with 139 additions and 17 deletions

View file

@ -13,8 +13,10 @@ import os
import tempfile import tempfile
import time import time
import unittest import unittest
from types import SimpleNamespace
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
from tools import file_state
from tools.file_tools import ( from tools.file_tools import (
read_file_tool, read_file_tool,
write_file_tool, write_file_tool,
@ -76,6 +78,7 @@ class TestStalenessCheck(unittest.TestCase):
def setUp(self): def setUp(self):
_read_tracker.clear() _read_tracker.clear()
file_state.get_registry().clear()
self._tmpdir = tempfile.mkdtemp() self._tmpdir = tempfile.mkdtemp()
self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt") self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
with open(self._tmpfile, "w") as f: with open(self._tmpfile, "w") as f:
@ -83,6 +86,7 @@ class TestStalenessCheck(unittest.TestCase):
def tearDown(self): def tearDown(self):
_read_tracker.clear() _read_tracker.clear()
file_state.get_registry().clear()
try: try:
os.unlink(self._tmpfile) os.unlink(self._tmpfile)
os.rmdir(self._tmpdir) os.rmdir(self._tmpdir)
@ -145,6 +149,53 @@ class TestStalenessCheck(unittest.TestCase):
result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b")) result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
self.assertNotIn("_warning", result) self.assertNotIn("_warning", result)
@patch("tools.file_tools._get_file_ops")
def test_relative_path_uses_live_cwd_for_staleness_tracking(self, mock_ops):
    """Relative-path stale tracking must follow the live terminal cwd."""
    # Two directories holding a same-named file: the terminal session
    # started in `start` but has since chdir'd into `worktree`.
    start_dir = os.path.join(self._tmpdir, "start")
    live_dir = os.path.join(self._tmpdir, "worktree")
    for directory in (start_dir, live_dir):
        os.makedirs(directory, exist_ok=True)
    start_file = os.path.join(start_dir, "shared.txt")
    live_file = os.path.join(live_dir, "shared.txt")
    with open(start_file, "w") as fh:
        fh.write("start copy\n")
    with open(live_file, "w") as fh:
        fh.write("live copy\n")

    # Fake ops whose live env.cwd points at the worktree while the stale
    # top-level cwd still points at the starting directory.
    fake_ops = _make_fake_ops("live copy\n", 10)
    fake_ops.env = SimpleNamespace(cwd=live_dir)
    fake_ops.cwd = start_dir
    mock_ops.return_value = fake_ops

    from tools import file_tools

    with file_tools._file_ops_lock:
        previous = file_tools._file_ops_cache.get("live_task")
        file_tools._file_ops_cache["live_task"] = fake_ops
    try:
        with patch.dict(os.environ, {"TERMINAL_CWD": start_dir}, clear=False):
            read_file_tool("shared.txt", task_id="live_task")
            time.sleep(0.05)
            with open(live_file, "w") as fh:
                fh.write("live copy modified elsewhere\n")
            result = json.loads(
                write_file_tool("shared.txt", "replacement", task_id="live_task")
            )
    finally:
        # Restore whatever was cached before so sibling tests are unaffected.
        with file_tools._file_ops_lock:
            if previous is not None:
                file_tools._file_ops_cache["live_task"] = previous
            else:
                file_tools._file_ops_cache.pop("live_task", None)

    self.assertIn("_warning", result)
    self.assertIn("modified since you last read", result["_warning"])
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Staleness in patch # Staleness in patch
@ -154,6 +205,7 @@ class TestPatchStaleness(unittest.TestCase):
def setUp(self): def setUp(self):
_read_tracker.clear() _read_tracker.clear()
file_state.get_registry().clear()
self._tmpdir = tempfile.mkdtemp() self._tmpdir = tempfile.mkdtemp()
self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt") self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
with open(self._tmpfile, "w") as f: with open(self._tmpfile, "w") as f:
@ -161,6 +213,7 @@ class TestPatchStaleness(unittest.TestCase):
def tearDown(self): def tearDown(self):
_read_tracker.clear() _read_tracker.clear()
file_state.get_registry().clear()
try: try:
os.unlink(self._tmpfile) os.unlink(self._tmpfile)
os.rmdir(self._tmpdir) os.rmdir(self._tmpdir)
@ -207,9 +260,11 @@ class TestCheckFileStalenessHelper(unittest.TestCase):
def setUp(self): def setUp(self):
_read_tracker.clear() _read_tracker.clear()
file_state.get_registry().clear()
def tearDown(self): def tearDown(self):
_read_tracker.clear() _read_tracker.clear()
file_state.get_registry().clear()
def test_returns_none_for_unknown_task(self): def test_returns_none_for_unknown_task(self):
self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent")) self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))

View file

@ -2,6 +2,7 @@
import os import os
from pathlib import Path from pathlib import Path
from types import SimpleNamespace
import pytest import pytest
@ -22,8 +23,9 @@ class TestResolvePath:
monkeypatch.setenv("TERMINAL_CWD", str(tmp_path)) monkeypatch.setenv("TERMINAL_CWD", str(tmp_path))
from tools.file_tools import _resolve_path from tools.file_tools import _resolve_path
result = _resolve_path("/etc/hosts") absolute = (tmp_path / "already-absolute.txt").resolve()
assert result == Path("/etc/hosts") result = _resolve_path(str(absolute))
assert result == absolute
def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch): def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch):
"""Without TERMINAL_CWD, falls back to os.getcwd().""" """Without TERMINAL_CWD, falls back to os.getcwd()."""
@ -50,3 +52,34 @@ class TestResolvePath:
result = _resolve_path("a/../b/file.txt") result = _resolve_path("a/../b/file.txt")
assert ".." not in str(result) assert ".." not in str(result)
assert result == (tmp_path / "b" / "file.txt") assert result == (tmp_path / "b" / "file.txt")
def test_relative_path_prefers_live_file_ops_cwd(self, monkeypatch, tmp_path):
    """Live env.cwd must win after the terminal session changes directory."""
    start_dir, live_dir = tmp_path / "start", tmp_path / "worktree"
    for directory in (start_dir, live_dir):
        directory.mkdir()
    monkeypatch.setenv("TERMINAL_CWD", str(start_dir))

    from tools import file_tools

    task_id = "live-cwd"
    # Cached ops object carries a stale top-level cwd but a fresh env.cwd;
    # resolution must follow the latter.
    fake_ops = SimpleNamespace(
        env=SimpleNamespace(cwd=str(live_dir)), cwd=str(start_dir)
    )
    with file_tools._file_ops_lock:
        saved = file_tools._file_ops_cache.get(task_id)
        file_tools._file_ops_cache[task_id] = fake_ops
    try:
        resolved = file_tools._resolve_path("nested/file.txt", task_id=task_id)
    finally:
        # Put the cache back exactly as it was found.
        with file_tools._file_ops_lock:
            if saved is not None:
                file_tools._file_ops_cache[task_id] = saved
            else:
                file_tools._file_ops_cache.pop(task_id, None)

    assert resolved == live_dir / "nested" / "file.txt"

View file

@ -79,13 +79,45 @@ _BLOCKED_DEVICE_PATHS = frozenset({
}) })
def _resolve_path(filepath: str, task_id: str = "default") -> Path:
    """Resolve a path relative to TERMINAL_CWD (the worktree base directory)
    instead of the main repository root.
    """
    # Thin compatibility wrapper: the task-aware resolution logic lives in
    # _resolve_path_for_task, and the default task keeps old callers working.
    return _resolve_path_for_task(filepath, task_id)
def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
    """Return the task's live terminal cwd for bookkeeping when available."""
    # First choice: the cached per-task file-ops object. Its env.cwd tracks
    # the live shell session; a top-level .cwd attribute is the fallback.
    with _file_ops_lock:
        ops = _file_ops_cache.get(task_id)
    if ops is not None:
        env_obj = getattr(ops, "env", None)
        candidate = getattr(env_obj, "cwd", None) or getattr(ops, "cwd", None)
        if candidate:
            return candidate
    # Second choice: the terminal tool's active-environment registry. Any
    # failure here (import error, no such task) just means "no live cwd".
    try:
        from tools.terminal_tool import _active_environments, _env_lock

        with _env_lock:
            env = _active_environments.get(task_id)
        candidate = getattr(env, "cwd", None) if env is not None else None
        if candidate:
            return candidate
    except Exception:
        pass
    return None
def _resolve_path_for_task(filepath: str, task_id: str = "default") -> Path:
    """Resolve *filepath* against the task's live terminal cwd when possible."""
    candidate = Path(filepath).expanduser()
    if candidate.is_absolute():
        return candidate.resolve()
    # Base-directory precedence for relative paths: live per-task terminal
    # cwd, then the TERMINAL_CWD env var, then the process cwd.
    base = _get_live_tracking_cwd(task_id)
    if not base:
        base = os.environ.get("TERMINAL_CWD", os.getcwd())
    return (Path(base) / candidate).resolve()
@ -118,10 +150,10 @@ _SENSITIVE_PATH_PREFIXES = (
_SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"} _SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"}
def _check_sensitive_path(filepath: str) -> str | None: def _check_sensitive_path(filepath: str, task_id: str = "default") -> str | None:
"""Return an error message if the path targets a sensitive system location.""" """Return an error message if the path targets a sensitive system location."""
try: try:
resolved = str(_resolve_path(filepath)) resolved = str(_resolve_path_for_task(filepath, task_id))
except (OSError, ValueError): except (OSError, ValueError):
resolved = filepath resolved = filepath
normalized = os.path.normpath(os.path.expanduser(filepath)) normalized = os.path.normpath(os.path.expanduser(filepath))
@ -368,7 +400,7 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
), ),
}) })
_resolved = _resolve_path(path) _resolved = _resolve_path_for_task(path, task_id)
# ── Binary file guard ───────────────────────────────────────── # ── Binary file guard ─────────────────────────────────────────
# Block binary files by extension (no I/O). # Block binary files by extension (no I/O).
@ -574,7 +606,7 @@ def _update_read_timestamp(filepath: str, task_id: str) -> None:
refreshes the stored timestamp to match the file's new state. refreshes the stored timestamp to match the file's new state.
""" """
try: try:
resolved = str(_resolve_path(filepath)) resolved = str(_resolve_path_for_task(filepath, task_id))
current_mtime = os.path.getmtime(resolved) current_mtime = os.path.getmtime(resolved)
except (OSError, ValueError): except (OSError, ValueError):
return return
@ -593,7 +625,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
or was never read. Does not block the write still proceeds. or was never read. Does not block the write still proceeds.
""" """
try: try:
resolved = str(_resolve_path(filepath)) resolved = str(_resolve_path_for_task(filepath, task_id))
except (OSError, ValueError): except (OSError, ValueError):
return None return None
with _read_tracker_lock: with _read_tracker_lock:
@ -618,7 +650,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
def write_file_tool(path: str, content: str, task_id: str = "default") -> str: def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
"""Write content to a file.""" """Write content to a file."""
sensitive_err = _check_sensitive_path(path) sensitive_err = _check_sensitive_path(path, task_id)
if sensitive_err: if sensitive_err:
return tool_error(sensitive_err) return tool_error(sensitive_err)
try: try:
@ -626,7 +658,7 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
# fall back to the legacy path — write proceeds, per-task staleness # fall back to the legacy path — write proceeds, per-task staleness
# check below still runs. # check below still runs.
try: try:
_resolved = str(_resolve_path(path)) _resolved = str(_resolve_path_for_task(path, task_id))
except Exception: except Exception:
_resolved = None _resolved = None
@ -681,7 +713,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE): for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE):
_paths_to_check.append(_m.group(1).strip()) _paths_to_check.append(_m.group(1).strip())
for _p in _paths_to_check: for _p in _paths_to_check:
sensitive_err = _check_sensitive_path(_p) sensitive_err = _check_sensitive_path(_p, task_id)
if sensitive_err: if sensitive_err:
return tool_error(sensitive_err) return tool_error(sensitive_err)
try: try:
@ -692,7 +724,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
_seen: set[str] = set() _seen: set[str] = set()
for _p in _paths_to_check: for _p in _paths_to_check:
try: try:
_r = str(_resolve_path(_p)) _r = str(_resolve_path_for_task(_p, task_id))
except Exception: except Exception:
_r = None _r = None
if _r and _r not in _seen: if _r and _r not in _seen:
@ -714,7 +746,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
_path_to_resolved: dict[str, str] = {} _path_to_resolved: dict[str, str] = {}
for _p in _paths_to_check: for _p in _paths_to_check:
try: try:
_r = str(_resolve_path(_p)) _r = str(_resolve_path_for_task(_p, task_id))
except Exception: except Exception:
_r = None _r = None
_path_to_resolved[_p] = _r _path_to_resolved[_p] = _r
@ -749,15 +781,17 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
_r = _path_to_resolved.get(_p) _r = _path_to_resolved.get(_p)
if _r: if _r:
file_state.note_write(task_id, _r) file_state.note_write(task_id, _r)
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when old_string not found — saves iterations where the agent # Hint when old_string not found — saves iterations where the agent
# retries with stale content instead of re-reading the file. # retries with stale content instead of re-reading the file.
# Suppressed when patch_replace already attached a rich "Did you mean?" # Suppressed when patch_replace already attached a rich "Did you mean?"
# snippet (which is strictly more useful than the generic hint). # snippet (which is strictly more useful than the generic hint).
if result_dict.get("error") and "Could not find" in str(result_dict["error"]): if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
if "Did you mean one of these sections?" not in str(result_dict["error"]): if "Did you mean one of these sections?" not in str(result_dict["error"]):
result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]" result_dict["_hint"] = (
return result_json "old_string not found. Use read_file to verify the current "
"content, or search_files to locate the text."
)
return json.dumps(result_dict, ensure_ascii=False)
except Exception as e: except Exception as e:
return tool_error(str(e)) return tool_error(str(e))