mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-15 09:21:36 +00:00
Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin
feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
commit
d4a7bfd3aa
14 changed files with 2170 additions and 151 deletions
313
hermes_cli/middleware.py
Normal file
313
hermes_cli/middleware.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""Hermes middleware contract helpers.
|
||||
|
||||
Observer hooks report what happened. Middleware can change what happens by
|
||||
rewriting a request or wrapping the actual execution callback. Keep the small
|
||||
contract helpers here so agent-loop call sites and plugins share one vocabulary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Schema identifiers stamped onto payloads by ``observer_payload`` and
# ``middleware_payload``.
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"

# Middleware kinds: "request" kinds may rewrite a payload, "execution" kinds
# wrap the callback that actually performs the work.
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"

# Back-compat aliases for older PoC branches that used API terminology.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE

# Every kind a plugin is expected to register against.
VALID_MIDDLEWARE: set[str] = {
    LLM_REQUEST_MIDDLEWARE,
    LLM_EXECUTION_MIDDLEWARE,
    TOOL_REQUEST_MIDDLEWARE,
    TOOL_EXECUTION_MIDDLEWARE,
}
|
||||
|
||||
|
||||
@dataclass
class RequestMiddlewareResult:
    """Outcome of running request middleware over a payload.

    Returned by the ``apply_*_request_middleware`` helpers so call sites can
    pick up the effective payload together with a record of what changed it.
    """

    # Payload to use after middleware has run.
    payload: Any
    # Payload as it looked before any middleware touched it.
    original_payload: Any
    # True when at least one middleware replacement was accepted.
    changed: bool = False
    # One entry per accepted replacement (see ``_trace_entry``).
    trace: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return *kwargs* tagged with the observer schema version.

    A ``telemetry_schema_version`` supplied by the caller is left untouched.
    """
    if "telemetry_schema_version" not in kwargs:
        kwargs["telemetry_schema_version"] = OBSERVER_SCHEMA_VERSION
    return kwargs
|
||||
|
||||
|
||||
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return *kwargs* tagged with the observer and middleware schema versions.

    Caller-supplied ``telemetry_schema_version`` / ``middleware_schema_version``
    values take precedence over the defaults.
    """
    schema_defaults = (
        ("telemetry_schema_version", OBSERVER_SCHEMA_VERSION),
        ("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION),
    )
    for key, version in schema_defaults:
        kwargs.setdefault(key, version)
    return kwargs
|
||||
|
||||
|
||||
def _safe_copy(payload: Any) -> Any:
|
||||
"""Deep-copy a request payload, tolerating non-deepcopyable members.
|
||||
|
||||
Request payloads are normally plain JSON-shaped dicts, but an LLM request
|
||||
can occasionally carry non-deepcopyable objects (clients, callbacks, file
|
||||
handles). A hard ``deepcopy`` failure there would otherwise abort the whole
|
||||
request-middleware pass. Fall back to a shallow ``dict`` copy so middleware
|
||||
still runs and the original nested objects are shared by reference rather
|
||||
than corrupting the live payload.
|
||||
"""
|
||||
try:
|
||||
return deepcopy(payload)
|
||||
except Exception as exc: # pragma: no cover - exercised via fallback test
|
||||
logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc)
|
||||
if isinstance(payload, dict):
|
||||
return dict(payload)
|
||||
return payload
|
||||
|
||||
|
||||
def apply_llm_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run LLM request middleware over the provider kwargs.

    A middleware callback may return ``{"request": {...}}`` to replace the
    effective provider kwargs before Hermes sends them. Results that are not
    dicts, or whose ``"request"`` value is not a dict, are ignored.

    The middleware is invoked once, up front, with copies of *request*; the
    returned results are then folded in order, so the last accepted
    replacement becomes the final payload.

    Args:
        request: Provider kwargs about to be sent.
        **context: Extra keyword arguments forwarded to the middleware.

    Returns:
        A ``RequestMiddlewareResult``. When no middleware is registered the
        same *request* object is returned as both payload and original.
    """
    if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
        # Nothing registered: skip the copies entirely.
        return RequestMiddlewareResult(
            payload=request,
            original_payload=request,
            changed=False,
            trace=[],
        )

    pristine = _safe_copy(request)
    effective = _safe_copy(pristine)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        LLM_REQUEST_MIDDLEWARE,
        request=effective,
        original_request=pristine,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("request") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so later plugin-side mutation cannot alter our payload.
            effective = _safe_copy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=pristine,
        changed=bool(applied),
        trace=applied,
    )
|
||||
|
||||
|
||||
def apply_tool_request_middleware(
    tool_name: str,
    args: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run tool request middleware over a tool call's arguments.

    A middleware callback may return ``{"args": {...}}`` to replace the
    effective tool arguments before hooks, guardrails, approvals, and
    execution see them. Results that are not dicts, or whose ``"args"`` value
    is not a dict, are ignored.

    The middleware is invoked once, up front, with copies of *args*; the
    returned results are then folded in order, so the last accepted
    replacement becomes the final payload.

    Args:
        tool_name: Name of the tool being called.
        args: Arguments the tool is about to receive.
        **context: Extra keyword arguments forwarded to the middleware.

    Returns:
        A ``RequestMiddlewareResult``. When no middleware is registered the
        same *args* object is returned as both payload and original.
    """
    if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
        # Nothing registered: skip the copies entirely.
        return RequestMiddlewareResult(
            payload=args,
            original_payload=args,
            changed=False,
            trace=[],
        )

    pristine = _safe_copy(args)
    effective = _safe_copy(pristine)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        TOOL_REQUEST_MIDDLEWARE,
        tool_name=tool_name,
        args=effective,
        original_args=pristine,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("args") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so later plugin-side mutation cannot alter our payload.
            effective = _safe_copy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=pristine,
        changed=bool(applied),
        trace=applied,
    )
|
||||
|
||||
|
||||
def apply_api_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Compatibility wrapper for older ``api_request`` naming.

    Delegates unchanged to ``apply_llm_request_middleware``.
    """
    result = apply_llm_request_middleware(request, **context)
    return result
|
||||
|
||||
|
||||
def run_llm_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Run provider execution through registered LLM execution middleware.

    Args:
        request: Provider kwargs for the call.
        next_call: Callable that performs the real provider call.
        **context: Extra keyword arguments forwarded to the middleware. An
            ``original_request`` entry is passed through as such; when absent
            it defaults to *request*.

    Returns:
        Whatever the middleware chain returns, or ``next_call(request)``
        directly when no execution middleware is registered.
    """
    chain = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
    if chain:
        pristine = context.pop("original_request", request)
        return _run_execution_chain(
            LLM_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            request=request,
            original_request=pristine,
            **context,
        )
    return next_call(request)
|
||||
|
||||
|
||||
def run_tool_execution_middleware(
    tool_name: str,
    args: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Run tool execution through registered tool execution middleware.

    Args:
        tool_name: Name of the tool being executed.
        args: Arguments for the tool call.
        next_call: Callable that performs the real tool execution.
        **context: Extra keyword arguments forwarded to the middleware. An
            ``original_args`` entry is passed through as such; when absent it
            defaults to *args*.

    Returns:
        Whatever the middleware chain returns, or ``next_call(args)`` directly
        when no execution middleware is registered.
    """
    chain = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
    if chain:
        pristine = context.pop("original_args", args)
        return _run_execution_chain(
            TOOL_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            tool_name=tool_name,
            args=args,
            original_args=pristine,
            **context,
        )
    return next_call(args)
|
||||
|
||||
|
||||
def run_api_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Compatibility wrapper for older ``api_execution`` naming.

    Delegates unchanged to ``run_llm_execution_middleware``.
    """
    outcome = run_llm_execution_middleware(request, next_call, **context)
    return outcome
|
||||
|
||||
|
||||
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Invoke plugin middleware of *kind* with schema-tagged keyword arguments."""
    # Imported lazily: the plugins module imports this module at load time.
    from hermes_cli import plugins

    tagged = middleware_payload(**kwargs)
    return plugins.invoke_middleware(kind, **tagged)
|
||||
|
||||
|
||||
def _has_middleware(kind: str) -> bool:
    """Report whether any plugin middleware is registered for *kind*."""
    # Imported lazily: the plugins module imports this module at load time.
    from hermes_cli import plugins

    return plugins.has_middleware(kind)
|
||||
|
||||
|
||||
def _get_middleware_callbacks(kind: str) -> List[Callable]:
    """Return a snapshot list of the callbacks registered for *kind*.

    The lookup is defensive, mirroring ``hermes_cli.plugins.has_middleware``:
    a plugin manager without a ``_middleware`` registry (or with no entry for
    *kind*) yields an empty list instead of raising. The execution helpers
    call this on the hot path with no surrounding ``try``, so an
    ``AttributeError`` here would abort the provider/tool call itself.

    Args:
        kind: Middleware kind to look up.

    Returns:
        A new list, so callers can iterate it without being affected by
        later registrations.
    """
    # Imported lazily: the plugins module imports this module at load time.
    from hermes_cli.plugins import get_plugin_manager

    registry = getattr(get_plugin_manager(), "_middleware", None) or {}
    return list(registry.get(kind) or [])
|
||||
|
||||
|
||||
def _run_execution_chain(
    kind: str,
    callbacks: List[Callable],
    terminal_call: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """Thread a payload through execution middleware down to *terminal_call*.

    Each callback receives the schema-tagged *kwargs*, the current payload
    under ``"request"`` (when *kwargs* has that key) or ``"args"``, and a
    ``next_call`` callable that runs the rest of the chain. ``next_call`` may
    be given a replacement payload; with no argument (or ``None``) the current
    payload is forwarded.

    Failure handling per callback:

    * An exception raised by the downstream chain and not handled by the
      callback is re-raised unchanged.
    * If the callback itself raises after downstream already succeeded, the
      downstream result is returned.
    * If the callback raises after downstream was started but failed, the
      callback's exception propagates.
    * If the callback raises before ever calling ``next_call``, it is skipped
      and the chain continues with the unchanged payload.

    Args:
        kind: Middleware kind, used in log and error messages.
        callbacks: Middleware callbacks, outermost first.
        terminal_call: The real execution, called with the final payload.
        **kwargs: Context for the callbacks; must contain the payload under
            ``"request"`` or ``"args"``.

    Returns:
        The value produced by the outermost callback, or by the fallbacks
        described above.
    """
    payload_key = "args" if "request" not in kwargs else "request"

    class _DownstreamExecutionError(Exception):
        # Marks an exception as coming from further down the chain, so it can
        # be told apart from a failure inside the middleware callback itself.
        def __init__(self, original: BaseException) -> None:
            super().__init__(str(original))
            self.original = original

    def call_at(index: int, payload: Any) -> Any:
        if index >= len(callbacks):
            return terminal_call(payload)

        callback = callbacks[index]
        downstream_started = False
        downstream_finished = False
        downstream_value: Any = None

        def next_call(next_payload: Any = None) -> Any:
            nonlocal downstream_started, downstream_finished, downstream_value
            # ``next_call`` is single-use per middleware frame. Calling it more
            # than once would re-run the downstream provider/tool, so a second
            # invocation is a contract violation rather than a retry. Surface it
            # instead of silently executing the terminal call twice.
            if downstream_started:
                raise RuntimeError(
                    f"Middleware '{kind}' callback "
                    f"{getattr(callback, '__name__', repr(callback))} called "
                    "next_call() more than once; downstream execution is single-use"
                )
            downstream_started = True
            forwarded = payload if next_payload is None else next_payload
            try:
                downstream_value = call_at(index + 1, forwarded)
            except Exception as exc:
                raise _DownstreamExecutionError(exc) from exc
            downstream_finished = True
            return downstream_value

        call_kwargs = {
            **middleware_payload(**kwargs),
            payload_key: payload,
            "next_call": next_call,
        }
        try:
            return callback(**call_kwargs)
        except _DownstreamExecutionError as exc:
            # Not the middleware's fault: hand the real error to the caller.
            raise exc.original
        except Exception as exc:
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(callback, "__name__", repr(callback)),
                exc,
            )
            if downstream_finished:
                # The real work already happened; keep its result.
                return downstream_value
            if downstream_started:
                raise
            # Callback failed before delegating: skip it.
            return call_at(index + 1, payload)

    return call_at(0, kwargs[payload_key])
|
||||
|
||||
|
||||
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
entry: Dict[str, Any] = {}
|
||||
for key in ("source", "reason", "name"):
|
||||
value = result.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
entry[key] = value
|
||||
if not entry:
|
||||
entry["source"] = "plugin"
|
||||
return entry
|
||||
|
|
@ -49,7 +49,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
|
|||
from hermes_constants import get_hermes_home
|
||||
from utils import env_var_enabled
|
||||
from hermes_cli.config import cfg_get
|
||||
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
|
||||
from hermes_cli.middleware import OBSERVER_SCHEMA_VERSION, VALID_MIDDLEWARE
|
||||
|
||||
|
||||
def get_bundled_plugins_dir() -> Path:
|
||||
|
|
@ -277,6 +277,7 @@ class LoadedPlugin:
|
|||
module: Optional[types.ModuleType] = None
|
||||
tools_registered: List[str] = field(default_factory=list)
|
||||
hooks_registered: List[str] = field(default_factory=list)
|
||||
middleware_registered: List[str] = field(default_factory=list)
|
||||
commands_registered: List[str] = field(default_factory=list)
|
||||
enabled: bool = False
|
||||
error: Optional[str] = None
|
||||
|
|
@ -952,6 +953,27 @@ class PluginContext:
|
|||
self._manager._hooks.setdefault(hook_name, []).append(callback)
|
||||
logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name)
|
||||
|
||||
# -- middleware registration -------------------------------------------
|
||||
|
||||
def register_middleware(self, kind: str, callback: Callable) -> None:
    """Register a behavior-changing middleware callback.

    Middleware is separate from observer hooks: request middleware may
    rewrite the effective payload, and execution middleware may wrap the
    real callback. A kind outside ``VALID_MIDDLEWARE`` is still stored (for
    forward compatibility) but triggers a warning so plugin authors can
    catch typos.

    Args:
        kind: Middleware kind to register under.
        callback: Callable appended to the manager's list for *kind*.
    """
    recognised = kind in VALID_MIDDLEWARE
    if not recognised:
        logger.warning(
            "Plugin '%s' registered unknown middleware '%s' "
            "(valid: %s)",
            self.manifest.name,
            kind,
            ", ".join(sorted(VALID_MIDDLEWARE)),
        )
    registry = self._manager._middleware
    registry.setdefault(kind, []).append(callback)
    logger.debug("Plugin %s registered middleware: %s", self.manifest.name, kind)
|
||||
|
||||
# -- skill registration -------------------------------------------------
|
||||
|
||||
def register_skill(
|
||||
|
|
@ -1010,6 +1032,7 @@ class PluginManager:
|
|||
def __init__(self) -> None:
|
||||
self._plugins: Dict[str, LoadedPlugin] = {}
|
||||
self._hooks: Dict[str, List[Callable]] = {}
|
||||
self._middleware: Dict[str, List[Callable]] = {}
|
||||
self._plugin_tool_names: Set[str] = set()
|
||||
self._plugin_platform_names: Set[str] = set()
|
||||
self._cli_commands: Dict[str, dict] = {}
|
||||
|
|
@ -1039,6 +1062,7 @@ class PluginManager:
|
|||
if force:
|
||||
self._plugins.clear()
|
||||
self._hooks.clear()
|
||||
self._middleware.clear()
|
||||
self._plugin_tool_names.clear()
|
||||
self._cli_commands.clear()
|
||||
self._plugin_commands.clear()
|
||||
|
|
@ -1449,15 +1473,28 @@ class PluginManager:
|
|||
for h in p.hooks_registered
|
||||
}
|
||||
)
|
||||
loaded.middleware_registered = list(
|
||||
{
|
||||
kind
|
||||
for kind, cbs in self._middleware.items()
|
||||
if cbs
|
||||
}
|
||||
- {
|
||||
kind
|
||||
for name, p in self._plugins.items()
|
||||
for kind in p.middleware_registered
|
||||
}
|
||||
)
|
||||
loaded.commands_registered = [
|
||||
c for c in self._plugin_commands
|
||||
if self._plugin_commands[c].get("plugin") == manifest.name
|
||||
]
|
||||
loaded.enabled = True
|
||||
logger.debug(
|
||||
" registered: %d tool(s), %d hook(s), %d slash command(s), %d CLI command(s)",
|
||||
" registered: %d tool(s), %d hook(s), %d middleware, %d slash command(s), %d CLI command(s)",
|
||||
len(loaded.tools_registered),
|
||||
len(loaded.hooks_registered),
|
||||
len(loaded.middleware_registered),
|
||||
len(loaded.commands_registered),
|
||||
sum(
|
||||
1 for c in self._cli_commands
|
||||
|
|
@ -1575,6 +1612,33 @@ class PluginManager:
|
|||
"""Return True when at least one callback is registered for a hook."""
|
||||
return bool(self._hooks.get(hook_name))
|
||||
|
||||
def has_middleware(self, kind: str) -> bool:
    """Return True when at least one callback is registered for middleware."""
    registered = self._middleware.get(kind)
    return True if registered else False
|
||||
|
||||
def invoke_middleware(self, kind: str, **kwargs: Any) -> List[Any]:
    """Call registered middleware callbacks for *kind*.

    Each callback is isolated so one plugin cannot break the base runtime
    path: an exception is logged as a warning and the remaining callbacks
    still run. Middleware that wants to change behavior must return the
    shape documented by the caller-specific contract.

    Args:
        kind: Middleware kind whose callbacks should run.
        **kwargs: Keyword arguments passed to every callback.

    Returns:
        The non-``None`` return values, in registration order.
    """
    collected: List[Any] = []
    for callback in self._middleware.get(kind, []):
        try:
            outcome = callback(**kwargs)
        except Exception as exc:
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(callback, "__name__", repr(callback)),
                exc,
            )
            continue
        if outcome is not None:
            collected.append(outcome)
    return collected
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Introspection
|
||||
# -----------------------------------------------------------------------
|
||||
|
|
@ -1594,6 +1658,7 @@ class PluginManager:
|
|||
"enabled": loaded.enabled,
|
||||
"tools": len(loaded.tools_registered),
|
||||
"hooks": len(loaded.hooks_registered),
|
||||
"middleware": len(loaded.middleware_registered),
|
||||
"commands": len(loaded.commands_registered),
|
||||
"error": loaded.error,
|
||||
}
|
||||
|
|
@ -1655,6 +1720,23 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]:
|
|||
return get_plugin_manager().invoke_hook(hook_name, **kwargs)
|
||||
|
||||
|
||||
def invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Invoke registered middleware callbacks.

    Returns a list of non-``None`` return values from middleware callbacks.
    """
    manager = get_plugin_manager()
    return manager.invoke_middleware(kind, **kwargs)
|
||||
|
||||
|
||||
def has_middleware(kind: str) -> bool:
    """Return True when middleware callbacks are registered for ``kind``.

    Uses the manager's ``has_middleware`` method when it is callable and
    otherwise falls back to inspecting a ``_middleware`` mapping, treating a
    missing mapping as empty.
    """
    manager = get_plugin_manager()
    checker = getattr(manager, "has_middleware", None)
    if not callable(checker):
        return bool(getattr(manager, "_middleware", {}).get(kind))
    return bool(checker(kind))
|
||||
|
||||
|
||||
def has_hook(hook_name: str) -> bool:
    """Return True when a hook has registered callbacks."""
    manager = get_plugin_manager()
    return manager.has_hook(hook_name)
|
||||
|
|
@ -1683,6 +1765,7 @@ def get_pre_tool_call_block_message(
|
|||
tool_call_id: str = "",
|
||||
turn_id: str = "",
|
||||
api_request_id: str = "",
|
||||
middleware_trace: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Optional[str]:
|
||||
"""Check ``pre_tool_call`` hooks for a blocking directive.
|
||||
|
||||
|
|
@ -1709,6 +1792,7 @@ def get_pre_tool_call_block_message(
|
|||
tool_call_id=tool_call_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
middleware_trace=list(middleware_trace or []),
|
||||
)
|
||||
|
||||
for result in hook_results:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue