mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-16 09:31:37 +00:00
fix(tools): respect session cwd in file tools
This commit is contained in:
parent
95715dcb03
commit
d6a8d9dcab
3 changed files with 91 additions and 10 deletions
|
|
@ -27,7 +27,7 @@ def _make_env_config(**overrides):
|
|||
|
||||
|
||||
class TestFileToolsContainerConfig:
|
||||
def _run(self, env_config, task_id):
|
||||
def _run(self, env_config, task_id, task_env_overrides=None):
|
||||
captured = {}
|
||||
mock_env = MagicMock()
|
||||
|
||||
|
|
@ -35,31 +35,51 @@ class TestFileToolsContainerConfig:
|
|||
captured.update(kwargs)
|
||||
return mock_env
|
||||
|
||||
with patch("tools.terminal_tool._get_env_config", return_value=env_config), patch("tools.terminal_tool._task_env_overrides", {}), patch("tools.terminal_tool._active_environments", {}), patch("tools.terminal_tool._creation_locks", {}), patch("tools.terminal_tool._creation_locks_lock", __import__("threading").Lock()), patch("tools.terminal_tool._create_environment", side_effect=fake_create_env), patch("tools.terminal_tool._start_cleanup_thread"), patch("tools.terminal_tool._check_disk_usage_warning"), patch("tools.file_tools._file_ops_cache", {}), patch("tools.file_tools._file_ops_lock", __import__("threading").Lock()):
|
||||
with patch("tools.terminal_tool._get_env_config", return_value=env_config), \
|
||||
patch("tools.terminal_tool._task_env_overrides", task_env_overrides or {}), \
|
||||
patch("tools.terminal_tool._active_environments", {}), \
|
||||
patch("tools.terminal_tool._creation_locks", {}), \
|
||||
patch("tools.terminal_tool._creation_locks_lock", __import__("threading").Lock()), \
|
||||
patch("tools.terminal_tool._create_environment", side_effect=fake_create_env), \
|
||||
patch("tools.terminal_tool._start_cleanup_thread"), \
|
||||
patch("tools.terminal_tool._check_disk_usage_warning"), \
|
||||
patch("tools.file_tools._file_ops_cache", {}), \
|
||||
patch("tools.file_tools._file_ops_lock", __import__("threading").Lock()):
|
||||
file_tools._get_file_ops(task_id)
|
||||
|
||||
return captured.get("container_config", {})
|
||||
return captured
|
||||
|
||||
def test_docker_mount_cwd_to_workspace_passed(self):
|
||||
"""docker_mount_cwd_to_workspace is forwarded to container_config."""
|
||||
cc = self._run(_make_env_config(docker_mount_cwd_to_workspace=True), "t1")
|
||||
cc = self._run(_make_env_config(docker_mount_cwd_to_workspace=True), "t1").get("container_config", {})
|
||||
assert cc.get("docker_mount_cwd_to_workspace") is True
|
||||
|
||||
def test_docker_forward_env_passed(self):
|
||||
"""docker_forward_env is forwarded to container_config."""
|
||||
cc = self._run(_make_env_config(docker_forward_env=["MY_SECRET"]), "t2")
|
||||
cc = self._run(_make_env_config(docker_forward_env=["MY_SECRET"]), "t2").get("container_config", {})
|
||||
assert cc.get("docker_forward_env") == ["MY_SECRET"]
|
||||
|
||||
def test_docker_mount_cwd_defaults_to_false(self):
|
||||
"""docker_mount_cwd_to_workspace defaults to False when absent from config."""
|
||||
cfg = _make_env_config()
|
||||
del cfg["docker_mount_cwd_to_workspace"]
|
||||
cc = self._run(cfg, "t3")
|
||||
cc = self._run(cfg, "t3").get("container_config", {})
|
||||
assert cc.get("docker_mount_cwd_to_workspace") is False
|
||||
|
||||
def test_docker_forward_env_defaults_to_empty_list(self):
|
||||
"""docker_forward_env defaults to [] when absent from config."""
|
||||
cfg = _make_env_config()
|
||||
del cfg["docker_forward_env"]
|
||||
cc = self._run(cfg, "t4")
|
||||
cc = self._run(cfg, "t4").get("container_config", {})
|
||||
assert cc.get("docker_forward_env") == []
|
||||
|
||||
def test_cwd_only_raw_task_override_reaches_file_environment(self):
|
||||
"""CWD-only task overrides collapse to default but must keep their cwd."""
|
||||
captured = self._run(
|
||||
_make_env_config(env_type="local", cwd="/config-cwd"),
|
||||
"desktop-session-cwd",
|
||||
task_env_overrides={"desktop-session-cwd": {"cwd": "/workspace/session"}},
|
||||
)
|
||||
|
||||
assert captured["task_id"] == "default"
|
||||
assert captured["cwd"] == "/workspace/session"
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
import tools.file_tools as ft
|
||||
import tools.terminal_tool as terminal_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -218,6 +219,28 @@ def test_absolute_terminal_cwd_anchors_with_empty_registry(_isolated_cwd, monkey
|
|||
assert not str(resolved).startswith(str(decoy))
|
||||
|
||||
|
||||
def test_registered_task_cwd_override_anchors_before_terminal_env_exists(_isolated_cwd, monkeypatch):
|
||||
"""TUI/Desktop sessions register cwd by raw session key before tools run.
|
||||
|
||||
CWD-only overrides collapse to the shared terminal environment key, but the
|
||||
file resolver must still read the raw task/session override before falling
|
||||
back to TERMINAL_CWD or the process cwd.
|
||||
"""
|
||||
workspace, decoy = _isolated_cwd
|
||||
task_id = "desktop-session-cwd"
|
||||
monkeypatch.setattr(ft, "_get_live_tracking_cwd", lambda task_id="default": None)
|
||||
monkeypatch.delenv("TERMINAL_CWD", raising=False)
|
||||
monkeypatch.setattr(terminal_tool, "_task_env_overrides", {})
|
||||
|
||||
terminal_tool.register_task_env_overrides(task_id, {"cwd": str(workspace)})
|
||||
|
||||
resolved = ft._resolve_path_for_task("target.py", task_id=task_id)
|
||||
|
||||
assert terminal_tool._resolve_container_task_id(task_id) == "default"
|
||||
assert resolved == (workspace / "target.py")
|
||||
assert not str(resolved).startswith(str(decoy))
|
||||
|
||||
|
||||
def test_warning_fires_from_terminal_cwd_when_registry_empty(_isolated_cwd, monkeypatch):
|
||||
"""Divergence warning must fire even before any terminal command runs.
|
||||
|
||||
|
|
@ -291,4 +314,3 @@ def test_patch_reports_resolved_absolute_path(_isolated_cwd, monkeypatch):
|
|||
assert "WORKSPACE_PATCHED" in (workspace / "target.py").read_text()
|
||||
# And the decoy copy is untouched.
|
||||
assert (decoy / "target.py").read_text() == "DECOY_ORIGINAL\n"
|
||||
|
||||
|
|
|
|||
|
|
@ -113,6 +113,37 @@ def _configured_terminal_cwd() -> str | None:
|
|||
return expanded
|
||||
|
||||
|
||||
def _registered_task_cwd_override(task_id: str = "default") -> str | None:
|
||||
"""Return a registered cwd override for the raw task id, when available.
|
||||
|
||||
``terminal_tool`` intentionally collapses CWD-only task overrides to the
|
||||
shared ``"default"`` environment so TUI/dashboard/ACP sessions do not spin
|
||||
up isolated sandboxes just because they have different workspaces. The cwd
|
||||
value itself is still keyed by the raw session/task id, so file tools must
|
||||
read that raw override before falling back to the collapsed container key.
|
||||
"""
|
||||
try:
|
||||
from tools.terminal_tool import _resolve_container_task_id, _task_env_overrides
|
||||
|
||||
raw_task_id = task_id or "default"
|
||||
container_key = _resolve_container_task_id(raw_task_id)
|
||||
overrides = (
|
||||
_task_env_overrides.get(raw_task_id)
|
||||
or _task_env_overrides.get(container_key)
|
||||
or {}
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
raw_cwd = str(overrides.get("cwd") or "").strip()
|
||||
if raw_cwd.lower() in _TERMINAL_CWD_SENTINELS:
|
||||
return None
|
||||
expanded = os.path.expanduser(raw_cwd)
|
||||
if not os.path.isabs(expanded):
|
||||
return None
|
||||
return expanded
|
||||
|
||||
|
||||
def _get_live_tracking_cwd(task_id: str = "default") -> str | None:
|
||||
"""Return the task's live terminal cwd for bookkeeping when available."""
|
||||
try:
|
||||
|
|
@ -159,6 +190,9 @@ def _authoritative_workspace_root(task_id: str = "default") -> str | None:
|
|||
live = _get_live_tracking_cwd(task_id)
|
||||
if live:
|
||||
return live
|
||||
registered = _registered_task_cwd_override(task_id)
|
||||
if registered:
|
||||
return registered
|
||||
return _configured_terminal_cwd()
|
||||
|
||||
|
||||
|
|
@ -625,7 +659,8 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
|||
)
|
||||
import time
|
||||
|
||||
task_id = _resolve_container_task_id(task_id)
|
||||
raw_task_id = task_id or "default"
|
||||
task_id = _resolve_container_task_id(raw_task_id)
|
||||
|
||||
# Fast path: check cache -- but also verify the underlying environment
|
||||
# is still alive (it may have been killed by the cleanup thread).
|
||||
|
|
@ -662,7 +697,11 @@ def _get_file_ops(task_id: str = "default") -> ShellFileOperations:
|
|||
|
||||
config = _get_env_config()
|
||||
env_type = config["env_type"]
|
||||
overrides = _task_env_overrides.get(task_id, {})
|
||||
overrides = (
|
||||
_task_env_overrides.get(raw_task_id)
|
||||
or _task_env_overrides.get(task_id)
|
||||
or {}
|
||||
)
|
||||
|
||||
if env_type == "docker":
|
||||
image = overrides.get("docker_image") or config["docker_image"]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue