"""Hermes middleware contract helpers. Observer hooks report what happened. Middleware can change what happens by rewriting a request or wrapping the actual execution callback. Keep the small contract helpers here so agent-loop call sites and plugins share one vocabulary. """ from __future__ import annotations import logging from copy import deepcopy from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional logger = logging.getLogger(__name__) OBSERVER_SCHEMA_VERSION = "hermes.observer.v1" MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1" TOOL_REQUEST_MIDDLEWARE = "tool_request" TOOL_EXECUTION_MIDDLEWARE = "tool_execution" LLM_REQUEST_MIDDLEWARE = "llm_request" LLM_EXECUTION_MIDDLEWARE = "llm_execution" # Back-compat aliases for older PoC branches that used API terminology. API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE VALID_MIDDLEWARE: set[str] = { TOOL_REQUEST_MIDDLEWARE, TOOL_EXECUTION_MIDDLEWARE, LLM_REQUEST_MIDDLEWARE, LLM_EXECUTION_MIDDLEWARE, } @dataclass class RequestMiddlewareResult: """Result of applying request middleware to a mutable payload.""" payload: Any original_payload: Any changed: bool = False trace: List[Dict[str, Any]] = field(default_factory=list) def observer_payload(**kwargs: Any) -> Dict[str, Any]: kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION) return kwargs def middleware_payload(**kwargs: Any) -> Dict[str, Any]: kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION) kwargs.setdefault("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION) return kwargs def apply_llm_request_middleware( request: Dict[str, Any], **context: Any, ) -> RequestMiddlewareResult: """Apply registered LLM request middleware. Middleware may return ``{"request": {...}}`` to replace the effective provider kwargs before Hermes sends them. """ if not _has_middleware(LLM_REQUEST_MIDDLEWARE): return RequestMiddlewareResult( payload=request, original_payload=request, changed=False, trace=[], ) original_request = deepcopy(request) current_request = deepcopy(original_request) trace: List[Dict[str, Any]] = [] for result in _invoke_middleware( LLM_REQUEST_MIDDLEWARE, request=current_request, original_request=original_request, **context, ): if not isinstance(result, dict): continue next_request = result.get("request") if not isinstance(next_request, dict): continue current_request = deepcopy(next_request) trace.append(_trace_entry(result)) return RequestMiddlewareResult( payload=current_request, original_payload=original_request, changed=bool(trace), trace=trace, ) def apply_tool_request_middleware( tool_name: str, args: Dict[str, Any], **context: Any, ) -> RequestMiddlewareResult: """Apply registered tool request middleware. Middleware may return ``{"args": {...}}`` to replace the effective tool arguments before hooks, guardrails, approvals, and execution see them. """ if not _has_middleware(TOOL_REQUEST_MIDDLEWARE): return RequestMiddlewareResult( payload=args, original_payload=args, changed=False, trace=[], ) original_args = deepcopy(args) current_args = deepcopy(original_args) trace: List[Dict[str, Any]] = [] for result in _invoke_middleware( TOOL_REQUEST_MIDDLEWARE, tool_name=tool_name, args=current_args, original_args=original_args, **context, ): if not isinstance(result, dict): continue next_args = result.get("args") if not isinstance(next_args, dict): continue current_args = deepcopy(next_args) trace.append(_trace_entry(result)) return RequestMiddlewareResult( payload=current_args, original_payload=original_args, changed=bool(trace), trace=trace, ) def apply_api_request_middleware( request: Dict[str, Any], **context: Any, ) -> RequestMiddlewareResult: """Compatibility wrapper for older ``api_request`` naming.""" return apply_llm_request_middleware(request, **context) def run_llm_execution_middleware( request: Dict[str, Any], next_call: Callable[[Dict[str, Any]], Any], **context: Any, ) -> Any: """Run provider execution through registered LLM execution middleware.""" callbacks = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE) if not callbacks: return next_call(request) return _run_execution_chain( LLM_EXECUTION_MIDDLEWARE, callbacks, next_call, request=request, original_request=context.pop("original_request", request), **context, ) def run_tool_execution_middleware( tool_name: str, args: Dict[str, Any], next_call: Callable[[Dict[str, Any]], Any], **context: Any, ) -> Any: """Run tool execution through registered tool execution middleware.""" callbacks = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE) if not callbacks: return next_call(args) return _run_execution_chain( TOOL_EXECUTION_MIDDLEWARE, callbacks, next_call, tool_name=tool_name, args=args, original_args=context.pop("original_args", args), **context, ) def run_api_execution_middleware( request: Dict[str, Any], next_call: Callable[[Dict[str, Any]], Any], **context: Any, ) -> Any: """Compatibility wrapper for older ``api_execution`` naming.""" return run_llm_execution_middleware(request, next_call, **context) def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]: from hermes_cli.plugins import invoke_middleware return invoke_middleware(kind, **middleware_payload(**kwargs)) def _has_middleware(kind: str) -> bool: from hermes_cli.plugins import has_middleware return has_middleware(kind) def _get_middleware_callbacks(kind: str) -> List[Callable]: from hermes_cli.plugins import get_plugin_manager return list(get_plugin_manager()._middleware.get(kind, [])) def _run_execution_chain( kind: str, callbacks: List[Callable], terminal_call: Callable[[Any], Any], **kwargs: Any, ) -> Any: payload_key = "request" if "request" in kwargs else "args" class _DownstreamExecutionError(Exception): def __init__(self, original: BaseException) -> None: super().__init__(str(original)) self.original = original def call_at(index: int, payload: Any) -> Any: if index >= len(callbacks): return terminal_call(payload) callback = callbacks[index] next_called = False next_succeeded = False next_result: Any = None def next_call(next_payload: Any = None) -> Any: nonlocal next_called, next_succeeded, next_result next_called = True try: next_result = call_at(index + 1, payload if next_payload is None else next_payload) next_succeeded = True return next_result except Exception as exc: raise _DownstreamExecutionError(exc) from exc call_kwargs = middleware_payload(**kwargs) call_kwargs[payload_key] = payload call_kwargs["next_call"] = next_call try: return callback(**call_kwargs) except _DownstreamExecutionError as exc: raise exc.original except Exception as exc: logger.warning( "Middleware '%s' callback %s raised: %s", kind, getattr(callback, "__name__", repr(callback)), exc, ) if next_succeeded: return next_result if next_called: raise return call_at(index + 1, payload) return call_at(0, kwargs[payload_key]) def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]: entry: Dict[str, Any] = {} for key in ("source", "reason", "name"): value = result.get(key) if isinstance(value, str) and value: entry[key] = value if not entry: entry["source"] = "plugin" return entry