Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin

feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
kshitij 2026-06-06 10:46:41 -07:00 committed by GitHub
commit d4a7bfd3aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2170 additions and 151 deletions

View file

@ -27,8 +27,16 @@ class _FakeNemoRelay:
pop=self._scope_pop,
event=self._scope_event,
)
self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end)
self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end)
self.llm = SimpleNamespace(
call=self._llm_call,
call_end=self._llm_call_end,
execute=self._llm_execute,
)
self.tools = SimpleNamespace(
call=self._tool_call,
call_end=self._tool_call_end,
execute=self._tool_execute,
)
self.plugin = SimpleNamespace(initialize=self._plugin_initialize)
self.LLMRequest = _FakeLLMRequest
self.AtofExporterConfig = _FakeAtofExporterConfig
@ -55,6 +63,12 @@ class _FakeNemoRelay:
def _llm_call_end(self, handle, response, **kwargs):
self.events.append(("llm.call_end", handle, response, kwargs))
def _llm_execute(self, name, request, func, **kwargs):
self.events.append(("llm.execute.start", name, request.content, kwargs))
result = func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
self.events.append(("llm.execute.end", name, result, kwargs))
return result
def _tool_call(self, name, args, **kwargs):
handle = ("tool", name)
self.events.append(("tool.call", name, args, kwargs))
@ -63,6 +77,12 @@ class _FakeNemoRelay:
def _tool_call_end(self, handle, result, **kwargs):
self.events.append(("tool.call_end", handle, result, kwargs))
def _tool_execute(self, name, args, func, **kwargs):
self.events.append(("tool.execute.start", name, args, kwargs))
result = func({"intercepted": True, **args})
self.events.append(("tool.execute.end", name, result, kwargs))
return result
def _make_atof_exporter(self, config):
return _FakeAtofExporter(self.events, config)
@ -425,6 +445,221 @@ output_directory = "{atif_dir}"
assert atif_dir.is_dir()
def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
seen_request = {}
raw_choice = SimpleNamespace(
message=SimpleNamespace(
role="assistant",
content=None,
tool_calls=[
SimpleNamespace(
id="tool-1",
type="function",
function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'),
)
],
reasoning_content="need a tool",
),
finish_reason="tool_calls",
)
def next_call(request):
seen_request.update(request)
return SimpleNamespace(
id="resp-1",
model="demo-model",
choices=[raw_choice],
usage=SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8),
)
response = plugin.on_llm_execution_middleware(
session_id="s1",
task_id="t1",
turn_id="turn-1",
api_request_id="api-1",
provider="anthropic",
model="demo-model",
api_call_count=1,
request={"messages": [{"role": "user", "content": "hi"}]},
next_call=next_call,
)
assert response.model == "demo-model"
assert response.choices == [raw_choice]
assert seen_request["intercepted"] is True
execute_start = next(event for event in fake.events if event[0] == "llm.execute.start")
assert execute_start[3]["data"]["mode"] == "route"
execute_end = next(event for event in fake.events if event[0] == "llm.execute.end")
assert execute_end[2] == {
"model": "demo-model",
"assistant_message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "tool-1",
"type": "function",
"function": {"name": "terminal", "arguments": '{"command":"pwd"}'},
}
],
"reasoning_content": "need a tool",
},
"finish_reason": "tool_calls",
"usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
}
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
response = plugin.on_llm_execution_middleware(
session_id="s1",
provider="anthropic",
model="demo-model",
request={"messages": []},
next_call=lambda request: {"raw": request},
)
assert response == {"raw": {"messages": []}}
assert not any(event[0] == "llm.execute.start" for event in fake.events)
def test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
seen_args = {}
def next_call(args):
seen_args.update(args)
return {"raw": True, "args": args}
response = plugin.on_tool_execution_middleware(
session_id="s1",
task_id="t1",
turn_id="turn-1",
api_request_id="api-1",
tool_name="terminal",
tool_call_id="tool-1",
args={"command": "pwd"},
next_call=next_call,
)
assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
assert seen_args["intercepted"] is True
execute_start = next(event for event in fake.events if event[0] == "tool.execute.start")
assert execute_start[3]["data"]["mode"] == "route"
assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
response = plugin.on_tool_execution_middleware(
session_id="s1",
tool_name="terminal",
args={"command": "pwd"},
next_call=lambda args: {"raw": args},
)
assert response == {"raw": {"command": "pwd"}}
assert not any(event[0] == "tool.execute.start" for event in fake.events)
def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
base = {
"session_id": "s1",
"task_id": "t1",
"turn_id": "turn-1",
"api_request_id": "api-1",
}
plugin.on_pre_api_request(
**base,
provider="anthropic",
model="demo-model",
request={"body": {"messages": [{"role": "user", "content": "hi"}]}},
)
plugin.on_post_api_request(**base, response={"ok": True})
plugin.on_pre_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", args={"command": "pwd"})
plugin.on_post_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", result={"ok": True})
plugin.on_llm_execution_middleware(
**base,
provider="anthropic",
model="demo-model",
request={"messages": [{"role": "user", "content": "hi"}]},
next_call=lambda request: {"raw": request},
)
plugin.on_tool_execution_middleware(
**base,
tool_name="terminal",
tool_call_id="tool-1",
args={"command": "pwd"},
next_call=lambda args: {"raw": args},
)
event_names = [event[0] for event in fake.events]
assert "llm.call" not in event_names
assert "llm.call_end" not in event_names
assert "tool.call" not in event_names
assert "tool.call_end" not in event_names
assert "llm.execute.start" in event_names
assert "tool.execute.start" in event_names
def test_nemo_relay_plugin_noops_without_dependency(monkeypatch):
monkeypatch.delitem(sys.modules, "nemo_relay", raising=False)
sys.modules.pop("plugins.observability.nemo_relay", None)