diff --git a/tests/tools/test_file_staleness.py b/tests/tools/test_file_staleness.py index 4d9136125..dc5a1e7bd 100644 --- a/tests/tools/test_file_staleness.py +++ b/tests/tools/test_file_staleness.py @@ -13,8 +13,10 @@ import os import tempfile import time import unittest +from types import SimpleNamespace from unittest.mock import patch, MagicMock +from tools import file_state from tools.file_tools import ( read_file_tool, write_file_tool, @@ -76,6 +78,7 @@ class TestStalenessCheck(unittest.TestCase): def setUp(self): _read_tracker.clear() + file_state.get_registry().clear() self._tmpdir = tempfile.mkdtemp() self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt") with open(self._tmpfile, "w") as f: @@ -83,6 +86,7 @@ class TestStalenessCheck(unittest.TestCase): def tearDown(self): _read_tracker.clear() + file_state.get_registry().clear() try: os.unlink(self._tmpfile) os.rmdir(self._tmpdir) @@ -145,6 +149,53 @@ class TestStalenessCheck(unittest.TestCase): result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b")) self.assertNotIn("_warning", result) + @patch("tools.file_tools._get_file_ops") + def test_relative_path_uses_live_cwd_for_staleness_tracking(self, mock_ops): + """Relative-path stale tracking must follow the live terminal cwd.""" + start_dir = os.path.join(self._tmpdir, "start") + live_dir = os.path.join(self._tmpdir, "worktree") + os.makedirs(start_dir, exist_ok=True) + os.makedirs(live_dir, exist_ok=True) + + start_file = os.path.join(start_dir, "shared.txt") + live_file = os.path.join(live_dir, "shared.txt") + with open(start_file, "w") as f: + f.write("start copy\n") + with open(live_file, "w") as f: + f.write("live copy\n") + + fake_ops = _make_fake_ops("live copy\n", 10) + fake_ops.env = SimpleNamespace(cwd=live_dir) + fake_ops.cwd = start_dir + mock_ops.return_value = fake_ops + + from tools import file_tools + + with file_tools._file_ops_lock: + previous = file_tools._file_ops_cache.get("live_task") + file_tools._file_ops_cache["live_task"] = fake_ops + + try: + with patch.dict(os.environ, {"TERMINAL_CWD": start_dir}, clear=False): + read_file_tool("shared.txt", task_id="live_task") + + time.sleep(0.05) + with open(live_file, "w") as f: + f.write("live copy modified elsewhere\n") + + result = json.loads( + write_file_tool("shared.txt", "replacement", task_id="live_task") + ) + finally: + with file_tools._file_ops_lock: + if previous is None: + file_tools._file_ops_cache.pop("live_task", None) + else: + file_tools._file_ops_cache["live_task"] = previous + + self.assertIn("_warning", result) + self.assertIn("modified since you last read", result["_warning"]) + # --------------------------------------------------------------------------- # Staleness in patch @@ -154,6 +205,7 @@ class TestPatchStaleness(unittest.TestCase): def setUp(self): _read_tracker.clear() + file_state.get_registry().clear() self._tmpdir = tempfile.mkdtemp() self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt") with open(self._tmpfile, "w") as f: @@ -161,6 +213,7 @@ class TestPatchStaleness(unittest.TestCase): def tearDown(self): _read_tracker.clear() + file_state.get_registry().clear() try: os.unlink(self._tmpfile) os.rmdir(self._tmpdir) @@ -207,9 +260,11 @@ class TestCheckFileStalenessHelper(unittest.TestCase): def setUp(self): _read_tracker.clear() + file_state.get_registry().clear() def tearDown(self): _read_tracker.clear() + file_state.get_registry().clear() def test_returns_none_for_unknown_task(self): self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent")) diff --git a/tests/tools/test_resolve_path.py b/tests/tools/test_resolve_path.py index beea3cc40..cd4d86896 100644 --- a/tests/tools/test_resolve_path.py +++ b/tests/tools/test_resolve_path.py @@ -2,6 +2,7 @@ import os from pathlib import Path +from types import SimpleNamespace import pytest @@ -22,8 +23,9 @@ class TestResolvePath: monkeypatch.setenv("TERMINAL_CWD", str(tmp_path)) from tools.file_tools import _resolve_path - result = _resolve_path("/etc/hosts") - assert result == Path("/etc/hosts") + absolute = (tmp_path / "already-absolute.txt").resolve() + result = _resolve_path(str(absolute)) + assert result == absolute def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch): """Without TERMINAL_CWD, falls back to os.getcwd().""" @@ -50,3 +52,34 @@ class TestResolvePath: result = _resolve_path("a/../b/file.txt") assert ".." not in str(result) assert result == (tmp_path / "b" / "file.txt") + + def test_relative_path_prefers_live_file_ops_cwd(self, monkeypatch, tmp_path): + """Live env.cwd must win after the terminal session changes directory.""" + start_dir = tmp_path / "start" + live_dir = tmp_path / "worktree" + start_dir.mkdir() + live_dir.mkdir() + monkeypatch.setenv("TERMINAL_CWD", str(start_dir)) + + from tools import file_tools + + task_id = "live-cwd" + fake_ops = SimpleNamespace( + env=SimpleNamespace(cwd=str(live_dir)), + cwd=str(start_dir), + ) + + with file_tools._file_ops_lock: + previous = file_tools._file_ops_cache.get(task_id) + file_tools._file_ops_cache[task_id] = fake_ops + + try: + result = file_tools._resolve_path("nested/file.txt", task_id=task_id) + finally: + with file_tools._file_ops_lock: + if previous is None: + file_tools._file_ops_cache.pop(task_id, None) + else: + file_tools._file_ops_cache[task_id] = previous + + assert result == live_dir / "nested" / "file.txt" diff --git a/tools/file_tools.py b/tools/file_tools.py index 3b6f45942..609506c05 100644 --- a/tools/file_tools.py +++ b/tools/file_tools.py @@ -79,13 +79,45 @@ _BLOCKED_DEVICE_PATHS = frozenset({ }) -def _resolve_path(filepath: str) -> Path: +def _resolve_path(filepath: str, task_id: str = "default") -> Path: """Resolve a path relative to TERMINAL_CWD (the worktree base directory) instead of the main repository root. """ + return _resolve_path_for_task(filepath, task_id) + + +def _get_live_tracking_cwd(task_id: str = "default") -> str | None: + """Return the task's live terminal cwd for bookkeeping when available.""" + with _file_ops_lock: + cached = _file_ops_cache.get(task_id) + if cached is not None: + live_cwd = getattr(getattr(cached, "env", None), "cwd", None) or getattr( + cached, "cwd", None + ) + if live_cwd: + return live_cwd + + try: + from tools.terminal_tool import _active_environments, _env_lock + + with _env_lock: + env = _active_environments.get(task_id) + live_cwd = getattr(env, "cwd", None) if env is not None else None + if live_cwd: + return live_cwd + except Exception: + pass + + return None + + +def _resolve_path_for_task(filepath: str, task_id: str = "default") -> Path: + """Resolve *filepath* against the task's live terminal cwd when possible.""" p = Path(filepath).expanduser() if not p.is_absolute(): - base = os.environ.get("TERMINAL_CWD", os.getcwd()) + base = _get_live_tracking_cwd(task_id) or os.environ.get( + "TERMINAL_CWD", os.getcwd() + ) p = Path(base) / p return p.resolve() @@ -118,10 +150,10 @@ _SENSITIVE_PATH_PREFIXES = ( _SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"} -def _check_sensitive_path(filepath: str) -> str | None: +def _check_sensitive_path(filepath: str, task_id: str = "default") -> str | None: """Return an error message if the path targets a sensitive system location.""" try: - resolved = str(_resolve_path(filepath)) + resolved = str(_resolve_path_for_task(filepath, task_id)) except (OSError, ValueError): resolved = filepath normalized = os.path.normpath(os.path.expanduser(filepath)) @@ -368,7 +400,7 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str = ), }) - _resolved = _resolve_path(path) + _resolved = _resolve_path_for_task(path, task_id) # ── Binary file guard ───────────────────────────────────────── # Block binary files by extension (no I/O). @@ -574,7 +606,7 @@ def _update_read_timestamp(filepath: str, task_id: str) -> None: refreshes the stored timestamp to match the file's new state. """ try: - resolved = str(_resolve_path(filepath)) + resolved = str(_resolve_path_for_task(filepath, task_id)) current_mtime = os.path.getmtime(resolved) except (OSError, ValueError): return @@ -593,7 +625,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None: or was never read. Does not block — the write still proceeds. """ try: - resolved = str(_resolve_path(filepath)) + resolved = str(_resolve_path_for_task(filepath, task_id)) except (OSError, ValueError): return None with _read_tracker_lock: @@ -618,7 +650,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None: def write_file_tool(path: str, content: str, task_id: str = "default") -> str: """Write content to a file.""" - sensitive_err = _check_sensitive_path(path) + sensitive_err = _check_sensitive_path(path, task_id) if sensitive_err: return tool_error(sensitive_err) try: @@ -626,7 +658,7 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str: # fall back to the legacy path — write proceeds, per-task staleness # check below still runs. try: - _resolved = str(_resolve_path(path)) + _resolved = str(_resolve_path_for_task(path, task_id)) except Exception: _resolved = None @@ -681,7 +713,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE): _paths_to_check.append(_m.group(1).strip()) for _p in _paths_to_check: - sensitive_err = _check_sensitive_path(_p) + sensitive_err = _check_sensitive_path(_p, task_id) if sensitive_err: return tool_error(sensitive_err) try: @@ -692,7 +724,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, _seen: set[str] = set() for _p in _paths_to_check: try: - _r = str(_resolve_path(_p)) + _r = str(_resolve_path_for_task(_p, task_id)) except Exception: _r = None if _r and _r not in _seen: @@ -714,7 +746,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, _path_to_resolved: dict[str, str] = {} for _p in _paths_to_check: try: - _r = str(_resolve_path(_p)) + _r = str(_resolve_path_for_task(_p, task_id)) except Exception: _r = None _path_to_resolved[_p] = _r @@ -749,15 +781,17 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None, _r = _path_to_resolved.get(_p) if _r: file_state.note_write(task_id, _r) - result_json = json.dumps(result_dict, ensure_ascii=False) # Hint when old_string not found — saves iterations where the agent # retries with stale content instead of re-reading the file. # Suppressed when patch_replace already attached a rich "Did you mean?" # snippet (which is strictly more useful than the generic hint). if result_dict.get("error") and "Could not find" in str(result_dict["error"]): if "Did you mean one of these sections?" not in str(result_dict["error"]): - result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]" - return result_json + result_dict["_hint"] = ( + "old_string not found. Use read_file to verify the current " + "content, or search_files to locate the text." + ) + return json.dumps(result_dict, ensure_ascii=False) except Exception as e: return tool_error(str(e))