diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index 98dacf131..efe760e69 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -56,6 +56,8 @@ VALID_HOOKS: Set[str] = { "post_tool_call", "pre_llm_call", "post_llm_call", + "pre_llm_request", + "post_llm_request", "on_session_start", "on_session_end", } diff --git a/model_tools.py b/model_tools.py index edea2315d..da5ba7154 100644 --- a/model_tools.py +++ b/model_tools.py @@ -460,6 +460,8 @@ def handle_function_call( function_name: str, function_args: Dict[str, Any], task_id: Optional[str] = None, + tool_call_id: Optional[str] = None, + session_id: Optional[str] = None, user_task: Optional[str] = None, enabled_tools: Optional[List[str]] = None, ) -> str: @@ -497,7 +499,14 @@ def handle_function_call( try: from hermes_cli.plugins import invoke_hook - invoke_hook("pre_tool_call", tool_name=function_name, args=function_args, task_id=task_id or "") + invoke_hook( + "pre_tool_call", + tool_name=function_name, + args=function_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + ) except Exception: pass @@ -519,7 +528,15 @@ def handle_function_call( try: from hermes_cli.plugins import invoke_hook - invoke_hook("post_tool_call", tool_name=function_name, args=function_args, result=result, task_id=task_id or "") + invoke_hook( + "post_tool_call", + tool_name=function_name, + args=function_args, + result=result, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + ) except Exception: pass diff --git a/run_agent.py b/run_agent.py index 47a8f11d6..b125b3a16 100644 --- a/run_agent.py +++ b/run_agent.py @@ -5965,7 +5965,8 @@ class AIAgent: finally: self._executing_tools = False - def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str) -> str: + def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str, + tool_call_id: Optional[str] = None) -> str: """Invoke a single tool and return the result string. No display logic. Handles both agent-level tools (todo, memory, etc.) and registry-dispatched @@ -6033,6 +6034,8 @@ class AIAgent: else: return handle_function_call( function_name, function_args, effective_task_id, + tool_call_id=tool_call_id, + session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, ) @@ -6134,7 +6137,7 @@ class AIAgent: """Worker function executed in a thread.""" start = time.time() try: - result = self._invoke_tool(function_name, function_args, effective_task_id) + result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id) except Exception as tool_error: result = f"Error executing tool '{function_name}': {tool_error}" logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True) @@ -6452,6 +6455,8 @@ class AIAgent: try: function_result = handle_function_call( function_name, function_args, effective_task_id, + tool_call_id=tool_call.id, + session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, ) _spinner_result = function_result @@ -6469,6 +6474,8 @@ class AIAgent: try: function_result = handle_function_call( function_name, function_args, effective_task_id, + tool_call_id=tool_call.id, + session_id=self.session_id or "", enabled_tools=list(self.valid_tool_names) if self.valid_tool_names else None, ) except Exception as tool_error: @@ -7273,7 +7280,26 @@ class AIAgent: if self.api_mode == "codex_responses": api_kwargs = self._preflight_codex_api_kwargs(api_kwargs, allow_stream=False) - if env_var_enabled("HERMES_DUMP_REQUESTS"): + try: + from hermes_cli.plugins import invoke_hook + invoke_hook( + "pre_llm_request", + task_id=effective_task_id, + session_id=self.session_id or "", + platform=self.platform or "", + model=self.model, + provider=self.provider, + base_url=self.base_url, + api_mode=self.api_mode, + api_call_count=api_call_count, + messages=api_messages, + max_tokens=self.max_tokens, + tools=self.tools or [], + ) + except Exception: + pass + + if os.getenv("HERMES_DUMP_REQUESTS", "").strip().lower() in {"1", "true", "yes", "on"}: self._dump_api_request_debug(api_kwargs, reason="preflight") # Always prefer the streaming path — even without stream @@ -8359,6 +8385,27 @@ class AIAgent: else: assistant_message.content = str(raw) + try: + from hermes_cli.plugins import invoke_hook + invoke_hook( + "post_llm_request", + task_id=effective_task_id, + session_id=self.session_id or "", + platform=self.platform or "", + model=self.model, + provider=self.provider, + base_url=self.base_url, + api_mode=self.api_mode, + api_call_count=api_call_count, + api_duration=api_duration, + finish_reason=finish_reason, + messages=api_messages, + response=response, + assistant_message=assistant_message, + ) + except Exception: + pass + # Handle assistant response if assistant_message.content and not self.quiet_mode: if self.verbose_logging: diff --git a/tests/test_model_tools.py b/tests/test_model_tools.py index 8c2f8e6f7..5e3b1d6ce 100644 --- a/tests/test_model_tools.py +++ b/tests/test_model_tools.py @@ -1,6 +1,8 @@ """Tests for model_tools.py — function call dispatch, agent-loop interception, legacy toolsets.""" import json +from unittest.mock import call, patch + import pytest from model_tools import ( @@ -38,6 +40,40 @@ class TestHandleFunctionCall: assert len(parsed["error"]) > 0 assert "error" in parsed["error"].lower() or "failed" in parsed["error"].lower() + def test_tool_hooks_receive_session_and_tool_call_ids(self): + with ( + patch("model_tools.registry.dispatch", return_value='{"ok":true}'), + patch("hermes_cli.plugins.invoke_hook") as mock_invoke_hook, + ): + result = handle_function_call( + "web_search", + {"q": "test"}, + task_id="task-1", + tool_call_id="call-1", + session_id="session-1", + ) + + assert result == '{"ok":true}' + assert mock_invoke_hook.call_args_list == [ + call( + "pre_tool_call", + tool_name="web_search", + args={"q": "test"}, + task_id="task-1", + session_id="session-1", + tool_call_id="call-1", + ), + call( + "post_tool_call", + tool_name="web_search", + args={"q": "test"}, + result='{"ok":true}', + task_id="task-1", + session_id="session-1", + tool_call_id="call-1", + ), + ] + # ========================================================================= # Agent loop tools diff --git a/tests/test_plugins.py b/tests/test_plugins.py index cba1a777d..f0576b1cb 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -196,6 +196,10 @@ class TestPluginLoading: class TestPluginHooks: """Tests for lifecycle hook registration and invocation.""" + def test_valid_hooks_include_request_scoped_llm_hooks(self): + assert "pre_llm_request" in VALID_HOOKS + assert "post_llm_request" in VALID_HOOKS + def test_register_and_invoke_hook(self, tmp_path, monkeypatch): """Registered hooks are called on invoke_hook().""" plugins_dir = tmp_path / "hermes_test" / "plugins" @@ -262,6 +266,28 @@ class TestPluginHooks: user_message="hi", assistant_response="bye", model="test") assert results == [] + def test_request_hooks_are_invokeable(self, tmp_path, monkeypatch): + plugins_dir = tmp_path / "hermes_test" / "plugins" + _make_plugin_dir( + plugins_dir, "request_hook", + register_body='ctx.register_hook("pre_llm_request", lambda **kw: {"seen": kw.get("api_call_count")})', + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test")) + + mgr = PluginManager() + mgr.discover_and_load() + + results = mgr.invoke_hook( + "pre_llm_request", + session_id="s1", + task_id="t1", + model="test", + api_call_count=2, + messages=[], + tools=[], + ) + assert results == [{"seen": 2}] + def test_invalid_hook_name_warns(self, tmp_path, monkeypatch, caplog): """Registering an unknown hook name logs a warning.""" plugins_dir = tmp_path / "hermes_test" / "plugins" diff --git a/tests/test_run_agent.py b/tests/test_run_agent.py index a407d27a9..9ab12bf59 100644 --- a/tests/test_run_agent.py +++ b/tests/test_run_agent.py @@ -1258,6 +1258,8 @@ class TestConcurrentToolExecution: result = agent._invoke_tool("web_search", {"q": "test"}, "task-1") mock_hfc.assert_called_once_with( "web_search", {"q": "test"}, "task-1", + tool_call_id=None, + session_id=agent.session_id, enabled_tools=list(agent.valid_tool_names), ) @@ -1441,7 +1443,7 @@ class TestRunConversation: resp2 = _mock_response(content="Done searching", finish_reason="stop") agent.client.chat.completions.create.side_effect = [resp1, resp2] with ( - patch("run_agent.handle_function_call", return_value="search result"), + patch("run_agent.handle_function_call", return_value="search result") as mock_handle_function_call, patch.object(agent, "_persist_session"), patch.object(agent, "_save_trajectory"), patch.object(agent, "_cleanup_task_resources"), @@ -1449,6 +1451,39 @@ class TestRunConversation: result = agent.run_conversation("search something") assert result["final_response"] == "Done searching" assert result["api_calls"] == 2 + assert mock_handle_function_call.call_args.kwargs["tool_call_id"] == "c1" + assert mock_handle_function_call.call_args.kwargs["session_id"] == agent.session_id + + def test_request_scoped_llm_hooks_fire_for_each_api_call(self, agent): + self._setup_agent(agent) + tc = _mock_tool_call(name="web_search", arguments="{}", call_id="c1") + resp1 = _mock_response(content="", finish_reason="tool_calls", tool_calls=[tc]) + resp2 = _mock_response(content="Done searching", finish_reason="stop") + agent.client.chat.completions.create.side_effect = [resp1, resp2] + + hook_calls = [] + + def _record_hook(name, **kwargs): + hook_calls.append((name, kwargs)) + return [] + + with ( + patch("run_agent.handle_function_call", return_value="search result"), + patch("hermes_cli.plugins.invoke_hook", side_effect=_record_hook), + patch.object(agent, "_persist_session"), + patch.object(agent, "_save_trajectory"), + patch.object(agent, "_cleanup_task_resources"), + ): + result = agent.run_conversation("search something") + + assert result["final_response"] == "Done searching" + pre_request_calls = [kw for name, kw in hook_calls if name == "pre_llm_request"] + post_request_calls = [kw for name, kw in hook_calls if name == "post_llm_request"] + assert len(pre_request_calls) == 2 + assert len(post_request_calls) == 2 + assert [call["api_call_count"] for call in pre_request_calls] == [1, 2] + assert [call["api_call_count"] for call in post_request_calls] == [1, 2] + assert all(call["session_id"] == agent.session_id for call in pre_request_calls) def test_interrupt_breaks_loop(self, agent): self._setup_agent(agent)