Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin

feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
kshitij 2026-06-06 10:46:41 -07:00 committed by GitHub
commit d4a7bfd3aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2170 additions and 151 deletions

View file

@ -165,6 +165,28 @@ When `HERMES_NEMO_RELAY_PLUGINS_TOML` is set and initializes successfully, NeMo
Relay owns exporter lifecycle through that config. The direct
`HERMES_NEMO_RELAY_ATOF_*` fallback setup is skipped.
To enable NeMo Relay managed execution intercepts for provider and tool calls,
include an adaptive component in the same `plugins.toml`:
```toml
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
```
When the adaptive component is enabled and the installed NeMo Relay runtime
exposes `llm.execute(...)` / `tools.execute(...)`, Hermes routes LLM and tool
execution through those middleware boundaries. The observer hooks still emit
session, turn, approval, and subagent marks; the plugin skips its manual
`llm.call` and `tools.call` spans for executions that are already managed by
NeMo Relay.
For the full generic Hermes middleware contract, see
[`docs/middleware/README.md`](../../../docs/middleware/README.md).
## Canonical Local Examples
The examples below use the official `nemo-relay==0.3` distribution and a local
@ -366,3 +388,166 @@ subagent IDs, role/status fields when present, and derived
`parent_trajectory_id` / `child_trajectory_id` values. This keeps the ATOF
stream lossless for later ATIF conversion that can compact subagents into
separate trajectories.
## Adaptive Middleware Example
The `observability/nemo_relay` plugin uses Hermes execution middleware to hand
LLM and tool calls to NeMo Relay managed execution when an adaptive component is
enabled.
Minimal `plugins.toml` (saved here as
`/tmp/hermes-middleware-test/plugins.toml`, the path exported below):
```toml
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
```
Enable it for Hermes:
```bash
export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/plugins.toml
```
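When the adaptive component is enabled, Hermes checks each boundary
independently: LLM calls are routed when the installed NeMo Relay runtime
exposes both `llm.execute(...)` and `LLMRequest`, and tool calls are routed when
it exposes `tools.execute(...)`. Execution then flows through these boundaries: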
```text
Hermes provider call
-> llm_execution middleware
-> nemo_relay.llm.execute(...)
-> Hermes provider adapter next_call(...)
Hermes tool call
-> tool_execution middleware
-> nemo_relay.tools.execute(...)
-> Hermes tool dispatcher next_call(...)
```
The plugin still emits observer marks for sessions, turns, approvals, and
subagents. When adaptive managed execution is active, it skips manual
`llm.call` and `tools.call` observer spans to avoid duplicate LLM/tool events
for the same execution.
### Local Adaptive E2E
This example enables both NeMo Relay observability export and adaptive execution
middleware for a local Hermes run.
```bash
pip install "nemo-relay==0.3"
export HERMES_HOME=/tmp/hermes-middleware-test/hermes-home
mkdir -p "$HERMES_HOME" /tmp/hermes-middleware-test/nemo-relay
cat > "$HERMES_HOME/config.yaml" <<'YAML'
model:
provider: custom
default: qwen3.6:35b
base_url: http://127.0.0.1:11434/v1
api_key: ollama
plugins:
enabled:
- observability/nemo_relay
YAML
cat > /tmp/hermes-middleware-test/nemo-relay/plugins.toml <<'TOML'
version = 1
[[components]]
kind = "observability"
enabled = true
[components.config]
version = 1
[components.config.atof]
enabled = true
output_directory = "/tmp/hermes-middleware-test/atof"
filename = "middleware-events.jsonl"
mode = "overwrite"
[components.config.atif]
enabled = true
output_directory = "/tmp/hermes-middleware-test/atif"
filename_template = "middleware-trajectory-{session_id}.json"
agent_name = "Hermes Middleware E2E"
agent_version = "local"
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
TOML
export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/nemo-relay/plugins.toml
hermes chat \
--query 'Use the terminal tool exactly once to run printf middleware_execution_ok. Then reply with exactly the command output.' \
--provider custom \
--model qwen3.6:35b \
--toolsets terminal \
--max-turns 4 \
--quiet \
--accept-hooks
```
Expected CLI output:
```text
session_id: middleware-demo-session
middleware_execution_ok
```
Expected ATOF shape:
```jsonl
{"kind":"scope","category":"llm","name":"custom","scope_category":"start","metadata":{"session_id":"middleware-demo-session"},"data":{"mode":"route"}}
{"kind":"scope","category":"tool","name":"terminal","scope_category":"start","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal"},"data":{"mode":"route"}}
{"kind":"scope","category":"tool","name":"terminal","scope_category":"end","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal","status":"ok"},"data":"{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"}
```
Expected ATIF shape:
```json
{
"schema_version": "ATIF-v1.7",
"session_id": "middleware-demo-session",
"agent": {
"name": "Hermes Middleware E2E",
"version": "local",
"model_name": "qwen3.6:35b"
},
"steps": [
{
"source": "agent",
"tool_calls": [
{
"function_name": "terminal",
"arguments": {"command": "printf middleware_execution_ok"}
}
],
"observation": {
"results": [
{
"source_call_id": "call_terminal",
"content": "{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"
}
]
}
},
{
"source": "agent",
"message": "middleware_execution_ok"
}
]
}
```

View file

@ -42,6 +42,9 @@ class _SubagentParent:
@dataclass
class _Settings:
plugins_toml_path: str = ""
plugins_config: dict[str, Any] | None = None
adaptive_enabled: bool = False
adaptive_mode: str = "observe"
atof_enabled: bool = False
atof_output_directory: str = ""
atof_filename: str = "hermes-atof.jsonl"
@ -67,17 +70,15 @@ class _Runtime:
self._configure_atof()
def _configure_plugins_toml(self) -> bool:
if not self.settings.plugins_toml_path:
if not self.settings.plugins_config:
return False
plugin_mod = getattr(self.nemo_relay, "plugin", None)
initialize = getattr(plugin_mod, "initialize", None)
if not callable(initialize):
return False
config_path = Path(self.settings.plugins_toml_path)
try:
config = tomllib.loads(config_path.read_text(encoding="utf-8"))
self._ensure_plugin_config_output_dirs(config)
result = initialize(config)
self._ensure_plugin_config_output_dirs(self.settings.plugins_config)
result = initialize(self.settings.plugins_config)
if inspect.isawaitable(result):
asyncio.run(result)
return True
@ -221,6 +222,100 @@ class _Runtime:
self.subagent_parents.pop(child_session_id, None)
self.mark("hermes.subagent.stop", kwargs)
def managed_llm_enabled(self) -> bool:
    """Report whether LLM calls should be handed to NeMo Relay.

    True only when the adaptive component is enabled in the settings and
    the loaded ``nemo_relay`` module exposes both a callable
    ``llm.execute`` and a callable ``LLMRequest``.
    """
    adaptive = self.settings.adaptive_enabled
    if not adaptive:
        return adaptive
    relay = self.nemo_relay
    llm_namespace = getattr(relay, "llm", None)
    if not callable(getattr(llm_namespace, "execute", None)):
        return False
    return callable(getattr(relay, "LLMRequest", None))
def managed_tool_enabled(self) -> bool:
    """Report whether tool calls should be handed to NeMo Relay.

    True only when the adaptive component is enabled in the settings and
    the loaded ``nemo_relay`` module exposes a callable ``tools.execute``.
    """
    adaptive = self.settings.adaptive_enabled
    if not adaptive:
        return adaptive
    tools_namespace = getattr(self.nemo_relay, "tools", None)
    return callable(getattr(tools_namespace, "execute", None))
def execute_llm(self, kwargs: dict[str, Any]) -> Any:
    """Run one provider call through ``nemo_relay.llm.execute``.

    ``kwargs`` is the middleware payload; this method reads ``request``,
    ``next_call``, ``provider``, ``model``, ``turn_id``, ``api_request_id``
    and ``api_call_count`` from it.

    Returns the raw value produced by ``next_call`` when NeMo Relay invoked
    the wrapped implementation; otherwise whatever ``llm.execute`` returned.
    When ``next_call`` is not callable the JSON-able request body is
    returned unchanged.
    """
    session = self.ensure_session(kwargs)
    request_body = _jsonable(kwargs.get("request") or {})
    request = self.nemo_relay.LLMRequest({}, request_body)
    provider_call = kwargs.get("next_call")
    if not callable(provider_call):
        return request_body

    # Holds every raw provider response; the most recent one is what the
    # caller gets back instead of the normalized payload handed to the relay.
    captured: list[Any] = []

    def _impl(next_request: Any) -> Any:
        candidate = getattr(next_request, "content", next_request)
        # Fall back to the original body when the relay hands back something
        # that is not a plain dict.
        outgoing = candidate if isinstance(candidate, dict) else request_body
        raw = provider_call(outgoing)
        captured.append(raw)
        return _llm_response_payload(raw)

    async def _managed_execute() -> Any:
        outcome = self.nemo_relay.llm.execute(
            str(kwargs.get("provider") or "llm"),
            request,
            _impl,
            handle=session.handle,
            data=_jsonable(
                {
                    "turn_id": kwargs.get("turn_id"),
                    "api_request_id": kwargs.get("api_request_id"),
                    "api_call_count": kwargs.get("api_call_count"),
                    "mode": self.settings.adaptive_mode,
                }
            ),
            metadata=_metadata(kwargs),
            model_name=str(kwargs.get("model") or ""),
        )
        # ``execute`` may be synchronous or return an awaitable.
        if not inspect.isawaitable(outcome):
            return outcome
        return await outcome

    managed_result = _resolve_awaitable(_managed_execute())
    if captured:
        return captured[-1]
    return managed_result
def execute_tool(self, kwargs: dict[str, Any]) -> Any:
    """Run one tool call through ``nemo_relay.tools.execute``.

    ``kwargs`` is the middleware payload; this method reads ``tool_name``,
    ``args``, ``next_call``, ``turn_id``, ``api_request_id`` and
    ``tool_call_id`` from it.

    Returns the raw value produced by ``next_call`` when NeMo Relay invoked
    the wrapped implementation; otherwise whatever ``tools.execute``
    returned. When ``next_call`` is not callable the JSON-able arguments
    are returned unchanged.
    """
    session = self.ensure_session(kwargs)
    tool_name = str(kwargs.get("tool_name") or "tool")
    args = _jsonable(kwargs.get("args") or {})
    dispatch = kwargs.get("next_call")
    if not callable(dispatch):
        return args

    # Holds every raw tool result; the most recent one is what the caller
    # gets back instead of the JSON-able copy handed to the relay.
    captured: list[Any] = []

    def _impl(next_args: Any) -> Any:
        # Fall back to the original arguments when the relay passes something
        # that is not a plain dict.
        raw = dispatch(next_args if isinstance(next_args, dict) else args)
        captured.append(raw)
        return _jsonable(raw)

    async def _managed_execute() -> Any:
        outcome = self.nemo_relay.tools.execute(
            tool_name,
            args,
            _impl,
            handle=session.handle,
            data=_jsonable(
                {
                    "turn_id": kwargs.get("turn_id"),
                    "api_request_id": kwargs.get("api_request_id"),
                    "tool_call_id": kwargs.get("tool_call_id"),
                    "mode": self.settings.adaptive_mode,
                }
            ),
            metadata=_metadata(kwargs),
        )
        # ``execute`` may be synchronous or return an awaitable.
        if not inspect.isawaitable(outcome):
            return outcome
        return await outcome

    managed_result = _resolve_awaitable(_managed_execute())
    if captured:
        return captured[-1]
    return managed_result
def register(ctx) -> None:
ctx.register_hook("on_session_start", on_session_start)
@ -238,6 +333,8 @@ def register(ctx) -> None:
ctx.register_hook("post_approval_response", on_post_approval_response)
ctx.register_hook("subagent_start", on_subagent_start)
ctx.register_hook("subagent_stop", on_subagent_stop)
ctx.register_middleware("llm_execution", on_llm_execution_middleware)
ctx.register_middleware("tool_execution", on_tool_execution_middleware)
def on_session_start(**kwargs: Any) -> None:
@ -280,6 +377,8 @@ def on_pre_api_request(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_llm_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -303,6 +402,8 @@ def on_post_api_request(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_llm_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -324,6 +425,8 @@ def on_api_request_error(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_llm_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -345,6 +448,8 @@ def on_pre_tool_call(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_tool_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -365,6 +470,8 @@ def on_post_tool_call(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_tool_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -406,6 +513,28 @@ def on_subagent_stop(**kwargs: Any) -> None:
_safe(lambda: runtime.mark_subagent_stop(kwargs))
def on_llm_execution_middleware(**kwargs: Any) -> Any:
    """Middleware entry point for the ``llm_execution`` boundary.

    Delegates to the runtime's managed LLM execution when it is available.
    Otherwise it is a pass-through: it calls ``next_call`` with the request
    when ``next_call`` is callable, and returns the request itself when not.
    """
    runtime = _get_runtime()
    if runtime is not None and runtime.managed_llm_enabled():
        return runtime.execute_llm(kwargs)

    request = kwargs.get("request") or {}
    downstream = kwargs.get("next_call")
    return downstream(request) if callable(downstream) else request
def on_tool_execution_middleware(**kwargs: Any) -> Any:
    """Middleware entry point for the ``tool_execution`` boundary.

    Delegates to the runtime's managed tool execution when it is available.
    Otherwise it is a pass-through: it calls ``next_call`` with the tool
    arguments when ``next_call`` is callable, and returns the arguments
    themselves when not.
    """
    runtime = _get_runtime()
    if runtime is not None and runtime.managed_tool_enabled():
        return runtime.execute_tool(kwargs)

    args = kwargs.get("args") or {}
    downstream = kwargs.get("next_call")
    return downstream(args) if callable(downstream) else args
def _get_runtime() -> Optional[_Runtime]:
global _RUNTIME
with _LOCK:
@ -429,8 +558,14 @@ def _get_runtime() -> Optional[_Runtime]:
def _load_settings() -> _Settings:
plugins_toml_path = _env("HERMES_NEMO_RELAY_PLUGINS_TOML")
plugins_config = _load_plugins_config(plugins_toml_path)
adaptive_config = _enabled_component_config(plugins_config, "adaptive")
return _Settings(
plugins_toml_path=_env("HERMES_NEMO_RELAY_PLUGINS_TOML"),
plugins_toml_path=plugins_toml_path,
plugins_config=plugins_config,
adaptive_enabled=adaptive_config is not None,
adaptive_mode=_adaptive_mode(adaptive_config),
atof_enabled=_env_bool("HERMES_NEMO_RELAY_ATOF_ENABLED"),
atof_output_directory=_env("HERMES_NEMO_RELAY_ATOF_OUTPUT_DIRECTORY"),
atof_filename=_env("HERMES_NEMO_RELAY_ATOF_FILENAME") or "hermes-atof.jsonl",
@ -445,6 +580,44 @@ def _load_settings() -> _Settings:
)
def _load_plugins_config(path: str) -> dict[str, Any] | None:
if not path:
return None
try:
return tomllib.loads(Path(path).read_text(encoding="utf-8"))
except Exception as exc:
logger.debug("NeMo Relay plugins.toml load failed: %s", exc, exc_info=True)
return None
def _enabled_component_config(
plugins_config: dict[str, Any] | None,
kind: str,
) -> dict[str, Any] | None:
if not isinstance(plugins_config, dict):
return None
components = plugins_config.get("components")
if not isinstance(components, list):
return None
for component in components:
if not isinstance(component, dict):
continue
if component.get("kind") != kind or not component.get("enabled", True):
continue
config = component.get("config")
return config if isinstance(config, dict) else {}
return None
def _adaptive_mode(config: dict[str, Any] | None) -> str:
if not isinstance(config, dict):
return "observe"
mode = config.get("mode")
if isinstance(mode, str) and mode.strip():
return mode.strip()
return "observe"
def _env(name: str) -> str:
    """Return environment variable ``name`` stripped of whitespace, or ``""`` if unset."""
    raw = os.environ.get(name)
    if raw is None:
        return ""
    return raw.strip()
@ -549,12 +722,78 @@ def _jsonable(value: Any) -> Any:
return _jsonable(value.model_dump(mode="json"))
except Exception:
pass
try:
if hasattr(value, "__dict__"):
return _jsonable(vars(value))
except Exception:
pass
try:
return json.loads(json.dumps(value, default=str))
except Exception:
return str(value)
def _value(obj: Any, key: str, default: Any = None) -> Any:
    """Read ``key`` from ``obj``: as a dict item for dicts, as an attribute otherwise.

    Returns ``default`` when the item or attribute is missing.
    """
    if not isinstance(obj, dict):
        return getattr(obj, key, default)
    return obj.get(key, default)
def _llm_response_payload(response: Any) -> Any:
    """Return the LLM response shape NeMo Relay's ATIF conversion expects.

    A JSON-able payload that already has an ``assistant_message`` key is
    returned as is. Otherwise the first entry of ``choices`` (looked up on
    the response, then on its JSON-able form) is flattened into a dict with
    ``model``, ``assistant_message``, ``finish_reason`` and ``usage`` keys.
    """
    payload = _jsonable(response)
    payload_is_dict = isinstance(payload, dict)
    if payload_is_dict and "assistant_message" in payload:
        return payload

    choices = _value(response, "choices")
    if choices is None and payload_is_dict:
        choices = payload.get("choices")
    first_choice = None
    if isinstance(choices, list) and choices:
        first_choice = choices[0]

    message = _value(first_choice, "message")
    finish_reason = _value(first_choice, "finish_reason")

    assistant_message: dict[str, Any] = {"role": "assistant", "content": ""}
    if message is None:
        # No chat-style message: fall back to flat text fields on the payload.
        if payload_is_dict:
            assistant_message["content"] = (
                payload.get("content") or payload.get("output_text") or ""
            )
    else:
        assistant_message["role"] = _value(message, "role", "assistant") or "assistant"
        content = _value(message, "content")
        if content is not None:
            assistant_message["content"] = _jsonable(content)
        tool_calls = _tool_calls_payload(_value(message, "tool_calls"))
        if tool_calls:
            assistant_message["tool_calls"] = tool_calls
        reasoning = _value(message, "reasoning_content")
        if reasoning is not None:
            assistant_message["reasoning_content"] = _jsonable(reasoning)

    fallback_model = payload.get("model") if payload_is_dict else None
    model = _value(response, "model", fallback_model)
    fallback_usage = payload.get("usage") if payload_is_dict else None
    usage = _jsonable(_value(response, "usage", fallback_usage))
    return {
        "model": model,
        "assistant_message": assistant_message,
        "finish_reason": finish_reason,
        "usage": usage,
    }
def _tool_calls_payload(tool_calls: Any) -> list[dict[str, Any]]:
    """Normalize tool calls into ``{"id", "type", "function": {"name", "arguments"}}`` dicts.

    Entries may be dicts or attribute-style objects. Anything other than a
    list yields an empty list; a missing or empty ``type`` becomes
    ``"function"``.
    """
    if not isinstance(tool_calls, list):
        return []

    def _normalize(call: Any) -> dict[str, Any]:
        function = _value(call, "function")
        return {
            "id": _value(call, "id"),
            "type": _value(call, "type", "function") or "function",
            "function": {
                "name": _value(function, "name"),
                "arguments": _value(function, "arguments"),
            },
        }

    return [_normalize(call) for call in tool_calls]
def _safe(fn) -> None:
try:
fn()
@ -562,6 +801,35 @@ def _safe(fn) -> None:
logger.debug("NeMo Relay hook handling failed: %s", exc, exc_info=True)
def _resolve_awaitable(value: Any) -> Any:
    """Synchronously resolve ``value`` if it is awaitable.

    Non-awaitable values are returned unchanged. Awaitables are driven to
    completion with ``asyncio.run``: directly when no event loop is running
    in the current thread, otherwise on a helper thread so the running loop
    is not re-entered. The calling thread blocks until the result is ready.

    Raises:
        BaseException: whatever the awaitable raised is re-raised here.
    """
    if not inspect.isawaitable(value):
        return value

    async def _drive() -> Any:
        # ``inspect.isawaitable`` also accepts objects that only define
        # ``__await__``, but ``asyncio.run`` historically requires a real
        # coroutine; awaiting inside a coroutine covers both cases.
        return await value

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so one can be started here.
        return asyncio.run(_drive())

    outcome: dict[str, Any] = {}
    failure: dict[str, BaseException] = {}

    def _runner() -> None:
        try:
            outcome["value"] = asyncio.run(_drive())
        except BaseException as exc:  # pragma: no cover - re-raised below
            failure["exc"] = exc

    worker = threading.Thread(
        target=_runner,
        name="hermes-nemo-relay-awaitable",
        daemon=True,
    )
    worker.start()
    worker.join()
    if "exc" in failure:
        raise failure["exc"]
    return outcome.get("value")
def reset_for_tests() -> None:
global _RUNTIME
with _LOCK: