mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Track successful next_call completion separately from invocation so execution middleware that catches and translates a downstream provider/tool failure does not accidentally convert that failure into a successful None result. Also avoid wrapping BaseException from downstream execution, and document the execution middleware error semantics. Tests cover: - pre-next_call middleware failures fail open to the remaining chain - post-next_call middleware failures preserve the downstream result - translated downstream failures propagate instead of returning None - downstream BaseException is not wrapped Signed-off-by: Bryan Bednarski <bbednarski@nvidia.com>
284 lines
8.3 KiB
Python
284 lines
8.3 KiB
Python
"""Hermes middleware contract helpers.
|
|
|
|
Observer hooks report what happened. Middleware can change what happens by
|
|
rewriting a request or wrapping the actual execution callback. Keep the small
|
|
contract helpers here so agent-loop call sites and plugins share one vocabulary.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from copy import deepcopy
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Dict, List, Optional
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

# Schema version tags stamped into payloads by observer_payload() and
# middleware_payload().
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"

# Middleware kinds. "request" kinds rewrite a payload before it is used;
# "execution" kinds wrap the call that actually runs the tool or LLM request.
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"

# Back-compat aliases for older PoC branches that used API terminology.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE

# Every middleware kind defined above (the API_* aliases resolve to the same
# strings as their LLM_* counterparts, so they are covered implicitly).
VALID_MIDDLEWARE: set[str] = {
    TOOL_REQUEST_MIDDLEWARE,
    TOOL_EXECUTION_MIDDLEWARE,
    LLM_REQUEST_MIDDLEWARE,
    LLM_EXECUTION_MIDDLEWARE,
}
|
|
|
|
|
|
@dataclass
class RequestMiddlewareResult:
    """Outcome of running request middleware over a payload.

    Produced by the ``apply_*_request_middleware`` helpers in this module.
    """

    # Effective payload after middleware had a chance to replace it.
    payload: Any
    # Payload as it was before any middleware ran.
    original_payload: Any
    # True when at least one middleware replacement was accepted.
    changed: bool = False
    # One entry per accepted replacement, in the order they were applied.
    trace: List[Dict[str, Any]] = field(default_factory=list)
|
|
|
|
|
|
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return ``kwargs`` stamped with the observer schema version.

    A caller-supplied ``telemetry_schema_version`` is kept as-is; the default
    is only filled in when the key is absent.
    """
    if "telemetry_schema_version" not in kwargs:
        kwargs["telemetry_schema_version"] = OBSERVER_SCHEMA_VERSION
    return kwargs
|
|
|
|
|
|
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return ``kwargs`` stamped with observer and middleware schema versions.

    Keys already supplied by the caller win; only missing version keys are
    added (telemetry version first, then middleware version).
    """
    defaults = (
        ("telemetry_schema_version", OBSERVER_SCHEMA_VERSION),
        ("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION),
    )
    for key, version in defaults:
        if key not in kwargs:
            kwargs[key] = version
    return kwargs
|
|
|
|
|
|
def apply_llm_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run registered LLM request middleware over ``request``.

    A middleware result counts only when it is a dict carrying a dict under
    ``"request"``; that dict replaces the effective provider kwargs before
    Hermes sends them. Later accepted results override earlier ones, and
    every other result is ignored.

    Args:
        request: Provider kwargs for the upcoming LLM call.
        **context: Extra keyword arguments forwarded to the middleware.

    Returns:
        A ``RequestMiddlewareResult``. With no middleware registered, the
        input object itself is used for both ``payload`` and
        ``original_payload`` (no copy is made). Otherwise both are deep
        copies and ``trace`` holds one entry per accepted replacement.
    """
    if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
        # Fast path: nothing registered, so skip the copies entirely.
        return RequestMiddlewareResult(request, request, False, [])

    pristine = deepcopy(request)
    effective = deepcopy(pristine)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        LLM_REQUEST_MIDDLEWARE,
        request=effective,
        original_request=pristine,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("request") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so the plugin cannot keep mutating the accepted request.
            effective = deepcopy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=pristine,
        changed=len(applied) > 0,
        trace=applied,
    )
|
|
|
|
|
|
def apply_tool_request_middleware(
    tool_name: str,
    args: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run registered tool request middleware over ``args``.

    A middleware result counts only when it is a dict carrying a dict under
    ``"args"``; that dict replaces the effective tool arguments before hooks,
    guardrails, approvals, and execution see them. Later accepted results
    override earlier ones, and every other result is ignored.

    Args:
        tool_name: Name of the tool about to be called; forwarded to the
            middleware.
        args: Tool arguments for the upcoming call.
        **context: Extra keyword arguments forwarded to the middleware.

    Returns:
        A ``RequestMiddlewareResult``. With no middleware registered, the
        input object itself is used for both ``payload`` and
        ``original_payload`` (no copy is made). Otherwise both are deep
        copies and ``trace`` holds one entry per accepted replacement.
    """
    if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
        # Fast path: nothing registered, so skip the copies entirely.
        return RequestMiddlewareResult(args, args, False, [])

    pristine = deepcopy(args)
    effective = deepcopy(pristine)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        TOOL_REQUEST_MIDDLEWARE,
        tool_name=tool_name,
        args=effective,
        original_args=pristine,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("args") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so the plugin cannot keep mutating the accepted arguments.
            effective = deepcopy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=pristine,
        changed=len(applied) > 0,
        trace=applied,
    )
|
|
|
|
|
|
def apply_api_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Alias of ``apply_llm_request_middleware`` kept for ``api_request`` callers.

    Arguments and return value are passed straight through.
    """
    outcome = apply_llm_request_middleware(request, **context)
    return outcome
|
|
|
|
|
|
def run_llm_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Execute an LLM request through the registered execution middleware.

    Args:
        request: Provider kwargs handed to the first middleware (or directly
            to ``next_call`` when none is registered).
        next_call: Terminal callable that performs the provider call.
        **context: Extra keyword arguments forwarded to each middleware. An
            ``original_request`` entry, if present, is forwarded under that
            name; otherwise ``request`` is used for it.

    Returns:
        Whatever the middleware chain (or ``next_call``) returns.
    """
    chain = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
    if chain:
        # Pull original_request out of context first so it is not passed twice.
        pristine = context.pop("original_request", request)
        return _run_execution_chain(
            LLM_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            request=request,
            original_request=pristine,
            **context,
        )
    # No middleware registered: call the provider directly.
    return next_call(request)
|
|
|
|
|
|
def run_tool_execution_middleware(
    tool_name: str,
    args: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Execute a tool call through the registered execution middleware.

    Args:
        tool_name: Name of the tool; forwarded to each middleware.
        args: Tool arguments handed to the first middleware (or directly to
            ``next_call`` when none is registered).
        next_call: Terminal callable that performs the tool call.
        **context: Extra keyword arguments forwarded to each middleware. An
            ``original_args`` entry, if present, is forwarded under that
            name; otherwise ``args`` is used for it.

    Returns:
        Whatever the middleware chain (or ``next_call``) returns.
    """
    chain = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
    if chain:
        # Pull original_args out of context first so it is not passed twice.
        pristine = context.pop("original_args", args)
        return _run_execution_chain(
            TOOL_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            tool_name=tool_name,
            args=args,
            original_args=pristine,
            **context,
        )
    # No middleware registered: run the tool directly.
    return next_call(args)
|
|
|
|
|
|
def run_api_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Alias of ``run_llm_execution_middleware`` kept for ``api_execution`` callers.

    Arguments and return value are passed straight through.
    """
    outcome = run_llm_execution_middleware(request, next_call, **context)
    return outcome
|
|
|
|
|
|
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Dispatch ``kind`` middleware through the plugin layer.

    ``kwargs`` are stamped with the schema versions via ``middleware_payload``
    before being forwarded.
    """
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import invoke_middleware as dispatch

    stamped = middleware_payload(**kwargs)
    return dispatch(kind, **stamped)
|
|
|
|
|
|
def _has_middleware(kind: str) -> bool:
    """Ask the plugin layer whether any ``kind`` middleware is registered."""
    # Imported at call time rather than at module import.
    from hermes_cli import plugins

    return plugins.has_middleware(kind)
|
|
|
|
|
|
def _get_middleware_callbacks(kind: str) -> List[Callable]:
    """Return a new list of the callbacks registered for ``kind``.

    Reads the plugin manager's ``_middleware`` registry; an unknown kind
    yields an empty list.
    """
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import get_plugin_manager

    registry = get_plugin_manager()._middleware
    registered = registry.get(kind, [])
    # Copy so callers never hold the registry's own list.
    return list(registered)
|
|
|
|
|
|
def _run_execution_chain(
    kind: str,
    callbacks: List[Callable],
    terminal_call: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """Run ``terminal_call`` wrapped by ``callbacks``, outermost first.

    Each callback is invoked with ``kwargs`` (stamped by
    ``middleware_payload``), the current payload under ``"args"`` or
    ``"request"``, and a ``next_call`` callable. ``next_call(new_payload)``
    runs the rest of the chain; calling it with no argument (or ``None``)
    forwards the current payload unchanged.

    Error semantics:

    * A downstream failure (from a later callback or ``terminal_call``) is
      raised into the calling middleware as the original exception object, so
      the middleware can catch it by its real type. If it escapes the
      middleware unchanged it propagates to the caller without being logged
      as a middleware failure.
    * If the middleware itself raises before ever calling ``next_call``, the
      failure is logged and the chain fails open: the remaining callbacks run
      with the payload this middleware was given.
    * If the middleware raises after its most recent ``next_call`` succeeded,
      the failure is logged and that downstream result is returned.
    * If the middleware raises after its most recent ``next_call`` failed
      (for example it translated the downstream error), the middleware's
      exception propagates; it is never turned into a ``None`` result.
    * Only ``Exception`` subclasses are handled; other ``BaseException``
      subclasses (``KeyboardInterrupt``, ``SystemExit``, ...) always
      propagate untouched.

    Args:
        kind: Middleware kind, used for log messages and payload key choice.
        callbacks: Middleware callables, outermost first.
        terminal_call: Callable that performs the real execution.
        **kwargs: Keyword arguments for the callbacks; must contain the
            initial payload under ``"request"`` or ``"args"``.

    Returns:
        The value returned by the outermost callback, or the downstream
        result / fail-open result as described above.
    """
    # Tool chains carry their payload under "args" even when the surrounding
    # context happens to contain an unrelated "request" entry; every other
    # case keeps the presence-based choice.
    if kind == TOOL_EXECUTION_MIDDLEWARE and "args" in kwargs:
        payload_key = "args"
    else:
        payload_key = "request" if "request" in kwargs else "args"

    def call_at(index: int, payload: Any) -> Any:
        if index >= len(callbacks):
            return terminal_call(payload)

        callback = callbacks[index]
        next_called = False
        next_succeeded = False
        next_result: Any = None
        # Exceptions that came out of next_call, so they can be told apart
        # (by identity) from exceptions raised by the middleware itself.
        downstream_errors: List[BaseException] = []

        def next_call(next_payload: Any = None) -> Any:
            nonlocal next_called, next_succeeded, next_result
            next_called = True
            # Track the outcome of the most recent attempt only, so a failed
            # retry is not masked by an earlier success.
            next_succeeded = False
            forwarded = payload if next_payload is None else next_payload
            try:
                result = call_at(index + 1, forwarded)
            except Exception as exc:
                downstream_errors.append(exc)
                raise
            next_result = result
            next_succeeded = True
            return result

        call_kwargs = middleware_payload(**kwargs)
        call_kwargs[payload_key] = payload
        call_kwargs["next_call"] = next_call
        try:
            return callback(**call_kwargs)
        except Exception as exc:
            if any(exc is seen for seen in downstream_errors):
                # Untranslated downstream failure: not this middleware's
                # fault, so propagate it as-is.
                raise
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(callback, "__name__", repr(callback)),
                exc,
            )
            if next_succeeded:
                return next_result
            if next_called:
                raise
            return call_at(index + 1, payload)
        finally:
            # Drop stored exceptions so their tracebacks do not keep this
            # frame alive through a reference cycle.
            downstream_errors.clear()

    return call_at(0, kwargs[payload_key])
|
|
|
|
|
|
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
|
|
entry: Dict[str, Any] = {}
|
|
for key in ("source", "reason", "name"):
|
|
value = result.get(key)
|
|
if isinstance(value, str) and value:
|
|
entry[key] = value
|
|
if not entry:
|
|
entry["source"] = "plugin"
|
|
return entry
|