# Patch summary
# -------------
# agent/tool_executor.py: in execute_tool_calls_concurrent, the checkpoint
# preflight for write_file / patch calls and for destructive terminal commands
# is moved below the block evaluation and wrapped in `if block_result is None:`,
# so a tool call that has already been given a block result no longer reaches
# agent._checkpoint_mgr.ensure_checkpoint().
#
# tests/run_agent/test_run_agent.py: four regression tests are added to
# TestConcurrentToolExecution. Three assert ensure_checkpoint is not called for
# a blocked write_file, patch, or terminal call; the fourth asserts that an
# allowed write_file following a blocked one for the same path calls it once.
#
# NOTE(review): this copy was rebuilt from a paste whose line breaks and
# indentation had been collapsed. The number of lines in every hunk matches its
# @@ header, but leading whitespace (in particular on the context lines of
# agent/tool_executor.py) is inferred from Python block structure -- confirm it
# against the target files before applying.
diff --git a/agent/tool_executor.py b/agent/tool_executor.py
index b249de3de04..1176d95c259 100644
--- a/agent/tool_executor.py
+++ b/agent/tool_executor.py
@@ -180,28 +180,9 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
         except Exception:
             pass
 
-        # Checkpoint for file-mutating tools
-        if function_name in {"write_file", "patch"} and agent._checkpoint_mgr.enabled:
-            try:
-                file_path = function_args.get("path", "")
-                if file_path:
-                    work_dir = agent._checkpoint_mgr.get_working_dir_for_path(file_path)
-                    agent._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
-            except Exception:
-                pass
-
-        # Checkpoint before destructive terminal commands
-        if function_name == "terminal" and agent._checkpoint_mgr.enabled:
-            try:
-                cmd = function_args.get("command", "")
-                if _is_destructive_command(cmd):
-                    cwd = function_args.get("workdir") or os.getenv("TERMINAL_CWD", os.getcwd())
-                    agent._checkpoint_mgr.ensure_checkpoint(
-                        cwd, f"before terminal: {cmd[:60]}"
-                    )
-            except Exception:
-                pass
-
+        # ── Block evaluation (BEFORE checkpoint preflight) ───────────
+        # We must know whether the tool will execute before touching
+        # checkpoint state (dedup slot, real snapshots).
         block_result = None
         blocked_by_guardrail = False
         if _ts_scope_block is not None:
@@ -224,6 +205,30 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
                     block_result = agent._guardrail_block_result(guardrail_decision)
                     blocked_by_guardrail = True
 
+        # ── Checkpoint preflight (only for tools that will execute) ──
+        if block_result is None:
+            # Checkpoint for file-mutating tools
+            if function_name in {"write_file", "patch"} and agent._checkpoint_mgr.enabled:
+                try:
+                    file_path = function_args.get("path", "")
+                    if file_path:
+                        work_dir = agent._checkpoint_mgr.get_working_dir_for_path(file_path)
+                        agent._checkpoint_mgr.ensure_checkpoint(work_dir, f"before {function_name}")
+                except Exception:
+                    pass
+
+            # Checkpoint before destructive terminal commands
+            if function_name == "terminal" and agent._checkpoint_mgr.enabled:
+                try:
+                    cmd = function_args.get("command", "")
+                    if _is_destructive_command(cmd):
+                        cwd = function_args.get("workdir") or os.getenv("TERMINAL_CWD", os.getcwd())
+                        agent._checkpoint_mgr.ensure_checkpoint(
+                            cwd, f"before terminal: {cmd[:60]}"
+                        )
+                except Exception:
+                    pass
+
         parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
 
     # ── Logging / callbacks ──────────────────────────────────────────
diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py
index 1653dc0d4ad..2bef65887da 100644
--- a/tests/run_agent/test_run_agent.py
+++ b/tests/run_agent/test_run_agent.py
@@ -2543,6 +2543,122 @@ class TestConcurrentToolExecution:
         assert json.loads(result) == {"error": "Blocked"}
         assert agent._turns_since_memory == 5
 
+    def test_concurrent_blocked_write_skips_checkpoint(self, agent, monkeypatch):
+        """Concurrent path: blocked write_file should not trigger checkpoint."""
+        tc1 = _mock_tool_call(name="write_file",
+                              arguments='{"path":"test.txt","content":"hello"}',
+                              call_id="c1")
+        tc2 = _mock_tool_call(name="read_file",
+                              arguments='{"path":"other.py"}',
+                              call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        monkeypatch.setattr(
+            "hermes_cli.plugins.get_pre_tool_call_block_message",
+            lambda *args, **kwargs: "Blocked" if args[0] == "write_file" else None,
+        )
+
+        agent._checkpoint_mgr.enabled = True
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return f"result_{name}"
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
+                agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        cp_mock.assert_not_called()
+
+    def test_concurrent_blocked_patch_skips_checkpoint(self, agent, monkeypatch):
+        """Concurrent path: blocked patch should not trigger checkpoint."""
+        tc1 = _mock_tool_call(name="patch",
+                              arguments='{"path":"f.py","old":"a","new":"b"}',
+                              call_id="c1")
+        tc2 = _mock_tool_call(name="read_file",
+                              arguments='{"path":"other.py"}',
+                              call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        monkeypatch.setattr(
+            "hermes_cli.plugins.get_pre_tool_call_block_message",
+            lambda *args, **kwargs: "Blocked" if args[0] == "patch" else None,
+        )
+
+        agent._checkpoint_mgr.enabled = True
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return f"result_{name}"
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
+                agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        cp_mock.assert_not_called()
+
+    def test_concurrent_blocked_terminal_skips_checkpoint(self, agent, monkeypatch):
+        """Concurrent path: blocked terminal should not trigger checkpoint."""
+        tc1 = _mock_tool_call(name="terminal",
+                              arguments='{"command":"rm -rf /tmp/foo"}',
+                              call_id="c1")
+        tc2 = _mock_tool_call(name="read_file",
+                              arguments='{"path":"other.py"}',
+                              call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        monkeypatch.setattr(
+            "hermes_cli.plugins.get_pre_tool_call_block_message",
+            lambda *args, **kwargs: "Blocked" if args[0] == "terminal" else None,
+        )
+
+        agent._checkpoint_mgr.enabled = True
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return f"result_{name}"
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
+                with patch("agent.tool_executor._is_destructive_command", return_value=True):
+                    agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        cp_mock.assert_not_called()
+
+    def test_concurrent_blocked_write_does_not_steal_slot_from_allowed_write(self, agent, monkeypatch):
+        """When write_file is blocked, its dedup slot must not be consumed,
+        so a subsequent allowed write_file for the same path still checkpoints."""
+        tc1 = _mock_tool_call(name="write_file",
+                              arguments='{"path":"dup.txt","content":"blocked"}',
+                              call_id="c1")
+        tc2 = _mock_tool_call(name="write_file",
+                              arguments='{"path":"dup.txt","content":"allowed"}',
+                              call_id="c2")
+        mock_msg = _mock_assistant_msg(content="", tool_calls=[tc1, tc2])
+        messages = []
+
+        call_count = {"n": 0}
+        def block_first_only(*args, **kwargs):
+            call_count["n"] += 1
+            return "Blocked" if call_count["n"] == 1 else None
+
+        monkeypatch.setattr(
+            "hermes_cli.plugins.get_pre_tool_call_block_message",
+            block_first_only,
+        )
+
+        agent._checkpoint_mgr.enabled = True
+
+        def fake_handle(name, args, task_id, **kwargs):
+            return f"result_{name}"
+
+        with patch("run_agent.handle_function_call", side_effect=fake_handle):
+            with patch.object(agent._checkpoint_mgr, "ensure_checkpoint") as cp_mock:
+                agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
+
+        # Second (allowed) write must checkpoint even though first was blocked.
+        cp_mock.assert_called_once()
+
 
 class TestPathsOverlap:
     """Unit tests for the _paths_overlap helper."""