mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Address the two non-blocking follow-ups from review: - next_call is now single-use per middleware frame. A second invocation raises instead of silently re-running the downstream provider/tool, so the terminal call cannot execute twice via the chain. The error surfaces through the existing handler, which preserves the first downstream result. - Request-middleware payload copies go through _safe_copy(), which falls back to a shallow dict copy when deepcopy() fails on a non-deepcopyable member (clients, callbacks, file handles) instead of aborting the pass. Adds regression coverage for both: double next_call() keeps the terminal single-run, and a non-deepcopyable (threading.Lock) request payload still runs middleware via the shallow fallback.
313 lines
9.8 KiB
Python
313 lines
9.8 KiB
Python
"""Hermes middleware contract helpers.
|
|
|
|
Observer hooks report what happened. Middleware can change what happens by
|
|
rewriting a request or wrapping the actual execution callback. Keep the small
|
|
contract helpers here so agent-loop call sites and plugins share one vocabulary.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
logger = logging.getLogger(__name__)

# Schema identifiers stamped onto observer and middleware payloads.
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"

# Middleware kinds. "*_request" kinds may replace a payload before it is used;
# "*_execution" kinds wrap the callback that actually runs the tool/provider.
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"

# Back-compat aliases for older PoC branches that used API terminology.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE

# The complete set of recognised middleware kind names (aliases included,
# since they share values with the LLM kinds).
VALID_MIDDLEWARE: set[str] = {
    TOOL_REQUEST_MIDDLEWARE, TOOL_EXECUTION_MIDDLEWARE,
    LLM_REQUEST_MIDDLEWARE, LLM_EXECUTION_MIDDLEWARE,
}
|
|
|
|
|
|
@dataclass()
class RequestMiddlewareResult:
    """Outcome of running request middleware over a payload.

    Field order is part of the positional constructor signature; keep it.
    """

    # Effective payload after middleware ran. On the no-middleware fast path
    # of apply_*_request_middleware this is the caller's object itself.
    payload: Any
    # The payload as it was before any middleware replaced it.
    original_payload: Any
    # True when at least one middleware supplied an accepted replacement.
    changed: bool = False
    # One summary entry per accepted replacement, in application order.
    trace: List[Dict[str, Any]] = field(default_factory=list)
|
|
|
|
|
|
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
    """Package keyword arguments as an observer payload dict.

    ``telemetry_schema_version`` is set to ``OBSERVER_SCHEMA_VERSION`` unless
    the caller already passed one, in which case the caller's value wins.
    """
    if "telemetry_schema_version" not in kwargs:
        kwargs["telemetry_schema_version"] = OBSERVER_SCHEMA_VERSION
    return kwargs
|
|
|
|
|
|
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
    """Package keyword arguments as a middleware payload dict.

    Both ``telemetry_schema_version`` (observer schema) and
    ``middleware_schema_version`` are filled in, in that order, unless the
    caller already supplied them.
    """
    schema_defaults = (
        ("telemetry_schema_version", OBSERVER_SCHEMA_VERSION),
        ("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION),
    )
    for key, version in schema_defaults:
        kwargs.setdefault(key, version)
    return kwargs
|
|
|
|
|
|
def _safe_copy(payload: Any) -> Any:
    """Deep-copy a request payload, tolerating non-deepcopyable members.

    Request payloads are normally plain JSON-shaped dicts, but an LLM request
    can occasionally carry non-deepcopyable objects (clients, callbacks, file
    handles). A hard ``deepcopy`` failure there would otherwise abort the whole
    request-middleware pass.

    Fallback when ``deepcopy(payload)`` raises:

    * dict payloads: each top-level value is deep-copied on its own and only
      the values that still cannot be copied are shared by reference. Plain
      members such as message lists stay isolated, so middleware mutating
      them in place cannot reach the live payload or the recorded original.
      References shared *between* different top-level values are not
      preserved on this path.
    * any other payload: returned as-is (no copy is made).

    Args:
        payload: The request kwargs / tool args to copy.

    Returns:
        A copy of ``payload`` as described above; a plain ``dict`` on the
        dict fallback path.
    """
    try:
        return deepcopy(payload)
    except Exception as exc:
        logger.debug(
            "deepcopy failed for request payload (%s); copying per key", exc
        )

    if not isinstance(payload, dict):
        return payload

    copied: Dict[Any, Any] = {}
    for key, value in payload.items():
        try:
            copied[key] = deepcopy(value)
        except Exception:
            # This member (client, callback, lock, ...) cannot be copied:
            # share it by reference instead of failing the whole pass.
            copied[key] = value
    return copied
|
|
|
|
|
|
def apply_llm_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run registered LLM request middleware over provider kwargs.

    Middleware may return ``{"request": {...}}`` to replace the effective
    provider kwargs before Hermes sends them. Results that are not dicts, or
    whose ``"request"`` value is not a dict, are ignored; when several
    middleware supply a replacement, the last one wins.

    Args:
        request: Provider kwargs about to be sent.
        **context: Extra keyword arguments forwarded to every middleware.

    Returns:
        A RequestMiddlewareResult. With no middleware registered, ``request``
        itself is returned as both ``payload`` and ``original_payload``.
        Otherwise both are copies, and ``trace`` holds one entry per accepted
        replacement.
    """
    if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
        # Fast path: nothing registered, so skip copying entirely.
        return RequestMiddlewareResult(
            payload=request, original_payload=request, changed=False, trace=[]
        )

    # Middleware is handed a copy rather than the caller's dict.
    snapshot = _safe_copy(request)
    effective = _safe_copy(snapshot)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        LLM_REQUEST_MIDDLEWARE,
        request=effective,
        original_request=snapshot,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("request") if isinstance(outcome, dict) else None
        if not isinstance(replacement, dict):
            continue
        effective = _safe_copy(replacement)
        applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=snapshot,
        changed=bool(applied),
        trace=applied,
    )
|
|
|
|
|
|
def apply_tool_request_middleware(
    tool_name: str,
    args: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run registered tool request middleware over a tool call's arguments.

    Middleware may return ``{"args": {...}}`` to replace the effective tool
    arguments before hooks, guardrails, approvals, and execution see them.
    Results that are not dicts, or whose ``"args"`` value is not a dict, are
    ignored; when several middleware supply a replacement, the last one wins.

    Args:
        tool_name: Name of the tool being called; forwarded to middleware.
        args: Tool arguments for this call.
        **context: Extra keyword arguments forwarded to every middleware.

    Returns:
        A RequestMiddlewareResult. With no middleware registered, ``args``
        itself is returned as both ``payload`` and ``original_payload``.
        Otherwise both are copies, and ``trace`` holds one entry per accepted
        replacement.
    """
    if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
        # Fast path: nothing registered, so skip copying entirely.
        return RequestMiddlewareResult(
            payload=args, original_payload=args, changed=False, trace=[]
        )

    # Middleware is handed a copy rather than the caller's dict.
    baseline = _safe_copy(args)
    working = _safe_copy(baseline)
    accepted: List[Dict[str, Any]] = []

    responses = _invoke_middleware(
        TOOL_REQUEST_MIDDLEWARE,
        tool_name=tool_name,
        args=working,
        original_args=baseline,
        **context,
    )
    for response in responses:
        rewritten = response.get("args") if isinstance(response, dict) else None
        if not isinstance(rewritten, dict):
            continue
        working = _safe_copy(rewritten)
        accepted.append(_trace_entry(response))

    return RequestMiddlewareResult(
        payload=working,
        original_payload=baseline,
        changed=bool(accepted),
        trace=accepted,
    )
|
|
|
|
|
|
def apply_api_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Alias of ``apply_llm_request_middleware`` kept for the older
    ``api_request`` naming; all arguments are forwarded unchanged."""
    return apply_llm_request_middleware(request=request, **context)
|
|
|
|
|
|
def run_llm_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Execute a provider call, wrapped by any registered LLM execution middleware.

    Args:
        request: Provider kwargs for this call.
        next_call: The real provider invocation; receives the request that
            reaches the end of the middleware chain.
        **context: Extra keyword arguments for middleware. An
            ``original_request`` entry, if given, is forwarded to middleware
            as ``original_request``; it defaults to ``request``.

    Returns:
        The result of the middleware chain, or of ``next_call`` directly when
        no middleware is registered.
    """
    chain = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
    if chain:
        baseline_request = context.pop("original_request", request)
        return _run_execution_chain(
            LLM_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            request=request,
            original_request=baseline_request,
            **context,
        )
    # Nothing registered: call the provider directly.
    return next_call(request)
|
|
|
|
|
|
def run_tool_execution_middleware(
    tool_name: str,
    args: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Execute a tool call, wrapped by any registered tool execution middleware.

    Args:
        tool_name: Name of the tool; forwarded to middleware.
        args: Tool arguments for this call.
        next_call: The real tool invocation; receives the args that reach the
            end of the middleware chain.
        **context: Extra keyword arguments for middleware. An
            ``original_args`` entry, if given, is forwarded to middleware as
            ``original_args``; it defaults to ``args``.

    Returns:
        The result of the middleware chain, or of ``next_call`` directly when
        no middleware is registered.
    """
    wrappers = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
    if not wrappers:
        # Nothing registered: run the tool directly.
        return next_call(args)

    baseline_args = context.pop("original_args", args)
    return _run_execution_chain(
        TOOL_EXECUTION_MIDDLEWARE,
        wrappers,
        next_call,
        tool_name=tool_name,
        args=args,
        original_args=baseline_args,
        **context,
    )
|
|
|
|
|
|
def run_api_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Alias of ``run_llm_execution_middleware`` kept for the older
    ``api_execution`` naming; all arguments are forwarded unchanged."""
    return run_llm_execution_middleware(request=request, next_call=next_call, **context)
|
|
|
|
|
|
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Dispatch ``kind`` middleware through ``hermes_cli.plugins``.

    ``kwargs`` are stamped with the schema versions (see middleware_payload)
    before being passed on; the plugin layer's return value is handed back
    unchanged and iterated by callers as per-middleware results.
    """
    # Imported at call time rather than at module import — presumably to
    # avoid an import cycle with the plugin layer (TODO confirm).
    from hermes_cli.plugins import invoke_middleware as dispatch_middleware

    stamped_kwargs = middleware_payload(**kwargs)
    return dispatch_middleware(kind, **stamped_kwargs)
|
|
|
|
|
|
def _has_middleware(kind: str) -> bool:
    """Delegate to ``hermes_cli.plugins.has_middleware`` for ``kind``.

    Used by the request helpers as a cheap fast-path check before copying
    payloads.
    """
    from hermes_cli.plugins import has_middleware as plugin_has_middleware

    registered = plugin_has_middleware(kind)
    return registered
|
|
|
|
|
|
def _get_middleware_callbacks(kind: str) -> List[Callable]:
    """Return a fresh list of the callbacks registered for ``kind``.

    An unknown ``kind`` yields an empty list.
    """
    from hermes_cli.plugins import get_plugin_manager

    # NOTE(review): reads the plugin manager's private ``_middleware`` mapping.
    registry = get_plugin_manager()._middleware
    # list() copies, so the caller's chain is insulated from later changes to
    # the registry's own list.
    return list(registry.get(kind, []))
|
|
|
|
|
|
def _run_execution_chain(
    kind: str,
    callbacks: List[Callable],
    terminal_call: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """Run ``terminal_call`` wrapped by ``callbacks`` (first callback outermost).

    Each callback is called with ``kwargs`` (stamped via middleware_payload()),
    its current payload under the payload key, and a ``next_call`` that runs
    the rest of the chain. Contract:

    * ``next_call(next_payload=None)`` continues with ``next_payload``, or with
      the payload this frame received when called without one.
    * ``next_call`` is single-use per frame; a second call raises RuntimeError
      instead of re-running the downstream provider/tool.
    * Exceptions raised downstream reach the callback unchanged (so it can
      catch specific types) and, if the callback lets them escape, propagate
      to the caller unchanged.
    * A callback that fails on its own is logged and isolated. If its
      ``next_call`` already succeeded, that result is returned. If it never
      called ``next_call``, the chain continues without it. If the downstream
      call failed, the callback's own exception propagates.

    Args:
        kind: Middleware kind. Selects the payload key: ``"args"`` for tool
            execution, ``"request"`` for LLM execution.
        callbacks: Middleware callables in invocation order.
        terminal_call: The real tool/provider call; receives the final payload.
        **kwargs: Keyword arguments forwarded to every callback; must contain
            the initial payload under the payload key.

    Returns:
        The outermost callback's return value, or a fallback as described above.
    """
    # Derive the payload key from the kind. Sniffing kwargs alone would let a
    # tool call whose context happens to carry a "request" entry route that
    # object (instead of the tool args) through the chain.
    if kind == TOOL_EXECUTION_MIDDLEWARE:
        payload_key = "args"
    elif kind == LLM_EXECUTION_MIDDLEWARE:
        payload_key = "request"
    else:
        payload_key = "request" if "request" in kwargs else "args"

    def _describe(cb: Callable) -> str:
        return getattr(cb, "__name__", repr(cb))

    def call_at(index: int, payload: Any) -> Any:
        if index >= len(callbacks):
            return terminal_call(payload)

        callback = callbacks[index]
        next_called = False
        next_succeeded = False
        next_result: Any = None
        downstream_error: Optional[BaseException] = None

        def next_call(next_payload: Any = None) -> Any:
            nonlocal next_called, next_succeeded, next_result, downstream_error
            # ``next_call`` is single-use per middleware frame. Calling it more
            # than once would re-run the downstream provider/tool, so a second
            # invocation is a contract violation rather than a retry.
            if next_called:
                raise RuntimeError(
                    f"Middleware '{kind}' callback "
                    f"{_describe(callback)} called "
                    "next_call() more than once; downstream execution is single-use"
                )
            next_called = True
            try:
                next_result = call_at(
                    index + 1, payload if next_payload is None else next_payload
                )
            except Exception as exc:
                # Remember the exact object so the handler below can tell a
                # downstream failure from the callback's own errors, without
                # wrapping it in a type the callback could not catch.
                downstream_error = exc
                raise
            next_succeeded = True
            return next_result

        call_kwargs = middleware_payload(**kwargs)
        call_kwargs[payload_key] = payload
        call_kwargs["next_call"] = next_call
        try:
            return callback(**call_kwargs)
        except Exception as exc:
            if exc is downstream_error:
                # The downstream failure escaped the callback: propagate as-is.
                raise
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                _describe(callback),
                exc,
            )
            if next_succeeded:
                # Downstream already ran; keep its result, never re-run it.
                return next_result
            if next_called:
                raise
            # The callback failed before delegating: skip it.
            return call_at(index + 1, payload)

    return call_at(0, kwargs[payload_key])
|
|
|
|
|
|
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
    """Summarise one middleware result for ``RequestMiddlewareResult.trace``.

    Keeps the non-empty string values of ``source``, ``reason`` and ``name``
    (in that order). When none of them qualify, the entry is
    ``{"source": "plugin"}``.
    """
    labelled_keys = ("source", "reason", "name")
    summary = {
        key: result.get(key)
        for key in labelled_keys
        if isinstance(result.get(key), str) and result.get(key)
    }
    return summary if summary else {"source": "plugin"}
|