From 2e0c9083db8425d6e087ba0af6024406aa513d05 Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Wed, 3 Jun 2026 11:22:06 -0700 Subject: [PATCH 1/3] feat(middleware): add adaptive execution intercepts Signed-off-by: Bryan Bednarski --- agent/agent_runtime_helpers.py | 200 ++++++++----- agent/conversation_loop.py | 42 ++- agent/tool_executor.py | 289 +++++++++++++++---- docs/middleware/README.md | 251 ++++++++++++++++ hermes_cli/middleware.py | 280 ++++++++++++++++++ hermes_cli/plugins.py | 88 +++++- model_tools.py | 44 ++- plugins/observability/nemo_relay/README.md | 185 ++++++++++++ plugins/observability/nemo_relay/__init__.py | 280 +++++++++++++++++- run_agent.py | 16 +- tests/hermes_cli/test_plugins.py | 111 +++++++ tests/plugins/test_nemo_relay_plugin.py | 239 ++++++++++++++- tests/run_agent/test_run_agent.py | 85 ++++++ tests/test_model_tools.py | 56 ++++ 14 files changed, 2015 insertions(+), 151 deletions(-) create mode 100644 docs/middleware/README.md create mode 100644 hermes_cli/middleware.py diff --git a/agent/agent_runtime_helpers.py b/agent/agent_runtime_helpers.py index 09eccef5f34..2c223345055 100644 --- a/agent/agent_runtime_helpers.py +++ b/agent/agent_runtime_helpers.py @@ -1619,13 +1619,37 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo def invoke_tool(agent, function_name: str, function_args: dict, effective_task_id: str, tool_call_id: Optional[str] = None, messages: list = None, - pre_tool_block_checked: bool = False) -> str: + pre_tool_block_checked: bool = False, + skip_tool_request_middleware: bool = False, + tool_request_middleware_trace: Optional[List[Dict[str, Any]]] = None) -> str: """Invoke a single tool and return the result string. No display logic. Handles both agent-level tools (todo, memory, etc.) and registry-dispatched tools. Used by the concurrent execution path; the sequential path retains its own inline invocation for backward-compatible display handling. 
""" + if not isinstance(function_args, dict): + function_args = {} + + _tool_middleware_trace = list(tool_request_middleware_trace or []) + try: + from hermes_cli.middleware import apply_tool_request_middleware + + if not skip_tool_request_middleware: + _tool_request_mw = apply_tool_request_middleware( + function_name, + function_args, + task_id=effective_task_id or "", + session_id=getattr(agent, "session_id", "") or "", + tool_call_id=tool_call_id or "", + turn_id=getattr(agent, "_current_turn_id", "") or "", + api_request_id=getattr(agent, "_current_api_request_id", "") or "", + ) + function_args = _tool_request_mw.payload + _tool_middleware_trace = _tool_request_mw.trace + except Exception as _mw_err: + logger.debug("tool_request middleware error: %s", _mw_err) + # Check plugin hooks for a block directive before executing anything. block_message: Optional[str] = None if not pre_tool_block_checked: @@ -1639,6 +1663,7 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i tool_call_id=tool_call_id or "", turn_id=getattr(agent, "_current_turn_id", "") or "", api_request_id=getattr(agent, "_current_api_request_id", "") or "", + middleware_trace=list(_tool_middleware_trace), ) except Exception: pass @@ -1658,6 +1683,7 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i status="blocked", error_type="plugin_block", error_message=block_message, + middleware_trace=list(_tool_middleware_trace), ) except Exception: pass @@ -1665,12 +1691,13 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i tool_start_time = time.monotonic() - def _finish_agent_tool(result: Any) -> Any: + def _finish_agent_tool(result: Any, observed_args: Optional[dict] = None) -> Any: + hook_args = observed_args if isinstance(observed_args, dict) else function_args try: from model_tools import _emit_post_tool_call_hook _emit_post_tool_call_hook( function_name=function_name, - function_args=function_args, + 
function_args=hook_args, result=result, task_id=effective_task_id or "", session_id=getattr(agent, "session_id", "") or "", @@ -1678,89 +1705,116 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i turn_id=getattr(agent, "_current_turn_id", "") or "", api_request_id=getattr(agent, "_current_api_request_id", "") or "", duration_ms=int((time.monotonic() - tool_start_time) * 1000), + middleware_trace=list(_tool_middleware_trace), ) except Exception: pass return result if function_name == "todo": - from tools.todo_tool import todo_tool as _todo_tool - return _finish_agent_tool( - _todo_tool( - todos=function_args.get("todos"), - merge=function_args.get("merge", False), - store=agent._todo_store, + def _execute(next_args: dict) -> Any: + from tools.todo_tool import todo_tool as _todo_tool + return _finish_agent_tool( + _todo_tool( + todos=next_args.get("todos"), + merge=next_args.get("merge", False), + store=agent._todo_store, + ), + next_args, ) - ) elif function_name == "session_search": - session_db = agent._get_session_db_for_recall() - if not session_db: - from hermes_state import format_session_db_unavailable - return _finish_agent_tool(json.dumps({"success": False, "error": format_session_db_unavailable()})) - from tools.session_search_tool import session_search as _session_search - return _finish_agent_tool( - _session_search( - query=function_args.get("query", ""), - role_filter=function_args.get("role_filter"), - limit=function_args.get("limit", 3), - session_id=function_args.get("session_id"), - around_message_id=function_args.get("around_message_id"), - window=function_args.get("window", 5), - sort=function_args.get("sort"), - db=session_db, - current_session_id=agent.session_id, + def _execute(next_args: dict) -> Any: + session_db = agent._get_session_db_for_recall() + if not session_db: + from hermes_state import format_session_db_unavailable + return _finish_agent_tool(json.dumps({"success": False, "error": 
format_session_db_unavailable()}), next_args) + from tools.session_search_tool import session_search as _session_search + return _finish_agent_tool( + _session_search( + query=next_args.get("query", ""), + role_filter=next_args.get("role_filter"), + limit=next_args.get("limit", 3), + session_id=next_args.get("session_id"), + around_message_id=next_args.get("around_message_id"), + window=next_args.get("window", 5), + sort=next_args.get("sort"), + db=session_db, + current_session_id=agent.session_id, + ), + next_args, ) - ) elif function_name == "memory": - target = function_args.get("target", "memory") - from tools.memory_tool import memory_tool as _memory_tool - result = _memory_tool( - action=function_args.get("action"), - target=target, - content=function_args.get("content"), - old_text=function_args.get("old_text"), - store=agent._memory_store, - ) - # Bridge: notify external memory provider of built-in memory writes - if agent._memory_manager and function_args.get("action") in {"add", "replace"}: - try: - agent._memory_manager.on_memory_write( - function_args.get("action", ""), - target, - function_args.get("content", ""), - metadata=agent._build_memory_write_metadata( - task_id=effective_task_id, - tool_call_id=tool_call_id, - ), - ) - except Exception: - pass - return _finish_agent_tool(result) - elif agent._memory_manager and agent._memory_manager.has_tool(function_name): - return _finish_agent_tool(agent._memory_manager.handle_tool_call(function_name, function_args)) - elif function_name == "clarify": - from tools.clarify_tool import clarify_tool as _clarify_tool - return _finish_agent_tool( - _clarify_tool( - question=function_args.get("question", ""), - choices=function_args.get("choices"), - callback=agent.clarify_callback, + def _execute(next_args: dict) -> Any: + target = next_args.get("target", "memory") + from tools.memory_tool import memory_tool as _memory_tool + result = _memory_tool( + action=next_args.get("action"), + target=target, + 
content=next_args.get("content"), + old_text=next_args.get("old_text"), + store=agent._memory_store, + ) + # Bridge: notify external memory provider of built-in memory writes + if agent._memory_manager and next_args.get("action") in {"add", "replace"}: + try: + agent._memory_manager.on_memory_write( + next_args.get("action", ""), + target, + next_args.get("content", ""), + metadata=agent._build_memory_write_metadata( + task_id=effective_task_id, + tool_call_id=tool_call_id, + ), + ) + except Exception: + pass + return _finish_agent_tool(result, next_args) + elif agent._memory_manager and agent._memory_manager.has_tool(function_name): + def _execute(next_args: dict) -> Any: + return _finish_agent_tool(agent._memory_manager.handle_tool_call(function_name, next_args), next_args) + elif function_name == "clarify": + def _execute(next_args: dict) -> Any: + from tools.clarify_tool import clarify_tool as _clarify_tool + return _finish_agent_tool( + _clarify_tool( + question=next_args.get("question", ""), + choices=next_args.get("choices"), + callback=agent.clarify_callback, + ), + next_args, ) - ) elif function_name == "delegate_task": - return _finish_agent_tool(agent._dispatch_delegate_task(function_args)) + def _execute(next_args: dict) -> Any: + return _finish_agent_tool(agent._dispatch_delegate_task(next_args), next_args) else: - return _ra().handle_function_call( - function_name, function_args, effective_task_id, - tool_call_id=tool_call_id, - session_id=agent.session_id or "", - turn_id=getattr(agent, "_current_turn_id", "") or "", - api_request_id=getattr(agent, "_current_api_request_id", "") or "", - enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None, - skip_pre_tool_call_hook=True, - enabled_toolsets=getattr(agent, "enabled_toolsets", None), - disabled_toolsets=getattr(agent, "disabled_toolsets", None), - ) + def _execute(next_args: dict) -> Any: + return _ra().handle_function_call( + function_name, next_args, effective_task_id, + 
tool_call_id=tool_call_id, + session_id=agent.session_id or "", + turn_id=getattr(agent, "_current_turn_id", "") or "", + api_request_id=getattr(agent, "_current_api_request_id", "") or "", + enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None, + skip_pre_tool_call_hook=True, + skip_tool_request_middleware=True, + enabled_toolsets=getattr(agent, "enabled_toolsets", None), + disabled_toolsets=getattr(agent, "disabled_toolsets", None), + tool_request_middleware_trace=list(_tool_middleware_trace), + ) + + from hermes_cli.middleware import run_tool_execution_middleware + + return run_tool_execution_middleware( + function_name, + function_args, + lambda next_args: _execute(next_args if isinstance(next_args, dict) else function_args), + original_args=function_args, + task_id=effective_task_id or "", + session_id=getattr(agent, "session_id", "") or "", + tool_call_id=tool_call_id or "", + turn_id=getattr(agent, "_current_turn_id", "") or "", + api_request_id=getattr(agent, "_current_api_request_id", "") or "", + ) diff --git a/agent/conversation_loop.py b/agent/conversation_loop.py index c52b9b72d7d..b66a6615ae0 100644 --- a/agent/conversation_loop.py +++ b/agent/conversation_loop.py @@ -1225,6 +1225,28 @@ def run_conversation( _sanitize_structure_non_ascii(api_kwargs) if agent.api_mode == "codex_responses": api_kwargs = agent._get_transport().preflight_kwargs(api_kwargs, allow_stream=False) + try: + from hermes_cli.middleware import apply_llm_request_middleware + + _llm_request_mw = apply_llm_request_middleware( + api_kwargs, + task_id=effective_task_id, + turn_id=turn_id, + api_request_id=api_request_id, + session_id=agent.session_id or "", + platform=agent.platform or "", + model=agent.model, + provider=agent.provider, + base_url=agent.base_url, + api_mode=agent.api_mode, + api_call_count=api_call_count, + ) + api_kwargs = _llm_request_mw.payload + _original_api_kwargs = _llm_request_mw.original_payload + _llm_middleware_trace = 
_llm_request_mw.trace + except Exception: + _original_api_kwargs = dict(api_kwargs) + _llm_middleware_trace = [] try: from hermes_cli.plugins import ( @@ -1277,6 +1299,7 @@ def run_conversation( request_char_count=total_chars, max_tokens=agent.max_tokens, started_at=api_start_time, + middleware_trace=list(_llm_middleware_trace), request=_request_payload, ) except Exception: @@ -1335,7 +1358,24 @@ def run_conversation( ) return agent._interruptible_api_call(next_api_kwargs) - response = _perform_api_call(api_kwargs) + from hermes_cli.middleware import run_llm_execution_middleware + + response = run_llm_execution_middleware( + api_kwargs, + _perform_api_call, + original_request=_original_api_kwargs, + task_id=effective_task_id, + turn_id=turn_id, + api_request_id=api_request_id, + session_id=agent.session_id or "", + platform=agent.platform or "", + model=agent.model, + provider=agent.provider, + base_url=agent.base_url, + api_mode=agent.api_mode, + api_call_count=api_call_count, + middleware_trace=list(_llm_middleware_trace), + ) api_duration = time.time() - api_start_time diff --git a/agent/tool_executor.py b/agent/tool_executor.py index fc3667edb50..f908aedb806 100644 --- a/agent/tool_executor.py +++ b/agent/tool_executor.py @@ -70,6 +70,7 @@ def _emit_terminal_post_tool_call( status: str | None = None, error_type: str | None = None, error_message: str | None = None, + middleware_trace: Optional[list[dict[str, Any]]] = None, ) -> None: try: from model_tools import _emit_post_tool_call_hook @@ -86,6 +87,7 @@ def _emit_terminal_post_tool_call( status=status, error_type=error_type, error_message=error_message, + middleware_trace=list(middleware_trace or []), ) except Exception: pass @@ -111,6 +113,7 @@ def _emit_cancelled_terminal_post_tool_call( start_time: float, reason: str = "user interrupt", error_type: str = "keyboard_interrupt", + middleware_trace: Optional[list[dict[str, Any]]] = None, ) -> str: result = _cancelled_tool_result(reason) 
_emit_terminal_post_tool_call( @@ -124,6 +127,7 @@ def _emit_cancelled_terminal_post_tool_call( status="cancelled", error_type=error_type, error_message=f"Tool execution cancelled by {reason}", + middleware_trace=list(middleware_trace or []), ) return result @@ -177,6 +181,65 @@ def _tool_search_scoped_names(agent) -> frozenset: return names +def _apply_tool_request_middleware_for_agent( + agent, + *, + function_name: str, + function_args: dict, + effective_task_id: str, + tool_call_id: str, +) -> tuple[dict, list[dict[str, Any]]]: + try: + from hermes_cli.middleware import apply_tool_request_middleware + + result = apply_tool_request_middleware( + function_name, + function_args, + task_id=effective_task_id or "", + session_id=getattr(agent, "session_id", "") or "", + tool_call_id=tool_call_id or "", + turn_id=getattr(agent, "_current_turn_id", "") or "", + api_request_id=getattr(agent, "_current_api_request_id", "") or "", + ) + payload = result.payload if isinstance(result.payload, dict) else function_args + return payload, list(result.trace) + except Exception as exc: + logger.debug("tool_request middleware error: %s", exc) + return function_args, [] + + +def _run_agent_tool_execution_middleware( + agent, + *, + function_name: str, + function_args: dict, + effective_task_id: str, + tool_call_id: str, + execute, +) -> tuple[Any, dict]: + observed_args = function_args + + def _execute(next_args: dict) -> Any: + nonlocal observed_args + observed_args = next_args if isinstance(next_args, dict) else function_args + return execute(observed_args) + + from hermes_cli.middleware import run_tool_execution_middleware + + result = run_tool_execution_middleware( + function_name, + function_args, + _execute, + original_args=function_args, + task_id=effective_task_id or "", + session_id=getattr(agent, "session_id", "") or "", + tool_call_id=tool_call_id or "", + turn_id=getattr(agent, "_current_turn_id", "") or "", + api_request_id=getattr(agent, "_current_api_request_id", "") 
or "", + ) + return result, observed_args + + def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None: """Execute multiple tool calls concurrently using a thread pool. @@ -198,7 +261,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe return # ── Parse args + pre-execution bookkeeping ─────────────────────── - parsed_calls = [] # list of (tool_call, function_name, function_args) + parsed_calls = [] # list of (tool_call, function_name, function_args, middleware_trace, block_result, blocked_by_guardrail) for tool_call in tool_calls: function_name = tool_call.function.name @@ -250,6 +313,14 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe except Exception: pass + function_args, middleware_trace = _apply_tool_request_middleware_for_agent( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + ) + # ── Block evaluation (BEFORE checkpoint preflight) ─────────── # We must know whether the tool will execute before touching # checkpoint state (dedup slot, real snapshots). 
@@ -268,6 +339,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe status="blocked", error_type="tool_scope_block", error_message=_ts_scope_block, + middleware_trace=list(middleware_trace), ) else: try: @@ -280,6 +352,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe tool_call_id=getattr(tool_call, "id", "") or "", turn_id=getattr(agent, "_current_turn_id", "") or "", api_request_id=getattr(agent, "_current_api_request_id", "") or "", + middleware_trace=list(middleware_trace), ) except Exception: block_message = None @@ -296,6 +369,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe status="blocked", error_type="plugin_block", error_message=block_message, + middleware_trace=list(middleware_trace), ) else: guardrail_decision = agent._tool_guardrails.before_call(function_name, function_args) @@ -312,6 +386,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe status="blocked", error_type="guardrail_block", error_message=getattr(guardrail_decision, "message", None) or "Tool blocked by guardrail policy", + middleware_trace=list(middleware_trace), ) # ── Checkpoint preflight (only for tools that will execute) ── @@ -338,13 +413,13 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe except Exception: pass - parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail)) + parsed_calls.append((tool_call, function_name, function_args, middleware_trace, block_result, blocked_by_guardrail)) # ── Logging / callbacks ────────────────────────────────────────── - tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls) + tool_names_str = ", ".join(name for _, name, _, _, _, _ in parsed_calls) if not agent.quiet_mode: print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}") - for i, (tc, name, args, block_result, blocked_by_guardrail) in 
enumerate(parsed_calls, 1): + for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1): args_str = json.dumps(args, ensure_ascii=False) if agent.verbose_logging: print(f" 📞 Tool {i}: {name}({list(args.keys())})") @@ -353,7 +428,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe args_preview = args_str[:agent.log_prefix_chars] + "..." if len(args_str) > agent.log_prefix_chars else args_str print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}") - for tc, name, args, block_result, blocked_by_guardrail in parsed_calls: + for tc, name, args, middleware_trace, block_result, blocked_by_guardrail in parsed_calls: if block_result is not None: continue if agent.tool_progress_callback: @@ -363,7 +438,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe except Exception as cb_err: logging.debug(f"Tool progress callback error: {cb_err}") - for tc, name, args, block_result, blocked_by_guardrail in parsed_calls: + for tc, name, args, middleware_trace, block_result, blocked_by_guardrail in parsed_calls: if block_result is not None: continue if agent.tool_start_callback: @@ -373,18 +448,18 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe logging.debug(f"Tool start callback error: {cb_err}") # ── Concurrent execution ───────────────────────────────────────── - # Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag) + # Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag, middleware_trace) results = [None] * num_tools - for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls): + for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls): if block_result is not None: - results[i] = (name, args, block_result, 0.0, True, True) + results[i] = (name, args, 
block_result, 0.0, True, True, middleware_trace) # Touch activity before launching workers so the gateway knows # we're executing tools (not stuck). agent._current_tool = tool_names_str agent._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}") - def _run_tool(index, tool_call, function_name, function_args): + def _run_tool(index, tool_call, function_name, function_args, middleware_trace): """Worker function executed in a thread.""" # Register this worker tid so the agent can fan out an interrupt # to it — see AIAgent.interrupt(). Must happen first thing, and @@ -423,6 +498,8 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe tool_call.id, messages=messages, pre_tool_block_checked=True, + skip_tool_request_middleware=True, + tool_request_middleware_trace=list(middleware_trace), ) except KeyboardInterrupt: try: @@ -436,10 +513,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe effective_task_id=effective_task_id, tool_call_id=getattr(tool_call, "id", "") or "", start_time=start, + middleware_trace=list(middleware_trace), ) duration = time.time() - start logger.info("tool %s cancelled (%.2fs)", function_name, duration) - results[index] = (function_name, function_args, result, duration, True, False) + results[index] = (function_name, function_args, result, duration, True, False, middleware_trace) return except Exception as tool_error: result = f"Error executing tool '{function_name}': {tool_error}" @@ -450,7 +528,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200]) else: logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result)) - results[index] = (function_name, function_args, result, duration, is_error, False) + results[index] = (function_name, function_args, result, duration, is_error, False, middleware_trace) finally: # Tear down 
worker-tid tracking. Clear any interrupt bit we may # have set so the next task scheduled onto this recycled tid @@ -475,7 +553,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe try: runnable_calls = [ (i, tc, name, args) - for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls) + for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls) if block_result is None ] futures = [] @@ -487,7 +565,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe # _approval_session_key) AND thread-local approval/sudo # callbacks into the worker thread; clears callbacks on exit. f = executor.submit( - propagate_context_to_thread(_run_tool), i, tc, name, args + propagate_context_to_thread(_run_tool), i, tc, name, args, parsed_calls[i][3] ) futures.append(f) @@ -545,7 +623,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total") # ── Post-execution: display per-tool results ───────────────────── - for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls): + for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls): r = results[i] blocked = False if r is None: @@ -562,6 +640,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe status="cancelled", error_type="keyboard_interrupt", error_message="Tool execution cancelled by user interrupt", + middleware_trace=list(middleware_trace), ) else: function_result = f"Error executing tool '{name}': thread did not return a result" @@ -575,10 +654,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe status="error", error_type="thread_missing_result", error_message=function_result, + middleware_trace=list(middleware_trace), ) tool_duration = 0.0 else: - 
function_name, function_args, function_result, tool_duration, is_error, blocked = r + function_name, function_args, function_result, tool_duration, is_error, blocked, middleware_trace = r if not blocked: function_result = agent._append_guardrail_observation( @@ -738,6 +818,14 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe except Exception: pass + function_args, middleware_trace = _apply_tool_request_middleware_for_agent( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + ) + # Check plugin hooks for a block directive before executing. _block_msg: Optional[str] = None _block_error_type = "plugin_block" @@ -755,6 +843,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe tool_call_id=getattr(tool_call, "id", "") or "", turn_id=getattr(agent, "_current_turn_id", "") or "", api_request_id=getattr(agent, "_current_api_request_id", "") or "", + middleware_trace=list(middleware_trace), ) except Exception: pass @@ -853,6 +942,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe status="blocked", error_type=_block_error_type, error_message=_block_msg, + middleware_trace=list(middleware_trace), ) elif _guardrail_block_decision is not None: # Tool blocked by tool-loop guardrail — synthesize exactly one @@ -869,71 +959,108 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe status="blocked", error_type="guardrail_block", error_message=getattr(_guardrail_block_decision, "message", None) or "Tool blocked by guardrail policy", + middleware_trace=list(middleware_trace), ) elif function_name == "todo": - from tools.todo_tool import todo_tool as _todo_tool - function_result = _todo_tool( - todos=function_args.get("todos"), - merge=function_args.get("merge", False), - store=agent._todo_store, + def _execute(next_args: dict) -> Any: + from 
tools.todo_tool import todo_tool as _todo_tool + return _todo_tool( + todos=next_args.get("todos"), + merge=next_args.get("merge", False), + store=agent._todo_store, + ) + function_result, function_args = _run_agent_tool_execution_middleware( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, ) tool_duration = time.time() - tool_start_time if agent._should_emit_quiet_tool_messages(): agent._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}") elif function_name == "session_search": - session_db = agent._get_session_db_for_recall() - if not session_db: - from hermes_state import format_session_db_unavailable - function_result = json.dumps({"success": False, "error": format_session_db_unavailable()}) - else: + def _execute(next_args: dict) -> Any: + session_db = agent._get_session_db_for_recall() + if not session_db: + from hermes_state import format_session_db_unavailable + return json.dumps({"success": False, "error": format_session_db_unavailable()}) from tools.session_search_tool import session_search as _session_search - function_result = _session_search( - query=function_args.get("query", ""), - role_filter=function_args.get("role_filter"), - limit=function_args.get("limit", 3), - session_id=function_args.get("session_id"), - around_message_id=function_args.get("around_message_id"), - window=function_args.get("window", 5), - sort=function_args.get("sort"), + return _session_search( + query=next_args.get("query", ""), + role_filter=next_args.get("role_filter"), + limit=next_args.get("limit", 3), + session_id=next_args.get("session_id"), + around_message_id=next_args.get("around_message_id"), + window=next_args.get("window", 5), + sort=next_args.get("sort"), db=session_db, current_session_id=agent.session_id, ) + function_result, function_args = _run_agent_tool_execution_middleware( + 
agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, + ) tool_duration = time.time() - tool_start_time if agent._should_emit_quiet_tool_messages(): agent._vprint(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}") elif function_name == "memory": - target = function_args.get("target", "memory") - from tools.memory_tool import memory_tool as _memory_tool - function_result = _memory_tool( - action=function_args.get("action"), - target=target, - content=function_args.get("content"), - old_text=function_args.get("old_text"), - store=agent._memory_store, + def _execute(next_args: dict) -> Any: + target = next_args.get("target", "memory") + from tools.memory_tool import memory_tool as _memory_tool + result = _memory_tool( + action=next_args.get("action"), + target=target, + content=next_args.get("content"), + old_text=next_args.get("old_text"), + store=agent._memory_store, + ) + # Bridge: notify external memory provider of built-in memory writes + if agent._memory_manager and next_args.get("action") in {"add", "replace"}: + try: + agent._memory_manager.on_memory_write( + next_args.get("action", ""), + target, + next_args.get("content", ""), + metadata=agent._build_memory_write_metadata( + task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", None), + ), + ) + except Exception: + pass + return result + function_result, function_args = _run_agent_tool_execution_middleware( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, ) - # Bridge: notify external memory provider of built-in memory writes - if agent._memory_manager and function_args.get("action") in {"add", "replace"}: - try: - agent._memory_manager.on_memory_write( - function_args.get("action", ""), 
- target, - function_args.get("content", ""), - metadata=agent._build_memory_write_metadata( - task_id=effective_task_id, - tool_call_id=getattr(tool_call, "id", None), - ), - ) - except Exception: - pass tool_duration = time.time() - tool_start_time if agent._should_emit_quiet_tool_messages(): agent._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}") elif function_name == "clarify": - from tools.clarify_tool import clarify_tool as _clarify_tool - function_result = _clarify_tool( - question=function_args.get("question", ""), - choices=function_args.get("choices"), - callback=agent.clarify_callback, + def _execute(next_args: dict) -> Any: + from tools.clarify_tool import clarify_tool as _clarify_tool + return _clarify_tool( + question=next_args.get("question", ""), + choices=next_args.get("choices"), + callback=agent.clarify_callback, + ) + function_result, function_args = _run_agent_tool_execution_middleware( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, ) tool_duration = time.time() - tool_start_time if agent._should_emit_quiet_tool_messages(): @@ -957,7 +1084,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe agent._delegate_spinner = spinner _delegate_result = None try: - function_result = agent._dispatch_delegate_task(function_args) + def _execute(next_args: dict) -> Any: + return agent._dispatch_delegate_task(next_args) + function_result, function_args = _run_agent_tool_execution_middleware( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, + ) _delegate_result = function_result finally: agent._delegate_spinner = None @@ -978,7 +1114,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: 
list, effe spinner.start() _ce_result = None try: - function_result = agent.context_compressor.handle_tool_call(function_name, function_args, messages=messages) + def _execute(next_args: dict) -> Any: + return agent.context_compressor.handle_tool_call(function_name, next_args, messages=messages) + function_result, function_args = _run_agent_tool_execution_middleware( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, + ) _ce_result = function_result except Exception as tool_error: function_result = json.dumps({"error": f"Context engine tool '{function_name}' failed: {tool_error}"}) @@ -1002,7 +1147,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe spinner.start() _mem_result = None try: - function_result = agent._memory_manager.handle_tool_call(function_name, function_args) + def _execute(next_args: dict) -> Any: + return agent._memory_manager.handle_tool_call(function_name, next_args) + function_result, function_args = _run_agent_tool_execution_middleware( + agent, + function_name=function_name, + function_args=function_args, + effective_task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", "") or "", + execute=_execute, + ) _mem_result = function_result except Exception as tool_error: function_result = json.dumps({"error": f"Memory tool '{function_name}' failed: {tool_error}"}) @@ -1032,8 +1186,10 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe api_request_id=getattr(agent, "_current_api_request_id", "") or "", enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None, skip_pre_tool_call_hook=True, + skip_tool_request_middleware=True, enabled_toolsets=getattr(agent, "enabled_toolsets", None), disabled_toolsets=getattr(agent, "disabled_toolsets", None), + tool_request_middleware_trace=list(middleware_trace), ) _spinner_result = 
function_result except KeyboardInterrupt: @@ -1044,6 +1200,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe effective_task_id=effective_task_id, tool_call_id=getattr(tool_call, "id", "") or "", start_time=tool_start_time, + middleware_trace=list(middleware_trace), ) _spinner_result = function_result try: @@ -1071,8 +1228,10 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe api_request_id=getattr(agent, "_current_api_request_id", "") or "", enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None, skip_pre_tool_call_hook=True, + skip_tool_request_middleware=True, enabled_toolsets=getattr(agent, "enabled_toolsets", None), disabled_toolsets=getattr(agent, "disabled_toolsets", None), + tool_request_middleware_trace=list(middleware_trace), ) except KeyboardInterrupt: _emit_cancelled_terminal_post_tool_call( @@ -1082,6 +1241,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe effective_task_id=effective_task_id, tool_call_id=getattr(tool_call, "id", "") or "", start_time=tool_start_time, + middleware_trace=list(middleware_trace), ) try: agent.interrupt("keyboard interrupt") @@ -1126,6 +1286,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe effective_task_id=effective_task_id, tool_call_id=getattr(tool_call, "id", "") or "", duration_ms=int(tool_duration * 1000), + middleware_trace=list(middleware_trace), ) if not _execution_blocked: function_result = agent._append_guardrail_observation( diff --git a/docs/middleware/README.md b/docs/middleware/README.md new file mode 100644 index 00000000000..b385b87eb29 --- /dev/null +++ b/docs/middleware/README.md @@ -0,0 +1,251 @@ +# Hermes Middleware + +Hermes middleware is the behavior-changing companion to observer hooks. +Observer hooks report what happened. 
Middleware can change what happens by +rewriting a request before execution or by wrapping the execution callback +itself. + +This contract is intentionally backend-neutral. A plugin can use it for local +policy, request shaping, tracing, adaptive routing, cache control, sandbox +selection, or handoff to runtimes such as NeMo Relay without changing Hermes' +planner, model provider adapters, tool registry, memory, or CLI UX. + +With middleware enabled, plugins can: + +- Rewrite LLM provider request kwargs before Hermes calls the provider. +- Rewrite tool arguments before guardrails, approval checks, hooks, and tool + execution see them. +- Wrap the actual LLM execution callback while preserving Hermes retry, + streaming, interrupt, and hook behavior. +- Wrap the actual tool execution callback while preserving Hermes guardrails, + approval, post-tool hooks, and tool-result transformation. + +## Contract + +Plugins register middleware from `register(ctx)`: + +```python +def register(ctx): + ctx.register_middleware("llm_request", on_llm_request) + ctx.register_middleware("llm_execution", on_llm_execution) + ctx.register_middleware("tool_request", on_tool_request) + ctx.register_middleware("tool_execution", on_tool_execution) +``` + +Every middleware callback receives: + +- `telemetry_schema_version`: currently `hermes.observer.v1` +- `middleware_schema_version`: currently `hermes.middleware.v1` +- Runtime context such as `session_id`, `task_id`, `turn_id`, + `api_request_id`, `provider`, `model`, `api_mode`, `tool_name`, and + `tool_call_id` when applicable. + +Supported middleware kinds: + +| Kind | Payload | Return shape | Purpose | +| --- | --- | --- | --- | +| `llm_request` | `request`, `original_request` | `{"request": {...}}` | Replace effective provider kwargs before provider execution. | +| `tool_request` | `tool_name`, `args`, `original_args` | `{"args": {...}}` | Replace effective tool args before hooks, guardrails, approvals, and execution. 
| +| `llm_execution` | `request`, `original_request`, `next_call` | Any provider response | Wrap or replace the actual provider call. | +| `tool_execution` | `tool_name`, `args`, `original_args`, `next_call` | Any tool result | Wrap or replace the actual tool call. | + +Request middleware can return optional trace fields: + +```python +return { + "request": updated_request, + "source": "my-plugin", + "reason": "selected fallback model", +} +``` + +Hermes stores those trace entries in later observer hook payloads as +`middleware_trace`. + +Execution middleware receives a `next_call` callback. Call it to continue the +chain: + +```python +def on_tool_execution(**kwargs): + result = kwargs["next_call"](kwargs["args"]) + return result +``` + +If multiple plugins register the same execution middleware kind, Hermes runs +them as a nested chain in registration order. Middleware failures are fail-open: +Hermes logs a warning and continues with the next middleware or the base +runtime path. + +## Execution Order + +### LLM Calls + +For each provider request, Hermes applies middleware in this order: + +1. Build provider kwargs from the current conversation. +2. Apply `llm_request` middleware. +3. Emit `pre_api_request` observer hooks with the effective request. +4. Run provider execution through `llm_execution` middleware. +5. Emit `post_api_request` or `api_request_error` observer hooks. + +Request middleware sees the full provider kwargs, including `messages` or +Responses API `input`, model settings, tool definitions, stream options, and +provider-specific options. Execution middleware receives the same effective +request plus `next_call`. + +### Tool Calls + +For each tool call, Hermes applies middleware in this order: + +1. Parse and coerce model-provided tool arguments. +2. Apply `tool_request` middleware. +3. 
Run the normal Hermes pre-execution path against the effective arguments: + tool availability checks, observer block directives, guardrails, and + approval checks. +4. Run tool execution through `tool_execution` middleware. +5. Emit `post_tool_call` observer hooks. +6. Apply `transform_tool_result` hooks before the result is appended back into + conversation context. + +Tool request middleware runs before approval checks. Use it carefully: a +rewritten path, command, or URL is the value downstream policy will evaluate. + +## Enablement + +Middleware only runs for enabled plugins. For a bundled plugin: + +```bash +hermes plugins enable <plugin-name> +``` + +For isolated local testing, use one `HERMES_HOME` for plugin enablement and the +agent run: + +```bash +export HERMES_HOME=/tmp/hermes-middleware-test +mkdir -p "$HERMES_HOME" +hermes plugins enable <plugin-name> +hermes chat --query 'Reply exactly ok' +``` + +For source checkouts, prefer the source command so the runtime sees plugins and +middleware from the working tree: + +```bash +uv sync +uv run hermes plugins enable <plugin-name> +uv run hermes chat --query 'Reply exactly ok' +``` + +## Generic Plugin Examples + +The examples below are intentionally small. They show the middleware contract +shape without depending on NeMo Relay. + +### LLM Request Middleware + +This plugin tags provider requests and records a middleware trace entry: + +```python +def register(ctx): + ctx.register_middleware("llm_request", tag_llm_request) + + +def tag_llm_request(**kwargs): + request = dict(kwargs["request"]) + extra_body = dict(request.get("extra_body") or {}) + extra_body.setdefault("metadata", {})["hermes_middleware_demo"] = True + request["extra_body"] = extra_body + return { + "request": request, + "source": "middleware-demo", + "reason": "tagged provider request", + } +``` + +The effective request is passed to `pre_api_request`, provider execution, and +`post_api_request`.
+ +### Tool Request Middleware + +This plugin defaults `terminal` calls to a known working directory when the +call did not supply one: + +```python +def register(ctx): + ctx.register_middleware("tool_request", normalize_terminal_workdir) + + +def normalize_terminal_workdir(**kwargs): + if kwargs.get("tool_name") != "terminal": + return None + args = dict(kwargs["args"]) + args.setdefault("workdir", "/tmp/hermes-middleware-demo") + return { + "args": args, + "source": "middleware-demo", + "reason": "defaulted terminal workdir", + } +``` + +Because this runs before hooks and approvals, downstream telemetry and policy +observe the rewritten `workdir`. + +### LLM Execution Middleware + +This plugin wraps the provider call and preserves the raw provider response: + +```python +import time + + +def register(ctx): + ctx.register_middleware("llm_execution", time_llm_execution) + + +def time_llm_execution(**kwargs): + started = time.monotonic() + response = kwargs["next_call"](kwargs["request"]) + elapsed_ms = int((time.monotonic() - started) * 1000) + print(f"llm_execution elapsed_ms={elapsed_ms}") + return response +``` + +Return the same response shape Hermes expects from the provider adapter. Do not +wrap the response in a plugin-specific envelope unless the rest of the runtime +expects that envelope. + +### Tool Execution Middleware + +This plugin wraps tool execution while preserving the tool result: + +```python +def register(ctx): + ctx.register_middleware("tool_execution", annotate_tool_execution) + + +def annotate_tool_execution(**kwargs): + result = kwargs["next_call"](kwargs["args"]) + # Metrics, logging, or external routing can happen here. + return result +``` + +Execution middleware may call `next_call(modified_args)` to pass a changed +payload to later middleware and the base tool dispatcher. + +Plugin-specific examples should live with the plugin that owns the behavior.
+For NeMo Relay adaptive execution middleware, see +[`plugins/observability/nemo_relay/README.md`](../../plugins/observability/nemo_relay/README.md). + +## Safety Notes + +- Middleware should be deterministic for the same input unless it is explicitly + routing to a dynamic external system. +- Request middleware should return complete replacement payloads, not partial + patches. +- Execution middleware should call `next_call(...)` exactly once unless it is + intentionally short-circuiting execution. +- Tool request middleware runs before approvals. If it mutates file paths, + commands, URLs, or arguments, the mutated values are what guardrails and + approvals evaluate. +- Observer hooks remain the right place for read-only telemetry. Use middleware + only when a plugin needs to alter or wrap behavior. diff --git a/hermes_cli/middleware.py b/hermes_cli/middleware.py new file mode 100644 index 00000000000..938bffcf172 --- /dev/null +++ b/hermes_cli/middleware.py @@ -0,0 +1,280 @@ +"""Hermes middleware contract helpers. + +Observer hooks report what happened. Middleware can change what happens by +rewriting a request or wrapping the actual execution callback. Keep the small +contract helpers here so agent-loop call sites and plugins share one vocabulary. +""" + +from __future__ import annotations + +import logging +from copy import deepcopy +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional + +logger = logging.getLogger(__name__) + +OBSERVER_SCHEMA_VERSION = "hermes.observer.v1" +MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1" + +TOOL_REQUEST_MIDDLEWARE = "tool_request" +TOOL_EXECUTION_MIDDLEWARE = "tool_execution" +LLM_REQUEST_MIDDLEWARE = "llm_request" +LLM_EXECUTION_MIDDLEWARE = "llm_execution" + +# Back-compat aliases for older PoC branches that used API terminology. 
+API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE +API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE + +VALID_MIDDLEWARE: set[str] = { + TOOL_REQUEST_MIDDLEWARE, + TOOL_EXECUTION_MIDDLEWARE, + LLM_REQUEST_MIDDLEWARE, + LLM_EXECUTION_MIDDLEWARE, +} + + +@dataclass +class RequestMiddlewareResult: + """Result of applying request middleware to a mutable payload.""" + + payload: Any + original_payload: Any + changed: bool = False + trace: List[Dict[str, Any]] = field(default_factory=list) + + +def observer_payload(**kwargs: Any) -> Dict[str, Any]: + kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION) + return kwargs + + +def middleware_payload(**kwargs: Any) -> Dict[str, Any]: + kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION) + kwargs.setdefault("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION) + return kwargs + + +def apply_llm_request_middleware( + request: Dict[str, Any], + **context: Any, +) -> RequestMiddlewareResult: + """Apply registered LLM request middleware. + + Middleware may return ``{"request": {...}}`` to replace the effective + provider kwargs before Hermes sends them. 
+ """ + if not _has_middleware(LLM_REQUEST_MIDDLEWARE): + return RequestMiddlewareResult( + payload=request, + original_payload=request, + changed=False, + trace=[], + ) + + original_request = deepcopy(request) + current_request = deepcopy(original_request) + trace: List[Dict[str, Any]] = [] + + for result in _invoke_middleware( + LLM_REQUEST_MIDDLEWARE, + request=current_request, + original_request=original_request, + **context, + ): + if not isinstance(result, dict): + continue + next_request = result.get("request") + if not isinstance(next_request, dict): + continue + current_request = deepcopy(next_request) + trace.append(_trace_entry(result)) + + return RequestMiddlewareResult( + payload=current_request, + original_payload=original_request, + changed=bool(trace), + trace=trace, + ) + + +def apply_tool_request_middleware( + tool_name: str, + args: Dict[str, Any], + **context: Any, +) -> RequestMiddlewareResult: + """Apply registered tool request middleware. + + Middleware may return ``{"args": {...}}`` to replace the effective tool + arguments before hooks, guardrails, approvals, and execution see them. 
+ """ + if not _has_middleware(TOOL_REQUEST_MIDDLEWARE): + return RequestMiddlewareResult( + payload=args, + original_payload=args, + changed=False, + trace=[], + ) + + original_args = deepcopy(args) + current_args = deepcopy(original_args) + trace: List[Dict[str, Any]] = [] + + for result in _invoke_middleware( + TOOL_REQUEST_MIDDLEWARE, + tool_name=tool_name, + args=current_args, + original_args=original_args, + **context, + ): + if not isinstance(result, dict): + continue + next_args = result.get("args") + if not isinstance(next_args, dict): + continue + current_args = deepcopy(next_args) + trace.append(_trace_entry(result)) + + return RequestMiddlewareResult( + payload=current_args, + original_payload=original_args, + changed=bool(trace), + trace=trace, + ) + + +def apply_api_request_middleware( + request: Dict[str, Any], + **context: Any, +) -> RequestMiddlewareResult: + """Compatibility wrapper for older ``api_request`` naming.""" + return apply_llm_request_middleware(request, **context) + + +def run_llm_execution_middleware( + request: Dict[str, Any], + next_call: Callable[[Dict[str, Any]], Any], + **context: Any, +) -> Any: + """Run provider execution through registered LLM execution middleware.""" + callbacks = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE) + if not callbacks: + return next_call(request) + return _run_execution_chain( + LLM_EXECUTION_MIDDLEWARE, + callbacks, + next_call, + request=request, + original_request=context.pop("original_request", request), + **context, + ) + + +def run_tool_execution_middleware( + tool_name: str, + args: Dict[str, Any], + next_call: Callable[[Dict[str, Any]], Any], + **context: Any, +) -> Any: + """Run tool execution through registered tool execution middleware.""" + callbacks = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE) + if not callbacks: + return next_call(args) + return _run_execution_chain( + TOOL_EXECUTION_MIDDLEWARE, + callbacks, + next_call, + tool_name=tool_name, + args=args, + 
original_args=context.pop("original_args", args), + **context, + ) + + +def run_api_execution_middleware( + request: Dict[str, Any], + next_call: Callable[[Dict[str, Any]], Any], + **context: Any, +) -> Any: + """Compatibility wrapper for older ``api_execution`` naming.""" + return run_llm_execution_middleware(request, next_call, **context) + + +def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]: + from hermes_cli.plugins import invoke_middleware + + return invoke_middleware(kind, **middleware_payload(**kwargs)) + + +def _has_middleware(kind: str) -> bool: + from hermes_cli.plugins import has_middleware + + return has_middleware(kind) + + +def _get_middleware_callbacks(kind: str) -> List[Callable]: + from hermes_cli.plugins import get_plugin_manager + + return list(get_plugin_manager()._middleware.get(kind, [])) + + +def _run_execution_chain( + kind: str, + callbacks: List[Callable], + terminal_call: Callable[[Any], Any], + **kwargs: Any, +) -> Any: + """Run ``callbacks`` as a nested chain that ends in ``terminal_call``. + + Each callback receives the current payload under ``request`` or ``args`` + plus a ``next_call`` continuation. A callback that raises is logged and + skipped (fail-open): the downstream result is returned if it already + called ``next_call``, otherwise the chain continues with the same payload. + Exceptions raised downstream of ``next_call`` are re-raised to the caller + as the original exception unless the callback handles them itself. + """ + payload_key = "request" if "request" in kwargs else "args" + + class _DownstreamExecutionError(Exception): + def __init__(self, original: BaseException) -> None: + super().__init__(str(original)) + self.original = original + + def call_at(index: int, payload: Any) -> Any: + if index >= len(callbacks): + return terminal_call(payload) + + callback = callbacks[index] + next_called = False + next_result: Any = None + + def next_call(next_payload: Any = None) -> Any: + nonlocal next_called, next_result + next_called = True + try: + next_result = call_at(index + 1, payload if next_payload is None else next_payload) + return next_result + except Exception as exc: # KeyboardInterrupt/SystemExit must pass through unwrapped + raise _DownstreamExecutionError(exc) from exc + + call_kwargs = middleware_payload(**kwargs) + call_kwargs[payload_key] = payload + call_kwargs["next_call"] = next_call + try: + return callback(**call_kwargs) + except _DownstreamExecutionError as exc: + raise exc.original + except Exception as exc: + logger.warning( + "Middleware '%s' callback %s
raised: %s", + kind, + getattr(callback, "__name__", repr(callback)), + exc, + ) + if next_called: + return next_result + return call_at(index + 1, payload) + + return call_at(0, kwargs[payload_key]) + + +def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]: + entry: Dict[str, Any] = {} + for key in ("source", "reason", "name"): + value = result.get(key) + if isinstance(value, str) and value: + entry[key] = value + if not entry: + entry["source"] = "plugin" + return entry diff --git a/hermes_cli/plugins.py b/hermes_cli/plugins.py index fd449fc27a4..d5cb7e8fe01 100644 --- a/hermes_cli/plugins.py +++ b/hermes_cli/plugins.py @@ -49,7 +49,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union from hermes_constants import get_hermes_home from utils import env_var_enabled from hermes_cli.config import cfg_get -OBSERVER_SCHEMA_VERSION = "hermes.observer.v1" +from hermes_cli.middleware import OBSERVER_SCHEMA_VERSION, VALID_MIDDLEWARE def get_bundled_plugins_dir() -> Path: @@ -277,6 +277,7 @@ class LoadedPlugin: module: Optional[types.ModuleType] = None tools_registered: List[str] = field(default_factory=list) hooks_registered: List[str] = field(default_factory=list) + middleware_registered: List[str] = field(default_factory=list) commands_registered: List[str] = field(default_factory=list) enabled: bool = False error: Optional[str] = None @@ -952,6 +953,27 @@ class PluginContext: self._manager._hooks.setdefault(hook_name, []).append(callback) logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name) + # -- middleware registration ------------------------------------------- + + def register_middleware(self, kind: str, callback: Callable) -> None: + """Register a behavior-changing middleware callback. + + Middleware is separate from observer hooks: request middleware may + rewrite the effective payload, and execution middleware may wrap the + real callback. 
Unknown kinds are stored for forward compatibility but + warned so plugin authors can catch typos. + """ + if kind not in VALID_MIDDLEWARE: + logger.warning( + "Plugin '%s' registered unknown middleware '%s' " + "(valid: %s)", + self.manifest.name, + kind, + ", ".join(sorted(VALID_MIDDLEWARE)), + ) + self._manager._middleware.setdefault(kind, []).append(callback) + logger.debug("Plugin %s registered middleware: %s", self.manifest.name, kind) + # -- skill registration ------------------------------------------------- def register_skill( @@ -1010,6 +1032,7 @@ class PluginManager: def __init__(self) -> None: self._plugins: Dict[str, LoadedPlugin] = {} self._hooks: Dict[str, List[Callable]] = {} + self._middleware: Dict[str, List[Callable]] = {} self._plugin_tool_names: Set[str] = set() self._plugin_platform_names: Set[str] = set() self._cli_commands: Dict[str, dict] = {} @@ -1039,6 +1062,7 @@ class PluginManager: if force: self._plugins.clear() self._hooks.clear() + self._middleware.clear() self._plugin_tool_names.clear() self._cli_commands.clear() self._plugin_commands.clear() @@ -1449,15 +1473,28 @@ class PluginManager: for h in p.hooks_registered } ) + loaded.middleware_registered = list( + { + kind + for kind, cbs in self._middleware.items() + if cbs + } + - { + kind + for name, p in self._plugins.items() + for kind in p.middleware_registered + } + ) loaded.commands_registered = [ c for c in self._plugin_commands if self._plugin_commands[c].get("plugin") == manifest.name ] loaded.enabled = True logger.debug( - " registered: %d tool(s), %d hook(s), %d slash command(s), %d CLI command(s)", + " registered: %d tool(s), %d hook(s), %d middleware, %d slash command(s), %d CLI command(s)", len(loaded.tools_registered), len(loaded.hooks_registered), + len(loaded.middleware_registered), len(loaded.commands_registered), sum( 1 for c in self._cli_commands @@ -1575,6 +1612,33 @@ class PluginManager: """Return True when at least one callback is registered for a hook.""" return 
bool(self._hooks.get(hook_name)) + def has_middleware(self, kind: str) -> bool: + """Return True when at least one callback is registered for middleware.""" + return bool(self._middleware.get(kind)) + + def invoke_middleware(self, kind: str, **kwargs: Any) -> List[Any]: + """Call registered middleware callbacks for *kind*. + + Each callback is isolated so one plugin cannot break the base runtime + path. Middleware that wants to change behavior must return the shape + documented by the caller-specific contract. + """ + callbacks = self._middleware.get(kind, []) + results: List[Any] = [] + for cb in callbacks: + try: + ret = cb(**kwargs) + if ret is not None: + results.append(ret) + except Exception as exc: + logger.warning( + "Middleware '%s' callback %s raised: %s", + kind, + getattr(cb, "__name__", repr(cb)), + exc, + ) + return results + # ----------------------------------------------------------------------- # Introspection # ----------------------------------------------------------------------- @@ -1594,6 +1658,7 @@ class PluginManager: "enabled": loaded.enabled, "tools": len(loaded.tools_registered), "hooks": len(loaded.hooks_registered), + "middleware": len(loaded.middleware_registered), "commands": len(loaded.commands_registered), "error": loaded.error, } @@ -1655,6 +1720,23 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]: return get_plugin_manager().invoke_hook(hook_name, **kwargs) +def invoke_middleware(kind: str, **kwargs: Any) -> List[Any]: + """Invoke registered middleware callbacks. + + Returns a list of non-``None`` return values from middleware callbacks. 
+ """ + return get_plugin_manager().invoke_middleware(kind, **kwargs) + + +def has_middleware(kind: str) -> bool: + """Return True when middleware callbacks are registered for ``kind``.""" + manager = get_plugin_manager() + method = getattr(manager, "has_middleware", None) + if callable(method): + return bool(method(kind)) + return bool(getattr(manager, "_middleware", {}).get(kind)) + + def has_hook(hook_name: str) -> bool: """Return True when a hook has registered callbacks.""" return get_plugin_manager().has_hook(hook_name) @@ -1683,6 +1765,7 @@ def get_pre_tool_call_block_message( tool_call_id: str = "", turn_id: str = "", api_request_id: str = "", + middleware_trace: Optional[List[Dict[str, Any]]] = None, ) -> Optional[str]: """Check ``pre_tool_call`` hooks for a blocking directive. @@ -1709,6 +1792,7 @@ def get_pre_tool_call_block_message( tool_call_id=tool_call_id, turn_id=turn_id, api_request_id=api_request_id, + middleware_trace=list(middleware_trace or []), ) for result in hook_results: diff --git a/model_tools.py b/model_tools.py index c3a9c98c60c..9d04ada2d75 100644 --- a/model_tools.py +++ b/model_tools.py @@ -823,6 +823,7 @@ def _emit_post_tool_call_hook( status: Optional[str] = None, error_type: Optional[str] = None, error_message: Optional[str] = None, + middleware_trace: Optional[List[Dict[str, Any]]] = None, ) -> None: """Emit the ``post_tool_call`` observer hook. 
@@ -853,6 +854,7 @@ def _emit_post_tool_call_hook( status=status, error_type=error_type, error_message=error_message, + middleware_trace=list(middleware_trace or []), ) except Exception as _hook_err: logger.debug("post_tool_call hook error: %s", _hook_err) @@ -869,6 +871,8 @@ def handle_function_call( user_task: Optional[str] = None, enabled_tools: Optional[List[str]] = None, skip_pre_tool_call_hook: bool = False, + skip_tool_request_middleware: bool = False, + tool_request_middleware_trace: Optional[List[Dict[str, Any]]] = None, enabled_toolsets: Optional[List[str]] = None, disabled_toolsets: Optional[List[str]] = None, ) -> str: @@ -900,6 +904,7 @@ def handle_function_call( function_args = coerce_tool_args(function_name, function_args) if not isinstance(function_args, dict): function_args = {} + _tool_middleware_trace = list(tool_request_middleware_trace or []) # ── Tool Search bridge dispatch ────────────────────────────────── # tool_search and tool_describe are pure catalog reads — handle them @@ -970,10 +975,32 @@ def handle_function_call( user_task=user_task, enabled_tools=enabled_tools, skip_pre_tool_call_hook=skip_pre_tool_call_hook, + skip_tool_request_middleware=skip_tool_request_middleware, + tool_request_middleware_trace=list(_tool_middleware_trace), enabled_toolsets=enabled_toolsets, disabled_toolsets=disabled_toolsets, ) + _tool_original_args = dict(function_args) + if not skip_tool_request_middleware: + try: + from hermes_cli.middleware import apply_tool_request_middleware + + _tool_request_mw = apply_tool_request_middleware( + function_name, + function_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + turn_id=turn_id or "", + api_request_id=api_request_id or "", + ) + function_args = _tool_request_mw.payload + _tool_original_args = _tool_request_mw.original_payload + _tool_middleware_trace = _tool_request_mw.trace + except Exception as _mw_err: + logger.debug("tool_request middleware error: %s", 
_mw_err) + try: if function_name in _AGENT_LOOP_TOOLS: return json.dumps({"error": f"{function_name} must be handled by the agent loop"}) @@ -1000,6 +1027,7 @@ def handle_function_call( tool_call_id=tool_call_id or "", turn_id=turn_id or "", api_request_id=api_request_id or "", + middleware_trace=list(_tool_middleware_trace), ) except Exception as _hook_err: logger.debug("pre_tool_call hook error: %s", _hook_err) @@ -1018,6 +1046,7 @@ def handle_function_call( status="blocked", error_type="plugin_block", error_message=block_message, + middleware_trace=list(_tool_middleware_trace), ) return result @@ -1082,7 +1111,19 @@ def handle_function_call( task_id=task_id, user_task=user_task, ) - result = _dispatch(function_args) + from hermes_cli.middleware import run_tool_execution_middleware + + result = run_tool_execution_middleware( + function_name, + function_args, + _dispatch, + original_args=_tool_original_args, + task_id=task_id or "", + session_id=session_id or "", + tool_call_id=tool_call_id or "", + turn_id=turn_id or "", + api_request_id=api_request_id or "", + ) finally: if _approval_tokens is not None and reset_current_observability_context is not None: try: @@ -1101,6 +1142,7 @@ def handle_function_call( turn_id=turn_id, api_request_id=api_request_id, duration_ms=duration_ms, + middleware_trace=list(_tool_middleware_trace), ) # Generic tool-result canonicalization seam: plugins receive the diff --git a/plugins/observability/nemo_relay/README.md b/plugins/observability/nemo_relay/README.md index f1a2c3b7dc3..b5376696213 100644 --- a/plugins/observability/nemo_relay/README.md +++ b/plugins/observability/nemo_relay/README.md @@ -165,6 +165,28 @@ When `HERMES_NEMO_RELAY_PLUGINS_TOML` is set and initializes successfully, NeMo Relay owns exporter lifecycle through that config. The direct `HERMES_NEMO_RELAY_ATOF_*` fallback setup is skipped. 
+To enable NeMo Relay managed execution intercepts for provider and tool calls, +include an adaptive component in the same `plugins.toml`: + +```toml +[[components]] +kind = "adaptive" +enabled = true + +[components.config] +mode = "route" +``` + +When the adaptive component is enabled and the installed NeMo Relay runtime +exposes `llm.execute(...)` / `tools.execute(...)`, Hermes routes LLM and tool +execution through those middleware boundaries. The observer hooks still emit +session, turn, approval, and subagent marks; the plugin skips its manual +`llm.call` and `tools.call` spans for executions that are already managed by +NeMo Relay. + +For the full generic Hermes middleware contract, see +[`docs/middleware/README.md`](../../../docs/middleware/README.md). + ## Canonical Local Examples The examples below use the official `nemo-relay==0.3` distribution and a local @@ -366,3 +388,166 @@ subagent IDs, role/status fields when present, and derived `parent_trajectory_id` / `child_trajectory_id` values. This keeps the ATOF stream lossless for later ATIF conversion that can compact subagents into separate trajectories. + +## Adaptive Middleware Example + +The `observability/nemo_relay` plugin uses Hermes execution middleware to hand +LLM and tool calls to NeMo Relay managed execution when an adaptive component is +enabled. + +Minimal `plugins.toml`: + +```toml +version = 1 + +[[components]] +kind = "adaptive" +enabled = true + +[components.config] +mode = "route" +``` + +Enable it for Hermes: + +```bash +export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/plugins.toml +``` + +When the adaptive component is enabled and the installed NeMo Relay runtime +exposes `llm.execute(...)` and `tools.execute(...)`, Hermes routes execution +through these boundaries: + +```text +Hermes provider call + -> llm_execution middleware + -> nemo_relay.llm.execute(...) + -> Hermes provider adapter next_call(...) 
+ +Hermes tool call + -> tool_execution middleware + -> nemo_relay.tools.execute(...) + -> Hermes tool dispatcher next_call(...) +``` + +The plugin still emits observer marks for sessions, turns, approvals, and +subagents. When adaptive managed execution is active, it skips manual +`llm.call` and `tools.call` observer spans to avoid duplicate LLM/tool events +for the same execution. + +### Local Adaptive E2E + +This example enables both NeMo Relay observability export and adaptive execution +middleware for a local Hermes run. + +```bash +pip install "nemo-relay==0.3" + +export HERMES_HOME=/tmp/hermes-middleware-test/hermes-home +mkdir -p "$HERMES_HOME" /tmp/hermes-middleware-test/nemo-relay + +cat > "$HERMES_HOME/config.yaml" <<'YAML' +model: + provider: custom + default: qwen3.6:35b + base_url: http://127.0.0.1:11434/v1 + api_key: ollama +plugins: + enabled: + - observability/nemo_relay +YAML + +cat > /tmp/hermes-middleware-test/nemo-relay/plugins.toml <<'TOML' +version = 1 + +[[components]] +kind = "observability" +enabled = true + +[components.config] +version = 1 + +[components.config.atof] +enabled = true +output_directory = "/tmp/hermes-middleware-test/atof" +filename = "middleware-events.jsonl" +mode = "overwrite" + +[components.config.atif] +enabled = true +output_directory = "/tmp/hermes-middleware-test/atif" +filename_template = "middleware-trajectory-{session_id}.json" +agent_name = "Hermes Middleware E2E" +agent_version = "local" + +[[components]] +kind = "adaptive" +enabled = true + +[components.config] +mode = "route" +TOML + +export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/nemo-relay/plugins.toml + +hermes chat \ + --query 'Use the terminal tool exactly once to run printf middleware_execution_ok. Then reply with exactly the command output.' 
\ + --provider custom \ + --model qwen3.6:35b \ + --toolsets terminal \ + --max-turns 4 \ + --quiet \ + --accept-hooks +``` + +Expected CLI output: + +```text +session_id: middleware-demo-session +middleware_execution_ok +``` + +Expected ATOF shape: + +```jsonl +{"kind":"scope","category":"llm","name":"custom","scope_category":"start","metadata":{"session_id":"middleware-demo-session"},"data":{"mode":"route"}} +{"kind":"scope","category":"tool","name":"terminal","scope_category":"start","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal"},"data":{"mode":"route"}} +{"kind":"scope","category":"tool","name":"terminal","scope_category":"end","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal","status":"ok"},"data":"{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"} +``` + +Expected ATIF shape: + +```json +{ + "schema_version": "ATIF-v1.7", + "session_id": "middleware-demo-session", + "agent": { + "name": "Hermes Middleware E2E", + "version": "local", + "model_name": "qwen3.6:35b" + }, + "steps": [ + { + "source": "agent", + "tool_calls": [ + { + "function_name": "terminal", + "arguments": {"command": "printf middleware_execution_ok"} + } + ], + "observation": { + "results": [ + { + "source_call_id": "call_terminal", + "content": "{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}" + } + ] + } + }, + { + "source": "agent", + "message": "middleware_execution_ok" + } + ] +} +``` diff --git a/plugins/observability/nemo_relay/__init__.py b/plugins/observability/nemo_relay/__init__.py index 25078a21c6d..cd1587fdab0 100644 --- a/plugins/observability/nemo_relay/__init__.py +++ b/plugins/observability/nemo_relay/__init__.py @@ -42,6 +42,9 @@ class _SubagentParent: @dataclass class _Settings: plugins_toml_path: str = "" + plugins_config: dict[str, Any] | None = None + adaptive_enabled: bool = False + adaptive_mode: str = "observe" atof_enabled: bool = False 
atof_output_directory: str = "" atof_filename: str = "hermes-atof.jsonl" @@ -67,17 +70,15 @@ class _Runtime: self._configure_atof() def _configure_plugins_toml(self) -> bool: - if not self.settings.plugins_toml_path: + if not self.settings.plugins_config: return False plugin_mod = getattr(self.nemo_relay, "plugin", None) initialize = getattr(plugin_mod, "initialize", None) if not callable(initialize): return False - config_path = Path(self.settings.plugins_toml_path) try: - config = tomllib.loads(config_path.read_text(encoding="utf-8")) - self._ensure_plugin_config_output_dirs(config) - result = initialize(config) + self._ensure_plugin_config_output_dirs(self.settings.plugins_config) + result = initialize(self.settings.plugins_config) if inspect.isawaitable(result): asyncio.run(result) return True @@ -221,6 +222,100 @@ class _Runtime: self.subagent_parents.pop(child_session_id, None) self.mark("hermes.subagent.stop", kwargs) + def managed_llm_enabled(self) -> bool: + return ( + self.settings.adaptive_enabled + and callable(getattr(getattr(self.nemo_relay, "llm", None), "execute", None)) + and callable(getattr(self.nemo_relay, "LLMRequest", None)) + ) + + def managed_tool_enabled(self) -> bool: + return ( + self.settings.adaptive_enabled + and callable(getattr(getattr(self.nemo_relay, "tools", None), "execute", None)) + ) + + def execute_llm(self, kwargs: dict[str, Any]) -> Any: + state = self.ensure_session(kwargs) + request_body = _jsonable(kwargs.get("request") or {}) + request = self.nemo_relay.LLMRequest({}, request_body) + next_call = kwargs.get("next_call") + if not callable(next_call): + return request_body + + raw_response: dict[str, Any] = {"set": False, "value": None} + + def _impl(next_request: Any) -> Any: + next_body = getattr(next_request, "content", next_request) + raw = next_call(next_body if isinstance(next_body, dict) else request_body) + raw_response["set"] = True + raw_response["value"] = raw + return _llm_response_payload(raw) + + async def 
_managed_execute() -> Any: + result = self.nemo_relay.llm.execute( + str(kwargs.get("provider") or "llm"), + request, + _impl, + handle=state.handle, + data=_jsonable( + { + "turn_id": kwargs.get("turn_id"), + "api_request_id": kwargs.get("api_request_id"), + "api_call_count": kwargs.get("api_call_count"), + "mode": self.settings.adaptive_mode, + } + ), + metadata=_metadata(kwargs), + model_name=str(kwargs.get("model") or ""), + ) + if inspect.isawaitable(result): + return await result + return result + + managed_result = _resolve_awaitable(_managed_execute()) + return raw_response["value"] if raw_response["set"] else managed_result + + def execute_tool(self, kwargs: dict[str, Any]) -> Any: + state = self.ensure_session(kwargs) + tool_name = str(kwargs.get("tool_name") or "tool") + args = _jsonable(kwargs.get("args") or {}) + next_call = kwargs.get("next_call") + if not callable(next_call): + return args + + raw_response: dict[str, Any] = {"set": False, "value": None} + + def _impl(next_args: Any) -> Any: + effective_args = next_args if isinstance(next_args, dict) else args + raw = next_call(effective_args) + raw_response["set"] = True + raw_response["value"] = raw + return _jsonable(raw) + + async def _managed_execute() -> Any: + result = self.nemo_relay.tools.execute( + tool_name, + args, + _impl, + handle=state.handle, + data=_jsonable( + { + "turn_id": kwargs.get("turn_id"), + "api_request_id": kwargs.get("api_request_id"), + "tool_call_id": kwargs.get("tool_call_id"), + "mode": self.settings.adaptive_mode, + } + ), + metadata=_metadata(kwargs), + ) + if inspect.isawaitable(result): + return await result + return result + + managed_result = _resolve_awaitable(_managed_execute()) + return raw_response["value"] if raw_response["set"] else managed_result + def register(ctx) -> None: ctx.register_hook("on_session_start", on_session_start) @@ -238,6 +333,8 @@ def register(ctx) -> None: ctx.register_hook("post_approval_response", on_post_approval_response) 
ctx.register_hook("subagent_start", on_subagent_start) ctx.register_hook("subagent_stop", on_subagent_stop) + ctx.register_middleware("llm_execution", on_llm_execution_middleware) + ctx.register_middleware("tool_execution", on_tool_execution_middleware) def on_session_start(**kwargs: Any) -> None: @@ -280,6 +377,8 @@ def on_pre_api_request(**kwargs: Any) -> None: runtime = _get_runtime() if runtime is None: return + if runtime.managed_llm_enabled(): + return def _record() -> None: state = runtime.ensure_session(kwargs) @@ -303,6 +402,8 @@ def on_post_api_request(**kwargs: Any) -> None: runtime = _get_runtime() if runtime is None: return + if runtime.managed_llm_enabled(): + return def _record() -> None: state = runtime.ensure_session(kwargs) @@ -324,6 +425,8 @@ def on_api_request_error(**kwargs: Any) -> None: runtime = _get_runtime() if runtime is None: return + if runtime.managed_llm_enabled(): + return def _record() -> None: state = runtime.ensure_session(kwargs) @@ -345,6 +448,8 @@ def on_pre_tool_call(**kwargs: Any) -> None: runtime = _get_runtime() if runtime is None: return + if runtime.managed_tool_enabled(): + return def _record() -> None: state = runtime.ensure_session(kwargs) @@ -365,6 +470,8 @@ def on_post_tool_call(**kwargs: Any) -> None: runtime = _get_runtime() if runtime is None: return + if runtime.managed_tool_enabled(): + return def _record() -> None: state = runtime.ensure_session(kwargs) @@ -406,6 +513,28 @@ def on_subagent_stop(**kwargs: Any) -> None: _safe(lambda: runtime.mark_subagent_stop(kwargs)) +def on_llm_execution_middleware(**kwargs: Any) -> Any: + runtime = _get_runtime() + next_call = kwargs.get("next_call") + request = kwargs.get("request") or {} + if runtime is not None and runtime.managed_llm_enabled(): + return runtime.execute_llm(kwargs) + if callable(next_call): + return next_call(request) + return request + + +def on_tool_execution_middleware(**kwargs: Any) -> Any: + runtime = _get_runtime() + next_call = 
kwargs.get("next_call") + args = kwargs.get("args") or {} + if runtime is not None and runtime.managed_tool_enabled(): + return runtime.execute_tool(kwargs) + if callable(next_call): + return next_call(args) + return args + + def _get_runtime() -> Optional[_Runtime]: global _RUNTIME with _LOCK: @@ -429,8 +558,14 @@ def _get_runtime() -> Optional[_Runtime]: def _load_settings() -> _Settings: + plugins_toml_path = _env("HERMES_NEMO_RELAY_PLUGINS_TOML") + plugins_config = _load_plugins_config(plugins_toml_path) + adaptive_config = _enabled_component_config(plugins_config, "adaptive") return _Settings( - plugins_toml_path=_env("HERMES_NEMO_RELAY_PLUGINS_TOML"), + plugins_toml_path=plugins_toml_path, + plugins_config=plugins_config, + adaptive_enabled=adaptive_config is not None, + adaptive_mode=_adaptive_mode(adaptive_config), atof_enabled=_env_bool("HERMES_NEMO_RELAY_ATOF_ENABLED"), atof_output_directory=_env("HERMES_NEMO_RELAY_ATOF_OUTPUT_DIRECTORY"), atof_filename=_env("HERMES_NEMO_RELAY_ATOF_FILENAME") or "hermes-atof.jsonl", @@ -445,6 +580,44 @@ def _load_settings() -> _Settings: ) +def _load_plugins_config(path: str) -> dict[str, Any] | None: + if not path: + return None + try: + return tomllib.loads(Path(path).read_text(encoding="utf-8")) + except Exception as exc: + logger.debug("NeMo Relay plugins.toml load failed: %s", exc, exc_info=True) + return None + + +def _enabled_component_config( + plugins_config: dict[str, Any] | None, + kind: str, +) -> dict[str, Any] | None: + if not isinstance(plugins_config, dict): + return None + components = plugins_config.get("components") + if not isinstance(components, list): + return None + for component in components: + if not isinstance(component, dict): + continue + if component.get("kind") != kind or not component.get("enabled", True): + continue + config = component.get("config") + return config if isinstance(config, dict) else {} + return None + + +def _adaptive_mode(config: dict[str, Any] | None) -> str: + if not 
isinstance(config, dict): + return "observe" + mode = config.get("mode") + if isinstance(mode, str) and mode.strip(): + return mode.strip() + return "observe" + + def _env(name: str) -> str: return os.environ.get(name, "").strip() @@ -549,12 +722,78 @@ def _jsonable(value: Any) -> Any: return _jsonable(value.model_dump(mode="json")) except Exception: pass + try: + if hasattr(value, "__dict__"): + return _jsonable(vars(value)) + except Exception: + pass try: return json.loads(json.dumps(value, default=str)) except Exception: return str(value) +def _value(obj: Any, key: str, default: Any = None) -> Any: + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _llm_response_payload(response: Any) -> Any: + """Return the LLM response shape NeMo Relay's ATIF conversion expects.""" + payload = _jsonable(response) + if isinstance(payload, dict) and "assistant_message" in payload: + return payload + + choices = _value(response, "choices") + if choices is None and isinstance(payload, dict): + choices = payload.get("choices") + first_choice = choices[0] if isinstance(choices, list) and choices else None + message = _value(first_choice, "message") + finish_reason = _value(first_choice, "finish_reason") + + assistant_message: dict[str, Any] = {"role": "assistant", "content": ""} + if message is not None: + assistant_message["role"] = _value(message, "role", "assistant") or "assistant" + content = _value(message, "content") + if content is not None: + assistant_message["content"] = _jsonable(content) + tool_calls = _tool_calls_payload(_value(message, "tool_calls")) + if tool_calls: + assistant_message["tool_calls"] = tool_calls + reasoning = _value(message, "reasoning_content") + if reasoning is not None: + assistant_message["reasoning_content"] = _jsonable(reasoning) + elif isinstance(payload, dict): + assistant_message["content"] = payload.get("content") or payload.get("output_text") or "" + + return { + "model": _value(response, 
"model", payload.get("model") if isinstance(payload, dict) else None), + "assistant_message": assistant_message, + "finish_reason": finish_reason, + "usage": _jsonable(_value(response, "usage", payload.get("usage") if isinstance(payload, dict) else None)), + } + + +def _tool_calls_payload(tool_calls: Any) -> list[dict[str, Any]]: + if not isinstance(tool_calls, list): + return [] + normalized: list[dict[str, Any]] = [] + for call in tool_calls: + function = _value(call, "function") + normalized.append( + { + "id": _value(call, "id"), + "type": _value(call, "type", "function") or "function", + "function": { + "name": _value(function, "name"), + "arguments": _value(function, "arguments"), + }, + } + ) + return normalized + + def _safe(fn) -> None: try: fn() @@ -562,6 +801,35 @@ def _safe(fn) -> None: logger.debug("NeMo Relay hook handling failed: %s", exc, exc_info=True) +def _resolve_awaitable(value: Any) -> Any: + if not inspect.isawaitable(value): + return value + try: + asyncio.get_running_loop() + except RuntimeError: + return asyncio.run(value) + + result: dict[str, Any] = {} + error: dict[str, BaseException] = {} + + def _runner() -> None: + try: + result["value"] = asyncio.run(value) + except BaseException as exc: # pragma: no cover - re-raised below + error["exc"] = exc + + thread = threading.Thread( + target=_runner, + name="hermes-nemo-relay-awaitable", + daemon=True, + ) + thread.start() + thread.join() + if "exc" in error: + raise error["exc"] + return result.get("value") + + def reset_for_tests() -> None: global _RUNTIME with _LOCK: diff --git a/run_agent.py b/run_agent.py index d0d0293439d..c0b3619896c 100644 --- a/run_agent.py +++ b/run_agent.py @@ -4775,10 +4775,22 @@ class AIAgent: def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str, tool_call_id: Optional[str] = None, messages: list = None, - pre_tool_block_checked: bool = False) -> str: + pre_tool_block_checked: bool = False, + skip_tool_request_middleware: bool 
= False, + tool_request_middleware_trace: Optional[list[dict[str, Any]]] = None) -> str: """Forwarder — see ``agent.agent_runtime_helpers.invoke_tool``.""" from agent.agent_runtime_helpers import invoke_tool - return invoke_tool(self, function_name, function_args, effective_task_id, tool_call_id, messages, pre_tool_block_checked) + return invoke_tool( + self, + function_name, + function_args, + effective_task_id, + tool_call_id, + messages, + pre_tool_block_checked, + skip_tool_request_middleware, + tool_request_middleware_trace, + ) @staticmethod def _wrap_verbose(label: str, text: str, indent: str = " ") -> str: diff --git a/tests/hermes_cli/test_plugins.py b/tests/hermes_cli/test_plugins.py index baf7f92fcfb..6bff7b6d87d 100644 --- a/tests/hermes_cli/test_plugins.py +++ b/tests/hermes_cli/test_plugins.py @@ -18,8 +18,15 @@ from hermes_cli.plugins import ( get_plugin_command_handler, get_plugin_commands, get_pre_tool_call_block_message, + has_middleware, resolve_plugin_command_result, ) +from hermes_cli.middleware import ( + VALID_MIDDLEWARE, + apply_llm_request_middleware, + apply_tool_request_middleware, + run_tool_execution_middleware, +) # ── Helpers ──────────────────────────────────────────────────────────────── @@ -96,6 +103,110 @@ class TestPluginDiscovery: assert "hello_plugin" in mgr._plugins assert mgr._plugins["hello_plugin"].enabled + def test_plugin_can_register_and_invoke_middleware(self, tmp_path, monkeypatch): + plugins_dir = tmp_path / "hermes_test" / "plugins" + _make_plugin_dir( + plugins_dir, + "mw_plugin", + register_body=( + "ctx.register_middleware('llm_request', " + "lambda **kw: {'request': {**kw['request'], 'mw': True}})\n" + " ctx.register_middleware('tool_request', " + "lambda **kw: {'args': {**kw['args'], 'mw': True}})" + ), + ) + monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test")) + + mgr = PluginManager() + mgr.discover_and_load() + + assert "llm_request" in VALID_MIDDLEWARE + assert "tool_request" in VALID_MIDDLEWARE 
+ assert set(mgr._plugins["mw_plugin"].middleware_registered) == {"llm_request", "tool_request"} + assert mgr.invoke_middleware("llm_request", request={"messages": []}) == [ + {"request": {"messages": [], "mw": True}} + ] + assert mgr.invoke_middleware("tool_request", args={"path": "README.md"}) == [ + {"args": {"path": "README.md", "mw": True}} + ] + assert mgr.has_middleware("llm_request") is True + + def test_execution_middleware_does_not_retry_downstream_failure(self, monkeypatch): + calls = [] + + def middleware(**kwargs): + return kwargs["next_call"](kwargs["args"]) + + manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(args) + raise RuntimeError("tool failed") + + with pytest.raises(RuntimeError, match="tool failed"): + run_tool_execution_middleware("terminal", {"command": "false"}, terminal) + + assert calls == [{"command": "false"}] + + def test_middleware_helpers_skip_no_listener_work(self, monkeypatch): + manager = types.SimpleNamespace(_middleware={}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + request = {"messages": []} + args = {"path": "README.md"} + + llm_result = apply_llm_request_middleware(request) + tool_result = apply_tool_request_middleware("read_file", args) + + assert llm_result.payload is request + assert llm_result.original_payload is request + assert llm_result.changed is False + assert llm_result.trace == [] + assert tool_result.payload is args + assert tool_result.original_payload is args + assert tool_result.changed is False + assert tool_result.trace == [] + assert run_tool_execution_middleware("terminal", args, lambda payload: payload) is args + assert has_middleware("tool_request") is False + + def test_request_middleware_changed_tracks_trace_not_deep_equality(self, monkeypatch): + def same_payload_middleware(**kwargs): + return {"args": 
kwargs["args"], "source": "same-payload"} + + manager = types.SimpleNamespace( + _middleware={"tool_request": [same_payload_middleware]}, + invoke_middleware=lambda kind, **kwargs: [same_payload_middleware(**kwargs)], + ) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + args = {"path": "README.md"} + result = apply_tool_request_middleware("read_file", args) + + assert result.payload == args + assert result.original_payload == args + assert result.changed is True + assert result.trace == [{"source": "same-payload"}] + + def test_execution_middleware_post_next_call_error_does_not_retry(self, monkeypatch): + calls = [] + + def middleware(**kwargs): + result = kwargs["next_call"](kwargs["args"]) + raise RuntimeError(f"post-processing failed after {result}") + + manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(args) + return "terminal-result" + + result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal) + + assert result == "terminal-result" + assert calls == [{"command": "printf ok"}] + def test_discover_project_plugins(self, tmp_path, monkeypatch): """Plugins in ./.hermes/plugins/ are discovered.""" project_dir = tmp_path / "project" diff --git a/tests/plugins/test_nemo_relay_plugin.py b/tests/plugins/test_nemo_relay_plugin.py index 7c18493fd39..c4970bf2415 100644 --- a/tests/plugins/test_nemo_relay_plugin.py +++ b/tests/plugins/test_nemo_relay_plugin.py @@ -27,8 +27,16 @@ class _FakeNemoRelay: pop=self._scope_pop, event=self._scope_event, ) - self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end) - self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end) + self.llm = SimpleNamespace( + call=self._llm_call, + call_end=self._llm_call_end, + execute=self._llm_execute, + ) + self.tools = SimpleNamespace( + 
call=self._tool_call, + call_end=self._tool_call_end, + execute=self._tool_execute, + ) self.plugin = SimpleNamespace(initialize=self._plugin_initialize) self.LLMRequest = _FakeLLMRequest self.AtofExporterConfig = _FakeAtofExporterConfig @@ -55,6 +63,12 @@ class _FakeNemoRelay: def _llm_call_end(self, handle, response, **kwargs): self.events.append(("llm.call_end", handle, response, kwargs)) + def _llm_execute(self, name, request, func, **kwargs): + self.events.append(("llm.execute.start", name, request.content, kwargs)) + result = func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content})) + self.events.append(("llm.execute.end", name, result, kwargs)) + return result + def _tool_call(self, name, args, **kwargs): handle = ("tool", name) self.events.append(("tool.call", name, args, kwargs)) @@ -63,6 +77,12 @@ class _FakeNemoRelay: def _tool_call_end(self, handle, result, **kwargs): self.events.append(("tool.call_end", handle, result, kwargs)) + def _tool_execute(self, name, args, func, **kwargs): + self.events.append(("tool.execute.start", name, args, kwargs)) + result = func({"intercepted": True, **args}) + self.events.append(("tool.execute.end", name, result, kwargs)) + return result + def _make_atof_exporter(self, config): return _FakeAtofExporter(self.events, config) @@ -425,6 +445,221 @@ output_directory = "{atif_dir}" assert atif_dir.is_dir() +def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + plugin = _fresh_plugin(monkeypatch, fake) + plugins_toml = tmp_path / "plugins.toml" + plugins_toml.write_text( + """ +version = 1 + +[[components]] +kind = "adaptive" +enabled = true + +[components.config] +mode = "route" +""", + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml)) + + seen_request = {} + raw_choice = SimpleNamespace( + message=SimpleNamespace( + role="assistant", + content=None, + tool_calls=[ + SimpleNamespace( 
+ id="tool-1", + type="function", + function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'), + ) + ], + reasoning_content="need a tool", + ), + finish_reason="tool_calls", + ) + + def next_call(request): + seen_request.update(request) + return SimpleNamespace( + id="resp-1", + model="demo-model", + choices=[raw_choice], + usage=SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8), + ) + + response = plugin.on_llm_execution_middleware( + session_id="s1", + task_id="t1", + turn_id="turn-1", + api_request_id="api-1", + provider="anthropic", + model="demo-model", + api_call_count=1, + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=next_call, + ) + + assert response.model == "demo-model" + assert response.choices == [raw_choice] + assert seen_request["intercepted"] is True + execute_start = next(event for event in fake.events if event[0] == "llm.execute.start") + assert execute_start[3]["data"]["mode"] == "route" + execute_end = next(event for event in fake.events if event[0] == "llm.execute.end") + assert execute_end[2] == { + "model": "demo-model", + "assistant_message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "tool-1", + "type": "function", + "function": {"name": "terminal", "arguments": '{"command":"pwd"}'}, + } + ], + "reasoning_content": "need a tool", + }, + "finish_reason": "tool_calls", + "usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8}, + } + + +def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch): + fake = _FakeNemoRelay() + plugin = _fresh_plugin(monkeypatch, fake) + + response = plugin.on_llm_execution_middleware( + session_id="s1", + provider="anthropic", + model="demo-model", + request={"messages": []}, + next_call=lambda request: {"raw": request}, + ) + + assert response == {"raw": {"messages": []}} + assert not any(event[0] == "llm.execute.start" for event in fake.events) + + +def 
test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + plugin = _fresh_plugin(monkeypatch, fake) + plugins_toml = tmp_path / "plugins.toml" + plugins_toml.write_text( + """ +version = 1 + +[[components]] +kind = "adaptive" +enabled = true + +[components.config] +mode = "route" +""", + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml)) + + seen_args = {} + + def next_call(args): + seen_args.update(args) + return {"raw": True, "args": args} + + response = plugin.on_tool_execution_middleware( + session_id="s1", + task_id="t1", + turn_id="turn-1", + api_request_id="api-1", + tool_name="terminal", + tool_call_id="tool-1", + args={"command": "pwd"}, + next_call=next_call, + ) + + assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}} + assert seen_args["intercepted"] is True + execute_start = next(event for event in fake.events if event[0] == "tool.execute.start") + assert execute_start[3]["data"]["mode"] == "route" + assert execute_start[3]["data"]["tool_call_id"] == "tool-1" + + +def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch): + fake = _FakeNemoRelay() + plugin = _fresh_plugin(monkeypatch, fake) + + response = plugin.on_tool_execution_middleware( + session_id="s1", + tool_name="terminal", + args={"command": "pwd"}, + next_call=lambda args: {"raw": args}, + ) + + assert response == {"raw": {"command": "pwd"}} + assert not any(event[0] == "tool.execute.start" for event in fake.events) + + +def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + plugin = _fresh_plugin(monkeypatch, fake) + plugins_toml = tmp_path / "plugins.toml" + plugins_toml.write_text( + """ +version = 1 + +[[components]] +kind = "adaptive" +enabled = true + +[components.config] +mode = "route" +""", + encoding="utf-8", + ) + 
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml)) + + base = { + "session_id": "s1", + "task_id": "t1", + "turn_id": "turn-1", + "api_request_id": "api-1", + } + plugin.on_pre_api_request( + **base, + provider="anthropic", + model="demo-model", + request={"body": {"messages": [{"role": "user", "content": "hi"}]}}, + ) + plugin.on_post_api_request(**base, response={"ok": True}) + plugin.on_pre_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", args={"command": "pwd"}) + plugin.on_post_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", result={"ok": True}) + + plugin.on_llm_execution_middleware( + **base, + provider="anthropic", + model="demo-model", + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=lambda request: {"raw": request}, + ) + plugin.on_tool_execution_middleware( + **base, + tool_name="terminal", + tool_call_id="tool-1", + args={"command": "pwd"}, + next_call=lambda args: {"raw": args}, + ) + + event_names = [event[0] for event in fake.events] + assert "llm.call" not in event_names + assert "llm.call_end" not in event_names + assert "tool.call" not in event_names + assert "tool.call_end" not in event_names + assert "llm.execute.start" in event_names + assert "tool.execute.start" in event_names + + def test_nemo_relay_plugin_noops_without_dependency(monkeypatch): monkeypatch.delitem(sys.modules, "nemo_relay", raising=False) sys.modules.pop("plugins.observability.nemo_relay", None) diff --git a/tests/run_agent/test_run_agent.py b/tests/run_agent/test_run_agent.py index e9e7011dd1e..aef73c665f3 100644 --- a/tests/run_agent/test_run_agent.py +++ b/tests/run_agent/test_run_agent.py @@ -2466,8 +2466,10 @@ class TestConcurrentToolExecution: api_request_id="", enabled_tools=list(agent.valid_tool_names), skip_pre_tool_call_hook=True, + skip_tool_request_middleware=True, enabled_toolsets=agent.enabled_toolsets, disabled_toolsets=agent.disabled_toolsets, + tool_request_middleware_trace=[], ) assert 
result == "result" @@ -2647,6 +2649,89 @@ class TestConcurrentToolExecution: assert post_call[1]["result"] == '{"ok":true}' assert post_call[1]["status"] == "ok" + def test_sequential_agent_level_tool_execution_middleware_wraps_inline_dispatch(self, agent, monkeypatch): + """Sequential built-in tool paths should expose the adaptive execution boundary.""" + tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call]) + messages = [] + hook_calls = [] + seen = {} + + def request_middleware(**kwargs): + return { + "args": {**kwargs["args"], "request_rewritten": True}, + "source": "request-test", + } + + def execution_middleware(**kwargs): + seen["middleware_args"] = kwargs["args"] + return kwargs["next_call"]({**kwargs["args"], "merge": True}) + + manager = SimpleNamespace(_middleware={ + "tool_request": [request_middleware], + "tool_execution": [execution_middleware], + }) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + monkeypatch.setattr( + "hermes_cli.plugins.invoke_middleware", + lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [], + ) + monkeypatch.setattr( + "hermes_cli.plugins.get_pre_tool_call_block_message", + lambda *args, **kwargs: None, + ) + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [], + ) + monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True) + + with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo: + agent._execute_tool_calls_sequential(mock_msg, messages, "task-1") + + assert seen["middleware_args"] == {"todos": [], "request_rewritten": True} + mock_todo.assert_called_once_with(todos=[], merge=True, store=agent._todo_store) + post_call = next(call for call in hook_calls if call[0] == "post_tool_call") + assert post_call[1]["tool_name"] == "todo" + assert 
post_call[1]["args"] == {"todos": [], "request_rewritten": True, "merge": True} + assert post_call[1]["middleware_trace"] == [{"source": "request-test"}] + + def test_concurrent_agent_level_tool_preserves_request_middleware_trace(self, agent, monkeypatch): + tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1") + mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call]) + messages = [] + hook_calls = [] + + def request_middleware(**kwargs): + return { + "args": {**kwargs["args"], "request_rewritten": True}, + "source": "request-test", + } + + manager = SimpleNamespace(_middleware={"tool_request": [request_middleware], "tool_execution": []}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + monkeypatch.setattr( + "hermes_cli.plugins.invoke_middleware", + lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [], + ) + monkeypatch.setattr( + "hermes_cli.plugins.get_pre_tool_call_block_message", + lambda *args, **kwargs: None, + ) + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [], + ) + monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True) + + with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}'): + agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1") + + post_call = next(call for call in hook_calls if call[0] == "post_tool_call") + assert post_call[1]["tool_name"] == "todo" + assert post_call[1]["args"] == {"todos": [], "request_rewritten": True} + assert post_call[1]["middleware_trace"] == [{"source": "request-test"}] + def test_agent_runtime_post_hook_ownership_predicate_covers_agent_tools(self, agent): """Sequential and concurrent agent-level paths share post-hook ownership.""" from agent.agent_runtime_helpers import agent_runtime_owns_post_tool_hook diff --git a/tests/test_model_tools.py b/tests/test_model_tools.py index 
f1a5b510cb2..91e7103aac7 100644 --- a/tests/test_model_tools.py +++ b/tests/test_model_tools.py @@ -64,6 +64,7 @@ class TestHandleFunctionCall: tool_call_id="call-1", turn_id="", api_request_id="", + middleware_trace=[], ), call( "post_tool_call", @@ -79,6 +80,7 @@ class TestHandleFunctionCall: status="ok", error_type=None, error_message=None, + middleware_trace=[], ), call( "transform_tool_result", @@ -145,6 +147,60 @@ class TestHandleFunctionCall: assert "post_tool_call" not in fired assert "transform_tool_result" not in fired + def test_tool_request_and_execution_middleware_wrap_registry_dispatch(self, monkeypatch): + seen = {} + + def fake_invoke_middleware(kind, **kwargs): + if kind == "tool_request": + return [{ + "args": {**kwargs["args"], "rewritten": True}, + "source": "test-middleware", + "reason": "rewrite", + }] + return [] + + def execution_middleware(**kwargs): + seen["execution_args"] = kwargs["args"] + return kwargs["next_call"]({**kwargs["args"], "wrapped": True}) + + def fake_dispatch(tool_name, args, **kwargs): + seen["dispatch"] = (tool_name, args, kwargs) + return json.dumps({"ok": True, "args": args}) + + manager = type( + "Manager", + (), + {"_middleware": {"tool_request": [fake_invoke_middleware], "tool_execution": [execution_middleware]}}, + )() + monkeypatch.setattr("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + hook_calls = [] + monkeypatch.setattr( + "hermes_cli.plugins.invoke_hook", + lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [], + ) + monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True) + monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch) + + result = json.loads( + handle_function_call( + "web_search", + {"q": "test"}, + task_id="task-1", + tool_call_id="tool-1", + session_id="session-1", + ) + ) + + assert seen["execution_args"] == {"q": "test", "rewritten": True} + assert 
seen["dispatch"][1] == {"q": "test", "rewritten": True, "wrapped": True} + assert result["args"] == {"q": "test", "rewritten": True, "wrapped": True} + expected_trace = [{"source": "test-middleware", "reason": "rewrite"}] + pre_call = next(call for call in hook_calls if call[0] == "pre_tool_call") + post_call = next(call for call in hook_calls if call[0] == "post_tool_call") + assert pre_call[1]["middleware_trace"] == expected_trace + assert post_call[1]["middleware_trace"] == expected_trace + # ========================================================================= # Agent loop tools From 5abe45674dc7eaf72190d97785738ea2ea8b607b Mon Sep 17 00:00:00 2001 From: Bryan Bednarski Date: Sat, 6 Jun 2026 09:26:18 -0700 Subject: [PATCH 2/3] fix(middleware): preserve translated downstream failures Track successful next_call completion separately from invocation so execution middleware that catches and translates a downstream provider/tool failure does not accidentally convert that failure into a successful None result. Also avoid wrapping BaseException from downstream execution, and document the execution middleware error semantics. Tests cover: - pre-next_call middleware failures fail open to the remaining chain - post-next_call middleware failures preserve the downstream result - translated downstream failures propagate instead of returning None - downstream BaseException is not wrapped Signed-off-by: Bryan Bednarski --- docs/middleware/README.md | 9 +++++ hermes_cli/middleware.py | 10 +++-- tests/hermes_cli/test_plugins.py | 65 ++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 3 deletions(-) diff --git a/docs/middleware/README.md b/docs/middleware/README.md index b385b87eb29..4a5c06f8cbe 100644 --- a/docs/middleware/README.md +++ b/docs/middleware/README.md @@ -244,6 +244,15 @@ For NeMo Relay adaptive execution middleware, see patches. 
- Execution middleware should call `next_call(...)` exactly once unless it is intentionally short-circuiting execution. +- If execution middleware raises before calling `next_call(...)`, Hermes treats + that as middleware failure and continues with the remaining middleware chain + and base execution. +- If execution middleware calls `next_call(...)` successfully and then raises + during post-processing, Hermes preserves the downstream result and does not + run the provider or tool a second time. +- If downstream provider or tool execution fails, middleware may let that error + propagate or translate it deliberately. Hermes does not convert downstream + failure into a successful `None` result. - Tool request middleware runs before approvals. If it mutates file paths, commands, URLs, or arguments, the mutated values are what guardrails and approvals evaluate. diff --git a/hermes_cli/middleware.py b/hermes_cli/middleware.py index 938bffcf172..277368dffb3 100644 --- a/hermes_cli/middleware.py +++ b/hermes_cli/middleware.py @@ -237,15 +237,17 @@ def _run_execution_chain( callback = callbacks[index] next_called = False + next_succeeded = False next_result: Any = None def next_call(next_payload: Any = None) -> Any: - nonlocal next_called, next_result + nonlocal next_called, next_succeeded, next_result next_called = True try: next_result = call_at(index + 1, payload if next_payload is None else next_payload) + next_succeeded = True return next_result - except BaseException as exc: + except Exception as exc: raise _DownstreamExecutionError(exc) from exc call_kwargs = middleware_payload(**kwargs) @@ -262,8 +264,10 @@ def _run_execution_chain( getattr(callback, "__name__", repr(callback)), exc, ) - if next_called: + if next_succeeded: return next_result + if next_called: + raise return call_at(index + 1, payload) return call_at(0, kwargs[payload_key]) diff --git a/tests/hermes_cli/test_plugins.py b/tests/hermes_cli/test_plugins.py index 6bff7b6d87d..ddd1dab56e4 100644 --- 
a/tests/hermes_cli/test_plugins.py +++ b/tests/hermes_cli/test_plugins.py @@ -207,6 +207,71 @@ class TestPluginDiscovery: assert result == "terminal-result" assert calls == [{"command": "printf ok"}] + def test_execution_middleware_pre_next_call_error_fails_open_to_remaining_chain(self, monkeypatch): + calls = [] + + def failing_middleware(**kwargs): + calls.append("failing") + raise RuntimeError("middleware setup failed") + + def downstream_middleware(**kwargs): + calls.append("downstream") + return kwargs["next_call"]({**kwargs["args"], "rewritten": True}) + + manager = types.SimpleNamespace(_middleware={"tool_execution": [failing_middleware, downstream_middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(("terminal", args)) + return args + + result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal) + + assert result == {"command": "printf ok", "rewritten": True} + assert calls == ["failing", "downstream", ("terminal", {"command": "printf ok", "rewritten": True})] + + def test_execution_middleware_translated_downstream_failure_is_not_masked(self, monkeypatch): + calls = [] + + def middleware(**kwargs): + try: + return kwargs["next_call"](kwargs["args"]) + except Exception as exc: + raise RuntimeError(f"translated downstream failure: {exc}") from exc + + manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(args) + raise RuntimeError("terminal failed") + + with pytest.raises(RuntimeError, match="translated downstream failure: terminal failed"): + run_tool_execution_middleware("terminal", {"command": "false"}, terminal) + + assert calls == [{"command": "false"}] + + def test_execution_middleware_downstream_base_exception_is_not_wrapped(self, monkeypatch): + calls = [] + + def middleware(**kwargs): + try: + return 
kwargs["next_call"](kwargs["args"]) + except Exception as exc: + raise RuntimeError(f"middleware should not catch base exception: {exc}") from exc + + manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(args) + raise KeyboardInterrupt() + + with pytest.raises(KeyboardInterrupt): + run_tool_execution_middleware("terminal", {"command": "interrupt"}, terminal) + + assert calls == [{"command": "interrupt"}] + def test_discover_project_plugins(self, tmp_path, monkeypatch): """Plugins in ./.hermes/plugins/ are discovered.""" project_dir = tmp_path / "project" From c4c5548eb4800068ff3dd1ac8361d3b4ee23a06b Mon Sep 17 00:00:00 2001 From: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com> Date: Sat, 6 Jun 2026 23:07:25 +0530 Subject: [PATCH 3/3] fix(middleware): single-use next_call guard + deepcopy-safe request copies Address the two non-blocking follow-ups from review: - next_call is now single-use per middleware frame. A second invocation raises instead of silently re-running the downstream provider/tool, so the terminal call cannot execute twice via the chain. The error surfaces through the existing handler, which preserves the first downstream result. - Request-middleware payload copies go through _safe_copy(), which falls back to a shallow dict copy when deepcopy() fails on a non-deepcopyable member (clients, callbacks, file handles) instead of aborting the pass. Adds regression coverage for both: double next_call() keeps the terminal single-run, and a non-deepcopyable (threading.Lock) request payload still runs middleware via the shallow fallback. 
--- hermes_cli/middleware.py | 41 +++++++++++++++++++++++---- tests/hermes_cli/test_plugins.py | 48 ++++++++++++++++++++++++++++++++ 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/hermes_cli/middleware.py b/hermes_cli/middleware.py index 277368dffb3..8795952a2b7 100644 --- a/hermes_cli/middleware.py +++ b/hermes_cli/middleware.py @@ -55,6 +55,25 @@ def middleware_payload(**kwargs: Any) -> Dict[str, Any]: return kwargs +def _safe_copy(payload: Any) -> Any: + """Deep-copy a request payload, tolerating non-deepcopyable members. + + Request payloads are normally plain JSON-shaped dicts, but an LLM request + can occasionally carry non-deepcopyable objects (clients, callbacks, file + handles). A hard ``deepcopy`` failure there would otherwise abort the whole + request-middleware pass. Fall back to a shallow ``dict`` copy so middleware + still runs and the original nested objects are shared by reference rather + than corrupting the live payload. + """ + try: + return deepcopy(payload) + except Exception as exc: # pragma: no cover - exercised via fallback test + logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc) + if isinstance(payload, dict): + return dict(payload) + return payload + + def apply_llm_request_middleware( request: Dict[str, Any], **context: Any, @@ -72,8 +91,8 @@ def apply_llm_request_middleware( trace=[], ) - original_request = deepcopy(request) - current_request = deepcopy(original_request) + original_request = _safe_copy(request) + current_request = _safe_copy(original_request) trace: List[Dict[str, Any]] = [] for result in _invoke_middleware( @@ -87,7 +106,7 @@ def apply_llm_request_middleware( next_request = result.get("request") if not isinstance(next_request, dict): continue - current_request = deepcopy(next_request) + current_request = _safe_copy(next_request) trace.append(_trace_entry(result)) return RequestMiddlewareResult( @@ -116,8 +135,8 @@ def apply_tool_request_middleware( trace=[], ) - 
original_args = deepcopy(args) - current_args = deepcopy(original_args) + original_args = _safe_copy(args) + current_args = _safe_copy(original_args) trace: List[Dict[str, Any]] = [] for result in _invoke_middleware( @@ -132,7 +151,7 @@ def apply_tool_request_middleware( next_args = result.get("args") if not isinstance(next_args, dict): continue - current_args = deepcopy(next_args) + current_args = _safe_copy(next_args) trace.append(_trace_entry(result)) return RequestMiddlewareResult( @@ -242,6 +261,16 @@ def _run_execution_chain( def next_call(next_payload: Any = None) -> Any: nonlocal next_called, next_succeeded, next_result + # ``next_call`` is single-use per middleware frame. Calling it more + # than once would re-run the downstream provider/tool, so a second + # invocation is a contract violation rather than a retry. Surface it + # instead of silently executing the terminal call twice. + if next_called: + raise RuntimeError( + f"Middleware '{kind}' callback " + f"{getattr(callback, '__name__', repr(callback))} called " + "next_call() more than once; downstream execution is single-use" + ) next_called = True try: next_result = call_at(index + 1, payload if next_payload is None else next_payload) diff --git a/tests/hermes_cli/test_plugins.py b/tests/hermes_cli/test_plugins.py index ddd1dab56e4..bb889450d00 100644 --- a/tests/hermes_cli/test_plugins.py +++ b/tests/hermes_cli/test_plugins.py @@ -272,6 +272,54 @@ class TestPluginDiscovery: assert calls == [{"command": "interrupt"}] + def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch): + calls = [] + + def middleware(**kwargs): + first = kwargs["next_call"](kwargs["args"]) + # Deliberate misuse: a second next_call() must not re-run the + # downstream tool. The chain surfaces it as an error and preserves + # the first (successful) downstream result. 
+ kwargs["next_call"](kwargs["args"]) + return first + + manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(args) + return "terminal-result" + + result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal) + + assert result == "terminal-result" + assert calls == [{"command": "printf ok"}] + + def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch): + import threading + + recorded = {} + + def middleware(**kwargs): + recorded["args"] = kwargs["args"] + return None + + manager = types.SimpleNamespace( + _middleware={"tool_request": [middleware]}, + invoke_middleware=lambda kind, **kwargs: [middleware(**kwargs)], + ) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + # threading.Lock is not deepcopyable; a hard deepcopy would raise. + args = {"command": "noop", "lock": threading.Lock()} + result = apply_tool_request_middleware("terminal", args) + + # Middleware ran (payload was copied via the shallow fallback) and the + # non-deepcopyable member is shared by reference rather than aborting. + assert recorded["args"]["command"] == "noop" + assert result.payload["command"] == "noop" + assert result.payload["lock"] is args["lock"] + def test_discover_project_plugins(self, tmp_path, monkeypatch): """Plugins in ./.hermes/plugins/ are discovered.""" project_dir = tmp_path / "project"