mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(file_tools): resolve bookkeeping paths against live terminal cwd
This commit is contained in:
parent
83859b4da0
commit
4a0c02b7dc
3 changed files with 139 additions and 17 deletions
|
|
@ -13,8 +13,10 @@ import os
|
|||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from tools import file_state
|
||||
from tools.file_tools import (
|
||||
read_file_tool,
|
||||
write_file_tool,
|
||||
|
|
@ -76,6 +78,7 @@ class TestStalenessCheck(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
_read_tracker.clear()
|
||||
file_state.get_registry().clear()
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._tmpfile = os.path.join(self._tmpdir, "stale_test.txt")
|
||||
with open(self._tmpfile, "w") as f:
|
||||
|
|
@ -83,6 +86,7 @@ class TestStalenessCheck(unittest.TestCase):
|
|||
|
||||
def tearDown(self):
|
||||
_read_tracker.clear()
|
||||
file_state.get_registry().clear()
|
||||
try:
|
||||
os.unlink(self._tmpfile)
|
||||
os.rmdir(self._tmpdir)
|
||||
|
|
@ -145,6 +149,53 @@ class TestStalenessCheck(unittest.TestCase):
|
|||
result = json.loads(write_file_tool(self._tmpfile, "new", task_id="task_b"))
|
||||
self.assertNotIn("_warning", result)
|
||||
|
||||
@patch("tools.file_tools._get_file_ops")
def test_relative_path_uses_live_cwd_for_staleness_tracking(self, mock_ops):
    """Relative-path stale tracking must follow the live terminal cwd.

    Two files named ``shared.txt`` exist: one under the directory the
    session started in (``TERMINAL_CWD`` and ``fake_ops.cwd``) and one under
    the directory the terminal has since moved to (``fake_ops.env.cwd``).
    Only the live copy is modified between the read and the write, and the
    test expects the write to carry a staleness warning.
    """
    import shutil

    # This test creates sub-directories inside self._tmpdir.  tearDown only
    # os.rmdir()s that directory, which cannot remove a non-empty directory,
    # so without this the temp tree would be left behind.  Cleanups run
    # after tearDown and remove whatever is still there.
    self.addCleanup(shutil.rmtree, self._tmpdir, ignore_errors=True)

    start_dir = os.path.join(self._tmpdir, "start")
    live_dir = os.path.join(self._tmpdir, "worktree")
    os.makedirs(start_dir, exist_ok=True)
    os.makedirs(live_dir, exist_ok=True)

    start_file = os.path.join(start_dir, "shared.txt")
    live_file = os.path.join(live_dir, "shared.txt")
    with open(start_file, "w") as f:
        f.write("start copy\n")
    with open(live_file, "w") as f:
        f.write("live copy\n")

    # The fake file-ops object reports two different directories: the stale
    # start directory on ``cwd`` and the live one on ``env.cwd``.
    fake_ops = _make_fake_ops("live copy\n", 10)
    fake_ops.env = SimpleNamespace(cwd=live_dir)
    fake_ops.cwd = start_dir
    mock_ops.return_value = fake_ops

    from tools import file_tools

    # Install the fake in the per-task cache, remembering any prior entry so
    # it can be put back afterwards.
    with file_tools._file_ops_lock:
        previous = file_tools._file_ops_cache.get("live_task")
        file_tools._file_ops_cache["live_task"] = fake_ops

    try:
        with patch.dict(os.environ, {"TERMINAL_CWD": start_dir}, clear=False):
            read_file_tool("shared.txt", task_id="live_task")

            # Give the filesystem clock time to advance before rewriting, so
            # the second write gets a different mtime.
            time.sleep(0.05)
            with open(live_file, "w") as f:
                f.write("live copy modified elsewhere\n")

            result = json.loads(
                write_file_tool("shared.txt", "replacement", task_id="live_task")
            )
    finally:
        with file_tools._file_ops_lock:
            if previous is None:
                file_tools._file_ops_cache.pop("live_task", None)
            else:
                file_tools._file_ops_cache["live_task"] = previous

    self.assertIn("_warning", result)
    self.assertIn("modified since you last read", result["_warning"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Staleness in patch
|
||||
|
|
@ -154,6 +205,7 @@ class TestPatchStaleness(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
_read_tracker.clear()
|
||||
file_state.get_registry().clear()
|
||||
self._tmpdir = tempfile.mkdtemp()
|
||||
self._tmpfile = os.path.join(self._tmpdir, "patch_test.txt")
|
||||
with open(self._tmpfile, "w") as f:
|
||||
|
|
@ -161,6 +213,7 @@ class TestPatchStaleness(unittest.TestCase):
|
|||
|
||||
def tearDown(self):
|
||||
_read_tracker.clear()
|
||||
file_state.get_registry().clear()
|
||||
try:
|
||||
os.unlink(self._tmpfile)
|
||||
os.rmdir(self._tmpdir)
|
||||
|
|
@ -207,9 +260,11 @@ class TestCheckFileStalenessHelper(unittest.TestCase):
|
|||
|
||||
def setUp(self):
    """Begin each test with an empty file-state registry and read tracker."""
    file_state.get_registry().clear()
    _read_tracker.clear()
|
||||
|
||||
def tearDown(self):
    """Empty the file-state registry and read tracker after each test."""
    file_state.get_registry().clear()
    _read_tracker.clear()
|
||||
|
||||
def test_returns_none_for_unknown_task(self):
|
||||
self.assertIsNone(_check_file_staleness("/tmp/x.py", "nonexistent"))
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -22,8 +23,9 @@ class TestResolvePath:
|
|||
monkeypatch.setenv("TERMINAL_CWD", str(tmp_path))
|
||||
from tools.file_tools import _resolve_path
|
||||
|
||||
result = _resolve_path("/etc/hosts")
|
||||
assert result == Path("/etc/hosts")
|
||||
absolute = (tmp_path / "already-absolute.txt").resolve()
|
||||
result = _resolve_path(str(absolute))
|
||||
assert result == absolute
|
||||
|
||||
def test_falls_back_to_cwd_without_terminal_cwd(self, monkeypatch):
|
||||
"""Without TERMINAL_CWD, falls back to os.getcwd()."""
|
||||
|
|
@ -50,3 +52,34 @@ class TestResolvePath:
|
|||
result = _resolve_path("a/../b/file.txt")
|
||||
assert ".." not in str(result)
|
||||
assert result == (tmp_path / "b" / "file.txt")
|
||||
|
||||
def test_relative_path_prefers_live_file_ops_cwd(self, monkeypatch, tmp_path):
    """A cached file-ops object's ``env.cwd`` must beat ``TERMINAL_CWD``.

    ``TERMINAL_CWD`` and the fake object's ``cwd`` both point at the start
    directory, while its ``env.cwd`` points at the worktree; the relative
    path is expected to resolve under the worktree.
    """
    launch_dir, current_dir = tmp_path / "start", tmp_path / "worktree"
    for directory in (launch_dir, current_dir):
        directory.mkdir()
    monkeypatch.setenv("TERMINAL_CWD", str(launch_dir))

    from tools import file_tools

    key = "live-cwd"
    stub_ops = SimpleNamespace(
        env=SimpleNamespace(cwd=str(current_dir)),
        cwd=str(launch_dir),
    )

    # Swap the stub into the per-task cache, keeping any existing entry so
    # it can be restored once the call under test has run.
    with file_tools._file_ops_lock:
        saved = file_tools._file_ops_cache.get(key)
        file_tools._file_ops_cache[key] = stub_ops

    try:
        resolved = file_tools._resolve_path("nested/file.txt", task_id=key)
    finally:
        with file_tools._file_ops_lock:
            if saved is not None:
                file_tools._file_ops_cache[key] = saved
            else:
                file_tools._file_ops_cache.pop(key, None)

    assert resolved == current_dir / "nested" / "file.txt"
|
||||
|
|
|
|||
|
|
@ -79,13 +79,45 @@ _BLOCKED_DEVICE_PATHS = frozenset({
|
|||
})
|
||||
|
||||
|
||||
def _resolve_path(filepath: str, task_id: str = "default") -> Path:
    """Resolve *filepath* to an absolute path for the given task.

    Thin wrapper around ``_resolve_path_for_task``: relative paths are
    resolved against the task's live terminal cwd when one is known, and
    otherwise against ``TERMINAL_CWD`` (the worktree base directory) rather
    than the main repository root.

    Args:
        filepath: Absolute, relative, or ``~``-prefixed path.
        task_id: Task whose terminal cwd should be used for relative paths.

    Returns:
        The resolved ``Path``.
    """
    return _resolve_path_for_task(filepath, task_id)
|
||||
|
||||
|
||||
def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
    """Return the task's live terminal cwd for bookkeeping when available.

    Lookup order:
      1. The file-ops object cached for *task_id*: its ``env.cwd`` if
         truthy, otherwise its own ``cwd``.
      2. The environment registered for *task_id* in
         ``tools.terminal_tool._active_environments``.

    Args:
        task_id: Key used in both the file-ops cache and the environment map.

    Returns:
        The first truthy cwd found, or ``None`` so the caller can fall back
        to another base directory.  The lookup is best-effort: a failure in
        step 2 is logged at debug level and treated as "unknown".
    """
    # Hold the cache lock only for the lookup itself.
    with _file_ops_lock:
        cached = _file_ops_cache.get(task_id)
    if cached is not None:
        env_cwd = getattr(getattr(cached, "env", None), "cwd", None)
        live_cwd = env_cwd or getattr(cached, "cwd", None)
        if live_cwd:
            return live_cwd

    try:
        # Imported at call time rather than module load.
        # NOTE(review): presumably to avoid an import cycle with terminal_tool — confirm.
        from tools.terminal_tool import _active_environments, _env_lock

        with _env_lock:
            env = _active_environments.get(task_id)
        live_cwd = getattr(env, "cwd", None) if env is not None else None
        if live_cwd:
            return live_cwd
    except Exception:
        # Still best-effort, but leave a trace instead of hiding the failure.
        import logging

        logging.getLogger(__name__).debug(
            "Could not read live terminal cwd for task %r", task_id, exc_info=True
        )

    return None
|
||||
|
||||
|
||||
def _resolve_path_for_task(filepath: str, task_id: str = "default") -> Path:
|
||||
"""Resolve *filepath* against the task's live terminal cwd when possible."""
|
||||
p = Path(filepath).expanduser()
|
||||
if not p.is_absolute():
|
||||
base = os.environ.get("TERMINAL_CWD", os.getcwd())
|
||||
base = _get_live_tracking_cwd(task_id) or os.environ.get(
|
||||
"TERMINAL_CWD", os.getcwd()
|
||||
)
|
||||
p = Path(base) / p
|
||||
return p.resolve()
|
||||
|
||||
|
|
@ -118,10 +150,10 @@ _SENSITIVE_PATH_PREFIXES = (
|
|||
_SENSITIVE_EXACT_PATHS = {"/var/run/docker.sock", "/run/docker.sock"}
|
||||
|
||||
|
||||
def _check_sensitive_path(filepath: str) -> str | None:
|
||||
def _check_sensitive_path(filepath: str, task_id: str = "default") -> str | None:
|
||||
"""Return an error message if the path targets a sensitive system location."""
|
||||
try:
|
||||
resolved = str(_resolve_path(filepath))
|
||||
resolved = str(_resolve_path_for_task(filepath, task_id))
|
||||
except (OSError, ValueError):
|
||||
resolved = filepath
|
||||
normalized = os.path.normpath(os.path.expanduser(filepath))
|
||||
|
|
@ -368,7 +400,7 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
|||
),
|
||||
})
|
||||
|
||||
_resolved = _resolve_path(path)
|
||||
_resolved = _resolve_path_for_task(path, task_id)
|
||||
|
||||
# ── Binary file guard ─────────────────────────────────────────
|
||||
# Block binary files by extension (no I/O).
|
||||
|
|
@ -574,7 +606,7 @@ def _update_read_timestamp(filepath: str, task_id: str) -> None:
|
|||
refreshes the stored timestamp to match the file's new state.
|
||||
"""
|
||||
try:
|
||||
resolved = str(_resolve_path(filepath))
|
||||
resolved = str(_resolve_path_for_task(filepath, task_id))
|
||||
current_mtime = os.path.getmtime(resolved)
|
||||
except (OSError, ValueError):
|
||||
return
|
||||
|
|
@ -593,7 +625,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
|
|||
or was never read. Does not block — the write still proceeds.
|
||||
"""
|
||||
try:
|
||||
resolved = str(_resolve_path(filepath))
|
||||
resolved = str(_resolve_path_for_task(filepath, task_id))
|
||||
except (OSError, ValueError):
|
||||
return None
|
||||
with _read_tracker_lock:
|
||||
|
|
@ -618,7 +650,7 @@ def _check_file_staleness(filepath: str, task_id: str) -> str | None:
|
|||
|
||||
def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
||||
"""Write content to a file."""
|
||||
sensitive_err = _check_sensitive_path(path)
|
||||
sensitive_err = _check_sensitive_path(path, task_id)
|
||||
if sensitive_err:
|
||||
return tool_error(sensitive_err)
|
||||
try:
|
||||
|
|
@ -626,7 +658,7 @@ def write_file_tool(path: str, content: str, task_id: str = "default") -> str:
|
|||
# fall back to the legacy path — write proceeds, per-task staleness
|
||||
# check below still runs.
|
||||
try:
|
||||
_resolved = str(_resolve_path(path))
|
||||
_resolved = str(_resolve_path_for_task(path, task_id))
|
||||
except Exception:
|
||||
_resolved = None
|
||||
|
||||
|
|
@ -681,7 +713,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
|||
for _m in _re.finditer(r'^\*\*\*\s+(?:Update|Add|Delete)\s+File:\s*(.+)$', patch, _re.MULTILINE):
|
||||
_paths_to_check.append(_m.group(1).strip())
|
||||
for _p in _paths_to_check:
|
||||
sensitive_err = _check_sensitive_path(_p)
|
||||
sensitive_err = _check_sensitive_path(_p, task_id)
|
||||
if sensitive_err:
|
||||
return tool_error(sensitive_err)
|
||||
try:
|
||||
|
|
@ -692,7 +724,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
|||
_seen: set[str] = set()
|
||||
for _p in _paths_to_check:
|
||||
try:
|
||||
_r = str(_resolve_path(_p))
|
||||
_r = str(_resolve_path_for_task(_p, task_id))
|
||||
except Exception:
|
||||
_r = None
|
||||
if _r and _r not in _seen:
|
||||
|
|
@ -714,7 +746,7 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
|||
_path_to_resolved: dict[str, str] = {}
|
||||
for _p in _paths_to_check:
|
||||
try:
|
||||
_r = str(_resolve_path(_p))
|
||||
_r = str(_resolve_path_for_task(_p, task_id))
|
||||
except Exception:
|
||||
_r = None
|
||||
_path_to_resolved[_p] = _r
|
||||
|
|
@ -749,15 +781,17 @@ def patch_tool(mode: str = "replace", path: str = None, old_string: str = None,
|
|||
_r = _path_to_resolved.get(_p)
|
||||
if _r:
|
||||
file_state.note_write(task_id, _r)
|
||||
result_json = json.dumps(result_dict, ensure_ascii=False)
|
||||
# Hint when old_string not found — saves iterations where the agent
|
||||
# retries with stale content instead of re-reading the file.
|
||||
# Suppressed when patch_replace already attached a rich "Did you mean?"
|
||||
# snippet (which is strictly more useful than the generic hint).
|
||||
if result_dict.get("error") and "Could not find" in str(result_dict["error"]):
|
||||
if "Did you mean one of these sections?" not in str(result_dict["error"]):
|
||||
result_json += "\n\n[Hint: old_string not found. Use read_file to verify the current content, or search_files to locate the text.]"
|
||||
return result_json
|
||||
result_dict["_hint"] = (
|
||||
"old_string not found. Use read_file to verify the current "
|
||||
"content, or search_files to locate the text."
|
||||
)
|
||||
return json.dumps(result_dict, ensure_ascii=False)
|
||||
except Exception as e:
|
||||
return tool_error(str(e))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue