Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin

feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
kshitij 2026-06-06 10:46:41 -07:00 committed by GitHub
commit d4a7bfd3aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2170 additions and 151 deletions

313
hermes_cli/middleware.py Normal file
View file

@ -0,0 +1,313 @@
"""Hermes middleware contract helpers.
Observer hooks report what happened. Middleware can change what happens by
rewriting a request or wrapping the actual execution callback. Keep the small
contract helpers here so agent-loop call sites and plugins share one vocabulary.
"""
from __future__ import annotations
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
logger = logging.getLogger(__name__)
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"
# Back-compat aliases for older PoC branches that used API terminology.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE
VALID_MIDDLEWARE: set[str] = {
TOOL_REQUEST_MIDDLEWARE,
TOOL_EXECUTION_MIDDLEWARE,
LLM_REQUEST_MIDDLEWARE,
LLM_EXECUTION_MIDDLEWARE,
}
@dataclass
class RequestMiddlewareResult:
"""Result of applying request middleware to a mutable payload."""
payload: Any
original_payload: Any
changed: bool = False
trace: List[Dict[str, Any]] = field(default_factory=list)
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION)
return kwargs
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION)
kwargs.setdefault("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION)
return kwargs
def _safe_copy(payload: Any) -> Any:
"""Deep-copy a request payload, tolerating non-deepcopyable members.
Request payloads are normally plain JSON-shaped dicts, but an LLM request
can occasionally carry non-deepcopyable objects (clients, callbacks, file
handles). A hard ``deepcopy`` failure there would otherwise abort the whole
request-middleware pass. Fall back to a shallow ``dict`` copy so middleware
still runs and the original nested objects are shared by reference rather
than corrupting the live payload.
"""
try:
return deepcopy(payload)
except Exception as exc: # pragma: no cover - exercised via fallback test
logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc)
if isinstance(payload, dict):
return dict(payload)
return payload
def apply_llm_request_middleware(
request: Dict[str, Any],
**context: Any,
) -> RequestMiddlewareResult:
"""Apply registered LLM request middleware.
Middleware may return ``{"request": {...}}`` to replace the effective
provider kwargs before Hermes sends them.
"""
if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
return RequestMiddlewareResult(
payload=request,
original_payload=request,
changed=False,
trace=[],
)
original_request = _safe_copy(request)
current_request = _safe_copy(original_request)
trace: List[Dict[str, Any]] = []
for result in _invoke_middleware(
LLM_REQUEST_MIDDLEWARE,
request=current_request,
original_request=original_request,
**context,
):
if not isinstance(result, dict):
continue
next_request = result.get("request")
if not isinstance(next_request, dict):
continue
current_request = _safe_copy(next_request)
trace.append(_trace_entry(result))
return RequestMiddlewareResult(
payload=current_request,
original_payload=original_request,
changed=bool(trace),
trace=trace,
)
def apply_tool_request_middleware(
tool_name: str,
args: Dict[str, Any],
**context: Any,
) -> RequestMiddlewareResult:
"""Apply registered tool request middleware.
Middleware may return ``{"args": {...}}`` to replace the effective tool
arguments before hooks, guardrails, approvals, and execution see them.
"""
if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
return RequestMiddlewareResult(
payload=args,
original_payload=args,
changed=False,
trace=[],
)
original_args = _safe_copy(args)
current_args = _safe_copy(original_args)
trace: List[Dict[str, Any]] = []
for result in _invoke_middleware(
TOOL_REQUEST_MIDDLEWARE,
tool_name=tool_name,
args=current_args,
original_args=original_args,
**context,
):
if not isinstance(result, dict):
continue
next_args = result.get("args")
if not isinstance(next_args, dict):
continue
current_args = _safe_copy(next_args)
trace.append(_trace_entry(result))
return RequestMiddlewareResult(
payload=current_args,
original_payload=original_args,
changed=bool(trace),
trace=trace,
)
def apply_api_request_middleware(
request: Dict[str, Any],
**context: Any,
) -> RequestMiddlewareResult:
"""Compatibility wrapper for older ``api_request`` naming."""
return apply_llm_request_middleware(request, **context)
def run_llm_execution_middleware(
request: Dict[str, Any],
next_call: Callable[[Dict[str, Any]], Any],
**context: Any,
) -> Any:
"""Run provider execution through registered LLM execution middleware."""
callbacks = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
if not callbacks:
return next_call(request)
return _run_execution_chain(
LLM_EXECUTION_MIDDLEWARE,
callbacks,
next_call,
request=request,
original_request=context.pop("original_request", request),
**context,
)
def run_tool_execution_middleware(
tool_name: str,
args: Dict[str, Any],
next_call: Callable[[Dict[str, Any]], Any],
**context: Any,
) -> Any:
"""Run tool execution through registered tool execution middleware."""
callbacks = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
if not callbacks:
return next_call(args)
return _run_execution_chain(
TOOL_EXECUTION_MIDDLEWARE,
callbacks,
next_call,
tool_name=tool_name,
args=args,
original_args=context.pop("original_args", args),
**context,
)
def run_api_execution_middleware(
request: Dict[str, Any],
next_call: Callable[[Dict[str, Any]], Any],
**context: Any,
) -> Any:
"""Compatibility wrapper for older ``api_execution`` naming."""
return run_llm_execution_middleware(request, next_call, **context)
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
from hermes_cli.plugins import invoke_middleware
return invoke_middleware(kind, **middleware_payload(**kwargs))
def _has_middleware(kind: str) -> bool:
from hermes_cli.plugins import has_middleware
return has_middleware(kind)
def _get_middleware_callbacks(kind: str) -> List[Callable]:
from hermes_cli.plugins import get_plugin_manager
return list(get_plugin_manager()._middleware.get(kind, []))
def _run_execution_chain(
kind: str,
callbacks: List[Callable],
terminal_call: Callable[[Any], Any],
**kwargs: Any,
) -> Any:
payload_key = "request" if "request" in kwargs else "args"
class _DownstreamExecutionError(Exception):
def __init__(self, original: BaseException) -> None:
super().__init__(str(original))
self.original = original
def call_at(index: int, payload: Any) -> Any:
if index >= len(callbacks):
return terminal_call(payload)
callback = callbacks[index]
next_called = False
next_succeeded = False
next_result: Any = None
def next_call(next_payload: Any = None) -> Any:
nonlocal next_called, next_succeeded, next_result
# ``next_call`` is single-use per middleware frame. Calling it more
# than once would re-run the downstream provider/tool, so a second
# invocation is a contract violation rather than a retry. Surface it
# instead of silently executing the terminal call twice.
if next_called:
raise RuntimeError(
f"Middleware '{kind}' callback "
f"{getattr(callback, '__name__', repr(callback))} called "
"next_call() more than once; downstream execution is single-use"
)
next_called = True
try:
next_result = call_at(index + 1, payload if next_payload is None else next_payload)
next_succeeded = True
return next_result
except Exception as exc:
raise _DownstreamExecutionError(exc) from exc
call_kwargs = middleware_payload(**kwargs)
call_kwargs[payload_key] = payload
call_kwargs["next_call"] = next_call
try:
return callback(**call_kwargs)
except _DownstreamExecutionError as exc:
raise exc.original
except Exception as exc:
logger.warning(
"Middleware '%s' callback %s raised: %s",
kind,
getattr(callback, "__name__", repr(callback)),
exc,
)
if next_succeeded:
return next_result
if next_called:
raise
return call_at(index + 1, payload)
return call_at(0, kwargs[payload_key])
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
entry: Dict[str, Any] = {}
for key in ("source", "reason", "name"):
value = result.get(key)
if isinstance(value, str) and value:
entry[key] = value
if not entry:
entry["source"] = "plugin"
return entry

View file

@ -49,7 +49,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
from hermes_constants import get_hermes_home
from utils import env_var_enabled
from hermes_cli.config import cfg_get
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
from hermes_cli.middleware import OBSERVER_SCHEMA_VERSION, VALID_MIDDLEWARE
def get_bundled_plugins_dir() -> Path:
@ -277,6 +277,7 @@ class LoadedPlugin:
module: Optional[types.ModuleType] = None
tools_registered: List[str] = field(default_factory=list)
hooks_registered: List[str] = field(default_factory=list)
middleware_registered: List[str] = field(default_factory=list)
commands_registered: List[str] = field(default_factory=list)
enabled: bool = False
error: Optional[str] = None
@ -952,6 +953,27 @@ class PluginContext:
self._manager._hooks.setdefault(hook_name, []).append(callback)
logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name)
# -- middleware registration -------------------------------------------
def register_middleware(self, kind: str, callback: Callable) -> None:
"""Register a behavior-changing middleware callback.
Middleware is separate from observer hooks: request middleware may
rewrite the effective payload, and execution middleware may wrap the
real callback. Unknown kinds are stored for forward compatibility but
warned so plugin authors can catch typos.
"""
if kind not in VALID_MIDDLEWARE:
logger.warning(
"Plugin '%s' registered unknown middleware '%s' "
"(valid: %s)",
self.manifest.name,
kind,
", ".join(sorted(VALID_MIDDLEWARE)),
)
self._manager._middleware.setdefault(kind, []).append(callback)
logger.debug("Plugin %s registered middleware: %s", self.manifest.name, kind)
# -- skill registration -------------------------------------------------
def register_skill(
@ -1010,6 +1032,7 @@ class PluginManager:
def __init__(self) -> None:
self._plugins: Dict[str, LoadedPlugin] = {}
self._hooks: Dict[str, List[Callable]] = {}
self._middleware: Dict[str, List[Callable]] = {}
self._plugin_tool_names: Set[str] = set()
self._plugin_platform_names: Set[str] = set()
self._cli_commands: Dict[str, dict] = {}
@ -1039,6 +1062,7 @@ class PluginManager:
if force:
self._plugins.clear()
self._hooks.clear()
self._middleware.clear()
self._plugin_tool_names.clear()
self._cli_commands.clear()
self._plugin_commands.clear()
@ -1449,15 +1473,28 @@ class PluginManager:
for h in p.hooks_registered
}
)
loaded.middleware_registered = list(
{
kind
for kind, cbs in self._middleware.items()
if cbs
}
- {
kind
for name, p in self._plugins.items()
for kind in p.middleware_registered
}
)
loaded.commands_registered = [
c for c in self._plugin_commands
if self._plugin_commands[c].get("plugin") == manifest.name
]
loaded.enabled = True
logger.debug(
" registered: %d tool(s), %d hook(s), %d slash command(s), %d CLI command(s)",
" registered: %d tool(s), %d hook(s), %d middleware, %d slash command(s), %d CLI command(s)",
len(loaded.tools_registered),
len(loaded.hooks_registered),
len(loaded.middleware_registered),
len(loaded.commands_registered),
sum(
1 for c in self._cli_commands
@ -1575,6 +1612,33 @@ class PluginManager:
"""Return True when at least one callback is registered for a hook."""
return bool(self._hooks.get(hook_name))
def has_middleware(self, kind: str) -> bool:
"""Return True when at least one callback is registered for middleware."""
return bool(self._middleware.get(kind))
def invoke_middleware(self, kind: str, **kwargs: Any) -> List[Any]:
"""Call registered middleware callbacks for *kind*.
Each callback is isolated so one plugin cannot break the base runtime
path. Middleware that wants to change behavior must return the shape
documented by the caller-specific contract.
"""
callbacks = self._middleware.get(kind, [])
results: List[Any] = []
for cb in callbacks:
try:
ret = cb(**kwargs)
if ret is not None:
results.append(ret)
except Exception as exc:
logger.warning(
"Middleware '%s' callback %s raised: %s",
kind,
getattr(cb, "__name__", repr(cb)),
exc,
)
return results
# -----------------------------------------------------------------------
# Introspection
# -----------------------------------------------------------------------
@ -1594,6 +1658,7 @@ class PluginManager:
"enabled": loaded.enabled,
"tools": len(loaded.tools_registered),
"hooks": len(loaded.hooks_registered),
"middleware": len(loaded.middleware_registered),
"commands": len(loaded.commands_registered),
"error": loaded.error,
}
@ -1655,6 +1720,23 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]:
return get_plugin_manager().invoke_hook(hook_name, **kwargs)
def invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
"""Invoke registered middleware callbacks.
Returns a list of non-``None`` return values from middleware callbacks.
"""
return get_plugin_manager().invoke_middleware(kind, **kwargs)
def has_middleware(kind: str) -> bool:
"""Return True when middleware callbacks are registered for ``kind``."""
manager = get_plugin_manager()
method = getattr(manager, "has_middleware", None)
if callable(method):
return bool(method(kind))
return bool(getattr(manager, "_middleware", {}).get(kind))
def has_hook(hook_name: str) -> bool:
"""Return True when a hook has registered callbacks."""
return get_plugin_manager().has_hook(hook_name)
@ -1683,6 +1765,7 @@ def get_pre_tool_call_block_message(
tool_call_id: str = "",
turn_id: str = "",
api_request_id: str = "",
middleware_trace: Optional[List[Dict[str, Any]]] = None,
) -> Optional[str]:
"""Check ``pre_tool_call`` hooks for a blocking directive.
@ -1709,6 +1792,7 @@ def get_pre_tool_call_block_message(
tool_call_id=tool_call_id,
turn_id=turn_id,
api_request_id=api_request_id,
middleware_trace=list(middleware_trace or []),
)
for result in hook_results: