mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(file_tools): resolve bookkeeping paths against live terminal cwd
This commit is contained in:
parent
83859b4da0
commit
4a0c02b7dc
3 changed files with 139 additions and 17 deletions
|
|
@ -13,8 +13,10 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import patch, MagicMock
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from tools import file_state
|
||||||
from tools.file_tools import (
|
from tools.file_tools import (
|
||||||
read_file_tool,
|
read_file_tool,
|
||||||
write_file_tool,
|
write_file_tool,
|
||||||
|
|
@ -76,6 +78,7 @@ class TestStalenessCheck(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
_read_tracker.clear()
|
_read_tracker.clear()
|
||||||
|
file_state.get_registry().clear()
|
||||||
self._tmpdir = tempfile.mkdtemp()
|
self._tmpdir = tempfile.mkdtemp()
|
||||||
self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
|
self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
|
||||||
with open(self._tmpfile, "w") as f:
|
with open(self._tmpfile, "w") as f:
|
||||||
|
|
@ -83,6 +86,7 @@ class TestStalenessCheck(unittest.TestCase):
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
_read_tracker.clear()
|
_read_tracker.clear()
|
||||||
|
file_state.get_registry().clear()
|
||||||
try:
|
try:
|
||||||
os.unlink(self._tmpfile)
|
os.unlink(self._tmpfile)
|
||||||
os.rmdir(self._tmpdir)
|
os.rmdir(self._tmpdir)
|
||||||
|
|
@ -145,6 +149,53 @@ class TestStalenessCheck(unittest.TestCase):
|
||||||
result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
|
result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
|
||||||
self.assertNotIn("_warning", result)
|
self.assertNotIn("_warning", result)
|
||||||
|
|
||||||
|
@patch("tools.file_tools._get_file_ops")
|
||||||
|
def test_relative_path_uses_live_cwd_for_staleness_tracking(self, mock_ops):
|
||||||
|
"""Relative-path stale tracking must follow the live terminal cwd."""
|
||||||
|
start_dir = os.path.join(self._tmpdir, "start")
|
||||||
|
live_dir = os.path.join(self._tmpdir, "worktree")
|
||||||
|
os.makedirs(start_dir, exist_ok=True)
|
||||||
|
os.makedirs(live_dir, exist_ok=True)
|
||||||
|
|
||||||
|
start_file = os.path.join(start_dir, "shared.txt")
|
||||||
|
live_file = os.path.join(live_dir, "shared.txt")
|
||||||
|
with open(start_file, "w") as f:
|
||||||
|
f.write("start copy\n")
|
||||||
|
with open(live_file, "w") as f:
|
||||||
|
f.write("live copy\n")
|
||||||
|
|
||||||
|
fake_ops = _make_fake_ops("live copy\n", 10)
|
||||||
|
fake_ops.env = SimpleNamespace(cwd=live_dir)
|
||||||
|
fake_ops.cwd = start_dir
|
||||||
|
mock_ops.return_value = fake_ops
|
||||||
|
|
||||||
|
from tools import file_tools
|
||||||
|
|
||||||
|
with file_tools._file_ops_lock:
|
||||||
|
previous = file_tools._file_ops_cache.get("live_task")
|
||||||
|
file_tools._file_ops_cache["live_task"] = fake_ops
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {"TERMINAL_CWD": start_dir}, clear=False):
|
||||||
|
read_file_tool("shared.txt", task_id="live_task")
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
with open(live_file, "w") as f:
|
||||||
|
f.write("live copy modified elsewhere\n")
|
||||||
|
|
||||||
|
result = json.loads(
|
||||||
|
write_file_tool("shared.txt", "replacement", task_id="live_task")
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
with file_tools._file_ops_lock:
|
||||||
|
if previous is None:
|
||||||
|
file_tools._file_ops_cache.pop("live_task", None)
|
||||||
|
else:
|
||||||
|
file_tools._file_ops_cache["live_task"] = previous
|
||||||
|
|
||||||
|
self.assertIn("_warning", result)
|
||||||
|
self.assertIn("modified since you last read", result["_warning"])
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Staleness in patch
|
# Staleness in patch
|
||||||
|
|
@ -154,6 +205,7 @@ class TestPatchStaleness(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
_read_tracker.clear()
|
_read_tracker.clear()
|
||||||
|
file_state.get_registry().clear()
|
||||||
self._tmpdir = tempfile.mkdtemp()
|
self._tmpdir = tempfile.mkdtemp()
|
||||||
self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
|
self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
|
||||||
with open(self._tmpfile, "w") as f:
|
with open(self._tmpfile, "w") as f:
|
||||||
|
|
@ -161,6 +213,7 @@ class TestPatchStaleness(unittest.TestCase):
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
_read_tracker.clear()
|
_read_tracker.clear()
|
||||||
|
file_state.get_registry().clear()
|
||||||
try:
|
try:
|
||||||
os.unlink(self._tmpfile)
|
os.unlink(self._tmpfile)
|
||||||
os.rmdir(self._tmpdir)
|
os.rmdir(self._tmpdir)
|
||||||
|
|
@ -207,9 +260,11 @@ class TestCheckFileStalenessHelper(unittest.TestCase):
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
_read_tracker.clear()
|
_read_tracker.clear()
|
||||||
|
file_state.get_registry().clear()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
_read_tracker.clear()
|
_read_tracker.clear()
|
||||||
|
file_state.get_registry().clear()
|
||||||
|
|
||||||
def test_returns_none_for_unknown_task(self):
|
def test_returns_none_for_unknown_task(self):
|
||||||
self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))
|
self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
@ -22,8 +23,9 @@ class TestResolvePath:
|
||||||
monkeypatch.setenv("TERMINAL_CWD", str(tmp_path))
|
monkeypatch.setenv("TERMINAL_CWD", str(tmp_path))
|
||||||
from tools.file_tools import _resolve_path
|
from tools.file_tools import _resolve_path
|
||||||
|
|
||||||
result = _resolve_path("/etc/hosts")
|
absolute = (tmp_path / "already-absolute.txt").resolve()
|
||||||
assert result == Path("/etc/hosts")
|
result = _resolve_path(str(absolute))
|
||||||
|
assert result == absolute
|
||||||
|
|
||||||
def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch):
|
def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch):
|
||||||
"""Without TERMINAL_CWD, falls back to os.getcwd()."""
|
"""Without TERMINAL_CWD, falls back to os.getcwd()."""
|
||||||
|
|
@ -50,3 +52,34 @@ class TestResolvePath:
|
||||||
result = _resolve_path("a/../b/file.txt")
|
result = _resolve_path("a/../b/file.txt")
|
||||||
assert ".." not in str(result)
|
assert ".." not in str(result)
|
||||||
assert result == (tmp_path / "b" / "file.txt")
|
assert result == (tmp_path / "b" / "file.txt")
|
||||||
|
|
||||||
|
def test_relative_path_prefers_live_file_ops_cwd(self, monkeypatch, tmp_path):
|
||||||
|
"""Live env.cwd must win after the terminal session changes directory."""
|
||||||
|
start_dir = tmp_path / "start"
|
||||||
|
live_dir = tmp_path / "worktree"
|
||||||
|
start_dir.mkdir()
|
||||||
|
live_dir.mkdir()
|
||||||
|
monkeypatch.setenv("TERMINAL_CWD", str(start_dir))
|
||||||
|
|
||||||
|
from tools import file_tools
|
||||||
|
|
||||||
|
task_id = "live-cwd"
|
||||||
|
fake_ops = SimpleNamespace(
|
||||||
|
env=SimpleNamespace(cwd=str(live_dir)),
|
||||||
|
cwd=str(start_dir),
|
||||||
|
)
|
||||||
|
|
||||||
|
with file_tools._file_ops_lock:
|
||||||
|
previous = file_tools._file_ops_cache.get(task_id)
|
||||||
|
file_tools._file_ops_cache[task_id] = fake_ops
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = file_tools._resolve_path("nested/file.txt", task_id=task_id)
|
||||||
|
finally:
|
||||||
|
with file_tools._file_ops_lock:
|
||||||
|
if previous is None:
|
||||||
|
file_tools._file_ops_cache.pop(task_id, None)
|
||||||
|
else:
|
||||||
|
file_tools._file_ops_cache[task_id] = previous
|
||||||
|
|
||||||
|
assert result == live_dir / "nested" / "file.txt"
|
||||||
|
|
|
||||||
|
|
@ -79,13 +79,45 @@ _BLOCKED_DEVICE_PATHS = frozenset({
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(filepath: str) -> Path:
|
def _resolve_path(filepath: str, task_id: str = "default") -> Path:
|
||||||
"""Resolve a path relative to TERMINAL_CWD (the worktree base directory)
|
"""Resolve a path relative to TERMINAL_CWD (the worktree base directory)
|
||||||
instead of the main repository root.
|
instead of the main repository root.
|
||||||
"""
|
"""
|
||||||
|
return _resolve_path_for_task(filepath, task_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
|
||||||
|
"""Return the task's live terminal cwd for bookkeeping when available."""
|
||||||
|
with _file_ops_lock:
|
||||||
|
cached = _file_ops_cache.get(task_id)
|
||||||
|
if cached is not None:
|
||||||
|
live_cwd = getattr(getattr(cached, "env", None), "cwd", None) or getattr(
|
||||||
|
cached, "cwd", None
|
||||||
|
)
|
||||||
|
if live_cwd:
|
||||||
|
return live_cwd
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tools.terminal_tool import _active_environments, _env_lock
|
||||||
|
|
||||||
|
with _env_lock:
|
||||||
|
env = _active_environments.get(task_id)
|
||||||
|
live_cwd = getattr(env, "cwd", None) if env is not None else None
|
||||||
|
if live_cwd:
|
||||||
|
return live_cwd
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path_for_task(filepath: str, task_id: str = "default") -> Path:
|
||||||
|
"""Resolve *filepath* against the task's live terminal cwd when possible."""
|
||||||
p = Path(filepath).expanduser()
|
p = Path(filepath).expanduser()
|
||||||
if not p.is_absolute():
|
if not p.is_absolute():
|
||||||
base = os.environ.get("TERMINAL_CWD", os.getcwd())
|
base = _get_live_tracking_cwd(task_id) or os.environ.get(
|
||||||
|
"TERMINAL_CWD", os.getcwd()
|
||||||
|
)
|
||||||
p = Path(base) / p
|
p = Path(base) / p
|
||||||
return p.resolve()
|
return p.resolve()
|
||||||
|
|
||||||
|
|
@ -118,10 +150,10 @@ _SENSITIVE_PATH_PREFIXES = (
|
||||||
_SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"}
|
_SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"}
|
||||||
|
|
||||||
|
|
||||||
def _check_sensitive_path(filepath: str) -> str | None:
|
def _check_sensitive_path(filepath: str, task_id: str = "default") -> str | None:
|
||||||
"""Return an error message if the path targets a sensitive system location."""
|
"""Return an error message if the path targets a sensitive system location."""
|
||||||
try:
|
try:
|
||||||
resolved = str(_resolve_path(filepath))
|
resolved = str(_resolve_path_for_task(filepath, task_id))
|
||||||
except (OSError, ValueError):
|
except (OSError, ValueError):
|
||||||
resolved = filepath
|
resolved = filepath
|
||||||
normalized = os.path.normpath(os.path.expanduser(filepath))
|
normalized = os.path.normpath(os.path.expanduser(filepath))
|
||||||
|
|
@ -368,7 +400,7 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
|
|
||||||
_resolved = _resolve_path(path)
|
_resolved = _resolve_path_for_task(path, task_id)
|
||||||
|
|
||||||
# ── Binary file guard ─────────────────────────────────────────
|
# ── Binary file guard ─────────────────────────────────────────
|
||||||
# Block binary files by extension (no I/O).
|
# Block binary files by extension (no I/O).
|
||||||
|
|
@ -574,7 +606,7 @@ def _update_read_timestamp(filepath: str, task_id: str) -> None:
|
||||||
refreshes the stored timestamp to match the file's new state.
|
refreshes the stored timestamp to match the file's new state.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
resolved = str(_resolve_path(filepath))
|
resolved = str(_resolve_path_for_task(filepath, task_id))
|
||||||
current_mtime = os.path.getmtime(resolved)
|
current_mtime = os.path.getmtime(resolved)
|
||||||
except (OSError, ValueError):
|
except (OSError, ValueError):
|
||||||
return
|
return
|
||||||
|
|
@ -593,7 +625,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
|
||||||
or was never read. Does not block — the write still proceeds.
|
or was never read. Does not block — the write still proceeds.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
resolved = str(_resolve_path(filepath))
|
resolved = str(_resolve_path_for_task(filepath, task_id))
|
||||||
except (OSError, ValueError):
|
except (OSError, ValueError):
|
||||||
return None
|
return None
|
||||||
with _read_tracker_lock:
|
with _read_tracker_lock:
|
||||||
|
|
@ -618,7 +650,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
|
||||||
|
|
||||||
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||||
"""Write content to a file."""
|
"""Write content to a file."""
|
||||||
sensitive_err = _check_sensitive_path(path)
|
sensitive_err = _check_sensitive_path(path, task_id)
|
||||||
if sensitive_err:
|
if sensitive_err:
|
||||||
return tool_error(sensitive_err)
|
return tool_error(sensitive_err)
|
||||||
try:
|
try:
|
||||||
|
|
@ -626,7 +658,7 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||||
# fall back to the legacy path — write proceeds, per-task staleness
|
# fall back to the legacy path — write proceeds, per-task staleness
|
||||||
# check below still runs.
|
# check below still runs.
|
||||||
try:
|
try:
|
||||||
_resolved = str(_resolve_path(path))
|
_resolved = str(_resolve_path_for_task(path, task_id))
|
||||||
except Exception:
|
except Exception:
|
||||||
_resolved = None
|
_resolved = None
|
||||||
|
|
||||||
|
|
@ -681,7 +713,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||||
for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE):
|
for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE):
|
||||||
_paths_to_check.append(_m.group(1).strip())
|
_paths_to_check.append(_m.group(1).strip())
|
||||||
for _p in _paths_to_check:
|
for _p in _paths_to_check:
|
||||||
sensitive_err = _check_sensitive_path(_p)
|
sensitive_err = _check_sensitive_path(_p, task_id)
|
||||||
if sensitive_err:
|
if sensitive_err:
|
||||||
return tool_error(sensitive_err)
|
return tool_error(sensitive_err)
|
||||||
try:
|
try:
|
||||||
|
|
@ -692,7 +724,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||||
_seen: set[str] = set()
|
_seen: set[str] = set()
|
||||||
for _p in _paths_to_check:
|
for _p in _paths_to_check:
|
||||||
try:
|
try:
|
||||||
_r = str(_resolve_path(_p))
|
_r = str(_resolve_path_for_task(_p, task_id))
|
||||||
except Exception:
|
except Exception:
|
||||||
_r = None
|
_r = None
|
||||||
if _r and _r not in _seen:
|
if _r and _r not in _seen:
|
||||||
|
|
@ -714,7 +746,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||||
_path_to_resolved: dict[str, str] = {}
|
_path_to_resolved: dict[str, str] = {}
|
||||||
for _p in _paths_to_check:
|
for _p in _paths_to_check:
|
||||||
try:
|
try:
|
||||||
_r = str(_resolve_path(_p))
|
_r = str(_resolve_path_for_task(_p, task_id))
|
||||||
except Exception:
|
except Exception:
|
||||||
_r = None
|
_r = None
|
||||||
_path_to_resolved[_p] = _r
|
_path_to_resolved[_p] = _r
|
||||||
|
|
@ -749,15 +781,17 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
||||||
_r = _path_to_resolved.get(_p)
|
_r = _path_to_resolved.get(_p)
|
||||||
if _r:
|
if _r:
|
||||||
file_state.note_write(task_id, _r)
|
file_state.note_write(task_id, _r)
|
||||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
|
||||||
# Hint when old_string not found — saves iterations where the agent
|
# Hint when old_string not found — saves iterations where the agent
|
||||||
# retries with stale content instead of re-reading the file.
|
# retries with stale content instead of re-reading the file.
|
||||||
# Suppressed when patch_replace already attached a rich "Did you mean?"
|
# Suppressed when patch_replace already attached a rich "Did you mean?"
|
||||||
# snippet (which is strictly more useful than the generic hint).
|
# snippet (which is strictly more useful than the generic hint).
|
||||||
if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
|
if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
|
||||||
if "Did you mean one of these sections?" not in str(result_dict["error"]):
|
if "Did you mean one of these sections?" not in str(result_dict["error"]):
|
||||||
result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]"
|
result_dict["_hint"] = (
|
||||||
return result_json
|
"old_string not found. Use read_file to verify the current "
|
||||||
|
"content, or search_files to locate the text."
|
||||||
|
)
|
||||||
|
return json.dumps(result_dict, ensure_ascii=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return tool_error(str(e))
|
return tool_error(str(e))
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue