mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-16 09:31:37 +00:00
Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin
feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
commit
d4a7bfd3aa
14 changed files with 2170 additions and 151 deletions
|
|
@ -18,8 +18,15 @@ from hermes_cli.plugins import (
|
|||
get_plugin_command_handler,
|
||||
get_plugin_commands,
|
||||
get_pre_tool_call_block_message,
|
||||
has_middleware,
|
||||
resolve_plugin_command_result,
|
||||
)
|
||||
from hermes_cli.middleware import (
|
||||
VALID_MIDDLEWARE,
|
||||
apply_llm_request_middleware,
|
||||
apply_tool_request_middleware,
|
||||
run_tool_execution_middleware,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
|
@ -96,6 +103,223 @@ class TestPluginDiscovery:
|
|||
assert "hello_plugin" in mgr._plugins
|
||||
assert mgr._plugins["hello_plugin"].enabled
|
||||
|
||||
def test_plugin_can_register_and_invoke_middleware(self, tmp_path, monkeypatch):
    """A discovered plugin can register middleware and have it invoked.

    The generated plugin registers one ``llm_request`` and one
    ``tool_request`` middleware; each returns its payload with an extra
    ``mw`` flag. The test checks that both kinds are recorded on the plugin
    and that ``invoke_middleware`` returns the middleware results.
    """
    hermes_home = tmp_path / "hermes_test"

    # NOTE(review): the leading whitespace before the second
    # ``ctx.register_middleware`` statement presumably has to match the
    # indentation ``_make_plugin_dir`` uses for the register() body — confirm
    # against that helper.
    register_source = (
        "ctx.register_middleware('llm_request', "
        "lambda **kw: {'request': {**kw['request'], 'mw': True}})\n"
        " ctx.register_middleware('tool_request', "
        "lambda **kw: {'args': {**kw['args'], 'mw': True}})"
    )
    _make_plugin_dir(hermes_home / "plugins", "mw_plugin", register_body=register_source)
    monkeypatch.setenv("HERMES_HOME", str(hermes_home))

    mgr = PluginManager()
    mgr.discover_and_load()

    for kind in ("llm_request", "tool_request"):
        assert kind in VALID_MIDDLEWARE

    registered = set(mgr._plugins["mw_plugin"].middleware_registered)
    assert registered == {"llm_request", "tool_request"}

    llm_results = mgr.invoke_middleware("llm_request", request={"messages": []})
    assert llm_results == [
        {"request": {"messages": [], "mw": True}}
    ]

    tool_results = mgr.invoke_middleware("tool_request", args={"path": "README.md"})
    assert tool_results == [
        {"args": {"path": "README.md", "mw": True}}
    ]

    assert mgr.has_middleware("llm_request") is True
|
||||
|
||||
def test_execution_middleware_does_not_retry_downstream_failure(self, monkeypatch):
    """A failure raised by the terminal callable propagates and is not retried.

    The middleware simply forwards to ``next_call``; the terminal callable
    records its arguments and raises. The terminal must have been called
    exactly once.
    """
    terminal_calls = []

    def passthrough(**kwargs):
        forward = kwargs["next_call"]
        return forward(kwargs["args"])

    def failing_terminal(args):
        terminal_calls.append(args)
        raise RuntimeError("tool failed")

    fake_manager = types.SimpleNamespace(_middleware={"tool_execution": [passthrough]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    with pytest.raises(RuntimeError, match="tool failed"):
        run_tool_execution_middleware("terminal", {"command": "false"}, failing_terminal)

    assert terminal_calls == [{"command": "false"}]
|
||||
|
||||
def test_middleware_helpers_skip_no_listener_work(self, monkeypatch):
    """With no middleware registered, every helper is a pure pass-through."""
    empty_manager = types.SimpleNamespace(_middleware={})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: empty_manager)

    request = {"messages": []}
    args = {"path": "README.md"}

    llm_result = apply_llm_request_middleware(request)
    tool_result = apply_tool_request_middleware("read_file", args)

    # Identity (``is``), not equality: the helpers are expected to hand back
    # the very same objects they were given, with an empty trace.
    for outcome, original in ((llm_result, request), (tool_result, args)):
        assert outcome.payload is original
        assert outcome.original_payload is original
        assert outcome.changed is False
        assert outcome.trace == []

    passthrough_result = run_tool_execution_middleware("terminal", args, lambda payload: payload)
    assert passthrough_result is args
    assert has_middleware("tool_request") is False
|
||||
|
||||
def test_request_middleware_changed_tracks_trace_not_deep_equality(self, monkeypatch):
    """``changed`` is True when middleware answered, even with an equal payload.

    The middleware returns the args it was given plus a ``source`` marker.
    The resulting payload compares equal to the input, yet ``changed`` is
    expected to be True and the trace to carry the marker.
    """

    def echo_middleware(**kwargs):
        return {"args": kwargs["args"], "source": "same-payload"}

    def fake_invoke(kind, **kwargs):
        return [echo_middleware(**kwargs)]

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_request": [echo_middleware]},
        invoke_middleware=fake_invoke,
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    args = {"path": "README.md"}
    outcome = apply_tool_request_middleware("read_file", args)

    assert outcome.payload == args
    assert outcome.original_payload == args
    assert outcome.changed is True
    assert outcome.trace == [{"source": "same-payload"}]
|
||||
|
||||
def test_execution_middleware_post_next_call_error_does_not_retry(self, monkeypatch):
    """A middleware error raised after ``next_call`` keeps the downstream result.

    The middleware calls through successfully and then raises. The test
    expects the terminal result to be returned and the terminal callable to
    have run exactly once.
    """
    terminal_calls = []

    def raises_after_downstream(**kwargs):
        downstream_result = kwargs["next_call"](kwargs["args"])
        raise RuntimeError(f"post-processing failed after {downstream_result}")

    def recording_terminal(args):
        terminal_calls.append(args)
        return "terminal-result"

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_execution": [raises_after_downstream]}
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    outcome = run_tool_execution_middleware(
        "terminal", {"command": "printf ok"}, recording_terminal
    )

    assert outcome == "terminal-result"
    assert terminal_calls == [{"command": "printf ok"}]
|
||||
|
||||
def test_execution_middleware_pre_next_call_error_fails_open_to_remaining_chain(self, monkeypatch):
    """A middleware that raises before ``next_call`` is skipped, not fatal.

    The first middleware raises immediately; the second rewrites the args
    and calls through. The recorded order must be: failing middleware,
    downstream middleware, terminal (with the rewritten args).
    """
    timeline = []

    def broken_first(**kwargs):
        timeline.append("failing")
        raise RuntimeError("middleware setup failed")

    def rewriting_second(**kwargs):
        timeline.append("downstream")
        rewritten = {**kwargs["args"], "rewritten": True}
        return kwargs["next_call"](rewritten)

    def echo_terminal(args):
        timeline.append(("terminal", args))
        return args

    chain = [broken_first, rewriting_second]
    fake_manager = types.SimpleNamespace(_middleware={"tool_execution": chain})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    outcome = run_tool_execution_middleware("terminal", {"command": "printf ok"}, echo_terminal)

    expected_args = {"command": "printf ok", "rewritten": True}
    assert outcome == expected_args
    assert timeline == ["failing", "downstream", ("terminal", expected_args)]
|
||||
|
||||
def test_execution_middleware_translated_downstream_failure_is_not_masked(self, monkeypatch):
    """An exception a middleware re-raises in translated form reaches the caller.

    The middleware catches the terminal failure and raises a new
    ``RuntimeError`` chained to it; the caller must see the translated
    message and the terminal must have run only once.
    """
    terminal_calls = []

    def translating_middleware(**kwargs):
        try:
            return kwargs["next_call"](kwargs["args"])
        except Exception as exc:
            raise RuntimeError(f"translated downstream failure: {exc}") from exc

    def failing_terminal(args):
        terminal_calls.append(args)
        raise RuntimeError("terminal failed")

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_execution": [translating_middleware]}
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    expected_message = "translated downstream failure: terminal failed"
    with pytest.raises(RuntimeError, match=expected_message):
        run_tool_execution_middleware("terminal", {"command": "false"}, failing_terminal)

    assert terminal_calls == [{"command": "false"}]
|
||||
|
||||
def test_execution_middleware_downstream_base_exception_is_not_wrapped(self, monkeypatch):
    """A ``KeyboardInterrupt`` from the terminal propagates unchanged.

    The middleware only catches ``Exception``, so the ``BaseException``
    raised by the terminal must surface as-is, after a single terminal call.
    """
    terminal_calls = []

    def exception_only_middleware(**kwargs):
        try:
            return kwargs["next_call"](kwargs["args"])
        except Exception as exc:
            raise RuntimeError(f"middleware should not catch base exception: {exc}") from exc

    def interrupting_terminal(args):
        terminal_calls.append(args)
        raise KeyboardInterrupt()

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_execution": [exception_only_middleware]}
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    with pytest.raises(KeyboardInterrupt):
        run_tool_execution_middleware(
            "terminal", {"command": "interrupt"}, interrupting_terminal
        )

    assert terminal_calls == [{"command": "interrupt"}]
|
||||
|
||||
def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch):
    """Calling ``next_call`` twice from one middleware runs the tool only once."""
    terminal_calls = []

    def double_calling_middleware(**kwargs):
        forward = kwargs["next_call"]
        first_result = forward(kwargs["args"])
        # Intentional misuse: invoking next_call a second time must not
        # execute the downstream tool again. The chain is expected to treat
        # it as an error while keeping the first (successful) result.
        forward(kwargs["args"])
        return first_result

    def recording_terminal(args):
        terminal_calls.append(args)
        return "terminal-result"

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_execution": [double_calling_middleware]}
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    outcome = run_tool_execution_middleware(
        "terminal", {"command": "printf ok"}, recording_terminal
    )

    assert outcome == "terminal-result"
    assert terminal_calls == [{"command": "printf ok"}]
|
||||
|
||||
def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch):
    """Request middleware still runs when the payload cannot be deep-copied."""
    import threading

    captured = {}

    def observing_middleware(**kwargs):
        captured["args"] = kwargs["args"]
        return None

    def fake_invoke(kind, **kwargs):
        return [observing_middleware(**kwargs)]

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_request": [observing_middleware]},
        invoke_middleware=fake_invoke,
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: fake_manager)

    # A threading.Lock cannot be deep-copied, so an unconditional deepcopy of
    # the payload would raise here.
    lock = threading.Lock()
    args = {"command": "noop", "lock": lock}
    outcome = apply_tool_request_middleware("terminal", args)

    # The middleware was reached, and the lock is expected to be shared by
    # reference (shallow fallback) instead of aborting the call.
    assert captured["args"]["command"] == "noop"
    assert outcome.payload["command"] == "noop"
    assert outcome.payload["lock"] is args["lock"]
|
||||
|
||||
def test_discover_project_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ./.hermes/plugins/ are discovered."""
|
||||
project_dir = tmp_path / "project"
|
||||
|
|
|
|||
|
|
@ -27,8 +27,16 @@ class _FakeNemoRelay:
|
|||
pop=self._scope_pop,
|
||||
event=self._scope_event,
|
||||
)
|
||||
self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end)
|
||||
self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end)
|
||||
self.llm = SimpleNamespace(
|
||||
call=self._llm_call,
|
||||
call_end=self._llm_call_end,
|
||||
execute=self._llm_execute,
|
||||
)
|
||||
self.tools = SimpleNamespace(
|
||||
call=self._tool_call,
|
||||
call_end=self._tool_call_end,
|
||||
execute=self._tool_execute,
|
||||
)
|
||||
self.plugin = SimpleNamespace(initialize=self._plugin_initialize)
|
||||
self.LLMRequest = _FakeLLMRequest
|
||||
self.AtofExporterConfig = _FakeAtofExporterConfig
|
||||
|
|
@ -55,6 +63,12 @@ class _FakeNemoRelay:
|
|||
def _llm_call_end(self, handle, response, **kwargs):
    """Record an ``llm.call_end`` event with the handle, response and kwargs."""
    record = ("llm.call_end", handle, response, kwargs)
    self.events.append(record)
|
||||
|
||||
def _llm_execute(self, name, request, func, **kwargs):
    """Fake ``llm.execute``: log start/end events around a call to ``func``.

    ``func`` receives a new ``_FakeLLMRequest`` carrying the original
    headers and the original content with an ``intercepted`` marker added
    (keys already present in the content take precedence). The value
    returned by ``func`` is returned unchanged.
    """
    self.events.append(("llm.execute.start", name, request.content, kwargs))

    headers = request.headers
    marked_content = {"intercepted": True}
    marked_content.update(request.content)
    outcome = func(_FakeLLMRequest(headers, marked_content))

    self.events.append(("llm.execute.end", name, outcome, kwargs))
    return outcome
|
||||
|
||||
def _tool_call(self, name, args, **kwargs):
|
||||
handle = ("tool", name)
|
||||
self.events.append(("tool.call", name, args, kwargs))
|
||||
|
|
@ -63,6 +77,12 @@ class _FakeNemoRelay:
|
|||
def _tool_call_end(self, handle, result, **kwargs):
    """Record a ``tool.call_end`` event with the handle, result and kwargs."""
    record = ("tool.call_end", handle, result, kwargs)
    self.events.append(record)
|
||||
|
||||
def _tool_execute(self, name, args, func, **kwargs):
    """Fake ``tools.execute``: log start/end events around a call to ``func``.

    ``func`` receives a copy of ``args`` with an ``intercepted`` marker
    added (a key of the same name already in ``args`` takes precedence).
    The value returned by ``func`` is returned unchanged.
    """
    self.events.append(("tool.execute.start", name, args, kwargs))

    marked_args = {"intercepted": True}
    marked_args.update(args)
    outcome = func(marked_args)

    self.events.append(("tool.execute.end", name, outcome, kwargs))
    return outcome
|
||||
|
||||
def _make_atof_exporter(self, config):
    """Build a ``_FakeAtofExporter`` bound to this fake's event list and ``config``."""
    exporter = _FakeAtofExporter(self.events, config)
    return exporter
|
||||
|
||||
|
|
@ -425,6 +445,221 @@ output_directory = "{atif_dir}"
|
|||
assert atif_dir.is_dir()
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
    """With the adaptive component enabled, the raw LLM response is returned.

    The caller must get back the provider response object untouched, while
    the fake relay's ``llm.execute`` sees the intercepted request and a
    normalized dict describing the response.
    """
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)

    adaptive_config = """
version = 1

[[components]]
kind = "adaptive"
enabled = true

[components.config]
mode = "route"
"""
    plugins_toml = tmp_path / "plugins.toml"
    plugins_toml.write_text(adaptive_config, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))

    seen_request = {}

    tool_call = SimpleNamespace(
        id="tool-1",
        type="function",
        function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'),
    )
    assistant_message = SimpleNamespace(
        role="assistant",
        content=None,
        tool_calls=[tool_call],
        reasoning_content="need a tool",
    )
    raw_choice = SimpleNamespace(message=assistant_message, finish_reason="tool_calls")

    def next_call(request):
        seen_request.update(request)
        token_usage = SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8)
        return SimpleNamespace(
            id="resp-1",
            model="demo-model",
            choices=[raw_choice],
            usage=token_usage,
        )

    response = plugin.on_llm_execution_middleware(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
        provider="anthropic",
        model="demo-model",
        api_call_count=1,
        request={"messages": [{"role": "user", "content": "hi"}]},
        next_call=next_call,
    )

    # The caller gets the provider's own response object back.
    assert response.model == "demo-model"
    assert response.choices == [raw_choice]
    assert seen_request["intercepted"] is True

    start_event = next(evt for evt in fake.events if evt[0] == "llm.execute.start")
    assert start_event[3]["data"]["mode"] == "route"

    # The relay side sees a plain-dict rendering of the response.
    end_event = next(evt for evt in fake.events if evt[0] == "llm.execute.end")
    expected_tool_call = {
        "id": "tool-1",
        "type": "function",
        "function": {"name": "terminal", "arguments": '{"command":"pwd"}'},
    }
    assert end_event[2] == {
        "model": "demo-model",
        "assistant_message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [expected_tool_call],
            "reasoning_content": "need a tool",
        },
        "finish_reason": "tool_calls",
        "usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
    }
|
||||
|
||||
|
||||
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
    """Without adaptive config, the LLM middleware just calls ``next_call``."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)

    def echo_request(request):
        return {"raw": request}

    response = plugin.on_llm_execution_middleware(
        session_id="s1",
        provider="anthropic",
        model="demo-model",
        request={"messages": []},
        next_call=echo_request,
    )

    assert response == {"raw": {"messages": []}}
    # The relay's execute path must not have been entered.
    recorded_names = [evt[0] for evt in fake.events]
    assert "llm.execute.start" not in recorded_names
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
    """With the adaptive component enabled, the raw tool result is returned.

    The downstream callable must receive the intercepted args, and the
    relay's ``tool.execute`` start event must carry the mode and tool call id.
    """
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)

    adaptive_config = """
version = 1

[[components]]
kind = "adaptive"
enabled = true

[components.config]
mode = "route"
"""
    plugins_toml = tmp_path / "plugins.toml"
    plugins_toml.write_text(adaptive_config, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))

    seen_args = {}

    def next_call(args):
        seen_args.update(args)
        return {"raw": True, "args": args}

    response = plugin.on_tool_execution_middleware(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
        next_call=next_call,
    )

    assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
    assert seen_args["intercepted"] is True

    start_event = next(evt for evt in fake.events if evt[0] == "tool.execute.start")
    event_data = start_event[3]["data"]
    assert event_data["mode"] == "route"
    assert event_data["tool_call_id"] == "tool-1"
|
||||
|
||||
|
||||
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
    """Without adaptive config, the tool middleware just calls ``next_call``."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)

    def echo_args(args):
        return {"raw": args}

    response = plugin.on_tool_execution_middleware(
        session_id="s1",
        tool_name="terminal",
        args={"command": "pwd"},
        next_call=echo_args,
    )

    assert response == {"raw": {"command": "pwd"}}
    # The relay's execute path must not have been entered.
    recorded_names = [evt[0] for evt in fake.events]
    assert "tool.execute.start" not in recorded_names
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch):
    """In adaptive mode the observer hooks emit no call/call_end events.

    The pre/post API and tool hooks are fired first, then the execution
    middleware. Only the ``*.execute.start`` events are expected in the
    fake relay's log; ``llm.call``/``tool.call`` and their ``_end``
    counterparts must be absent.
    """
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)

    adaptive_config = """
version = 1

[[components]]
kind = "adaptive"
enabled = true

[components.config]
mode = "route"
"""
    plugins_toml = tmp_path / "plugins.toml"
    plugins_toml.write_text(adaptive_config, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))

    base = {
        "session_id": "s1",
        "task_id": "t1",
        "turn_id": "turn-1",
        "api_request_id": "api-1",
    }
    user_messages = [{"role": "user", "content": "hi"}]

    # Observer hooks first.
    plugin.on_pre_api_request(
        **base,
        provider="anthropic",
        model="demo-model",
        request={"body": {"messages": [{"role": "user", "content": "hi"}]}},
    )
    plugin.on_post_api_request(**base, response={"ok": True})
    plugin.on_pre_tool_call(
        **base, tool_name="terminal", tool_call_id="tool-1", args={"command": "pwd"}
    )
    plugin.on_post_tool_call(
        **base, tool_name="terminal", tool_call_id="tool-1", result={"ok": True}
    )

    # Then the execution middleware for the same request and tool call.
    plugin.on_llm_execution_middleware(
        **base,
        provider="anthropic",
        model="demo-model",
        request={"messages": user_messages},
        next_call=lambda request: {"raw": request},
    )
    plugin.on_tool_execution_middleware(
        **base,
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
        next_call=lambda args: {"raw": args},
    )

    event_names = [evt[0] for evt in fake.events]
    for observer_event in ("llm.call", "llm.call_end", "tool.call", "tool.call_end"):
        assert observer_event not in event_names
    for execute_event in ("llm.execute.start", "tool.execute.start"):
        assert execute_event in event_names
|
||||
|
||||
|
||||
def test_nemo_relay_plugin_noops_without_dependency(monkeypatch):
|
||||
monkeypatch.delitem(sys.modules, "nemo_relay", raising=False)
|
||||
sys.modules.pop("plugins.observability.nemo_relay", None)
|
||||
|
|
|
|||
|
|
@ -2466,8 +2466,10 @@ class TestConcurrentToolExecution:
|
|||
api_request_id="",
|
||||
enabled_tools=list(agent.valid_tool_names),
|
||||
skip_pre_tool_call_hook=True,
|
||||
skip_tool_request_middleware=True,
|
||||
enabled_toolsets=agent.enabled_toolsets,
|
||||
disabled_toolsets=agent.disabled_toolsets,
|
||||
tool_request_middleware_trace=[],
|
||||
)
|
||||
assert result == "result"
|
||||
|
||||
|
|
@ -2647,6 +2649,89 @@ class TestConcurrentToolExecution:
|
|||
assert post_call[1]["result"] == '{"ok":true}'
|
||||
assert post_call[1]["status"] == "ok"
|
||||
|
||||
def test_sequential_agent_level_tool_execution_middleware_wraps_inline_dispatch(self, agent, monkeypatch):
    """The sequential path for built-in tools runs request and execution middleware.

    Request middleware adds ``request_rewritten``; execution middleware then
    adds ``merge`` before calling through. The ``todo`` tool must be called
    with the merged args, and the ``post_tool_call`` hook must see the final
    args together with the request middleware trace.
    """
    tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
    mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
    messages = []
    hook_calls = []
    seen = {}

    def request_middleware(**kwargs):
        rewritten = {**kwargs["args"], "request_rewritten": True}
        return {"args": rewritten, "source": "request-test"}

    def execution_middleware(**kwargs):
        seen["middleware_args"] = kwargs["args"]
        return kwargs["next_call"]({**kwargs["args"], "merge": True})

    def fake_invoke_middleware(kind, **kwargs):
        if kind == "tool_request":
            return [request_middleware(**kwargs)]
        return []

    def record_hook(hook_name, **kwargs):
        hook_calls.append((hook_name, kwargs))
        return []

    manager = SimpleNamespace(
        _middleware={
            "tool_request": [request_middleware],
            "tool_execution": [execution_middleware],
        }
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    monkeypatch.setattr("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware)
    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", record_hook)
    monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)

    with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
        agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")

    assert seen["middleware_args"] == {"todos": [], "request_rewritten": True}
    mock_todo.assert_called_once_with(todos=[], merge=True, store=agent._todo_store)

    post_call = next(entry for entry in hook_calls if entry[0] == "post_tool_call")
    post_kwargs = post_call[1]
    assert post_kwargs["tool_name"] == "todo"
    assert post_kwargs["args"] == {"todos": [], "request_rewritten": True, "merge": True}
    assert post_kwargs["middleware_trace"] == [{"source": "request-test"}]
|
||||
|
||||
def test_concurrent_agent_level_tool_preserves_request_middleware_trace(self, agent, monkeypatch):
    """The concurrent path passes the request middleware trace to ``post_tool_call``.

    Only request middleware is registered (the execution chain is empty);
    the hook must see the rewritten args and the trace entry.
    """
    tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
    mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
    messages = []
    hook_calls = []

    def request_middleware(**kwargs):
        rewritten = {**kwargs["args"], "request_rewritten": True}
        return {"args": rewritten, "source": "request-test"}

    def fake_invoke_middleware(kind, **kwargs):
        if kind == "tool_request":
            return [request_middleware(**kwargs)]
        return []

    def record_hook(hook_name, **kwargs):
        hook_calls.append((hook_name, kwargs))
        return []

    manager = SimpleNamespace(
        _middleware={"tool_request": [request_middleware], "tool_execution": []}
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    monkeypatch.setattr("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware)
    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        lambda *args, **kwargs: None,
    )
    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", record_hook)
    monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)

    with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}'):
        agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")

    post_call = next(entry for entry in hook_calls if entry[0] == "post_tool_call")
    post_kwargs = post_call[1]
    assert post_kwargs["tool_name"] == "todo"
    assert post_kwargs["args"] == {"todos": [], "request_rewritten": True}
    assert post_kwargs["middleware_trace"] == [{"source": "request-test"}]
|
||||
|
||||
def test_agent_runtime_post_hook_ownership_predicate_covers_agent_tools(self, agent):
|
||||
"""Sequential and concurrent agent-level paths share post-hook ownership."""
|
||||
from agent.agent_runtime_helpers import agent_runtime_owns_post_tool_hook
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class TestHandleFunctionCall:
|
|||
tool_call_id="call-1",
|
||||
turn_id="",
|
||||
api_request_id="",
|
||||
middleware_trace=[],
|
||||
),
|
||||
call(
|
||||
"post_tool_call",
|
||||
|
|
@ -79,6 +80,7 @@ class TestHandleFunctionCall:
|
|||
status="ok",
|
||||
error_type=None,
|
||||
error_message=None,
|
||||
middleware_trace=[],
|
||||
),
|
||||
call(
|
||||
"transform_tool_result",
|
||||
|
|
@ -145,6 +147,60 @@ class TestHandleFunctionCall:
|
|||
assert "post_tool_call" not in fired
|
||||
assert "transform_tool_result" not in fired
|
||||
|
||||
def test_tool_request_and_execution_middleware_wrap_registry_dispatch(self, monkeypatch):
    """``handle_function_call`` applies request then execution middleware.

    Request middleware adds ``rewritten``; execution middleware sees that
    and adds ``wrapped`` before dispatch. The registry dispatch must receive
    both flags, and the pre/post tool hooks must carry the request
    middleware trace (the result minus its ``args`` key).
    """
    seen = {}
    hook_calls = []

    def fake_invoke_middleware(kind, **kwargs):
        if kind != "tool_request":
            return []
        rewrite = {
            "args": {**kwargs["args"], "rewritten": True},
            "source": "test-middleware",
            "reason": "rewrite",
        }
        return [rewrite]

    def execution_middleware(**kwargs):
        seen["execution_args"] = kwargs["args"]
        return kwargs["next_call"]({**kwargs["args"], "wrapped": True})

    def fake_dispatch(tool_name, args, **kwargs):
        seen["dispatch"] = (tool_name, args, kwargs)
        return json.dumps({"ok": True, "args": args})

    def record_hook(hook_name, **kwargs):
        hook_calls.append((hook_name, kwargs))
        return []

    registered = {
        "tool_request": [fake_invoke_middleware],
        "tool_execution": [execution_middleware],
    }
    manager_cls = type("Manager", (), {"_middleware": registered})
    manager = manager_cls()

    monkeypatch.setattr("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware)
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    monkeypatch.setattr("hermes_cli.plugins.invoke_hook", record_hook)
    monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
    monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)

    raw_result = handle_function_call(
        "web_search",
        {"q": "test"},
        task_id="task-1",
        tool_call_id="tool-1",
        session_id="session-1",
    )
    result = json.loads(raw_result)

    final_args = {"q": "test", "rewritten": True, "wrapped": True}
    assert seen["execution_args"] == {"q": "test", "rewritten": True}
    assert seen["dispatch"][1] == final_args
    assert result["args"] == final_args

    expected_trace = [{"source": "test-middleware", "reason": "rewrite"}]
    pre_call = next(entry for entry in hook_calls if entry[0] == "pre_tool_call")
    post_call = next(entry for entry in hook_calls if entry[0] == "post_tool_call")
    assert pre_call[1]["middleware_trace"] == expected_trace
    assert post_call[1]["middleware_trace"] == expected_trace
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Agent loop tools
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue