hermes-agent/hermes_cli/middleware.py
kshitijk4poor c4c5548eb4 fix(middleware): single-use next_call guard + deepcopy-safe request copies
Address the two non-blocking follow-ups from review:

- next_call is now single-use per middleware frame. A second invocation
  raises instead of silently re-running the downstream provider/tool, so
  the terminal call cannot execute twice via the chain. The error surfaces
  through the existing handler, which preserves the first downstream result.
- Request-middleware payload copies go through _safe_copy(), which falls
  back to a shallow dict copy when deepcopy() fails on a non-deepcopyable
  member (clients, callbacks, file handles) instead of aborting the pass.

Adds regression coverage for both: double next_call() keeps the terminal
single-run, and a non-deepcopyable (threading.Lock) request payload still
runs middleware via the shallow fallback.
2026-06-06 23:07:25 +05:30

313 lines
9.8 KiB
Python

"""Hermes middleware contract helpers.
Observer hooks report what happened. Middleware can change what happens by
rewriting a request or wrapping the actual execution callback. Keep the small
contract helpers here so agent-loop call sites and plugins share one vocabulary.
"""
from __future__ import annotations
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
logger = logging.getLogger(__name__)
# Schema identifiers stamped into hook payloads. ``observer_payload`` and
# ``middleware_payload`` both default "telemetry_schema_version" to
# OBSERVER_SCHEMA_VERSION; ``middleware_payload`` additionally defaults
# "middleware_schema_version" to MIDDLEWARE_SCHEMA_VERSION.
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"
# Middleware kind names. The ``*_request`` kinds are looked up by the
# ``apply_*_request_middleware`` helpers in this module; the ``*_execution``
# kinds are looked up by the ``run_*_execution_middleware`` helpers.
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"
# Back-compat aliases for older PoC branches that used API terminology.
# They are plain rebinds, so they carry the same string values as the LLM_* names.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE
# The four distinct kind names defined above (the API_* aliases add no new values).
VALID_MIDDLEWARE: set[str] = {
TOOL_REQUEST_MIDDLEWARE,
TOOL_EXECUTION_MIDDLEWARE,
LLM_REQUEST_MIDDLEWARE,
LLM_EXECUTION_MIDDLEWARE,
}
@dataclass
class RequestMiddlewareResult:
    """Outcome of one request-middleware pass over a payload.

    Attributes:
        payload: The payload callers should use after the pass.
        original_payload: The payload as it was before the pass.
        changed: Whether the pass recorded a replacement; defaults to False.
        trace: One dict per recorded replacement; defaults to a fresh list.
    """

    payload: Any
    original_payload: Any
    changed: bool = field(default=False)
    # A factory rather than a literal so instances never share one list.
    trace: List[Dict[str, Any]] = field(default_factory=list)
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return the keyword arguments as a dict tagged with the observer schema.

    ``telemetry_schema_version`` is filled in with ``OBSERVER_SCHEMA_VERSION``
    only when the caller did not pass that key; a caller-supplied value wins.
    """
    version = OBSERVER_SCHEMA_VERSION
    if "telemetry_schema_version" not in kwargs:
        kwargs["telemetry_schema_version"] = version
    return kwargs
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return the keyword arguments as a dict tagged with both schema versions.

    ``telemetry_schema_version`` and ``middleware_schema_version`` are added,
    in that order, only when the caller did not already pass them.
    """
    schema_defaults = (
        ("telemetry_schema_version", OBSERVER_SCHEMA_VERSION),
        ("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION),
    )
    for key, version in schema_defaults:
        if key not in kwargs:
            kwargs[key] = version
    return kwargs
def _safe_copy(payload: Any) -> Any:
    """Copy a request payload as deeply as its members allow.

    Request payloads are normally plain JSON-shaped dicts, but an LLM request
    can occasionally carry non-deepcopyable objects (clients, callbacks, file
    handles). A hard ``deepcopy`` failure there would otherwise abort the whole
    request-middleware pass.

    When the whole-payload ``deepcopy`` fails and the payload is a dict, each
    top-level value is deep-copied on its own. Only the values that cannot be
    copied are shared by reference, so one lock or client in the request no
    longer makes every other member (for example a nested message list)
    alias the live payload.

    Args:
        payload: The payload to copy.

    Returns:
        A deep copy when possible; otherwise, for a dict, a new plain ``dict``
        with per-member copies; otherwise ``payload`` itself, unchanged.
    """
    try:
        return deepcopy(payload)
    except Exception as exc:  # pragma: no cover - exercised via fallback test
        logger.debug(
            "deepcopy failed for request payload (%s); copying members individually",
            exc,
        )
    if not isinstance(payload, dict):
        # No generic member-wise strategy for arbitrary objects; share as-is.
        return payload
    duplicate: dict = {}
    for key, value in payload.items():
        try:
            duplicate[key] = deepcopy(value)
        except Exception:
            # This member is the non-copyable one; share it by reference.
            duplicate[key] = value
    return duplicate
def apply_llm_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run the registered LLM request middleware over ``request``.

    Middleware may return ``{"request": {...}}`` to replace the effective
    provider kwargs before Hermes sends them. Results that are not dicts, or
    whose ``"request"`` entry is not a dict, are ignored. All results come from
    a single ``_invoke_middleware`` call, and each accepted replacement
    supersedes the previous one.

    Args:
        request: The provider kwargs about to be sent.
        **context: Extra keyword arguments forwarded to the middleware.

    Returns:
        A ``RequestMiddlewareResult``. When no middleware of this kind is
        registered, both payload fields are the caller's own ``request``
        object (no copy is made) and ``changed`` is False.
    """
    if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
        return RequestMiddlewareResult(
            payload=request,
            original_payload=request,
            changed=False,
            trace=[],
        )

    # Keep a snapshot separate from the working copy handed to middleware.
    snapshot = _safe_copy(request)
    effective = _safe_copy(snapshot)
    outcomes = _invoke_middleware(
        LLM_REQUEST_MIDDLEWARE,
        request=effective,
        original_request=snapshot,
        **context,
    )

    applied: List[Dict[str, Any]] = []
    for outcome in outcomes:
        replacement = outcome.get("request") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            effective = _safe_copy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=snapshot,
        changed=len(applied) > 0,
        trace=applied,
    )
def apply_tool_request_middleware(
    tool_name: str,
    args: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run the registered tool request middleware over ``args``.

    Middleware may return ``{"args": {...}}`` to replace the effective tool
    arguments before hooks, guardrails, approvals, and execution see them.
    Results that are not dicts, or whose ``"args"`` entry is not a dict, are
    ignored; each accepted replacement supersedes the previous one.

    Args:
        tool_name: Name of the tool, forwarded to the middleware.
        args: The tool arguments about to be used.
        **context: Extra keyword arguments forwarded to the middleware.

    Returns:
        A ``RequestMiddlewareResult``. When no middleware of this kind is
        registered, both payload fields are the caller's own ``args`` object
        (no copy is made) and ``changed`` is False.
    """
    if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
        return RequestMiddlewareResult(
            payload=args,
            original_payload=args,
            changed=False,
            trace=[],
        )

    # Keep a snapshot separate from the working copy handed to middleware.
    pristine_args = _safe_copy(args)
    effective_args = _safe_copy(pristine_args)
    outcomes = _invoke_middleware(
        TOOL_REQUEST_MIDDLEWARE,
        tool_name=tool_name,
        args=effective_args,
        original_args=pristine_args,
        **context,
    )

    applied: List[Dict[str, Any]] = []
    for outcome in outcomes:
        replacement = outcome.get("args") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            effective_args = _safe_copy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective_args,
        original_payload=pristine_args,
        changed=len(applied) > 0,
        trace=applied,
    )
def apply_api_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Compatibility wrapper for older ``api_request`` naming.

    Delegates unchanged to ``apply_llm_request_middleware``.
    """
    outcome = apply_llm_request_middleware(request, **context)
    return outcome
def run_llm_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Run provider execution through registered LLM execution middleware.

    Args:
        request: The provider kwargs to execute with.
        next_call: The real execution callback; receives the request.
        **context: Extra keyword arguments forwarded to the middleware. An
            ``original_request`` entry, if present, is passed on as the
            pre-middleware request; otherwise ``request`` is used for it.

    Returns:
        Whatever the middleware chain returns, or ``next_call(request)``
        directly when no LLM execution middleware is registered.
    """
    chain = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
    if chain:
        # Pull ``original_request`` out first so it is not also in ``**context``.
        baseline = context.pop("original_request", request)
        return _run_execution_chain(
            LLM_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            request=request,
            original_request=baseline,
            **context,
        )
    return next_call(request)
def run_tool_execution_middleware(
    tool_name: str,
    args: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Run tool execution through registered tool execution middleware.

    Args:
        tool_name: Name of the tool, forwarded to the middleware.
        args: The tool arguments to execute with.
        next_call: The real execution callback; receives the arguments.
        **context: Extra keyword arguments forwarded to the middleware. An
            ``original_args`` entry, if present, is passed on as the
            pre-middleware arguments; otherwise ``args`` is used for it.

    Returns:
        Whatever the middleware chain returns, or ``next_call(args)`` directly
        when no tool execution middleware is registered.
    """
    chain = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
    if chain:
        # Pull ``original_args`` out first so it is not also in ``**context``.
        baseline_args = context.pop("original_args", args)
        return _run_execution_chain(
            TOOL_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            tool_name=tool_name,
            args=args,
            original_args=baseline_args,
            **context,
        )
    return next_call(args)
def run_api_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Compatibility wrapper for older ``api_execution`` naming.

    Delegates unchanged to ``run_llm_execution_middleware``.
    """
    outcome = run_llm_execution_middleware(request, next_call, **context)
    return outcome
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Invoke the plugin middleware of ``kind`` with schema-stamped kwargs.

    The keyword arguments pass through ``middleware_payload`` before being
    forwarded to ``hermes_cli.plugins.invoke_middleware``.
    """
    # Imported at call time rather than at module import.
    # NOTE(review): presumably to avoid an import cycle with the plugins module - confirm.
    from hermes_cli.plugins import invoke_middleware

    stamped = middleware_payload(**kwargs)
    return invoke_middleware(kind, **stamped)
def _has_middleware(kind: str) -> bool:
    """Return what ``hermes_cli.plugins.has_middleware`` reports for ``kind``."""
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import has_middleware

    registered = has_middleware(kind)
    return registered
def _get_middleware_callbacks(kind: str) -> List[Callable]:
    """Return a new list of the callbacks the plugin manager holds for ``kind``.

    The result is a copy, so callers can iterate it without touching the
    manager's own container. An unknown ``kind`` yields an empty list.
    """
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import get_plugin_manager

    manager = get_plugin_manager()
    # Reads the manager's private ``_middleware`` mapping directly.
    registered = manager._middleware.get(kind, [])
    return list(registered)
def _run_execution_chain(
    kind: str,
    callbacks: List[Callable],
    terminal_call: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """Run ``terminal_call`` wrapped by ``callbacks``, first callback outermost.

    Each callback is called with the schema-stamped ``kwargs``, the current
    payload under ``"request"`` (or ``"args"`` when no ``"request"`` keyword
    was given), and a ``next_call`` callable that runs the rest of the chain.

    Failure handling per callback:

    * A downstream exception (from a later callback or ``terminal_call``) that
      the callback does not handle is re-raised to the caller as the original
      exception object.
    * If the callback itself raises, a warning is logged. Then, if
      ``next_call`` had already succeeded, its result is returned; if it had
      been called but failed, the callback's exception propagates; if it was
      never called, the callback is skipped and the rest of the chain runs
      with the unmodified payload.
    * A second ``next_call`` invocation raises ``RuntimeError``.

    Args:
        kind: Middleware kind name, used only in log and error messages.
        callbacks: Middleware callables, in wrapping order.
        terminal_call: The real execution; receives the final payload.
        **kwargs: Keyword arguments for the callbacks. Must contain
            ``"request"`` or ``"args"``.

    Returns:
        The value produced by the outermost callback, or a fallback value as
        described above.
    """
    payload_key = "request" if "request" in kwargs else "args"

    class _DownstreamExecutionError(Exception):
        """Tags an exception that came from downstream, not the callback itself."""

        def __init__(self, original: BaseException) -> None:
            super().__init__(str(original))
            self.original = original

    def call_at(index: int, payload: Any) -> Any:
        if index >= len(callbacks):
            return terminal_call(payload)
        callback = callbacks[index]
        next_called = False
        next_succeeded = False
        next_result: Any = None

        def next_call(next_payload: Any = None) -> Any:
            nonlocal next_called, next_succeeded, next_result
            # ``next_call`` is single-use per middleware frame. Calling it more
            # than once would re-run the downstream provider/tool, so a second
            # invocation is a contract violation rather than a retry. Surface it
            # instead of silently executing the terminal call twice.
            if next_called:
                raise RuntimeError(
                    f"Middleware '{kind}' callback "
                    f"{getattr(callback, '__name__', repr(callback))} called "
                    "next_call() more than once; downstream execution is single-use"
                )
            next_called = True
            try:
                next_result = call_at(
                    index + 1,
                    payload if next_payload is None else next_payload,
                )
            except Exception as exc:
                # Wrap so the frame below can tell downstream failures apart
                # from failures raised by the callback's own code.
                raise _DownstreamExecutionError(exc) from exc
            next_succeeded = True
            return next_result

        call_kwargs = middleware_payload(**kwargs)
        call_kwargs[payload_key] = payload
        call_kwargs["next_call"] = next_call
        try:
            return callback(**call_kwargs)
        except _DownstreamExecutionError as wrapped:
            downstream_error = wrapped.original
        except Exception as exc:
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(callback, "__name__", repr(callback)),
                exc,
            )
            if next_succeeded:
                return next_result
            if next_called:
                raise
            return call_at(index + 1, payload)
        # Re-raise only after the handler above has exited. Raising inside it
        # would overwrite the downstream exception's ``__context__`` with the
        # internal wrapper, cluttering tracebacks and creating a reference
        # cycle between the two exceptions.
        try:
            raise downstream_error
        finally:
            # Drop the local so this frame does not keep the exception (and
            # the traceback that points back at this frame) alive.
            del downstream_error

    return call_at(0, kwargs[payload_key])
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
entry: Dict[str, Any] = {}
for key in ("source", "reason", "name"):
value = result.get(key)
if isinstance(value, str) and value:
entry[key] = value
if not entry:
entry["source"] = "plugin"
return entry