diff --git a/tests/tools/test_file_staleness.py b/tests/tools/test_file_staleness.py
new file mode 100644
index 000000000..46e7aac9f
--- /dev/null
+++ b/tests/tools/test_file_staleness.py
@@ -0,0 +1,259 @@
+#!/usr/bin/env python3
+"""
+Tests for file staleness detection in write_file and patch.
+
+When a file is modified externally between the agent's read and write,
+the write should include a warning so the agent can re-read and verify.
+
+Run with: python -m pytest tests/tools/test_file_staleness.py -v
+"""
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+from unittest.mock import patch, MagicMock
+
+from tools.file_tools import (
+    read_file_tool,
+    write_file_tool,
+    patch_tool,
+    clear_read_tracker,
+    _check_file_staleness,
+)
+
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+class _FakeReadResult:
+    """Stand-in for the object returned by the fake ``read_file``."""
+
+    def __init__(self, content="line1\nline2\n", total_lines=2, file_size=100):
+        self.content = content
+        self._total_lines = total_lines
+        self._file_size = file_size
+
+    def to_dict(self):
+        return {
+            "content": self.content,
+            "total_lines": self._total_lines,
+            "file_size": self._file_size,
+        }
+
+
+class _FakeWriteResult:
+    """Stand-in for the object returned by the fake ``write_file``."""
+
+    def __init__(self):
+        self.bytes_written = 10
+
+    def to_dict(self):
+        return {"bytes_written": self.bytes_written}
+
+
+class _FakePatchResult:
+    """Stand-in for the object returned by the fake ``patch_replace``."""
+
+    def __init__(self):
+        self.success = True
+
+    def to_dict(self):
+        return {"success": True, "diff": "--- a\n+++ b\n@@ ...\n"}
+
+
+def _make_fake_ops(read_content="hello\n", file_size=6):
+    """Build a fake file-ops object whose methods never touch the disk."""
+    fake = MagicMock()
+    fake.read_file = lambda path, offset=1, limit=500: _FakeReadResult(
+        content=read_content, total_lines=1, file_size=file_size,
+    )
+    fake.write_file = lambda path, content: _FakeWriteResult()
+    fake.patch_replace = lambda path, old, new, replace_all=False: _FakePatchResult()
+    return fake
+
+
+def _overwrite_with_new_mtime(path, content):
+    """Rewrite *path* and push its mtime forward.
+
+    Setting the mtime explicitly keeps the tests deterministic on
+    filesystems with coarse timestamp granularity and avoids sleeping.
+    """
+    previous = os.stat(path).st_mtime
+    with open(path, "w", encoding="utf-8") as f:
+        f.write(content)
+    os.utime(path, (previous + 10, previous + 10))
+
+
+# ---------------------------------------------------------------------------
+# Core staleness check
+# ---------------------------------------------------------------------------
+
+class TestStalenessCheck(unittest.TestCase):
+    """Staleness warnings attached by ``write_file_tool``."""
+
+    def setUp(self):
+        clear_read_tracker()
+        self.addCleanup(clear_read_tracker)
+        self._tmpdir = tempfile.mkdtemp()
+        # rmtree also removes any extra files a test leaves behind.
+        self.addCleanup(shutil.rmtree, self._tmpdir, ignore_errors=True)
+        self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
+        with open(self._tmpfile, "w", encoding="utf-8") as f:
+            f.write("original content\n")
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_no_warning_when_file_unchanged(self, mock_ops):
+        """Read then write with no external modification — no warning."""
+        mock_ops.return_value = _make_fake_ops("original content\n", 18)
+        read_file_tool(self._tmpfile, task_id="t1")
+
+        result = json.loads(write_file_tool(self._tmpfile, "new content", task_id="t1"))
+        self.assertNotIn("_warning", result)
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_warning_when_file_modified_externally(self, mock_ops):
+        """Read, then external modify, then write — should warn."""
+        mock_ops.return_value = _make_fake_ops("original content\n", 18)
+        read_file_tool(self._tmpfile, task_id="t1")
+
+        # Simulate external modification
+        _overwrite_with_new_mtime(self._tmpfile, "someone else changed this\n")
+
+        result = json.loads(write_file_tool(self._tmpfile, "new content", task_id="t1"))
+        self.assertIn("_warning", result)
+        self.assertIn("modified since you last read", result["_warning"])
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_no_warning_after_own_write(self, mock_ops):
+        """Read, write, write again — the task's own write is not stale."""
+        fake_ops = _make_fake_ops("original content\n", 18)
+
+        def _write_to_disk(path, content):
+            _overwrite_with_new_mtime(path, content)
+            return _FakeWriteResult()
+
+        fake_ops.write_file = _write_to_disk
+        mock_ops.return_value = fake_ops
+        read_file_tool(self._tmpfile, task_id="t1")
+
+        for content in ("first edit\n", "second edit\n"):
+            result = json.loads(write_file_tool(self._tmpfile, content, task_id="t1"))
+            self.assertNotIn("_warning", result)
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_no_warning_when_file_never_read(self, mock_ops):
+        """Writing a file that was never read — no warning."""
+        mock_ops.return_value = _make_fake_ops()
+        result = json.loads(write_file_tool(self._tmpfile, "new content", task_id="t2"))
+        self.assertNotIn("_warning", result)
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_no_warning_for_new_file(self, mock_ops):
+        """Creating a new file — no warning."""
+        mock_ops.return_value = _make_fake_ops()
+        new_path = os.path.join(self._tmpdir, "brand_new.txt")
+        result = json.loads(write_file_tool(new_path, "content", task_id="t3"))
+        self.assertNotIn("_warning", result)
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_different_task_isolated(self, mock_ops):
+        """Task A reads, file changes, Task B writes — no warning for B."""
+        mock_ops.return_value = _make_fake_ops("original content\n", 18)
+        read_file_tool(self._tmpfile, task_id="task_a")
+
+        _overwrite_with_new_mtime(self._tmpfile, "changed\n")
+
+        result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
+        self.assertNotIn("_warning", result)
+
+
+# ---------------------------------------------------------------------------
+# Staleness in patch
+# ---------------------------------------------------------------------------
+
+class TestPatchStaleness(unittest.TestCase):
+    """Staleness warnings attached by ``patch_tool``."""
+
+    def setUp(self):
+        clear_read_tracker()
+        self.addCleanup(clear_read_tracker)
+        self._tmpdir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self._tmpdir, ignore_errors=True)
+        self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
+        with open(self._tmpfile, "w", encoding="utf-8") as f:
+            f.write("original line\n")
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_patch_warns_on_stale_file(self, mock_ops):
+        """Patch should warn if the target file changed since last read."""
+        mock_ops.return_value = _make_fake_ops("original line\n", 15)
+        read_file_tool(self._tmpfile, task_id="p1")
+
+        _overwrite_with_new_mtime(self._tmpfile, "externally modified\n")
+
+        result = json.loads(patch_tool(
+            mode="replace", path=self._tmpfile,
+            old_string="original", new_string="patched",
+            task_id="p1",
+        ))
+        self.assertIn("_warning", result)
+        self.assertIn("modified since you last read", result["_warning"])
+
+    @patch("tools.file_tools._get_file_ops")
+    def test_patch_no_warning_when_fresh(self, mock_ops):
+        """Patch with no external changes — no warning."""
+        mock_ops.return_value = _make_fake_ops("original line\n", 15)
+        read_file_tool(self._tmpfile, task_id="p2")
+
+        result = json.loads(patch_tool(
+            mode="replace", path=self._tmpfile,
+            old_string="original", new_string="patched",
+            task_id="p2",
+        ))
+        self.assertNotIn("_warning", result)
+
+
+# ---------------------------------------------------------------------------
+# Unit test for the helper
+# ---------------------------------------------------------------------------
+
+class TestCheckFileStalenessHelper(unittest.TestCase):
+    """Direct tests of ``_check_file_staleness``."""
+
+    def setUp(self):
+        clear_read_tracker()
+
+    def tearDown(self):
+        clear_read_tracker()
+
+    def test_returns_none_for_unknown_task(self):
+        self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))
+
+    def test_returns_none_for_unread_file(self):
+        # Populate tracker with a different file
+        from tools.file_tools import _read_tracker, _read_tracker_lock
+        with _read_tracker_lock:
+            _read_tracker["t1"] = {
+                "last_key": None, "consecutive": 0,
+                "read_history": set(), "dedup": {},
+                "file_mtimes": {"/tmp/other.py": 12345.0},
+            }
+        self.assertIsNone(_check_file_staleness("/tmp/x.py", "t1"))
+
+    def test_returns_none_when_stat_fails(self):
+        from tools.file_tools import _read_tracker, _read_tracker_lock
+        with _read_tracker_lock:
+            _read_tracker["t1"] = {
+                "last_key": None, "consecutive": 0,
+                "read_history": set(), "dedup": {},
+                "file_mtimes": {"/nonexistent/path": 99999.0},
+            }
+        # File doesn't exist → stat fails → returns None (let write handle it)
+        self.assertIsNone(_check_file_staleness("/nonexistent/path", "t1"))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/file_tools.py b/tools/file_tools.py
index 1245e68de..07fb86d1a 100644
--- a/tools/file_tools.py
+++ b/tools/file_tools.py
@@ -136,6 +136,9 @@ _file_ops_cache: dict = {}
 #                   Used to skip re-reads of unchanged files.  Reset on
 #                   context compression (the original content is summarised
 #                   away so the model needs the full content again).
+#   "file_mtimes":  dict mapping resolved_path → mtime float recorded at the
+#                   task's last read or own write.  Used by write_file and
+#                   patch to detect external modification in between.
 _read_tracker_lock = threading.Lock()
 _read_tracker: dict = {}
 
@@ -391,14 +394,16 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
                 task_data["consecutive"] = 1
             count = task_data["consecutive"]
 
-            # Store dedup entry (mtime at read time).
-            # Writes/patches will naturally change mtime, so subsequent
-            # dedup checks after edits will see a different mtime and
-            # return the full content — no special handling needed.
+            # Store mtime at read time for two purposes:
+            # 1. Dedup: skip identical re-reads of unchanged files.
+            # 2. Staleness: warn on write/patch if the file changed since
+            #    the agent last read it (external edit, concurrent agent, etc.).
             try:
-                task_data["dedup"][dedup_key] = os.path.getmtime(resolved_str)
+                _mtime_now = os.path.getmtime(resolved_str)
+                task_data["dedup"][dedup_key] = _mtime_now
+                task_data.setdefault("file_mtimes", {})[resolved_str] = _mtime_now
             except OSError:
-                pass  # Can't stat — skip dedup for this entry
+                pass  # Can't stat — skip tracking for this entry
 
         if count >= 4:
             # Hard block: stop returning content to break the loop
@@ -495,15 +500,74 @@ def notify_other_tool_call(task_id: str = "default"):
             task_data["consecutive"] = 0
 
 
+def _check_file_staleness(filepath: str, task_id: str) -> str | None:
+    """Check whether a file was modified since the agent last read it.
+
+    Returns a warning string if the file is stale (its mtime differs from
+    the one recorded at this task's last read or own write of the file),
+    or None if the file is fresh or was never read.  Does not block — the
+    write still proceeds.
+    """
+    try:
+        resolved = str(Path(filepath).expanduser().resolve())
+    except (OSError, ValueError):
+        return None
+    with _read_tracker_lock:
+        task_data = _read_tracker.get(task_id)
+        if not task_data:
+            return None
+        read_mtime = task_data.get("file_mtimes", {}).get(resolved)
+    if read_mtime is None:
+        return None  # File was never read — nothing to compare against
+    try:
+        current_mtime = os.path.getmtime(resolved)
+    except OSError:
+        return None  # Can't stat — file may have been deleted, let write handle it
+    if current_mtime != read_mtime:
+        return (
+            f"Warning: {filepath} was modified since you last read it "
+            "(external edit or concurrent agent). The content you read may be "
+            "stale. Consider re-reading the file to verify before writing."
+        )
+    return None
+
+
+def _refresh_read_mtime(filepath: str, task_id: str) -> None:
+    """Record the current mtime of a file this task has just written.
+
+    The task's own write changes the file's mtime; without refreshing the
+    recorded value, the next write or patch by the same task would be
+    reported as stale even though nothing external touched the file.
+    Best-effort: does nothing if the path cannot be resolved or stat'ed,
+    or if the task has no tracker entry.
+    """
+    try:
+        resolved = str(Path(filepath).expanduser().resolve())
+        current_mtime = os.path.getmtime(resolved)
+    except (OSError, ValueError):
+        return
+    with _read_tracker_lock:
+        task_data = _read_tracker.get(task_id)
+        if task_data is not None:
+            task_data.setdefault("file_mtimes", {})[resolved] = current_mtime
+
+
 def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
     """Write content to a file."""
     sensitive_err = _check_sensitive_path(path)
     if sensitive_err:
         return json.dumps({"error": sensitive_err}, ensure_ascii=False)
     try:
+        stale_warning = _check_file_staleness(path, task_id)
         file_ops = _get_file_ops(task_id)
         result = file_ops.write_file(path, content)
-        return json.dumps(result.to_dict(), ensure_ascii=False)
+        result_dict = result.to_dict()
+        if not result_dict.get("error"):
+            # The task's own write is not an external modification.
+            _refresh_read_mtime(path, task_id)
+        if stale_warning:
+            result_dict["_warning"] = stale_warning
+        return json.dumps(result_dict, ensure_ascii=False)
     except Exception as e:
         if _is_expected_write_exception(e):
             logger.debug("write_file expected denial: %s: %s", type(e).__name__, e)
@@ -529,6 +593,13 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
         if sensitive_err:
             return json.dumps({"error": sensitive_err}, ensure_ascii=False)
     try:
+        # Check staleness for all files this patch will touch.
+        stale_warnings = []
+        for _p in _paths_to_check:
+            _sw = _check_file_staleness(_p, task_id)
+            if _sw:
+                stale_warnings.append(_sw)
+
         file_ops = _get_file_ops(task_id)
 
         if mode == "replace":
@@ -545,6 +616,12 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
             return json.dumps({"error": f"Unknown mode: {mode}"})
 
         result_dict = result.to_dict()
+        if not result_dict.get("error"):
+            # The task's own patch is not an external modification.
+            for _p in _paths_to_check:
+                _refresh_read_mtime(_p, task_id)
+        if stale_warnings:
+            result_dict["_warning"] = " | ".join(stale_warnings)
         result_json = json.dumps(result_dict, ensure_ascii=False)
         # Hint when old_string not found — saves iterations where the agent
         # retries with stale content instead of re-reading the file.