def get_pre_tool_call_block_message(
    tool_name: str,
    args: Optional[Dict[str, Any]],
    task_id: str = "",
    session_id: str = "",
    tool_call_id: str = "",
) -> Optional[str]:
    """Return the first blocking message produced by ``pre_tool_call`` hooks.

    Fires the ``pre_tool_call`` hook through ``invoke_hook`` and scans the
    callback return values for a block directive of the form::

        {"action": "block", "message": "Reason the tool was blocked"}

    Args:
        tool_name: Name of the tool about to run; forwarded to the hooks.
        args: Tool arguments.  Anything that is not a ``dict`` is forwarded
            to the hooks as an empty dict.
        task_id: Forwarded to the hooks unchanged.
        session_id: Forwarded to the hooks unchanged.
        tool_call_id: Forwarded to the hooks unchanged.

    Returns:
        The ``message`` of the first result whose ``action`` is ``"block"``
        and whose ``message`` is a non-empty string, or ``None`` when no
        hook result qualifies.

    Results that are not dicts, or whose ``action`` is anything other than
    ``"block"``, are skipped without comment so observer-only hooks are
    unaffected.  A result that asks to block but carries no usable message
    is still ignored (the tool is NOT blocked), but a warning is logged so
    a misconfigured policy plugin does not fail open unnoticed.
    """
    import logging  # function-scope import; only used on the warning path

    hook_results = invoke_hook(
        "pre_tool_call",
        tool_name=tool_name,
        args=args if isinstance(args, dict) else {},
        task_id=task_id,
        session_id=session_id,
        tool_call_id=tool_call_id,
    )

    for result in hook_results:
        # Anything that is not a block directive is skipped quietly.
        if not isinstance(result, dict) or result.get("action") != "block":
            continue

        message = result.get("message")
        if isinstance(message, str) and message:
            # First valid directive wins; later results are not inspected.
            return message

        # The plugin asked to block but gave no usable reason.  The directive
        # is still ignored, but the mistake is surfaced in the log.
        logging.getLogger(__name__).warning(
            "pre_tool_call hook returned a block directive for tool %r "
            "without a non-empty string 'message'; ignoring it",
            tool_name,
        )

    return None
"42"→42) function_args = coerce_tool_args(function_name, function_args) - # Notify the read-loop tracker when a non-read/search tool runs, - # so the *consecutive* counter resets (reads after other work are fine). - if function_name not in _READ_SEARCH_TOOLS: - try: - from tools.file_tools import notify_other_tool_call - notify_other_tool_call(task_id or "default") - except Exception: - pass # file_tools may not be loaded yet - try: if function_name in _AGENT_LOOP_TOOLS: return json.dumps({"error": f"{function_name} must be handled by the agent loop"}) - try: - from hermes_cli.plugins import invoke_hook - invoke_hook( - "pre_tool_call", - tool_name=function_name, - args=function_args, - task_id=task_id or "", - session_id=session_id or "", - tool_call_id=tool_call_id or "", - ) - except Exception: - pass + # Check plugin hooks for a block directive (unless caller already + # checked — e.g. run_agent._invoke_tool passes skip=True to + # avoid double-firing the hook). + if not skip_pre_tool_call_hook: + block_message: Optional[str] = None + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + block_message = get_pre_tool_call_block_message( + function_name, + function_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + ) + except Exception: + pass + + if block_message is not None: + return json.dumps({"error": block_message}, ensure_ascii=False) + else: + # Still fire the hook for observers — just don't check for blocking + # (the caller already did that). + try: + from hermes_cli.plugins import invoke_hook + invoke_hook( + "pre_tool_call", + tool_name=function_name, + args=function_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + ) + except Exception: + pass + + # Notify the read-loop tracker when a non-read/search tool runs, + # so the *consecutive* counter resets (reads after other work are fine). 
+ if function_name not in _READ_SEARCH_TOOLS: + try: + from tools.file_tools import notify_other_tool_call + notify_other_tool_call(task_id or "default") + except Exception: + pass # file_tools may not be loaded yet if function_name == "execute_code": # Prefer the caller-provided list so subagents can't overwrite diff --git a/run_agent.py b/run_agent.py index 5005153b3b..5922534646 100644 --- a/run_agent.py +++ b/run_agent.py @@ -6890,6 +6890,18 @@ class AIAgent: tools. Used by the concurrent execution path; the sequential path retains its own inline invocation for backward-compatible display handling. """ + # Check plugin hooks for a block directive before executing anything. + block_message: Optional[str] = None + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + block_message = get_pre_tool_call_block_message( + function_name, function_args, task_id=effective_task_id or "", + ) + except Exception: + pass + if block_message is not None: + return json.dumps({"error": block_message}, ensure_ascii=False) + if function_name == "todo": from tools.todo_tool import todo_tool as _todo_tool return _todo_tool( @@ -6954,6 +6966,7 @@ class AIAgent: tool_call_id=tool_call_id, session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, + skip_pre_tool_call_hook=True, ) def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: @@ -7184,12 +7197,6 @@ class AIAgent: function_name = tool_call.function.name - # Reset nudge counters when the relevant tool is actually used - if function_name == "memory": - self._turns_since_memory = 0 - elif function_name == "skill_manage": - self._iters_since_skill = 0 - try: function_args = json.loads(tool_call.function.arguments) except json.JSONDecodeError as e: @@ -7198,6 +7205,27 @@ class AIAgent: if not isinstance(function_args, dict): function_args = {} + # Check plugin hooks for a block 
directive before executing. + _block_msg: Optional[str] = None + try: + from hermes_cli.plugins import get_pre_tool_call_block_message + _block_msg = get_pre_tool_call_block_message( + function_name, function_args, task_id=effective_task_id or "", + ) + except Exception: + pass + + if _block_msg is not None: + # Tool blocked by plugin policy — skip counter resets. + # Execution is handled below in the tool dispatch chain. + pass + else: + # Reset nudge counters when the relevant tool is actually used + if function_name == "memory": + self._turns_since_memory = 0 + elif function_name == "skill_manage": + self._iters_since_skill = 0 + if not self.quiet_mode: args_str = json.dumps(function_args, ensure_ascii=False) if self.verbose_logging: @@ -7207,33 +7235,35 @@ class AIAgent: args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}") - self._current_tool = function_name - self._touch_activity(f"executing tool: {function_name}") + if _block_msg is None: + self._current_tool = function_name + self._touch_activity(f"executing tool: {function_name}") # Set activity callback for long-running tool execution (terminal # commands, etc.) so the gateway's inactivity monitor doesn't kill # the agent while a command is running. 
- try: - from tools.environments.base import set_activity_callback - set_activity_callback(self._touch_activity) - except Exception: - pass + if _block_msg is None: + try: + from tools.environments.base import set_activity_callback + set_activity_callback(self._touch_activity) + except Exception: + pass - if self.tool_progress_callback: + if _block_msg is None and self.tool_progress_callback: try: preview = _build_tool_preview(function_name, function_args) self.tool_progress_callback("tool.started", function_name, preview, function_args) except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") - if self.tool_start_callback: + if _block_msg is None and self.tool_start_callback: try: self.tool_start_callback(tool_call.id, function_name, function_args) except Exception as cb_err: logging.debug(f"Tool start callback error: {cb_err}") # Checkpoint: snapshot working dir before file-mutating tools - if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled: + if _block_msg is None and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled: try: file_path = function_args.get("path", "") if file_path: @@ -7245,7 +7275,7 @@ class AIAgent: pass # never block tool execution # Checkpoint before destructive terminal commands - if function_name == "terminal" and self._checkpoint_mgr.enabled: + if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled: try: cmd = function_args.get("command", "") if _is_destructive_command(cmd): @@ -7258,7 +7288,11 @@ class AIAgent: tool_start_time = time.time() - if function_name == "todo": + if _block_msg is not None: + # Tool blocked by plugin policy — return error without executing. 
class TestPreToolCallBlocking:
    """Unit tests for ``get_pre_tool_call_block_message`` with ``invoke_hook`` stubbed out."""

    @staticmethod
    def _stub_invoke_hook(monkeypatch, canned_results):
        """Patch ``hermes_cli.plugins.invoke_hook`` to return a copy of *canned_results*."""

        def _fake_invoke_hook(hook_name, **kwargs):
            return list(canned_results)

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", _fake_invoke_hook)

    def test_block_message_returned_for_valid_directive(self, monkeypatch):
        """A well-formed block directive yields its message."""
        directive = {"action": "block", "message": "blocked by plugin"}
        self._stub_invoke_hook(monkeypatch, [directive])

        outcome = get_pre_tool_call_block_message("todo", {}, task_id="t1")

        assert outcome == "blocked by plugin"

    def test_invalid_returns_are_ignored(self, monkeypatch):
        """Various malformed hook returns should not trigger a block."""
        malformed = [
            "block",  # not a dict
            123,  # not a dict
            {"action": "block"},  # missing message
            {"action": "deny", "message": "nope"},  # wrong action
            {"message": "missing action"},  # no action key
            {"action": "block", "message": 123},  # message not str
        ]
        self._stub_invoke_hook(monkeypatch, malformed)

        outcome = get_pre_tool_call_block_message("todo", {}, task_id="t1")

        assert outcome is None

    def test_none_when_no_hooks(self, monkeypatch):
        """An empty hook result list produces no block message."""
        self._stub_invoke_hook(monkeypatch, [])

        outcome = get_pre_tool_call_block_message("web_search", {"q": "test"})

        assert outcome is None

    def test_first_valid_block_wins(self, monkeypatch):
        """With several block directives, the earliest one supplies the message."""
        ordered_results = [
            {"action": "allow"},
            {"action": "block", "message": "first blocker"},
            {"action": "block", "message": "second blocker"},
        ]
        self._stub_invoke_hook(monkeypatch, ordered_results)

        outcome = get_pre_tool_call_block_message("terminal", {})

        assert outcome == "first blocker"
def test_invoke_tool_blocked_skips_handle_function_call(self, agent, monkeypatch):
    """Blocked registry tools should not reach handle_function_call."""

    def _always_block(*args, **kwargs):
        return "Blocked"

    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        _always_block,
    )
    must_not_run = AssertionError("should not run")
    with patch("run_agent.handle_function_call", side_effect=must_not_run):
        result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")

    assert json.loads(result) == {"error": "Blocked"}


def test_sequential_blocked_tool_skips_checkpoints_and_callbacks(self, agent, monkeypatch):
    """Sequential path: blocked tool should not trigger checkpoints or start callbacks."""
    # Stub the block check so every tool call is reported as blocked.
    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        lambda *args, **kwargs: "Blocked by policy",
    )

    # A write_file call, which the sequential path would otherwise checkpoint.
    blocked_call = _mock_tool_call(
        name="write_file",
        arguments='{"path":"test.txt","content":"hello"}',
        call_id="c1",
    )
    assistant_msg = _mock_assistant_msg(content="", tool_calls=[blocked_call])
    transcript = []

    # Any checkpoint attempt trips this mock.
    agent._checkpoint_mgr.enabled = True
    agent._checkpoint_mgr.ensure_checkpoint = MagicMock(
        side_effect=AssertionError("checkpoint should not run")
    )

    # Record every tool-start notification.
    recorded_starts = []
    agent.tool_start_callback = lambda *a: recorded_starts.append(a)

    with patch("run_agent.handle_function_call", side_effect=AssertionError("should not run")):
        agent._execute_tool_calls_sequential(assistant_msg, transcript, "task-1")

    agent._checkpoint_mgr.ensure_checkpoint.assert_not_called()
    assert recorded_starts == []
    assert len(transcript) == 1
    tool_message = transcript[0]
    assert tool_message["role"] == "tool"
    assert json.loads(tool_message["content"]) == {"error": "Blocked by policy"}
class TestPreToolCallBlocking:
    """Verify that pre_tool_call hooks can block tool execution."""

    def test_blocked_tool_returns_error_and_skips_dispatch(self, monkeypatch):
        """A block directive becomes the error payload and dispatch never runs."""

        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked by policy"}]
            return []

        # Tracked separately from the raise below.
        # NOTE(review): presumably handle_function_call could catch the
        # AssertionError itself — confirm; the flag guards against that.
        dispatch_called = False

        def fake_dispatch(*args, **kwargs):
            nonlocal dispatch_called
            dispatch_called = True
            raise AssertionError("dispatch should not run when blocked")

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)

        result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
        assert result == {"error": "Blocked by policy"}
        assert not dispatch_called

    def test_blocked_tool_skips_read_loop_notification(self, monkeypatch):
        """A blocked tool must not call ``notify_other_tool_call``."""
        notifications = []

        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [{"action": "block", "message": "Blocked"}]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        # Generator-throw idiom: a lambda that raises when called.
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: (_ for _ in ()).throw(AssertionError("should not run")))
        monkeypatch.setattr("tools.file_tools.notify_other_tool_call",
                            lambda task_id: notifications.append(task_id))

        result = json.loads(handle_function_call("web_search", {"q": "test"}, task_id="t1"))
        assert result == {"error": "Blocked"}
        assert notifications == []

    def test_invalid_hook_returns_do_not_block(self, monkeypatch):
        """Malformed hook returns should be ignored — tool executes normally."""
        def fake_invoke_hook(hook_name, **kwargs):
            if hook_name == "pre_tool_call":
                return [
                    "block",
                    {"action": "block"},  # missing message
                    {"action": "deny", "message": "nope"},
                ]
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
        assert result == {"ok": True}

    def test_skip_flag_prevents_double_block_check(self, monkeypatch):
        """When skip_pre_tool_call_hook=True, blocking is not checked (caller did it)."""
        hook_calls = []

        def fake_invoke_hook(hook_name, **kwargs):
            hook_calls.append(hook_name)
            return []

        monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
        monkeypatch.setattr("model_tools.registry.dispatch",
                            lambda *a, **kw: json.dumps({"ok": True}))

        handle_function_call("web_search", {"q": "test"}, task_id="t1",
                             skip_pre_tool_call_hook=True)

        # Hook still fires for observer notification, but get_pre_tool_call_block_message
        # is not called — invoke_hook fires directly in the skip=True branch.
        assert "pre_tool_call" in hook_calls
        assert "post_tool_call" in hook_calls