Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin

feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
kshitij 2026-06-06 10:46:41 -07:00 committed by GitHub
commit d4a7bfd3aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2170 additions and 151 deletions

View file

@ -18,8 +18,15 @@ from hermes_cli.plugins import (
get_plugin_command_handler,
get_plugin_commands,
get_pre_tool_call_block_message,
has_middleware,
resolve_plugin_command_result,
)
from hermes_cli.middleware import (
VALID_MIDDLEWARE,
apply_llm_request_middleware,
apply_tool_request_middleware,
run_tool_execution_middleware,
)
# ── Helpers ────────────────────────────────────────────────────────────────
@ -96,6 +103,223 @@ class TestPluginDiscovery:
assert "hello_plugin" in mgr._plugins
assert mgr._plugins["hello_plugin"].enabled
def test_plugin_can_register_and_invoke_middleware(self, tmp_path, monkeypatch):
plugins_dir = tmp_path / "hermes_test" / "plugins"
_make_plugin_dir(
plugins_dir,
"mw_plugin",
register_body=(
"ctx.register_middleware('llm_request', "
"lambda **kw: {'request': {**kw['request'], 'mw': True}})\n"
" ctx.register_middleware('tool_request', "
"lambda **kw: {'args': {**kw['args'], 'mw': True}})"
),
)
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
mgr = PluginManager()
mgr.discover_and_load()
assert "llm_request" in VALID_MIDDLEWARE
assert "tool_request" in VALID_MIDDLEWARE
assert set(mgr._plugins["mw_plugin"].middleware_registered) == {"llm_request", "tool_request"}
assert mgr.invoke_middleware("llm_request", request={"messages": []}) == [
{"request": {"messages": [], "mw": True}}
]
assert mgr.invoke_middleware("tool_request", args={"path": "README.md"}) == [
{"args": {"path": "README.md", "mw": True}}
]
assert mgr.has_middleware("llm_request") is True
def test_execution_middleware_does_not_retry_downstream_failure(self, monkeypatch):
calls = []
def middleware(**kwargs):
return kwargs["next_call"](kwargs["args"])
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
def terminal(args):
calls.append(args)
raise RuntimeError("tool failed")
with pytest.raises(RuntimeError, match="tool failed"):
run_tool_execution_middleware("terminal", {"command": "false"}, terminal)
assert calls == [{"command": "false"}]
def test_middleware_helpers_skip_no_listener_work(self, monkeypatch):
manager = types.SimpleNamespace(_middleware={})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
request = {"messages": []}
args = {"path": "README.md"}
llm_result = apply_llm_request_middleware(request)
tool_result = apply_tool_request_middleware("read_file", args)
assert llm_result.payload is request
assert llm_result.original_payload is request
assert llm_result.changed is False
assert llm_result.trace == []
assert tool_result.payload is args
assert tool_result.original_payload is args
assert tool_result.changed is False
assert tool_result.trace == []
assert run_tool_execution_middleware("terminal", args, lambda payload: payload) is args
assert has_middleware("tool_request") is False
def test_request_middleware_changed_tracks_trace_not_deep_equality(self, monkeypatch):
def same_payload_middleware(**kwargs):
return {"args": kwargs["args"], "source": "same-payload"}
manager = types.SimpleNamespace(
_middleware={"tool_request": [same_payload_middleware]},
invoke_middleware=lambda kind, **kwargs: [same_payload_middleware(**kwargs)],
)
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
args = {"path": "README.md"}
result = apply_tool_request_middleware("read_file", args)
assert result.payload == args
assert result.original_payload == args
assert result.changed is True
assert result.trace == [{"source": "same-payload"}]
def test_execution_middleware_post_next_call_error_does_not_retry(self, monkeypatch):
calls = []
def middleware(**kwargs):
result = kwargs["next_call"](kwargs["args"])
raise RuntimeError(f"post-processing failed after {result}")
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
def terminal(args):
calls.append(args)
return "terminal-result"
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
assert result == "terminal-result"
assert calls == [{"command": "printf ok"}]
def test_execution_middleware_pre_next_call_error_fails_open_to_remaining_chain(self, monkeypatch):
calls = []
def failing_middleware(**kwargs):
calls.append("failing")
raise RuntimeError("middleware setup failed")
def downstream_middleware(**kwargs):
calls.append("downstream")
return kwargs["next_call"]({**kwargs["args"], "rewritten": True})
manager = types.SimpleNamespace(_middleware={"tool_execution": [failing_middleware, downstream_middleware]})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
def terminal(args):
calls.append(("terminal", args))
return args
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
assert result == {"command": "printf ok", "rewritten": True}
assert calls == ["failing", "downstream", ("terminal", {"command": "printf ok", "rewritten": True})]
def test_execution_middleware_translated_downstream_failure_is_not_masked(self, monkeypatch):
calls = []
def middleware(**kwargs):
try:
return kwargs["next_call"](kwargs["args"])
except Exception as exc:
raise RuntimeError(f"translated downstream failure: {exc}") from exc
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
def terminal(args):
calls.append(args)
raise RuntimeError("terminal failed")
with pytest.raises(RuntimeError, match="translated downstream failure: terminal failed"):
run_tool_execution_middleware("terminal", {"command": "false"}, terminal)
assert calls == [{"command": "false"}]
def test_execution_middleware_downstream_base_exception_is_not_wrapped(self, monkeypatch):
calls = []
def middleware(**kwargs):
try:
return kwargs["next_call"](kwargs["args"])
except Exception as exc:
raise RuntimeError(f"middleware should not catch base exception: {exc}") from exc
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
def terminal(args):
calls.append(args)
raise KeyboardInterrupt()
with pytest.raises(KeyboardInterrupt):
run_tool_execution_middleware("terminal", {"command": "interrupt"}, terminal)
assert calls == [{"command": "interrupt"}]
def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch):
calls = []
def middleware(**kwargs):
first = kwargs["next_call"](kwargs["args"])
# Deliberate misuse: a second next_call() must not re-run the
# downstream tool. The chain surfaces it as an error and preserves
# the first (successful) downstream result.
kwargs["next_call"](kwargs["args"])
return first
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
def terminal(args):
calls.append(args)
return "terminal-result"
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
assert result == "terminal-result"
assert calls == [{"command": "printf ok"}]
def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch):
import threading
recorded = {}
def middleware(**kwargs):
recorded["args"] = kwargs["args"]
return None
manager = types.SimpleNamespace(
_middleware={"tool_request": [middleware]},
invoke_middleware=lambda kind, **kwargs: [middleware(**kwargs)],
)
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
# threading.Lock is not deepcopyable; a hard deepcopy would raise.
args = {"command": "noop", "lock": threading.Lock()}
result = apply_tool_request_middleware("terminal", args)
# Middleware ran (payload was copied via the shallow fallback) and the
# non-deepcopyable member is shared by reference rather than aborting.
assert recorded["args"]["command"] == "noop"
assert result.payload["command"] == "noop"
assert result.payload["lock"] is args["lock"]
def test_discover_project_plugins(self, tmp_path, monkeypatch):
"""Plugins in ./.hermes/plugins/ are discovered."""
project_dir = tmp_path / "project"

View file

@ -27,8 +27,16 @@ class _FakeNemoRelay:
pop=self._scope_pop,
event=self._scope_event,
)
self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end)
self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end)
self.llm = SimpleNamespace(
call=self._llm_call,
call_end=self._llm_call_end,
execute=self._llm_execute,
)
self.tools = SimpleNamespace(
call=self._tool_call,
call_end=self._tool_call_end,
execute=self._tool_execute,
)
self.plugin = SimpleNamespace(initialize=self._plugin_initialize)
self.LLMRequest = _FakeLLMRequest
self.AtofExporterConfig = _FakeAtofExporterConfig
@ -55,6 +63,12 @@ class _FakeNemoRelay:
def _llm_call_end(self, handle, response, **kwargs):
self.events.append(("llm.call_end", handle, response, kwargs))
def _llm_execute(self, name, request, func, **kwargs):
self.events.append(("llm.execute.start", name, request.content, kwargs))
result = func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
self.events.append(("llm.execute.end", name, result, kwargs))
return result
def _tool_call(self, name, args, **kwargs):
handle = ("tool", name)
self.events.append(("tool.call", name, args, kwargs))
@ -63,6 +77,12 @@ class _FakeNemoRelay:
def _tool_call_end(self, handle, result, **kwargs):
self.events.append(("tool.call_end", handle, result, kwargs))
def _tool_execute(self, name, args, func, **kwargs):
self.events.append(("tool.execute.start", name, args, kwargs))
result = func({"intercepted": True, **args})
self.events.append(("tool.execute.end", name, result, kwargs))
return result
def _make_atof_exporter(self, config):
return _FakeAtofExporter(self.events, config)
@ -425,6 +445,221 @@ output_directory = "{atif_dir}"
assert atif_dir.is_dir()
def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
seen_request = {}
raw_choice = SimpleNamespace(
message=SimpleNamespace(
role="assistant",
content=None,
tool_calls=[
SimpleNamespace(
id="tool-1",
type="function",
function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'),
)
],
reasoning_content="need a tool",
),
finish_reason="tool_calls",
)
def next_call(request):
seen_request.update(request)
return SimpleNamespace(
id="resp-1",
model="demo-model",
choices=[raw_choice],
usage=SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8),
)
response = plugin.on_llm_execution_middleware(
session_id="s1",
task_id="t1",
turn_id="turn-1",
api_request_id="api-1",
provider="anthropic",
model="demo-model",
api_call_count=1,
request={"messages": [{"role": "user", "content": "hi"}]},
next_call=next_call,
)
assert response.model == "demo-model"
assert response.choices == [raw_choice]
assert seen_request["intercepted"] is True
execute_start = next(event for event in fake.events if event[0] == "llm.execute.start")
assert execute_start[3]["data"]["mode"] == "route"
execute_end = next(event for event in fake.events if event[0] == "llm.execute.end")
assert execute_end[2] == {
"model": "demo-model",
"assistant_message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "tool-1",
"type": "function",
"function": {"name": "terminal", "arguments": '{"command":"pwd"}'},
}
],
"reasoning_content": "need a tool",
},
"finish_reason": "tool_calls",
"usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
}
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
response = plugin.on_llm_execution_middleware(
session_id="s1",
provider="anthropic",
model="demo-model",
request={"messages": []},
next_call=lambda request: {"raw": request},
)
assert response == {"raw": {"messages": []}}
assert not any(event[0] == "llm.execute.start" for event in fake.events)
def test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
seen_args = {}
def next_call(args):
seen_args.update(args)
return {"raw": True, "args": args}
response = plugin.on_tool_execution_middleware(
session_id="s1",
task_id="t1",
turn_id="turn-1",
api_request_id="api-1",
tool_name="terminal",
tool_call_id="tool-1",
args={"command": "pwd"},
next_call=next_call,
)
assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
assert seen_args["intercepted"] is True
execute_start = next(event for event in fake.events if event[0] == "tool.execute.start")
assert execute_start[3]["data"]["mode"] == "route"
assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
response = plugin.on_tool_execution_middleware(
session_id="s1",
tool_name="terminal",
args={"command": "pwd"},
next_call=lambda args: {"raw": args},
)
assert response == {"raw": {"command": "pwd"}}
assert not any(event[0] == "tool.execute.start" for event in fake.events)
def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
base = {
"session_id": "s1",
"task_id": "t1",
"turn_id": "turn-1",
"api_request_id": "api-1",
}
plugin.on_pre_api_request(
**base,
provider="anthropic",
model="demo-model",
request={"body": {"messages": [{"role": "user", "content": "hi"}]}},
)
plugin.on_post_api_request(**base, response={"ok": True})
plugin.on_pre_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", args={"command": "pwd"})
plugin.on_post_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", result={"ok": True})
plugin.on_llm_execution_middleware(
**base,
provider="anthropic",
model="demo-model",
request={"messages": [{"role": "user", "content": "hi"}]},
next_call=lambda request: {"raw": request},
)
plugin.on_tool_execution_middleware(
**base,
tool_name="terminal",
tool_call_id="tool-1",
args={"command": "pwd"},
next_call=lambda args: {"raw": args},
)
event_names = [event[0] for event in fake.events]
assert "llm.call" not in event_names
assert "llm.call_end" not in event_names
assert "tool.call" not in event_names
assert "tool.call_end" not in event_names
assert "llm.execute.start" in event_names
assert "tool.execute.start" in event_names
def test_nemo_relay_plugin_noops_without_dependency(monkeypatch):
monkeypatch.delitem(sys.modules, "nemo_relay", raising=False)
sys.modules.pop("plugins.observability.nemo_relay", None)

View file

@ -2466,8 +2466,10 @@ class TestConcurrentToolExecution:
api_request_id="",
enabled_tools=list(agent.valid_tool_names),
skip_pre_tool_call_hook=True,
skip_tool_request_middleware=True,
enabled_toolsets=agent.enabled_toolsets,
disabled_toolsets=agent.disabled_toolsets,
tool_request_middleware_trace=[],
)
assert result == "result"
@ -2647,6 +2649,89 @@ class TestConcurrentToolExecution:
assert post_call[1]["result"] == '{"ok":true}'
assert post_call[1]["status"] == "ok"
def test_sequential_agent_level_tool_execution_middleware_wraps_inline_dispatch(self, agent, monkeypatch):
"""Sequential built-in tool paths should expose the adaptive execution boundary."""
tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
messages = []
hook_calls = []
seen = {}
def request_middleware(**kwargs):
return {
"args": {**kwargs["args"], "request_rewritten": True},
"source": "request-test",
}
def execution_middleware(**kwargs):
seen["middleware_args"] = kwargs["args"]
return kwargs["next_call"]({**kwargs["args"], "merge": True})
manager = SimpleNamespace(_middleware={
"tool_request": [request_middleware],
"tool_execution": [execution_middleware],
})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
monkeypatch.setattr(
"hermes_cli.plugins.invoke_middleware",
lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [],
)
monkeypatch.setattr(
"hermes_cli.plugins.get_pre_tool_call_block_message",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
)
monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")
assert seen["middleware_args"] == {"todos": [], "request_rewritten": True}
mock_todo.assert_called_once_with(todos=[], merge=True, store=agent._todo_store)
post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
assert post_call[1]["tool_name"] == "todo"
assert post_call[1]["args"] == {"todos": [], "request_rewritten": True, "merge": True}
assert post_call[1]["middleware_trace"] == [{"source": "request-test"}]
def test_concurrent_agent_level_tool_preserves_request_middleware_trace(self, agent, monkeypatch):
tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
messages = []
hook_calls = []
def request_middleware(**kwargs):
return {
"args": {**kwargs["args"], "request_rewritten": True},
"source": "request-test",
}
manager = SimpleNamespace(_middleware={"tool_request": [request_middleware], "tool_execution": []})
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
monkeypatch.setattr(
"hermes_cli.plugins.invoke_middleware",
lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [],
)
monkeypatch.setattr(
"hermes_cli.plugins.get_pre_tool_call_block_message",
lambda *args, **kwargs: None,
)
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
)
monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}'):
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
assert post_call[1]["tool_name"] == "todo"
assert post_call[1]["args"] == {"todos": [], "request_rewritten": True}
assert post_call[1]["middleware_trace"] == [{"source": "request-test"}]
def test_agent_runtime_post_hook_ownership_predicate_covers_agent_tools(self, agent):
"""Sequential and concurrent agent-level paths share post-hook ownership."""
from agent.agent_runtime_helpers import agent_runtime_owns_post_tool_hook

View file

@ -64,6 +64,7 @@ class TestHandleFunctionCall:
tool_call_id="call-1",
turn_id="",
api_request_id="",
middleware_trace=[],
),
call(
"post_tool_call",
@ -79,6 +80,7 @@ class TestHandleFunctionCall:
status="ok",
error_type=None,
error_message=None,
middleware_trace=[],
),
call(
"transform_tool_result",
@ -145,6 +147,60 @@ class TestHandleFunctionCall:
assert "post_tool_call" not in fired
assert "transform_tool_result" not in fired
def test_tool_request_and_execution_middleware_wrap_registry_dispatch(self, monkeypatch):
seen = {}
def fake_invoke_middleware(kind, **kwargs):
if kind == "tool_request":
return [{
"args": {**kwargs["args"], "rewritten": True},
"source": "test-middleware",
"reason": "rewrite",
}]
return []
def execution_middleware(**kwargs):
seen["execution_args"] = kwargs["args"]
return kwargs["next_call"]({**kwargs["args"], "wrapped": True})
def fake_dispatch(tool_name, args, **kwargs):
seen["dispatch"] = (tool_name, args, kwargs)
return json.dumps({"ok": True, "args": args})
manager = type(
"Manager",
(),
{"_middleware": {"tool_request": [fake_invoke_middleware], "tool_execution": [execution_middleware]}},
)()
monkeypatch.setattr("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware)
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
hook_calls = []
monkeypatch.setattr(
"hermes_cli.plugins.invoke_hook",
lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
)
monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)
result = json.loads(
handle_function_call(
"web_search",
{"q": "test"},
task_id="task-1",
tool_call_id="tool-1",
session_id="session-1",
)
)
assert seen["execution_args"] == {"q": "test", "rewritten": True}
assert seen["dispatch"][1] == {"q": "test", "rewritten": True, "wrapped": True}
assert result["args"] == {"q": "test", "rewritten": True, "wrapped": True}
expected_trace = [{"source": "test-middleware", "reason": "rewrite"}]
pre_call = next(call for call in hook_calls if call[0] == "pre_tool_call")
post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
assert pre_call[1]["middleware_trace"] == expected_trace
assert post_call[1]["middleware_trace"] == expected_trace
# =========================================================================
# Agent loop tools