diff --git a/acp_adapter/edit_approval.py b/acp_adapter/edit_approval.py new file mode 100644 index 00000000000..ebeab0bc7ec --- /dev/null +++ b/acp_adapter/edit_approval.py @@ -0,0 +1,228 @@ +"""Pre-execution ACP edit approval helpers. + +This module is intentionally isolated from the generic tool registry. ACP binds +an edit approval requester in a ContextVar for the duration of one ACP agent run; +CLI, gateway, and other sessions leave it unset and therefore bypass this guard. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from concurrent.futures import TimeoutError as FutureTimeout +from contextvars import ContextVar, Token +from dataclasses import dataclass +from itertools import count +from pathlib import Path +from typing import Any, Callable + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class EditProposal: + """A proposed single-file edit that can be shown to an ACP client.""" + + tool_name: str + path: str + old_text: str | None + new_text: str + arguments: dict[str, Any] + + +EditApprovalRequester = Callable[[EditProposal], bool] + +_EDIT_APPROVAL_REQUESTER: ContextVar[EditApprovalRequester | None] = ContextVar( + "ACP_EDIT_APPROVAL_REQUESTER", + default=None, +) +_PERMISSION_REQUEST_IDS = count(1) + + +def set_edit_approval_requester(requester: EditApprovalRequester | None) -> Token: + """Bind an ACP edit approval requester for the current context.""" + + return _EDIT_APPROVAL_REQUESTER.set(requester) + + +def reset_edit_approval_requester(token: Token) -> None: + """Restore a previous edit approval requester binding.""" + + _EDIT_APPROVAL_REQUESTER.reset(token) + + +def clear_edit_approval_requester() -> None: + """Clear the current requester; primarily used by tests.""" + + _EDIT_APPROVAL_REQUESTER.set(None) + + +def get_edit_approval_requester() -> EditApprovalRequester | None: + return _EDIT_APPROVAL_REQUESTER.get() + + +def _read_text_if_exists(path: str) -> str | None: + p = Path(path).expanduser() + if not p.exists(): + return None + if not p.is_file(): + raise OSError(f"Cannot edit non-file path: {path}") + return p.read_text(encoding="utf-8", errors="replace") + + +def _proposal_for_write_file(arguments: dict[str, Any]) -> EditProposal: + path = str(arguments.get("path") or "") + if not path: + raise ValueError("path required") + content = arguments.get("content") + if content is None: + raise ValueError("content required") + return EditProposal( + tool_name="write_file", + path=path, + old_text=_read_text_if_exists(path), + new_text=str(content), + arguments=dict(arguments), + ) + + +def _proposal_for_patch_replace(arguments: dict[str, Any]) -> EditProposal: + path = str(arguments.get("path") or "") + if not path: + raise ValueError("path required") + old_string = arguments.get("old_string") + new_string = arguments.get("new_string") + if old_string is None or new_string is None: + raise ValueError("old_string and new_string required") + + old_text = _read_text_if_exists(path) + if old_text is None: + raise ValueError(f"Failed to read file: {path}") + + from tools.fuzzy_match import fuzzy_find_and_replace + + new_text, match_count, _strategy, error = fuzzy_find_and_replace( + old_text, + str(old_string), + str(new_string), + bool(arguments.get("replace_all", False)), + ) + if error or match_count == 0: + raise ValueError(error or f"Could not find match for old_string in {path}") + + return EditProposal( + tool_name="patch", + path=path, + old_text=old_text, + new_text=new_text, + arguments=dict(arguments), + ) + + +def build_edit_proposal(tool_name: str, arguments: dict[str, Any]) -> EditProposal | None: + """Return an edit proposal for supported file mutation calls.""" + + if tool_name == "write_file": + return _proposal_for_write_file(arguments) + if tool_name == "patch" and arguments.get("mode", "replace") == "replace": + return _proposal_for_patch_replace(arguments) + return None + + +def maybe_require_edit_approval(tool_name: str, arguments: dict[str, Any]) -> str | None: + """Run ACP edit approval if bound. + + Returns a JSON tool-error string when the edit must be blocked, otherwise + ``None`` so dispatch can continue. Requester exceptions deny by default. + """ + + requester = get_edit_approval_requester() + if requester is None: + return None + + try: + proposal = build_edit_proposal(tool_name, arguments) + except Exception as exc: + logger.warning("Could not build ACP edit approval proposal for %s: %s", tool_name, exc) + return json.dumps({"error": f"Edit approval denied: could not prepare diff ({exc})"}, ensure_ascii=False) + + if proposal is None: + return None + + try: + approved = bool(requester(proposal)) + except Exception as exc: + logger.warning("ACP edit approval requester failed: %s", exc) + approved = False + + if approved: + return None + return json.dumps({"error": "Edit approval denied by ACP client; file was not modified."}, ensure_ascii=False) + + +def build_acp_edit_tool_call(proposal: EditProposal): + """Build the ToolCallUpdate payload for ACP request_permission.""" + + import acp + + tool_call_id = f"edit-approval-{next(_PERMISSION_REQUEST_IDS)}" + return acp.update_tool_call( + tool_call_id, + title=f"Approve edit: {proposal.path}", + kind="edit", + status="pending", + content=[ + acp.tool_diff_content( + path=proposal.path, + old_text=proposal.old_text, + new_text=proposal.new_text, + ) + ], + raw_input={"tool": proposal.tool_name, "arguments": proposal.arguments}, + ) + + +def make_acp_edit_approval_requester( + request_permission_fn: Callable, + loop: asyncio.AbstractEventLoop, + session_id: str, + timeout: float = 60.0, +) -> EditApprovalRequester: + """Return a sync requester that bridges edit proposals to ACP permissions.""" + + def _requester(proposal: EditProposal) -> bool: + from acp.schema import PermissionOption + from agent.async_utils import safe_schedule_threadsafe + + options = [ + PermissionOption(option_id="allow_once", kind="allow_once", name="Allow edit"), + PermissionOption(option_id="deny", kind="reject_once", name="Deny"), + ] + tool_call = build_acp_edit_tool_call(proposal) + coro = request_permission_fn( + session_id=session_id, + tool_call=tool_call, + options=options, + ) + future = safe_schedule_threadsafe( + coro, + loop, + logger=logger, + log_message="Edit approval request: failed to schedule on loop", + ) + if future is None: + return False + try: + response = future.result(timeout=timeout) + except (FutureTimeout, Exception) as exc: + future.cancel() + logger.warning("Edit approval request timed out or failed: %s", exc) + return False + outcome = getattr(response, "outcome", None) + return ( + getattr(outcome, "outcome", None) == "selected" + and getattr(outcome, "option_id", None) == "allow_once" + ) + + return _requester diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 3031de161fd..ebec969205c 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -1243,6 +1243,7 @@ class HermesACPAgent(acp.Agent): tool_call_ids: dict[str, Deque[str]] = defaultdict(deque) tool_call_meta: dict[str, dict[str, Any]] = {} previous_approval_cb = None + edit_approval_requester = None streamed_message = False @@ -1259,6 +1260,16 @@ class HermesACPAgent(acp.Agent): message_cb(text) approval_cb = make_approval_callback(conn.request_permission, loop, session_id) + try: + from acp_adapter.edit_approval import make_acp_edit_approval_requester + + edit_approval_requester = make_acp_edit_approval_requester( + conn.request_permission, + loop, + session_id, + ) + except Exception: + logger.debug("Could not create ACP edit approval requester", exc_info=True) else: tool_progress_cb = None reasoning_cb = None @@ -1288,9 +1299,10 @@ class HermesACPAgent(acp.Agent): # which requires a notify_cb registered in _gateway_notify_cbs. previous_approval_cb = None previous_interactive = None + edit_approval_token = None def _run_agent() -> dict: - nonlocal previous_approval_cb, previous_interactive + nonlocal previous_approval_cb, previous_interactive, edit_approval_token # Bind HERMES_SESSION_KEY for this session so per-session caches # (e.g. the interactive sudo password cache in tools.terminal_tool) # scope to the ACP session rather than leaking across sessions @@ -1314,6 +1326,13 @@ class HermesACPAgent(acp.Agent): _terminal_tool.set_approval_callback(approval_cb) except Exception: logger.debug("Could not set ACP approval callback", exc_info=True) + if edit_approval_requester: + try: + from acp_adapter.edit_approval import set_edit_approval_requester + + edit_approval_token = set_edit_approval_requester(edit_approval_requester) + except Exception: + logger.debug("Could not set ACP edit approval requester", exc_info=True) # Signal to tools.approval that we have an interactive callback # and the non-interactive auto-approve path must not fire. previous_interactive = os.environ.get("HERMES_INTERACTIVE") @@ -1341,6 +1360,13 @@ class HermesACPAgent(acp.Agent): _terminal_tool.set_approval_callback(previous_approval_cb) except Exception: logger.debug("Could not restore approval callback", exc_info=True) + if edit_approval_token is not None: + try: + from acp_adapter.edit_approval import reset_edit_approval_requester + + reset_edit_approval_requester(edit_approval_token) + except Exception: + logger.debug("Could not restore ACP edit approval requester", exc_info=True) if session_tokens is not None and clear_session_vars is not None: try: clear_session_vars(session_tokens) diff --git a/model_tools.py b/model_tools.py index 1cbc83096ac..ad938b5f18b 100644 --- a/model_tools.py +++ b/model_tools.py @@ -788,6 +788,20 @@ def handle_function_call( if block_message is not None: return json.dumps({"error": block_message}, ensure_ascii=False) + # ACP/Zed edit approval runs before any file mutation. The requester + # is bound via ContextVar only for ACP sessions, so CLI/gateway paths + # are unaffected when it is unset. + try: + from acp_adapter.edit_approval import maybe_require_edit_approval + + edit_block_message = maybe_require_edit_approval(function_name, function_args) + if edit_block_message is not None: + return edit_block_message + except Exception as _edit_approval_err: + logger.debug("ACP edit approval guard error: %s", _edit_approval_err) + if function_name in {"write_file", "patch"}: + return json.dumps({"error": "Edit approval denied: approval guard failed"}, ensure_ascii=False) + # Notify the read-loop tracker when a non-read/search tool runs, # so the *consecutive* counter resets (reads after other work are fine). if function_name not in _READ_SEARCH_TOOLS: diff --git a/tests/acp/test_edit_approval.py b/tests/acp/test_edit_approval.py new file mode 100644 index 00000000000..2d68e22045b --- /dev/null +++ b/tests/acp/test_edit_approval.py @@ -0,0 +1,179 @@ +"""Tests for ACP pre-edit approval gating.""" + +from __future__ import annotations + +import json + +from acp_adapter.edit_approval import ( + EditProposal, + build_acp_edit_tool_call, + clear_edit_approval_requester, + set_edit_approval_requester, +) +from model_tools import handle_function_call + + +def teardown_function() -> None: + clear_edit_approval_requester() + + +def test_acp_permission_tool_call_uses_edit_kind_and_diff_content(): + proposal = EditProposal( + tool_name="write_file", + path="demo.txt", + old_text="old\n", + new_text="new\n", + arguments={"path": "demo.txt", "content": "new\n"}, + ) + + tool_call = build_acp_edit_tool_call(proposal) + + assert tool_call.kind == "edit" + assert tool_call.status == "pending" + assert tool_call.rawInput == {"tool": "write_file", "arguments": proposal.arguments} + assert len(tool_call.content) == 1 + diff = tool_call.content[0] + assert diff.path == "demo.txt" + assert diff.oldText == "old\n" + assert diff.newText == "new\n" + + +def test_write_file_rejection_does_not_mutate_existing_file(tmp_path): + target = tmp_path / "sample.txt" + target.write_text("before\n", encoding="utf-8") + + set_edit_approval_requester(lambda _proposal: False) + + result = json.loads( + handle_function_call( + "write_file", + {"path": str(target), "content": "after\n"}, + task_id="acp-edit-reject", + ) + ) + + assert "error" in result + assert "Edit approval denied" in result["error"] + assert target.read_text(encoding="utf-8") == "before\n" + + +def test_write_file_approval_mutates_and_request_includes_diff(tmp_path): + target = tmp_path / "sample.txt" + target.write_text("before\n", encoding="utf-8") + proposals = [] + + def approve(proposal): + proposals.append(proposal) + return True + + set_edit_approval_requester(approve) + + result = json.loads( + handle_function_call( + "write_file", + {"path": str(target), "content": "after\n"}, + task_id="acp-edit-approve", + ) + ) + + assert result.get("bytes_written") == len("after\n") + assert target.read_text(encoding="utf-8") == "after\n" + assert len(proposals) == 1 + proposal = proposals[0] + assert proposal.tool_name == "write_file" + assert proposal.path == str(target) + assert proposal.old_text == "before\n" + assert proposal.new_text == "after\n" + + +def test_write_file_new_file_request_has_empty_old_text(tmp_path): + target = tmp_path / "new.txt" + proposals = [] + + set_edit_approval_requester(lambda proposal: proposals.append(proposal) or True) + + result = json.loads( + handle_function_call( + "write_file", + {"path": str(target), "content": "created\n"}, + task_id="acp-edit-new-file", + ) + ) + + assert result.get("bytes_written") == len("created\n") + assert target.read_text(encoding="utf-8") == "created\n" + assert proposals[0].old_text is None + assert proposals[0].new_text == "created\n" + + +def test_requester_exception_denies_and_does_not_mutate(tmp_path): + target = tmp_path / "sample.txt" + target.write_text("before\n", encoding="utf-8") + + def boom(_proposal): + raise RuntimeError("zed disconnected") + + set_edit_approval_requester(boom) + + result = json.loads( + handle_function_call( + "write_file", + {"path": str(target), "content": "after\n"}, + task_id="acp-edit-exception", + ) + ) + + assert "error" in result + assert "Edit approval denied" in result["error"] + assert target.read_text(encoding="utf-8") == "before\n" + + +def test_patch_replace_rejection_does_not_mutate(tmp_path): + target = tmp_path / "sample.txt" + target.write_text("alpha\nbeta\n", encoding="utf-8") + + set_edit_approval_requester(lambda _proposal: False) + + result = json.loads( + handle_function_call( + "patch", + { + "mode": "replace", + "path": str(target), + "old_string": "beta\n", + "new_string": "gamma\n", + }, + task_id="acp-patch-reject", + ) + ) + + assert "error" in result + assert "Edit approval denied" in result["error"] + assert target.read_text(encoding="utf-8") == "alpha\nbeta\n" + + +def test_patch_replace_approval_request_includes_full_file_diff(tmp_path): + target = tmp_path / "sample.txt" + target.write_text("alpha\nbeta\n", encoding="utf-8") + proposals = [] + + set_edit_approval_requester(lambda proposal: proposals.append(proposal) or True) + + result = json.loads( + handle_function_call( + "patch", + { + "mode": "replace", + "path": str(target), + "old_string": "beta\n", + "new_string": "gamma\n", + }, + task_id="acp-patch-approve", + ) + ) + + assert result.get("success") is True + assert target.read_text(encoding="utf-8") == "alpha\ngamma\n" + assert proposals[0].tool_name == "patch" + assert proposals[0].old_text == "alpha\nbeta\n" + assert proposals[0].new_text == "alpha\ngamma\n"