fix(file_tools): resolve bookkeeping paths against live terminal cwd

This commit is contained in:
Yukipukii1 2026-04-23 22:36:07 +03:00 committed by Teknium
parent 83859b4da0
commit 4a0c02b7dc
3 changed files with 139 additions and 17 deletions

View file

@ -13,8 +13,10 @@ import os
import tempfile
import time
import unittest
from types import SimpleNamespace
from unittest.mock import patch, MagicMock
from tools import file_state
from tools.file_tools import (
read_file_tool,
write_file_tool,
@ -76,6 +78,7 @@ class TestStalenessCheck(unittest.TestCase):
def setUp(self):
_read_tracker.clear()
file_state.get_registry().clear()
self._tmpdir = tempfile.mkdtemp()
self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
with open(self._tmpfile, "w") as f:
@ -83,6 +86,7 @@ class TestStalenessCheck(unittest.TestCase):
def tearDown(self):
_read_tracker.clear()
file_state.get_registry().clear()
try:
os.unlink(self._tmpfile)
os.rmdir(self._tmpdir)
@ -145,6 +149,53 @@ class TestStalenessCheck(unittest.TestCase):
result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
self.assertNotIn("_warning", result)
@patch("tools.file_tools._get_file_ops")
def test_relative_path_uses_live_cwd_for_staleness_tracking(self, mock_ops):
    """Relative-path stale tracking must follow the live terminal cwd.

    Simulates a terminal session that started in ``start`` and later
    cd'd into ``worktree``: TERMINAL_CWD still points at the stale
    starting directory, while the fake ops object's ``env.cwd`` carries
    the live directory.  The staleness bookkeeping must resolve the
    relative path against the live directory.
    """
    start_dir = os.path.join(self._tmpdir, "start")
    live_dir = os.path.join(self._tmpdir, "worktree")
    os.makedirs(start_dir, exist_ok=True)
    os.makedirs(live_dir, exist_ok=True)
    # Same relative name exists in both directories so a wrong base
    # would silently resolve to the other copy.
    start_file = os.path.join(start_dir, "shared.txt")
    live_file = os.path.join(live_dir, "shared.txt")
    with open(start_file, "w") as f:
        f.write("start copy\n")
    with open(live_file, "w") as f:
        f.write("live copy\n")
    # env.cwd is the live cwd; the plain .cwd attribute is deliberately
    # the stale starting directory, so env.cwd must take precedence.
    fake_ops = _make_fake_ops("live copy\n", 10)
    fake_ops.env = SimpleNamespace(cwd=live_dir)
    fake_ops.cwd = start_dir
    mock_ops.return_value = fake_ops
    from tools import file_tools
    # Install the fake ops under a dedicated task id, remembering any
    # previous cache entry so it can be restored afterwards.
    with file_tools._file_ops_lock:
        previous = file_tools._file_ops_cache.get("live_task")
        file_tools._file_ops_cache["live_task"] = fake_ops
    try:
        with patch.dict(os.environ, {"TERMINAL_CWD": start_dir}, clear=False):
            read_file_tool("shared.txt", task_id="live_task")
            # Ensure the out-of-band write lands at a later mtime.
            time.sleep(0.05)
            with open(live_file, "w") as f:
                f.write("live copy modified elsewhere\n")
            result = json.loads(
                write_file_tool("shared.txt", "replacement", task_id="live_task")
            )
    finally:
        # Restore the cache exactly as it was before the test.
        with file_tools._file_ops_lock:
            if previous is None:
                file_tools._file_ops_cache.pop("live_task", None)
            else:
                file_tools._file_ops_cache["live_task"] = previous
    # The write must flag the out-of-band modification of the *live* file.
    self.assertIn("_warning", result)
    self.assertIn("modified since you last read", result["_warning"])
# ---------------------------------------------------------------------------
# Staleness in patch
@ -154,6 +205,7 @@ class TestPatchStaleness(unittest.TestCase):
def setUp(self):
_read_tracker.clear()
file_state.get_registry().clear()
self._tmpdir = tempfile.mkdtemp()
self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
with open(self._tmpfile, "w") as f:
@ -161,6 +213,7 @@ class TestPatchStaleness(unittest.TestCase):
def tearDown(self):
_read_tracker.clear()
file_state.get_registry().clear()
try:
os.unlink(self._tmpfile)
os.rmdir(self._tmpdir)
@ -207,9 +260,11 @@ class TestCheckFileStalenessHelper(unittest.TestCase):
def setUp(self):
_read_tracker.clear()
file_state.get_registry().clear()
def tearDown(self):
_read_tracker.clear()
file_state.get_registry().clear()
def test_returns_none_for_unknown_task(self):
    """An unknown task id yields None — no staleness warning is produced."""
    self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))

View file

@ -2,6 +2,7 @@
import os
from pathlib import Path
from types import SimpleNamespace
import pytest
@ -22,8 +23,9 @@ class TestResolvePath:
monkeypatch.setenv("TERMINAL_CWD", str(tmp_path))
from tools.file_tools import _resolve_path
result = _resolve_path("/etc/hosts")
assert result == Path("/etc/hosts")
absolute = (tmp_path / "already-absolute.txt").resolve()
result = _resolve_path(str(absolute))
assert result == absolute
def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch):
"""Without TERMINAL_CWD, falls back to os.getcwd()."""
@ -50,3 +52,34 @@ class TestResolvePath:
result = _resolve_path("a/../b/file.txt")
assert ".." not in str(result)
assert result == (tmp_path / "b" / "file.txt")
def test_relative_path_prefers_live_file_ops_cwd(self, monkeypatch, tmp_path):
    """Live env.cwd must win after the terminal session changes directory."""
    start_dir = tmp_path / "start"
    live_dir = tmp_path / "worktree"
    for directory in (start_dir, live_dir):
        directory.mkdir()
    # TERMINAL_CWD deliberately still points at the stale starting
    # directory; the cached ops object's env.cwd carries the live one.
    monkeypatch.setenv("TERMINAL_CWD", str(start_dir))
    from tools import file_tools

    task_id = "live-cwd"
    fake_ops = SimpleNamespace(
        env=SimpleNamespace(cwd=str(live_dir)),
        cwd=str(start_dir),
    )
    # Swap the fake ops into the cache, keeping whatever was there so the
    # cache can be restored even if the assertion below fails.
    with file_tools._file_ops_lock:
        saved_entry = file_tools._file_ops_cache.get(task_id)
        file_tools._file_ops_cache[task_id] = fake_ops
    try:
        result = file_tools._resolve_path("nested/file.txt", task_id=task_id)
    finally:
        with file_tools._file_ops_lock:
            if saved_entry is None:
                file_tools._file_ops_cache.pop(task_id, None)
            else:
                file_tools._file_ops_cache[task_id] = saved_entry
    assert result == live_dir / "nested" / "file.txt"

View file

@ -79,13 +79,45 @@ _BLOCKED_DEVICE_PATHS = frozenset({
})
def _resolve_path(filepath: str) -> Path:
def _resolve_path(filepath: str, task_id: str = "default") -> Path:
    """Resolve *filepath* for *task_id*.

    Thin wrapper around ``_resolve_path_for_task``: relative paths are
    anchored at the task's live terminal cwd when one is known, falling
    back to TERMINAL_CWD (the worktree base directory) and finally the
    process cwd.
    """
    return _resolve_path_for_task(filepath, task_id)
def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
    """Return the task's live terminal cwd for bookkeeping when available.

    Consults, in order: the cached file-ops object's ``env.cwd`` (then its
    plain ``cwd`` attribute), and finally the active terminal environment
    registered in ``tools.terminal_tool``.  Returns ``None`` when neither
    source knows a cwd for *task_id*.
    """
    with _file_ops_lock:
        ops = _file_ops_cache.get(task_id)
    if ops is not None:
        candidate = getattr(getattr(ops, "env", None), "cwd", None)
        if not candidate:
            candidate = getattr(ops, "cwd", None)
        if candidate:
            return candidate
    # Best-effort fallback: the terminal tool may be absent or broken in
    # some deployments, so any failure here is deliberately swallowed.
    try:
        from tools.terminal_tool import _active_environments, _env_lock
        with _env_lock:
            environment = _active_environments.get(task_id)
        candidate = None if environment is None else getattr(environment, "cwd", None)
        if candidate:
            return candidate
    except Exception:
        pass
    return None
def _resolve_path_for_task(filepath: str, task_id: str = "default") -> Path:
"""Resolve *filepath* against the task's live terminal cwd when possible."""
p = Path(filepath).expanduser()
if not p.is_absolute():
base = os.environ.get("TERMINAL_CWD", os.getcwd())
base = _get_live_tracking_cwd(task_id) or os.environ.get(
"TERMINAL_CWD", os.getcwd()
)
p = Path(base) / p
return p.resolve()
@ -118,10 +150,10 @@ _SENSITIVE_PATH_PREFIXES = (
_SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"}
def _check_sensitive_path(filepath: str) -> str | None:
def _check_sensitive_path(filepath: str, task_id: str = "default") -> str | None:
"""Return an error message if the path targets a sensitive system location."""
try:
resolved = str(_resolve_path(filepath))
resolved = str(_resolve_path_for_task(filepath, task_id))
except (OSError, ValueError):
resolved = filepath
normalized = os.path.normpath(os.path.expanduser(filepath))
@ -368,7 +400,7 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
),
})
_resolved = _resolve_path(path)
_resolved = _resolve_path_for_task(path, task_id)
# ── Binary file guard ─────────────────────────────────────────
# Block binary files by extension (no I/O).
@ -574,7 +606,7 @@ def _update_read_timestamp(filepath: str, task_id: str) -> None:
refreshes the stored timestamp to match the file's new state.
"""
try:
resolved = str(_resolve_path(filepath))
resolved = str(_resolve_path_for_task(filepath, task_id))
current_mtime = os.path.getmtime(resolved)
except (OSError, ValueError):
return
@ -593,7 +625,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
or was never read. Does not block — the write still proceeds.
"""
try:
resolved = str(_resolve_path(filepath))
resolved = str(_resolve_path_for_task(filepath, task_id))
except (OSError, ValueError):
return None
with _read_tracker_lock:
@ -618,7 +650,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
"""Write content to a file."""
sensitive_err = _check_sensitive_path(path)
sensitive_err = _check_sensitive_path(path, task_id)
if sensitive_err:
return tool_error(sensitive_err)
try:
@ -626,7 +658,7 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
# fall back to the legacy path — write proceeds, per-task staleness
# check below still runs.
try:
_resolved = str(_resolve_path(path))
_resolved = str(_resolve_path_for_task(path, task_id))
except Exception:
_resolved = None
@ -681,7 +713,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE):
_paths_to_check.append(_m.group(1).strip())
for _p in _paths_to_check:
sensitive_err = _check_sensitive_path(_p)
sensitive_err = _check_sensitive_path(_p, task_id)
if sensitive_err:
return tool_error(sensitive_err)
try:
@ -692,7 +724,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
_seen: set[str] = set()
for _p in _paths_to_check:
try:
_r = str(_resolve_path(_p))
_r = str(_resolve_path_for_task(_p, task_id))
except Exception:
_r = None
if _r and _r not in _seen:
@ -714,7 +746,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
_path_to_resolved: dict[str, str] = {}
for _p in _paths_to_check:
try:
_r = str(_resolve_path(_p))
_r = str(_resolve_path_for_task(_p, task_id))
except Exception:
_r = None
_path_to_resolved[_p] = _r
@ -749,15 +781,17 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
_r = _path_to_resolved.get(_p)
if _r:
file_state.note_write(task_id, _r)
result_json = json.dumps(result_dict, ensure_ascii=False)
# Hint when old_string not found — saves iterations where the agent
# retries with stale content instead of re-reading the file.
# Suppressed when patch_replace already attached a rich "Did you mean?"
# snippet (which is strictly more useful than the generic hint).
if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
if "Did you mean one of these sections?" not in str(result_dict["error"]):
result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]"
return result_json
result_dict["_hint"] = (
"old_string not found. Use read_file to verify the current "
"content, or search_files to locate the text."
)
return json.dumps(result_dict, ensure_ascii=False)
except Exception as e:
return tool_error(str(e))