mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-11 08:42:11 +00:00
Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin
feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
commit
d4a7bfd3aa
14 changed files with 2170 additions and 151 deletions
|
|
@ -27,8 +27,16 @@ class _FakeNemoRelay:
|
|||
pop=self._scope_pop,
|
||||
event=self._scope_event,
|
||||
)
|
||||
self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end)
|
||||
self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end)
|
||||
self.llm = SimpleNamespace(
|
||||
call=self._llm_call,
|
||||
call_end=self._llm_call_end,
|
||||
execute=self._llm_execute,
|
||||
)
|
||||
self.tools = SimpleNamespace(
|
||||
call=self._tool_call,
|
||||
call_end=self._tool_call_end,
|
||||
execute=self._tool_execute,
|
||||
)
|
||||
self.plugin = SimpleNamespace(initialize=self._plugin_initialize)
|
||||
self.LLMRequest = _FakeLLMRequest
|
||||
self.AtofExporterConfig = _FakeAtofExporterConfig
|
||||
|
|
@ -55,6 +63,12 @@ class _FakeNemoRelay:
|
|||
def _llm_call_end(self, handle, response, **kwargs):
|
||||
self.events.append(("llm.call_end", handle, response, kwargs))
|
||||
|
||||
def _llm_execute(self, name, request, func, **kwargs):
    """Fake managed LLM execution: log, forward a marked request, log, return.

    A start event is recorded with the incoming request content, ``func`` is
    invoked with a new ``_FakeLLMRequest`` whose content carries
    ``"intercepted": True`` ahead of the original content keys, and an end
    event is recorded with whatever ``func`` returned. That value is returned
    unchanged.
    """
    self.events.append(("llm.execute.start", name, request.content, kwargs))
    # Original content keys are unpacked last, so they win on a key clash.
    forwarded = _FakeLLMRequest(
        request.headers,
        {"intercepted": True, **request.content},
    )
    outcome = func(forwarded)
    self.events.append(("llm.execute.end", name, outcome, kwargs))
    return outcome
|
||||
|
||||
def _tool_call(self, name, args, **kwargs):
|
||||
handle = ("tool", name)
|
||||
self.events.append(("tool.call", name, args, kwargs))
|
||||
|
|
@ -63,6 +77,12 @@ class _FakeNemoRelay:
|
|||
def _tool_call_end(self, handle, result, **kwargs):
|
||||
self.events.append(("tool.call_end", handle, result, kwargs))
|
||||
|
||||
def _tool_execute(self, name, args, func, **kwargs):
|
||||
self.events.append(("tool.execute.start", name, args, kwargs))
|
||||
result = func({"intercepted": True, **args})
|
||||
self.events.append(("tool.execute.end", name, result, kwargs))
|
||||
return result
|
||||
|
||||
def _make_atof_exporter(self, config):
    """Build a ``_FakeAtofExporter`` from this fake's event list and ``config``."""
    exporter = _FakeAtofExporter(self.events, config)
    return exporter
|
||||
|
||||
|
|
@ -425,6 +445,221 @@ output_directory = "{atif_dir}"
|
|||
assert atif_dir.is_dir()
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
    """Adaptive LLM middleware must hand the caller the untouched provider response.

    An enabled ``adaptive`` component in ``route`` mode is configured through
    the TOML file named by ``HERMES_NEMO_RELAY_PLUGINS_TOML``. The test checks
    that the downstream callable received the request the fake relay marked as
    intercepted, that the caller gets the raw response object back, and that
    the relay's ``llm.execute`` end event carries a plain-dict summary of it.
    """
    relay = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, relay)

    config_text = """
version = 1

[[components]]
kind = "adaptive"
enabled = true

[components.config]
mode = "route"
"""
    config_path = tmp_path / "plugins.toml"
    config_path.write_text(config_text, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(config_path))

    captured_request = {}

    # SDK-shaped response pieces, built bottom-up.
    requested_tool = SimpleNamespace(
        id="tool-1",
        type="function",
        function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'),
    )
    assistant_message = SimpleNamespace(
        role="assistant",
        content=None,
        tool_calls=[requested_tool],
        reasoning_content="need a tool",
    )
    raw_choice = SimpleNamespace(message=assistant_message, finish_reason="tool_calls")

    def provider_call(request):
        # Remember what reached the provider, then answer with the raw object.
        captured_request.update(request)
        token_usage = SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8)
        return SimpleNamespace(
            id="resp-1",
            model="demo-model",
            choices=[raw_choice],
            usage=token_usage,
        )

    def first_event(kind):
        return next(entry for entry in relay.events if entry[0] == kind)

    response = plugin.on_llm_execution_middleware(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
        provider="anthropic",
        model="demo-model",
        api_call_count=1,
        request={"messages": [{"role": "user", "content": "hi"}]},
        next_call=provider_call,
    )

    assert response.model == "demo-model"
    assert response.choices == [raw_choice]
    assert captured_request["intercepted"] is True

    started = first_event("llm.execute.start")
    assert started[3]["data"]["mode"] == "route"

    expected_summary = {
        "model": "demo-model",
        "assistant_message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "tool-1",
                    "type": "function",
                    "function": {"name": "terminal", "arguments": '{"command":"pwd"}'},
                }
            ],
            "reasoning_content": "need a tool",
        },
        "finish_reason": "tool_calls",
        "usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
    }
    finished = first_event("llm.execute.end")
    assert finished[2] == expected_summary
|
||||
|
||||
|
||||
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
    """Without adaptive config, the LLM middleware just invokes ``next_call``.

    The request must reach the downstream callable unmodified, and the fake
    relay must not record any ``llm.execute.start`` event.
    """
    relay = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, relay)

    def echo(request):
        return {"raw": request}

    response = plugin.on_llm_execution_middleware(
        session_id="s1",
        provider="anthropic",
        model="demo-model",
        request={"messages": []},
        next_call=echo,
    )

    assert response == {"raw": {"messages": []}}
    recorded_kinds = [entry[0] for entry in relay.events]
    assert "llm.execute.start" not in recorded_kinds
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
    """Adaptive tool middleware must return the downstream result unchanged.

    An enabled ``adaptive`` component in ``route`` mode is configured through
    the TOML file named by ``HERMES_NEMO_RELAY_PLUGINS_TOML``. The test checks
    that the downstream callable received the args the fake relay marked as
    intercepted, that its return value comes back as-is, and that the relay's
    ``tool.execute`` start event carries the mode and the tool call id.
    """
    relay = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, relay)

    config_text = """
version = 1

[[components]]
kind = "adaptive"
enabled = true

[components.config]
mode = "route"
"""
    config_path = tmp_path / "plugins.toml"
    config_path.write_text(config_text, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(config_path))

    captured_args = {}

    def tool_call(args):
        # Remember what reached the tool, then echo it inside a raw payload.
        captured_args.update(args)
        return {"raw": True, "args": args}

    response = plugin.on_tool_execution_middleware(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
        next_call=tool_call,
    )

    assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
    assert captured_args["intercepted"] is True

    started = next(entry for entry in relay.events if entry[0] == "tool.execute.start")
    start_data = started[3]["data"]
    assert start_data["mode"] == "route"
    assert start_data["tool_call_id"] == "tool-1"
|
||||
|
||||
|
||||
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
    """Without adaptive config, the tool middleware just invokes ``next_call``.

    The args must reach the downstream callable unmodified, and the fake relay
    must not record any ``tool.execute.start`` event.
    """
    relay = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, relay)

    def echo(args):
        return {"raw": args}

    response = plugin.on_tool_execution_middleware(
        session_id="s1",
        tool_name="terminal",
        args={"command": "pwd"},
        next_call=echo,
    )

    assert response == {"raw": {"command": "pwd"}}
    recorded_kinds = [entry[0] for entry in relay.events]
    assert "tool.execute.start" not in recorded_kinds
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch):
    """With adaptive execution on, observer hooks must not emit their own spans.

    The pre/post API and tool hooks are fired first, then both execution
    middlewares. The fake relay's log must contain the ``*.execute.start``
    events and none of ``llm.call``, ``llm.call_end``, ``tool.call`` or
    ``tool.call_end``.
    """
    relay = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, relay)

    config_text = """
version = 1

[[components]]
kind = "adaptive"
enabled = true

[components.config]
mode = "route"
"""
    config_path = tmp_path / "plugins.toml"
    config_path.write_text(config_text, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(config_path))

    # Identifiers shared by every hook invocation below.
    common = dict(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
    )

    # Observer-style hooks.
    plugin.on_pre_api_request(
        **common,
        provider="anthropic",
        model="demo-model",
        request={"body": {"messages": [{"role": "user", "content": "hi"}]}},
    )
    plugin.on_post_api_request(**common, response={"ok": True})
    plugin.on_pre_tool_call(
        **common,
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
    )
    plugin.on_post_tool_call(
        **common,
        tool_name="terminal",
        tool_call_id="tool-1",
        result={"ok": True},
    )

    def echo_request(request):
        return {"raw": request}

    def echo_args(args):
        return {"raw": args}

    # Execution middlewares.
    plugin.on_llm_execution_middleware(
        **common,
        provider="anthropic",
        model="demo-model",
        request={"messages": [{"role": "user", "content": "hi"}]},
        next_call=echo_request,
    )
    plugin.on_tool_execution_middleware(
        **common,
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
        next_call=echo_args,
    )

    event_names = [entry[0] for entry in relay.events]
    for unexpected in ("llm.call", "llm.call_end", "tool.call", "tool.call_end"):
        assert unexpected not in event_names
    for required in ("llm.execute.start", "tool.execute.start"):
        assert required in event_names
|
||||
|
||||
|
||||
def test_nemo_relay_plugin_noops_without_dependency(monkeypatch):
|
||||
monkeypatch.delitem(sys.modules, "nemo_relay", raising=False)
|
||||
sys.modules.pop("plugins.observability.nemo_relay", None)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue