Add request-scoped plugin lifecycle hooks

This commit is contained in:
kshitijk4poor 2026-03-29 12:26:44 +05:30 committed by Teknium
parent dce5f51c7c
commit 9e820dda37
6 changed files with 169 additions and 6 deletions

View file

@ -56,6 +56,8 @@ VALID_HOOKS: Set[str] = {
"post_tool_call",
"pre_llm_call",
"post_llm_call",
"pre_llm_request",
"post_llm_request",
"on_session_start",
"on_session_end",
}

View file

@ -460,6 +460,8 @@ def handle_function_call(
function_name: str,
function_args: Dict[str, Any],
task_id: Optional[str] = None,
tool_call_id: Optional[str] = None,
session_id: Optional[str] = None,
user_task: Optional[str] = None,
enabled_tools: Optional[List[str]] = None,
) -> str:
@ -497,7 +499,14 @@ def handle_function_call(
try:
from hermes_cli.plugins import invoke_hook
invoke_hook("pre_tool_call", tool_name=function_name, args=function_args, task_id=task_id or "")
invoke_hook(
"pre_tool_call",
tool_name=function_name,
args=function_args,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
)
except Exception:
pass
@ -519,7 +528,15 @@ def handle_function_call(
try:
from hermes_cli.plugins import invoke_hook
invoke_hook("post_tool_call", tool_name=function_name, args=function_args, result=result, task_id=task_id or "")
invoke_hook(
"post_tool_call",
tool_name=function_name,
args=function_args,
result=result,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
)
except Exception:
pass

View file

@ -5965,7 +5965,8 @@ class AIAgent:
finally:
self._executing_tools = False
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str:
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
tool_call_id: Optional[str] = None) -> str:
"""Invoke a single tool and return the result string. No display logic.
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
@ -6033,6 +6034,8 @@ class AIAgent:
else:
return handle_function_call(
function_name, function_args, effective_task_id,
tool_call_id=tool_call_id,
session_id=self.session_id or "",
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
)
@ -6134,7 +6137,7 @@ class AIAgent:
"""Worker function executed in a thread."""
start = time.time()
try:
result = self._invoke_tool(function_name, function_args, effective_task_id)
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id)
except Exception as tool_error:
result = f"Error executing tool '{function_name}': {tool_error}"
logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
@ -6452,6 +6455,8 @@ class AIAgent:
try:
function_result = handle_function_call(
function_name, function_args, effective_task_id,
tool_call_id=tool_call.id,
session_id=self.session_id or "",
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
)
_spinner_result = function_result
@ -6469,6 +6474,8 @@ class AIAgent:
try:
function_result = handle_function_call(
function_name, function_args, effective_task_id,
tool_call_id=tool_call.id,
session_id=self.session_id or "",
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
)
except Exception as tool_error:
@ -7273,7 +7280,26 @@ class AIAgent:
if self.api_mode == "codex_responses":
api_kwargs = self._preflight_codex_api_kwargs(api_kwargs, allow_stream=False)
if env_var_enabled("HERMES_DUMP_REQUESTS"):
try:
from hermes_cli.plugins import invoke_hook
invoke_hook(
"pre_llm_request",
task_id=effective_task_id,
session_id=self.session_id or "",
platform=self.platform or "",
model=self.model,
provider=self.provider,
base_url=self.base_url,
api_mode=self.api_mode,
api_call_count=api_call_count,
messages=api_messages,
max_tokens=self.max_tokens,
tools=self.tools or [],
)
except Exception:
pass
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
self._dump_api_request_debug(api_kwargs, reason="preflight")
# Always prefer the streaming path — even without stream
@ -8359,6 +8385,27 @@ class AIAgent:
else:
assistant_message.content = str(raw)
try:
from hermes_cli.plugins import invoke_hook
invoke_hook(
"post_llm_request",
task_id=effective_task_id,
session_id=self.session_id or "",
platform=self.platform or "",
model=self.model,
provider=self.provider,
base_url=self.base_url,
api_mode=self.api_mode,
api_call_count=api_call_count,
api_duration=api_duration,
finish_reason=finish_reason,
messages=api_messages,
response=response,
assistant_message=assistant_message,
)
except Exception:
pass
# Handle assistant response
if assistant_message.content and not self.quiet_mode:
if self.verbose_logging:

View file

@ -1,6 +1,8 @@
"""Tests for model_tools.py — function call dispatch, agent-loop interception, legacy toolsets."""
import json
from unittest.mock import call, patch
import pytest
from model_tools import (
@ -38,6 +40,40 @@ class TestHandleFunctionCall:
assert len(parsed["error"]) > 0
assert "error" in parsed["error"].lower() or "failed" in parsed["error"].lower()
def test_tool_hooks_receive_session_and_tool_call_ids(self):
with (
patch("model_tools.registry.dispatch", return_value='{"ok":true}'),
patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook,
):
result = handle_function_call(
"web_search",
{"q": "test"},
task_id="task-1",
tool_call_id="call-1",
session_id="session-1",
)
assert result == '{"ok":true}'
assert mock_invoke_hook.call_args_list == [
call(
"pre_tool_call",
tool_name="web_search",
args={"q": "test"},
task_id="task-1",
session_id="session-1",
tool_call_id="call-1",
),
call(
"post_tool_call",
tool_name="web_search",
args={"q": "test"},
result='{"ok":true}',
task_id="task-1",
session_id="session-1",
tool_call_id="call-1",
),
]
# =========================================================================
# Agent loop tools

View file

@ -196,6 +196,10 @@ class TestPluginLoading:
class TestPluginHooks:
"""Tests for lifecycle hook registration and invocation."""
def test_valid_hooks_include_request_scoped_llm_hooks(self):
assert "pre_llm_request" in VALID_HOOKS
assert "post_llm_request" in VALID_HOOKS
def test_register_and_invoke_hook(self, tmp_path, monkeypatch):
"""Registered hooks are called on invoke_hook()."""
plugins_dir = tmp_path / "hermes_test" / "plugins"
@ -262,6 +266,28 @@ class TestPluginHooks:
user_message="hi", assistant_response="bye", model="test")
assert results == []
def test_request_hooks_are_invokeable(self, tmp_path, monkeypatch):
plugins_dir = tmp_path / "hermes_test" / "plugins"
_make_plugin_dir(
plugins_dir, "request_hook",
register_body='ctx.register_hook("pre_llm_request", lambda **kw: {"seen": kw.get("api_call_count")})',
)
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
mgr = PluginManager()
mgr.discover_and_load()
results = mgr.invoke_hook(
"pre_llm_request",
session_id="s1",
task_id="t1",
model="test",
api_call_count=2,
messages=[],
tools=[],
)
assert results == [{"seen": 2}]
def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog):
"""Registering an unknown hook name logs a warning."""
plugins_dir = tmp_path / "hermes_test" / "plugins"

View file

@ -1258,6 +1258,8 @@ class TestConcurrentToolExecution:
result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")
mock_hfc.assert_called_once_with(
"web_search", {"q": "test"}, "task-1",
tool_call_id=None,
session_id=agent.session_id,
enabled_tools=list(agent.valid_tool_names),
)
@ -1441,7 +1443,7 @@ class TestRunConversation:
resp2 = _mock_response(content="Done searching", finish_reason="stop")
agent.client.chat.completions.create.side_effect = [resp1, resp2]
with (
patch("run_agent.handle_function_call", return_value="search result"),
patch("run_agent.handle_function_call", return_value="search result") as mock_handle_function_call,
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
@ -1449,6 +1451,39 @@ class TestRunConversation:
result = agent.run_conversation("search something")
assert result["final_response"] == "Done searching"
assert result["api_calls"] == 2
assert mock_handle_function_call.call_args.kwargs["tool_call_id"] == "c1"
assert mock_handle_function_call.call_args.kwargs["session_id"] == agent.session_id
def test_request_scoped_llm_hooks_fire_for_each_api_call(self, agent):
self._setup_agent(agent)
tc = _mock_tool_call(name="web_search", arguments="{}", call_id="c1")
resp1 = _mock_response(content="", finish_reason="tool_calls", tool_calls=[tc])
resp2 = _mock_response(content="Done searching", finish_reason="stop")
agent.client.chat.completions.create.side_effect = [resp1, resp2]
hook_calls = []
def _record_hook(name, **kwargs):
hook_calls.append((name, kwargs))
return []
with (
patch("run_agent.handle_function_call", return_value="search result"),
patch("hermes_cli.plugins.invoke_hook", side_effect=_record_hook),
patch.object(agent, "_persist_session"),
patch.object(agent, "_save_trajectory"),
patch.object(agent, "_cleanup_task_resources"),
):
result = agent.run_conversation("search something")
assert result["final_response"] == "Done searching"
pre_request_calls = [kw for name, kw in hook_calls if name == "pre_llm_request"]
post_request_calls = [kw for name, kw in hook_calls if name == "post_llm_request"]
assert len(pre_request_calls) == 2
assert len(post_request_calls) == 2
assert [call["api_call_count"] for call in pre_request_calls] == [1, 2]
assert [call["api_call_count"] for call in post_request_calls] == [1, 2]
assert all(call["session_id"] == agent.session_id for call in pre_request_calls)
def test_interrupt_breaks_loop(self, agent):
self._setup_agent(agent)