mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Add request-scoped plugin lifecycle hooks
This commit is contained in:
parent
dce5f51c7c
commit
9e820dda37
6 changed files with 169 additions and 6 deletions
|
|
@ -56,6 +56,8 @@ VALID_HOOKS: Set[str] = {
|
|||
"post_tool_call",
|
||||
"pre_llm_call",
|
||||
"post_llm_call",
|
||||
"pre_llm_request",
|
||||
"post_llm_request",
|
||||
"on_session_start",
|
||||
"on_session_end",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -460,6 +460,8 @@ def handle_function_call(
|
|||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
task_id: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
user_task: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
|
|
@ -497,7 +499,14 @@ def handle_function_call(
|
|||
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook("pre_tool_call", tool_name=function_name, args=function_args, task_id=task_id or "")
|
||||
invoke_hook(
|
||||
"pre_tool_call",
|
||||
tool_name=function_name,
|
||||
args=function_args,
|
||||
task_id=task_id or "",
|
||||
session_id=session_id or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
@ -519,7 +528,15 @@ def handle_function_call(
|
|||
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook("post_tool_call", tool_name=function_name, args=function_args, result=result, task_id=task_id or "")
|
||||
invoke_hook(
|
||||
"post_tool_call",
|
||||
tool_name=function_name,
|
||||
args=function_args,
|
||||
result=result,
|
||||
task_id=task_id or "",
|
||||
session_id=session_id or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
|
|
|||
53
run_agent.py
53
run_agent.py
|
|
@ -5965,7 +5965,8 @@ class AIAgent:
|
|||
finally:
|
||||
self._executing_tools = False
|
||||
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str:
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
|
||||
tool_call_id: Optional[str] = None) -> str:
|
||||
"""Invoke a single tool and return the result string. No display logic.
|
||||
|
||||
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
|
||||
|
|
@ -6033,6 +6034,8 @@ class AIAgent:
|
|||
else:
|
||||
return handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=self.session_id or "",
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
)
|
||||
|
||||
|
|
@ -6134,7 +6137,7 @@ class AIAgent:
|
|||
"""Worker function executed in a thread."""
|
||||
start = time.time()
|
||||
try:
|
||||
result = self._invoke_tool(function_name, function_args, effective_task_id)
|
||||
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id)
|
||||
except Exception as tool_error:
|
||||
result = f"Error executing tool '{function_name}': {tool_error}"
|
||||
logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
|
||||
|
|
@ -6452,6 +6455,8 @@ class AIAgent:
|
|||
try:
|
||||
function_result = handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
tool_call_id=tool_call.id,
|
||||
session_id=self.session_id or "",
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
)
|
||||
_spinner_result = function_result
|
||||
|
|
@ -6469,6 +6474,8 @@ class AIAgent:
|
|||
try:
|
||||
function_result = handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
tool_call_id=tool_call.id,
|
||||
session_id=self.session_id or "",
|
||||
enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None,
|
||||
)
|
||||
except Exception as tool_error:
|
||||
|
|
@ -7273,7 +7280,26 @@ class AIAgent:
|
|||
if self.api_mode == "codex_responses":
|
||||
api_kwargs = self._preflight_codex_api_kwargs(api_kwargs, allow_stream=False)
|
||||
|
||||
if env_var_enabled("HERMES_DUMP_REQUESTS"):
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook(
|
||||
"pre_llm_request",
|
||||
task_id=effective_task_id,
|
||||
session_id=self.session_id or "",
|
||||
platform=self.platform or "",
|
||||
model=self.model,
|
||||
provider=self.provider,
|
||||
base_url=self.base_url,
|
||||
api_mode=self.api_mode,
|
||||
api_call_count=api_call_count,
|
||||
messages=api_messages,
|
||||
max_tokens=self.max_tokens,
|
||||
tools=self.tools or [],
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}:
|
||||
self._dump_api_request_debug(api_kwargs, reason="preflight")
|
||||
|
||||
# Always prefer the streaming path — even without stream
|
||||
|
|
@ -8359,6 +8385,27 @@ class AIAgent:
|
|||
else:
|
||||
assistant_message.content = str(raw)
|
||||
|
||||
try:
|
||||
from hermes_cli.plugins import invoke_hook
|
||||
invoke_hook(
|
||||
"post_llm_request",
|
||||
task_id=effective_task_id,
|
||||
session_id=self.session_id or "",
|
||||
platform=self.platform or "",
|
||||
model=self.model,
|
||||
provider=self.provider,
|
||||
base_url=self.base_url,
|
||||
api_mode=self.api_mode,
|
||||
api_call_count=api_call_count,
|
||||
api_duration=api_duration,
|
||||
finish_reason=finish_reason,
|
||||
messages=api_messages,
|
||||
response=response,
|
||||
assistant_message=assistant_message,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Handle assistant response
|
||||
if assistant_message.content and not self.quiet_mode:
|
||||
if self.verbose_logging:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
"""Tests for model_tools.py — function call dispatch, agent-loop interception, legacy toolsets."""
|
||||
|
||||
import json
|
||||
from unittest.mock import call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from model_tools import (
|
||||
|
|
@ -38,6 +40,40 @@ class TestHandleFunctionCall:
|
|||
assert len(parsed["error"]) > 0
|
||||
assert "error" in parsed["error"].lower() or "failed" in parsed["error"].lower()
|
||||
|
||||
def test_tool_hooks_receive_session_and_tool_call_ids(self):
|
||||
with (
|
||||
patch("model_tools.registry.dispatch", return_value='{"ok":true}'),
|
||||
patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook,
|
||||
):
|
||||
result = handle_function_call(
|
||||
"web_search",
|
||||
{"q": "test"},
|
||||
task_id="task-1",
|
||||
tool_call_id="call-1",
|
||||
session_id="session-1",
|
||||
)
|
||||
|
||||
assert result == '{"ok":true}'
|
||||
assert mock_invoke_hook.call_args_list == [
|
||||
call(
|
||||
"pre_tool_call",
|
||||
tool_name="web_search",
|
||||
args={"q": "test"},
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
),
|
||||
call(
|
||||
"post_tool_call",
|
||||
tool_name="web_search",
|
||||
args={"q": "test"},
|
||||
result='{"ok":true}',
|
||||
task_id="task-1",
|
||||
session_id="session-1",
|
||||
tool_call_id="call-1",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Agent loop tools
|
||||
|
|
|
|||
|
|
@ -196,6 +196,10 @@ class TestPluginLoading:
|
|||
class TestPluginHooks:
|
||||
"""Tests for lifecycle hook registration and invocation."""
|
||||
|
||||
def test_valid_hooks_include_request_scoped_llm_hooks(self):
|
||||
assert "pre_llm_request" in VALID_HOOKS
|
||||
assert "post_llm_request" in VALID_HOOKS
|
||||
|
||||
def test_register_and_invoke_hook(self, tmp_path, monkeypatch):
|
||||
"""Registered hooks are called on invoke_hook()."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
|
|
@ -262,6 +266,28 @@ class TestPluginHooks:
|
|||
user_message="hi", assistant_response="bye", model="test")
|
||||
assert results == []
|
||||
|
||||
def test_request_hooks_are_invokeable(self, tmp_path, monkeypatch):
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir, "request_hook",
|
||||
register_body='ctx.register_hook("pre_llm_request", lambda **kw: {"seen": kw.get("api_call_count")})',
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
results = mgr.invoke_hook(
|
||||
"pre_llm_request",
|
||||
session_id="s1",
|
||||
task_id="t1",
|
||||
model="test",
|
||||
api_call_count=2,
|
||||
messages=[],
|
||||
tools=[],
|
||||
)
|
||||
assert results == [{"seen": 2}]
|
||||
|
||||
def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog):
|
||||
"""Registering an unknown hook name logs a warning."""
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
|
|
|
|||
|
|
@ -1258,6 +1258,8 @@ class TestConcurrentToolExecution:
|
|||
result = agent._invoke_tool("web_search", {"q": "test"}, "task-1")
|
||||
mock_hfc.assert_called_once_with(
|
||||
"web_search", {"q": "test"}, "task-1",
|
||||
tool_call_id=None,
|
||||
session_id=agent.session_id,
|
||||
enabled_tools=list(agent.valid_tool_names),
|
||||
|
||||
)
|
||||
|
|
@ -1441,7 +1443,7 @@ class TestRunConversation:
|
|||
resp2 = _mock_response(content="Done searching", finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [resp1, resp2]
|
||||
with (
|
||||
patch("run_agent.handle_function_call", return_value="search result"),
|
||||
patch("run_agent.handle_function_call", return_value="search result") as mock_handle_function_call,
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
|
|
@ -1449,6 +1451,39 @@ class TestRunConversation:
|
|||
result = agent.run_conversation("search something")
|
||||
assert result["final_response"] == "Done searching"
|
||||
assert result["api_calls"] == 2
|
||||
assert mock_handle_function_call.call_args.kwargs["tool_call_id"] == "c1"
|
||||
assert mock_handle_function_call.call_args.kwargs["session_id"] == agent.session_id
|
||||
|
||||
def test_request_scoped_llm_hooks_fire_for_each_api_call(self, agent):
|
||||
self._setup_agent(agent)
|
||||
tc = _mock_tool_call(name="web_search", arguments="{}", call_id="c1")
|
||||
resp1 = _mock_response(content="", finish_reason="tool_calls", tool_calls=[tc])
|
||||
resp2 = _mock_response(content="Done searching", finish_reason="stop")
|
||||
agent.client.chat.completions.create.side_effect = [resp1, resp2]
|
||||
|
||||
hook_calls = []
|
||||
|
||||
def _record_hook(name, **kwargs):
|
||||
hook_calls.append((name, kwargs))
|
||||
return []
|
||||
|
||||
with (
|
||||
patch("run_agent.handle_function_call", return_value="search result"),
|
||||
patch("hermes_cli.plugins.invoke_hook", side_effect=_record_hook),
|
||||
patch.object(agent, "_persist_session"),
|
||||
patch.object(agent, "_save_trajectory"),
|
||||
patch.object(agent, "_cleanup_task_resources"),
|
||||
):
|
||||
result = agent.run_conversation("search something")
|
||||
|
||||
assert result["final_response"] == "Done searching"
|
||||
pre_request_calls = [kw for name, kw in hook_calls if name == "pre_llm_request"]
|
||||
post_request_calls = [kw for name, kw in hook_calls if name == "post_llm_request"]
|
||||
assert len(pre_request_calls) == 2
|
||||
assert len(post_request_calls) == 2
|
||||
assert [call["api_call_count"] for call in pre_request_calls] == [1, 2]
|
||||
assert [call["api_call_count"] for call in post_request_calls] == [1, 2]
|
||||
assert all(call["session_id"] == agent.session_id for call in pre_request_calls)
|
||||
|
||||
def test_interrupt_breaks_loop(self, agent):
|
||||
self._setup_agent(agent)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue