feat(plugins): let pre_tool_call hooks block tool execution

Plugins can now return {"action": "block", "message": "reason"} from
their pre_tool_call hook to prevent a tool from executing. The error
message is returned to the model as a tool result so it can adjust.

Covers both execution paths: handle_function_call (model_tools.py) and
agent-level tools (run_agent.py _invoke_tool + sequential/concurrent).
Blocked tools skip all side effects (counter resets, checkpoints,
callbacks, read-loop tracker).

Adds skip_pre_tool_call_hook flag to avoid double-firing the hook when
run_agent.py already checked and then calls handle_function_call.

Salvaged from PR #5385 (gianfrancopiana) and PR #4610 (oredsecurity).
This commit is contained in:
Gianfranco Piana 2026-04-13 21:15:25 -07:00 committed by Teknium
parent ea74f61d98
commit eabc0a2f66
6 changed files with 335 additions and 40 deletions

View file

@ -584,6 +584,45 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]:
def get_pre_tool_call_block_message(
tool_name: str,
args: Optional[Dict[str, Any]],
task_id: str = "",
session_id: str = "",
tool_call_id: str = "",
) -> Optional[str]:
"""Check ``pre_tool_call`` hooks for a blocking directive.
Plugins that need to enforce policy (rate limiting, security
restrictions, approval workflows) can return::
{"action": "block", "message": "Reason the tool was blocked"}
from their ``pre_tool_call`` callback. The first valid block
directive wins. Invalid or irrelevant hook return values are
silently ignored so existing observer-only hooks are unaffected.
"""
hook_results = invoke_hook(
"pre_tool_call",
tool_name=tool_name,
args=args if isinstance(args, dict) else {},
task_id=task_id,
session_id=session_id,
tool_call_id=tool_call_id,
)
for result in hook_results:
if not isinstance(result, dict):
continue
if result.get("action") != "block":
continue
message = result.get("message")
if isinstance(message, str) and message:
return message
return None
def get_plugin_context_engine():
"""Return the plugin-registered context engine, or None."""
return get_plugin_manager()._context_engine

View file

@ -464,6 +464,7 @@ def handle_function_call(
session_id: Optional[str] = None,
user_task: Optional[str] = None,
enabled_tools: Optional[List[str]] = None,
skip_pre_tool_call_hook: bool = False,
) -> str:
"""
Main function call dispatcher that routes calls to the tool registry.
@ -484,31 +485,53 @@ def handle_function_call(
# Coerce string arguments to their schema-declared types (e.g. "42"→42)
function_args = coerce_tool_args(function_name, function_args)
# Notify the read-loop tracker when a non-read/search tool runs,
# so the *consecutive* counter resets (reads after other work are fine).
if function_name not in _READ_SEARCH_TOOLS:
try:
from tools.file_tools import notify_other_tool_call
notify_other_tool_call(task_id or "default")
except Exception:
pass # file_tools may not be loaded yet
try:
if function_name in _AGENT_LOOP_TOOLS:
return json.dumps({"error": f"{function_name} must be handled by the agent loop"})
try:
from hermes_cli.plugins import invoke_hook
invoke_hook(
"pre_tool_call",
tool_name=function_name,
args=function_args,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
)
except Exception:
pass
# Check plugin hooks for a block directive (unless caller already
# checked — e.g. run_agent._invoke_tool passes skip=True to
# avoid double-firing the hook).
if not skip_pre_tool_call_hook:
block_message: Optional[str] = None
try:
from hermes_cli.plugins import get_pre_tool_call_block_message
block_message = get_pre_tool_call_block_message(
function_name,
function_args,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
)
except Exception:
pass
if block_message is not None:
return json.dumps({"error": block_message}, ensure_ascii=False)
else:
# Still fire the hook for observers — just don't check for blocking
# (the caller already did that).
try:
from hermes_cli.plugins import invoke_hook
invoke_hook(
"pre_tool_call",
tool_name=function_name,
args=function_args,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
)
except Exception:
pass
# Notify the read-loop tracker when a non-read/search tool runs,
# so the *consecutive* counter resets (reads after other work are fine).
if function_name not in _READ_SEARCH_TOOLS:
try:
from tools.file_tools import notify_other_tool_call
notify_other_tool_call(task_id or "default")
except Exception:
pass # file_tools may not be loaded yet
if function_name == "execute_code":
# Prefer the caller-provided list so subagents can't overwrite

View file

@ -6890,6 +6890,18 @@ class AIAgent:
tools. Used by the concurrent execution path; the sequential path retains
its own inline invocation for backward-compatible display handling.
"""
# Check plugin hooks for a block directive before executing anything.
block_message: Optional[str] = None
try:
from hermes_cli.plugins import get_pre_tool_call_block_message
block_message = get_pre_tool_call_block_message(
function_name, function_args, task_id=effective_task_id or "",
)
except Exception:
pass
if block_message is not None:
return json.dumps({"error": block_message}, ensure_ascii=False)
if function_name == "todo":
from tools.todo_tool import todo_tool as _todo_tool
return _todo_tool(
@ -6954,6 +6966,7 @@ class AIAgent:
tool_call_id=tool_call_id,
session_id=self.session_id or "",
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
skip_pre_tool_call_hook=True,
)
def _execute_tool_calls_concurrent(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
@ -7184,12 +7197,6 @@ class AIAgent:
function_name = tool_call.function.name
# Reset nudge counters when the relevant tool is actually used
if function_name == "memory":
self._turns_since_memory = 0
elif function_name == "skill_manage":
self._iters_since_skill = 0
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
@ -7198,6 +7205,27 @@ class AIAgent:
if not isinstance(function_args, dict):
function_args = {}
# Check plugin hooks for a block directive before executing.
_block_msg: Optional[str] = None
try:
from hermes_cli.plugins import get_pre_tool_call_block_message
_block_msg = get_pre_tool_call_block_message(
function_name, function_args, task_id=effective_task_id or "",
)
except Exception:
pass
if _block_msg is not None:
# Tool blocked by plugin policy — skip counter resets.
# Execution is handled below in the tool dispatch chain.
pass
else:
# Reset nudge counters when the relevant tool is actually used
if function_name == "memory":
self._turns_since_memory = 0
elif function_name == "skill_manage":
self._iters_since_skill = 0
if not self.quiet_mode:
args_str = json.dumps(function_args, ensure_ascii=False)
if self.verbose_logging:
@ -7207,33 +7235,35 @@ class AIAgent:
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
self._current_tool = function_name
self._touch_activity(f"executing tool: {function_name}")
if _block_msg is None:
self._current_tool = function_name
self._touch_activity(f"executing tool: {function_name}")
# Set activity callback for long-running tool execution (terminal
# commands, etc.) so the gateway's inactivity monitor doesn't kill
# the agent while a command is running.
try:
from tools.environments.base import set_activity_callback
set_activity_callback(self._touch_activity)
except Exception:
pass
if _block_msg is None:
try:
from tools.environments.base import set_activity_callback
set_activity_callback(self._touch_activity)
except Exception:
pass
if self.tool_progress_callback:
if _block_msg is None and self.tool_progress_callback:
try:
preview = _build_tool_preview(function_name, function_args)
self.tool_progress_callback("tool.started", function_name, preview, function_args)
except Exception as cb_err:
logging.debug(f"Tool progress callback error: {cb_err}")
if self.tool_start_callback:
if _block_msg is None and self.tool_start_callback:
try:
self.tool_start_callback(tool_call.id, function_name, function_args)
except Exception as cb_err:
logging.debug(f"Tool start callback error: {cb_err}")
# Checkpoint: snapshot working dir before file-mutating tools
if function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
if _block_msg is None and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
try:
file_path = function_args.get("path", "")
if file_path:
@ -7245,7 +7275,7 @@ class AIAgent:
pass # never block tool execution
# Checkpoint before destructive terminal commands
if function_name == "terminal" and self._checkpoint_mgr.enabled:
if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled:
try:
cmd = function_args.get("command", "")
if _is_destructive_command(cmd):
@ -7258,7 +7288,11 @@ class AIAgent:
tool_start_time = time.time()
if function_name == "todo":
if _block_msg is not None:
# Tool blocked by plugin policy — return error without executing.
function_result = json.dumps({"error": _block_msg}, ensure_ascii=False)
tool_duration = 0.0
elif function_name == "todo":
from tools.todo_tool import todo_tool as _todo_tool
function_result = _todo_tool(
todos=function_args.get("todos"),
@ -7401,6 +7435,7 @@ class AIAgent:
tool_call_id=tool_call.id,
session_id=self.session_id or "",
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
skip_pre_tool_call_hook=True,
)
_spinner_result = function_result
except Exception as tool_error:
@ -7420,6 +7455,7 @@ class AIAgent:
tool_call_id=tool_call.id,
session_id=self.session_id or "",
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
skip_pre_tool_call_hook=True,
)
except Exception as tool_error:
function_result = f"Error executing tool '{function_name}': {tool_error}"

View file

@ -18,6 +18,7 @@ from hermes_cli.plugins import (
PluginManager,
PluginManifest,
get_plugin_manager,
get_pre_tool_call_block_message,
discover_plugins,
invoke_hook,
)
@ -310,6 +311,50 @@ class TestPluginHooks:
assert any("on_banana" in record.message for record in caplog.records)
class TestPreToolCallBlocking:
"""Tests for the pre_tool_call block directive helper."""
def test_block_message_returned_for_valid_directive(self, monkeypatch):
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: [{"action": "block", "message": "blocked by plugin"}],
)
assert get_pre_tool_call_block_message("todo", {}, task_id="t1") == "blocked by plugin"
def test_invalid_returns_are_ignored(self, monkeypatch):
"""Various malformed hook returns should not trigger a block."""
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: [
"block", # not a dict
123, # not a dict
{"action": "block"}, # missing message
{"action": "deny", "message": "nope"}, # wrong action
{"message": "missing action"}, # no action key
{"action": "block", "message": 123}, # message not str
],
)
assert get_pre_tool_call_block_message("todo", {}, task_id="t1") is None
def test_none_when_no_hooks(self, monkeypatch):
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: [],
)
assert get_pre_tool_call_block_message("web_search", {"q": "test"}) is None
def test_first_valid_block_wins(self, monkeypatch):
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: [
{"action": "allow"},
{"action": "block", "message": "first blocker"},
{"action": "block", "message": "second blocker"},
],
)
assert get_pre_tool_call_block_message("terminal", {}) == "first blocker"
# ── TestPluginContext ──────────────────────────────────────────────────────

View file

@ -1442,7 +1442,7 @@ class TestConcurrentToolExecution:
tool_call_id=None,
session_id=agent.session_id,
enabled_tools=list(agent.valid_tool_names),
skip_pre_tool_call_hook=True,
)
assert result == "result"
@ -1489,6 +1489,73 @@ class TestConcurrentToolExecution:
mock_todo.assert_called_once()
assert "ok" in result
def test_invoke_tool_blocked_returns_error_and_skips_execution(self, agent, monkeypatch):
"""_invoke_tool should return error JSON when a plugin blocks the tool."""
monkeypatch.setattr(
"hermes_cli.plugins.get_pre_tool_call_block_message",
lambda *args, **kwargs: "Blocked by test policy",
)
with patch("tools.todo_tool.todo_tool", side_effect=AssertionError("should not run")) as mock_todo:
result = agent._invoke_tool("todo", {"todos": []}, "task-1")
assert json.loads(result) == {"error": "Blocked by test policy"}
mock_todo.assert_not_called()
def test_invoke_tool_blocked_skips_handle_function_call(self, agent, monkeypatch):
"""Blocked registry tools should not reach handle_function_call."""
monkeypatch.setattr(
"hermes_cli.plugins.get_pre_tool_call_block_message",
lambda *args, **kwargs: "Blocked",
)
with patch("run_agent.handle_function_call", side_effect=AssertionError("should not run")):
result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")
assert json.loads(result) == {"error": "Blocked"}
def test_sequential_blocked_tool_skips_checkpoints_and_callbacks(self, agent, monkeypatch):
"""Sequential path: blocked tool should not trigger checkpoints or start callbacks."""
tool_call = _mock_tool_call(name="write_file",
arguments='{"path":"test.txt","content":"hello"}',
call_id="c1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
messages = []
monkeypatch.setattr(
"hermes_cli.plugins.get_pre_tool_call_block_message",
lambda *args, **kwargs: "Blocked by policy",
)
agent._checkpoint_mgr.enabled = True
agent._checkpoint_mgr.ensure_checkpoint = MagicMock(
side_effect=AssertionError("checkpoint should not run")
)
starts = []
agent.tool_start_callback = lambda *a: starts.append(a)
with patch("run_agent.handle_function_call", side_effect=AssertionError("should not run")):
agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")
agent._checkpoint_mgr.ensure_checkpoint.assert_not_called()
assert starts == []
assert len(messages) == 1
assert messages[0]["role"] == "tool"
assert json.loads(messages[0]["content"]) == {"error": "Blocked by policy"}
def test_blocked_memory_tool_does_not_reset_counter(self, agent, monkeypatch):
"""Blocked memory tool should not reset the nudge counter."""
agent._turns_since_memory = 5
monkeypatch.setattr(
"hermes_cli.plugins.get_pre_tool_call_block_message",
lambda *args, **kwargs: "Blocked",
)
with patch("tools.memory_tool.memory_tool", side_effect=AssertionError("should not run")):
result = agent._invoke_tool(
"memory", {"action": "add", "target": "memory", "content": "x"}, "task-1",
)
assert json.loads(result) == {"error": "Blocked"}
assert agent._turns_since_memory == 5
class TestPathsOverlap:
"""Unit tests for the _paths_overlap helper."""

View file

@ -91,6 +91,91 @@ class TestAgentLoopTools:
assert "terminal" not in _AGENT_LOOP_TOOLS
# =========================================================================
# Pre-tool-call blocking via plugin hooks
# =========================================================================
class TestPreToolCallBlocking:
"""Verify that pre_tool_call hooks can block tool execution."""
def test_blocked_tool_returns_error_and_skips_dispatch(self, monkeypatch):
def fake_invoke_hook(hook_name, **kwargs):
if hook_name == "pre_tool_call":
return [{"action": "block", "message": "Blocked by policy"}]
return []
dispatch_called = False
_orig_dispatch = None
def fake_dispatch(*args, **kwargs):
nonlocal dispatch_called
dispatch_called = True
raise AssertionError("dispatch should not run when blocked")
monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)
result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
assert result == {"error": "Blocked by policy"}
assert not dispatch_called
def test_blocked_tool_skips_read_loop_notification(self, monkeypatch):
notifications = []
def fake_invoke_hook(hook_name, **kwargs):
if hook_name == "pre_tool_call":
return [{"action": "block", "message": "Blocked"}]
return []
monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
monkeypatch.setattr("model_tools.registry.dispatch",
lambda *a, **kw: (_ for _ in ()).throw(AssertionError("should not run")))
monkeypatch.setattr("tools.file_tools.notify_other_tool_call",
lambda task_id: notifications.append(task_id))
result = json.loads(handle_function_call("web_search", {"q": "test"}, task_id="t1"))
assert result == {"error": "Blocked"}
assert notifications == []
def test_invalid_hook_returns_do_not_block(self, monkeypatch):
"""Malformed hook returns should be ignored — tool executes normally."""
def fake_invoke_hook(hook_name, **kwargs):
if hook_name == "pre_tool_call":
return [
"block",
{"action": "block"}, # missing message
{"action": "deny", "message": "nope"},
]
return []
monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
monkeypatch.setattr("model_tools.registry.dispatch",
lambda *a, **kw: json.dumps({"ok": True}))
result = json.loads(handle_function_call("read_file", {"path": "test.txt"}, task_id="t1"))
assert result == {"ok": True}
def test_skip_flag_prevents_double_block_check(self, monkeypatch):
"""When skip_pre_tool_call_hook=True, blocking is not checked (caller did it)."""
hook_calls = []
def fake_invoke_hook(hook_name, **kwargs):
hook_calls.append(hook_name)
return []
monkeypatch.setattr("hermes_cli.plugins.invoke_hook", fake_invoke_hook)
monkeypatch.setattr("model_tools.registry.dispatch",
lambda *a, **kw: json.dumps({"ok": True}))
handle_function_call("web_search", {"q": "test"}, task_id="t1",
skip_pre_tool_call_hook=True)
# Hook still fires for observer notification, but get_pre_tool_call_block_message
# is not called — invoke_hook fires directly in the skip=True branch.
assert "pre_tool_call" in hook_calls
assert "post_tool_call" in hook_calls
# =========================================================================
# Legacy toolset map
# =========================================================================