hermes-agent/hermes_cli/middleware.py
Bryan Bednarski 5abe45674d
fix(middleware): preserve translated downstream failures
Track successful next_call completion separately from invocation so execution
  middleware that catches and translates a downstream provider/tool failure does
  not accidentally convert that failure into a successful None result.

  Also avoid wrapping BaseException from downstream execution, and document the
  execution middleware error semantics.

  Tests cover:
  - pre-next_call middleware failures fail open to the remaining chain
  - post-next_call middleware failures preserve the downstream result
  - translated downstream failures propagate instead of returning None
  - downstream BaseException is not wrapped

Signed-off-by: Bryan Bednarski <bbednarski@nvidia.com>
2026-06-06 09:26:18 -07:00

284 lines
8.3 KiB
Python

"""Hermes middleware contract helpers.
Observer hooks report what happened. Middleware can change what happens by
rewriting a request or wrapping the actual execution callback. Keep the small
contract helpers here so agent-loop call sites and plugins share one vocabulary.
"""
from __future__ import annotations
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
logger = logging.getLogger(__name__)

# Version tags stamped onto the keyword payloads handed to observer and
# middleware callbacks (see ``observer_payload`` / ``middleware_payload``).
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"

# Middleware kinds. "*_request" kinds rewrite a payload before it is used;
# "*_execution" kinds wrap the callback that actually performs the work.
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"

# Back-compat aliases for older PoC branches that used API terminology.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE

# All distinct middleware kinds (the API aliases resolve to the LLM kinds).
VALID_MIDDLEWARE: set[str] = set(
    (
        TOOL_REQUEST_MIDDLEWARE,
        TOOL_EXECUTION_MIDDLEWARE,
        LLM_REQUEST_MIDDLEWARE,
        LLM_EXECUTION_MIDDLEWARE,
    )
)
@dataclass
class RequestMiddlewareResult:
    """Outcome of running request middleware over a payload.

    As populated by the ``apply_*_request_middleware`` helpers in this module:
    ``payload`` is the effective value to use, ``original_payload`` is the value
    before any middleware replacement, ``changed`` is true when at least one
    replacement was accepted, and ``trace`` holds one provenance entry per
    accepted replacement.
    """

    # Effective payload after middleware ran.
    payload: Any
    # Payload as it was before any replacement.
    original_payload: Any
    # Whether any middleware replacement was accepted.
    changed: bool = field(default=False)
    # Provenance entries, one per accepted replacement (fresh list per instance).
    trace: List[Dict[str, Any]] = field(default_factory=list)
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return the keyword arguments stamped with the observer schema version.

    A ``telemetry_schema_version`` supplied by the caller is kept as-is;
    otherwise ``OBSERVER_SCHEMA_VERSION`` is appended under that key.
    """
    payload = kwargs
    if "telemetry_schema_version" not in payload:
        payload["telemetry_schema_version"] = OBSERVER_SCHEMA_VERSION
    return payload
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return the keyword arguments stamped with telemetry and middleware versions.

    Caller-supplied ``telemetry_schema_version`` / ``middleware_schema_version``
    values win; missing ones are filled from ``OBSERVER_SCHEMA_VERSION`` and
    ``MIDDLEWARE_SCHEMA_VERSION`` respectively, in that order.
    """
    version_defaults = (
        ("telemetry_schema_version", OBSERVER_SCHEMA_VERSION),
        ("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION),
    )
    for key, version in version_defaults:
        kwargs.setdefault(key, version)
    return kwargs
def apply_llm_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Apply registered LLM request middleware.

    Middleware may return ``{"request": {...}}`` to replace the effective
    provider kwargs before Hermes sends them.

    With no ``llm_request`` middleware registered, *request* is returned
    untouched as both ``payload`` and ``original_payload``. Otherwise the
    middleware are handed deep copies of the caller's request (as ``request``
    and ``original_request``), never the caller's dict itself. Every dict
    result carrying a dict ``"request"`` value replaces the working request
    (the last such result wins) and contributes one trace entry.
    """
    if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
        return RequestMiddlewareResult(
            payload=request, original_payload=request, changed=False, trace=[]
        )

    baseline = deepcopy(request)
    working = deepcopy(baseline)
    accepted: List[Dict[str, Any]] = []
    results = _invoke_middleware(
        LLM_REQUEST_MIDDLEWARE,
        request=working,
        original_request=baseline,
        **context,
    )
    for outcome in results:
        replacement = outcome.get("request") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so later mutation by the plugin cannot leak into the result.
            working = deepcopy(replacement)
            accepted.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=working,
        original_payload=baseline,
        changed=len(accepted) > 0,
        trace=accepted,
    )
def apply_tool_request_middleware(
    tool_name: str,
    args: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Apply registered tool request middleware.

    Middleware may return ``{"args": {...}}`` to replace the effective tool
    arguments before hooks, guardrails, approvals, and execution see them.

    With no ``tool_request`` middleware registered, *args* is returned
    untouched as both ``payload`` and ``original_payload``. Otherwise the
    middleware are handed *tool_name* plus deep copies of the caller's
    arguments (as ``args`` and ``original_args``). Every dict result carrying
    a dict ``"args"`` value replaces the working arguments (the last such
    result wins) and contributes one trace entry.
    """
    if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
        return RequestMiddlewareResult(
            payload=args, original_payload=args, changed=False, trace=[]
        )

    baseline = deepcopy(args)
    working = deepcopy(baseline)
    accepted: List[Dict[str, Any]] = []
    results = _invoke_middleware(
        TOOL_REQUEST_MIDDLEWARE,
        tool_name=tool_name,
        args=working,
        original_args=baseline,
        **context,
    )
    for outcome in results:
        replacement = outcome.get("args") if isinstance(outcome, dict) else None
        if not isinstance(replacement, dict):
            continue
        # Copy so later mutation by the plugin cannot leak into the result.
        working = deepcopy(replacement)
        accepted.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=working,
        original_payload=baseline,
        changed=len(accepted) > 0,
        trace=accepted,
    )
def apply_api_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Compatibility alias of ``apply_llm_request_middleware``.

    Kept for callers written against the older ``api_request`` naming; all
    arguments are forwarded unchanged.
    """
    forwarded = apply_llm_request_middleware(request, **context)
    return forwarded
def run_llm_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Run provider execution through registered LLM execution middleware.

    With no ``llm_execution`` middleware registered this is simply
    ``next_call(request)``. Otherwise the call is routed through the
    middleware chain. An ``original_request`` entry in *context* is forwarded
    to the chain as ``original_request``; when absent, *request* is used.
    """
    chain = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
    if chain:
        # Pull it out of context so it is not passed twice.
        original = context.pop("original_request", request)
        return _run_execution_chain(
            LLM_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            request=request,
            original_request=original,
            **context,
        )
    return next_call(request)
def run_tool_execution_middleware(
    tool_name: str,
    args: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Run tool execution through registered tool execution middleware.

    With no ``tool_execution`` middleware registered this is simply
    ``next_call(args)``. Otherwise the call is routed through the middleware
    chain, which also receives *tool_name*. An ``original_args`` entry in
    *context* is forwarded to the chain as ``original_args``; when absent,
    *args* is used.
    """
    chain = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
    if not chain:
        return next_call(args)

    # Pull it out of context so it is not passed twice.
    original = context.pop("original_args", args)
    return _run_execution_chain(
        TOOL_EXECUTION_MIDDLEWARE,
        chain,
        next_call,
        tool_name=tool_name,
        args=args,
        original_args=original,
        **context,
    )
def run_api_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Compatibility alias of ``run_llm_execution_middleware``.

    Kept for callers written against the older ``api_execution`` naming; all
    arguments are forwarded unchanged.
    """
    forwarded = run_llm_execution_middleware(request, next_call, **context)
    return forwarded
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Dispatch *kind* middleware via ``hermes_cli.plugins.invoke_middleware``.

    The keyword arguments are stamped with schema versions by
    ``middleware_payload`` before being handed over.
    """
    # Imported at call time, presumably to avoid an import cycle with the
    # plugins module -- TODO confirm.
    from hermes_cli.plugins import invoke_middleware as dispatch

    stamped = middleware_payload(**kwargs)
    return dispatch(kind, **stamped)
def _has_middleware(kind: str) -> bool:
    """Ask ``hermes_cli.plugins.has_middleware`` whether *kind* is registered."""
    # Imported at call time, presumably to avoid an import cycle -- TODO confirm.
    from hermes_cli.plugins import has_middleware as is_registered

    return is_registered(kind)
def _get_middleware_callbacks(kind: str) -> List[Callable]:
    """Return a new list of the callbacks registered for *kind*.

    Reads the plugin manager's private ``_middleware`` mapping (presumably a
    dict keyed by middleware kind); an unknown kind yields an empty list.
    """
    from hermes_cli.plugins import get_plugin_manager

    registry = get_plugin_manager()._middleware
    registered = registry.get(kind, [])
    # Copy so the caller iterates a snapshot, not the live registry list.
    return list(registered)
def _run_execution_chain(
    kind: str,
    callbacks: List[Callable],
    terminal_call: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """Run *terminal_call* wrapped by *callbacks*, the first callback outermost.

    Each callback is invoked with the schema-stamped *kwargs* (see
    ``middleware_payload``), the current payload under ``args`` (tool
    execution) or ``request`` (LLM execution), and a ``next_call`` function.
    ``next_call(next_payload=None)`` runs the rest of the chain with
    *next_payload*, or with the callback's own payload when it is ``None``,
    and returns that result.

    Error semantics when a callback raises an ``Exception``:

    * If it is an exception that came out of ``next_call`` and the callback
      let it escape unchanged, it is re-raised untouched. Downstream errors
      are never wrapped, so callbacks can catch them by their real type
      (e.g. to retry or translate), and their ``__cause__``/``__context__``
      chain is preserved.
    * Otherwise it is logged as a middleware failure, and then:

      - if any ``next_call`` invocation succeeded, that downstream result is
        returned (post-``next_call`` failures fail open);
      - if ``next_call`` was invoked but never succeeded, the callback's
        exception is re-raised (e.g. a translated downstream failure);
      - if ``next_call`` was never invoked, the callback is skipped and the
        rest of the chain runs with the same payload (fail open).

    ``BaseException`` subclasses that are not ``Exception`` (for example
    ``KeyboardInterrupt``) are never intercepted.
    """
    # Derive the payload key from the kind first, so an unrelated "request"
    # entry in a tool call's context cannot redirect the tool's payload.
    if kind == TOOL_EXECUTION_MIDDLEWARE:
        payload_key = "args"
    elif kind == LLM_EXECUTION_MIDDLEWARE:
        payload_key = "request"
    else:
        payload_key = "request" if "request" in kwargs else "args"

    def call_at(index: int, payload: Any) -> Any:
        # Past the last callback: perform the real execution.
        if index >= len(callbacks):
            return terminal_call(payload)

        callback = callbacks[index]
        next_called = False
        next_succeeded = False
        next_result: Any = None
        # Exceptions that escaped next_call, so they can be recognised by
        # identity if the callback lets them propagate unchanged.
        downstream_errors: List[Exception] = []

        def next_call(next_payload: Any = None) -> Any:
            nonlocal next_called, next_succeeded, next_result
            next_called = True
            try:
                next_result = call_at(
                    index + 1, payload if next_payload is None else next_payload
                )
            except Exception as exc:
                # Record, then re-raise unchanged: the callback must see the
                # real exception type to be able to handle it.
                downstream_errors.append(exc)
                raise
            next_succeeded = True
            return next_result

        call_kwargs = middleware_payload(**kwargs)
        call_kwargs[payload_key] = payload
        call_kwargs["next_call"] = next_call
        try:
            return callback(**call_kwargs)
        except Exception as exc:
            if any(exc is err for err in downstream_errors):
                # Unhandled downstream failure: bare raise keeps it intact.
                raise
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(callback, "__name__", repr(callback)),
                exc,
            )
            if next_succeeded:
                return next_result
            if next_called:
                raise
            return call_at(index + 1, payload)
        finally:
            # Drop references so recorded exceptions (whose tracebacks hold
            # this frame) do not keep a reference cycle alive.
            downstream_errors.clear()

    return call_at(0, kwargs[payload_key])
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
    """Build a provenance record for an accepted middleware result.

    Copies whichever of ``source``, ``reason`` and ``name`` hold non-empty
    strings in *result*, in that order. When none qualify, the record is
    ``{"source": "plugin"}``.
    """
    candidates = ((key, result.get(key)) for key in ("source", "reason", "name"))
    entry: Dict[str, Any] = {
        key: value for key, value in candidates if isinstance(value, str) and value
    }
    return entry or {"source": "plugin"}