Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin

feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
kshitij 2026-06-06 10:46:41 -07:00 committed by GitHub
commit d4a7bfd3aa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 2170 additions and 151 deletions

View file

@ -1620,13 +1620,37 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo
def invoke_tool(agent, function_name: str, function_args: dict, effective_task_id: str,
tool_call_id: Optional[str] = None, messages: list = None,
pre_tool_block_checked: bool = False) -> str:
pre_tool_block_checked: bool = False,
skip_tool_request_middleware: bool = False,
tool_request_middleware_trace: Optional[List[Dict[str, Any]]] = None) -> str:
"""Invoke a single tool and return the result string. No display logic.
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
tools. Used by the concurrent execution path; the sequential path retains
its own inline invocation for backward-compatible display handling.
"""
if not isinstance(function_args, dict):
function_args = {}
_tool_middleware_trace = list(tool_request_middleware_trace or [])
try:
from hermes_cli.middleware import apply_tool_request_middleware
if not skip_tool_request_middleware:
_tool_request_mw = apply_tool_request_middleware(
function_name,
function_args,
task_id=effective_task_id or "",
session_id=getattr(agent, "session_id", "") or "",
tool_call_id=tool_call_id or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
)
function_args = _tool_request_mw.payload
_tool_middleware_trace = _tool_request_mw.trace
except Exception as _mw_err:
logger.debug("tool_request middleware error: %s", _mw_err)
# Check plugin hooks for a block directive before executing anything.
block_message: Optional[str] = None
if not pre_tool_block_checked:
@ -1640,6 +1664,7 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
tool_call_id=tool_call_id or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
middleware_trace=list(_tool_middleware_trace),
)
except Exception:
pass
@ -1659,6 +1684,7 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
status="blocked",
error_type="plugin_block",
error_message=block_message,
middleware_trace=list(_tool_middleware_trace),
)
except Exception:
pass
@ -1666,12 +1692,13 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
tool_start_time = time.monotonic()
def _finish_agent_tool(result: Any) -> Any:
def _finish_agent_tool(result: Any, observed_args: Optional[dict] = None) -> Any:
hook_args = observed_args if isinstance(observed_args, dict) else function_args
try:
from model_tools import _emit_post_tool_call_hook
_emit_post_tool_call_hook(
function_name=function_name,
function_args=function_args,
function_args=hook_args,
result=result,
task_id=effective_task_id or "",
session_id=getattr(agent, "session_id", "") or "",
@ -1679,89 +1706,116 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
duration_ms=int((time.monotonic() - tool_start_time) * 1000),
middleware_trace=list(_tool_middleware_trace),
)
except Exception:
pass
return result
if function_name == "todo":
from tools.todo_tool import todo_tool as _todo_tool
return _finish_agent_tool(
_todo_tool(
todos=function_args.get("todos"),
merge=function_args.get("merge", False),
store=agent._todo_store,
def _execute(next_args: dict) -> Any:
from tools.todo_tool import todo_tool as _todo_tool
return _finish_agent_tool(
_todo_tool(
todos=next_args.get("todos"),
merge=next_args.get("merge", False),
store=agent._todo_store,
),
next_args,
)
)
elif function_name == "session_search":
session_db = agent._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable
return _finish_agent_tool(json.dumps({"success": False, "error": format_session_db_unavailable()}))
from tools.session_search_tool import session_search as _session_search
return _finish_agent_tool(
_session_search(
query=function_args.get("query", ""),
role_filter=function_args.get("role_filter"),
limit=function_args.get("limit", 3),
session_id=function_args.get("session_id"),
around_message_id=function_args.get("around_message_id"),
window=function_args.get("window", 5),
sort=function_args.get("sort"),
db=session_db,
current_session_id=agent.session_id,
def _execute(next_args: dict) -> Any:
session_db = agent._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable
return _finish_agent_tool(json.dumps({"success": False, "error": format_session_db_unavailable()}), next_args)
from tools.session_search_tool import session_search as _session_search
return _finish_agent_tool(
_session_search(
query=next_args.get("query", ""),
role_filter=next_args.get("role_filter"),
limit=next_args.get("limit", 3),
session_id=next_args.get("session_id"),
around_message_id=next_args.get("around_message_id"),
window=next_args.get("window", 5),
sort=next_args.get("sort"),
db=session_db,
current_session_id=agent.session_id,
),
next_args,
)
)
elif function_name == "memory":
target = function_args.get("target", "memory")
from tools.memory_tool import memory_tool as _memory_tool
result = _memory_tool(
action=function_args.get("action"),
target=target,
content=function_args.get("content"),
old_text=function_args.get("old_text"),
store=agent._memory_store,
)
# Bridge: notify external memory provider of built-in memory writes
if agent._memory_manager and function_args.get("action") in {"add", "replace"}:
try:
agent._memory_manager.on_memory_write(
function_args.get("action", ""),
target,
function_args.get("content", ""),
metadata=agent._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=tool_call_id,
),
)
except Exception:
pass
return _finish_agent_tool(result)
elif agent._memory_manager and agent._memory_manager.has_tool(function_name):
return _finish_agent_tool(agent._memory_manager.handle_tool_call(function_name, function_args))
elif function_name == "clarify":
from tools.clarify_tool import clarify_tool as _clarify_tool
return _finish_agent_tool(
_clarify_tool(
question=function_args.get("question", ""),
choices=function_args.get("choices"),
callback=agent.clarify_callback,
def _execute(next_args: dict) -> Any:
target = next_args.get("target", "memory")
from tools.memory_tool import memory_tool as _memory_tool
result = _memory_tool(
action=next_args.get("action"),
target=target,
content=next_args.get("content"),
old_text=next_args.get("old_text"),
store=agent._memory_store,
)
# Bridge: notify external memory provider of built-in memory writes
if agent._memory_manager and next_args.get("action") in {"add", "replace"}:
try:
agent._memory_manager.on_memory_write(
next_args.get("action", ""),
target,
next_args.get("content", ""),
metadata=agent._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=tool_call_id,
),
)
except Exception:
pass
return _finish_agent_tool(result, next_args)
elif agent._memory_manager and agent._memory_manager.has_tool(function_name):
def _execute(next_args: dict) -> Any:
return _finish_agent_tool(agent._memory_manager.handle_tool_call(function_name, next_args), next_args)
elif function_name == "clarify":
def _execute(next_args: dict) -> Any:
from tools.clarify_tool import clarify_tool as _clarify_tool
return _finish_agent_tool(
_clarify_tool(
question=next_args.get("question", ""),
choices=next_args.get("choices"),
callback=agent.clarify_callback,
),
next_args,
)
)
elif function_name == "delegate_task":
return _finish_agent_tool(agent._dispatch_delegate_task(function_args))
def _execute(next_args: dict) -> Any:
return _finish_agent_tool(agent._dispatch_delegate_task(next_args), next_args)
else:
return _ra().handle_function_call(
function_name, function_args, effective_task_id,
tool_call_id=tool_call_id,
session_id=agent.session_id or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
skip_pre_tool_call_hook=True,
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
)
def _execute(next_args: dict) -> Any:
return _ra().handle_function_call(
function_name, next_args, effective_task_id,
tool_call_id=tool_call_id,
session_id=agent.session_id or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
skip_pre_tool_call_hook=True,
skip_tool_request_middleware=True,
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
tool_request_middleware_trace=list(_tool_middleware_trace),
)
from hermes_cli.middleware import run_tool_execution_middleware
return run_tool_execution_middleware(
function_name,
function_args,
lambda next_args: _execute(next_args if isinstance(next_args, dict) else function_args),
original_args=function_args,
task_id=effective_task_id or "",
session_id=getattr(agent, "session_id", "") or "",
tool_call_id=tool_call_id or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
)

View file

@ -1239,6 +1239,28 @@ def run_conversation(
_sanitize_structure_non_ascii(api_kwargs)
if agent.api_mode == "codex_responses":
api_kwargs = agent._get_transport().preflight_kwargs(api_kwargs, allow_stream=False)
try:
from hermes_cli.middleware import apply_llm_request_middleware
_llm_request_mw = apply_llm_request_middleware(
api_kwargs,
task_id=effective_task_id,
turn_id=turn_id,
api_request_id=api_request_id,
session_id=agent.session_id or "",
platform=agent.platform or "",
model=agent.model,
provider=agent.provider,
base_url=agent.base_url,
api_mode=agent.api_mode,
api_call_count=api_call_count,
)
api_kwargs = _llm_request_mw.payload
_original_api_kwargs = _llm_request_mw.original_payload
_llm_middleware_trace = _llm_request_mw.trace
except Exception:
_original_api_kwargs = dict(api_kwargs)
_llm_middleware_trace = []
try:
from hermes_cli.plugins import (
@ -1291,6 +1313,7 @@ def run_conversation(
request_char_count=total_chars,
max_tokens=agent.max_tokens,
started_at=api_start_time,
middleware_trace=list(_llm_middleware_trace),
request=_request_payload,
)
except Exception:
@ -1349,7 +1372,24 @@ def run_conversation(
)
return agent._interruptible_api_call(next_api_kwargs)
response = _perform_api_call(api_kwargs)
from hermes_cli.middleware import run_llm_execution_middleware
response = run_llm_execution_middleware(
api_kwargs,
_perform_api_call,
original_request=_original_api_kwargs,
task_id=effective_task_id,
turn_id=turn_id,
api_request_id=api_request_id,
session_id=agent.session_id or "",
platform=agent.platform or "",
model=agent.model,
provider=agent.provider,
base_url=agent.base_url,
api_mode=agent.api_mode,
api_call_count=api_call_count,
middleware_trace=list(_llm_middleware_trace),
)
api_duration = time.time() - api_start_time

View file

@ -70,6 +70,7 @@ def _emit_terminal_post_tool_call(
status: str | None = None,
error_type: str | None = None,
error_message: str | None = None,
middleware_trace: Optional[list[dict[str, Any]]] = None,
) -> None:
try:
from model_tools import _emit_post_tool_call_hook
@ -86,6 +87,7 @@ def _emit_terminal_post_tool_call(
status=status,
error_type=error_type,
error_message=error_message,
middleware_trace=list(middleware_trace or []),
)
except Exception:
pass
@ -111,6 +113,7 @@ def _emit_cancelled_terminal_post_tool_call(
start_time: float,
reason: str = "user interrupt",
error_type: str = "keyboard_interrupt",
middleware_trace: Optional[list[dict[str, Any]]] = None,
) -> str:
result = _cancelled_tool_result(reason)
_emit_terminal_post_tool_call(
@ -124,6 +127,7 @@ def _emit_cancelled_terminal_post_tool_call(
status="cancelled",
error_type=error_type,
error_message=f"Tool execution cancelled by {reason}",
middleware_trace=list(middleware_trace or []),
)
return result
@ -177,6 +181,65 @@ def _tool_search_scoped_names(agent) -> frozenset:
return names
def _apply_tool_request_middleware_for_agent(
    agent,
    *,
    function_name: str,
    function_args: dict,
    effective_task_id: str,
    tool_call_id: str,
) -> tuple[dict, list[dict[str, Any]]]:
    """Run the tool-request middleware for one tool call, best-effort.

    Calls ``hermes_cli.middleware.apply_tool_request_middleware`` with the
    tool name, its arguments and the identifiers of the current call
    (task, session, tool call, turn and API request; missing or falsy
    identifiers are passed as empty strings).

    Returns:
        A ``(args, trace)`` pair. ``args`` is the middleware result's
        ``payload`` when that is a dict, otherwise the ``function_args``
        passed in. ``trace`` is a fresh list copied from the result's
        ``trace``.

    Any exception raised along the way (including a failing import) is
    logged at debug level and the original ``function_args`` are returned
    with an empty trace, so a middleware failure never blocks the tool call.
    """
    try:
        from hermes_cli.middleware import apply_tool_request_middleware

        # Identifiers read off the agent fall back to "" when absent or falsy.
        call_context = {
            "task_id": effective_task_id or "",
            "session_id": getattr(agent, "session_id", "") or "",
            "tool_call_id": tool_call_id or "",
            "turn_id": getattr(agent, "_current_turn_id", "") or "",
            "api_request_id": getattr(agent, "_current_api_request_id", "") or "",
        }
        outcome = apply_tool_request_middleware(
            function_name,
            function_args,
            **call_context,
        )

        # Only accept a dict payload; anything else keeps the caller's args.
        if isinstance(outcome.payload, dict):
            effective_args = outcome.payload
        else:
            effective_args = function_args
        return effective_args, list(outcome.trace)
    except Exception as exc:
        logger.debug("tool_request middleware error: %s", exc)
        return function_args, []
def _run_agent_tool_execution_middleware(
    agent,
    *,
    function_name: str,
    function_args: dict,
    effective_task_id: str,
    tool_call_id: str,
    execute,
) -> tuple[Any, dict]:
    """Run ``execute`` through the tool-execution middleware chain.

    ``execute`` is wrapped so that the arguments it is finally invoked with
    are recorded. If the middleware hands the wrapper something that is not
    a dict, the original ``function_args`` are used instead.

    Returns:
        A ``(result, observed_args)`` pair: whatever
        ``run_tool_execution_middleware`` returned, and the arguments most
        recently passed to ``execute`` (the original ``function_args`` if
        the wrapper was never called).

    Exceptions from the import, the middleware or ``execute`` are not
    caught here.
    """
    from hermes_cli.middleware import run_tool_execution_middleware

    # Mutable holder for the args the wrapped callable last ran with.
    seen = {"args": function_args}

    def _record_and_execute(candidate_args: dict) -> Any:
        if not isinstance(candidate_args, dict):
            candidate_args = function_args
        seen["args"] = candidate_args
        return execute(candidate_args)

    call_context = {
        "original_args": function_args,
        "task_id": effective_task_id or "",
        "session_id": getattr(agent, "session_id", "") or "",
        "tool_call_id": tool_call_id or "",
        "turn_id": getattr(agent, "_current_turn_id", "") or "",
        "api_request_id": getattr(agent, "_current_api_request_id", "") or "",
    }
    outcome = run_tool_execution_middleware(
        function_name,
        function_args,
        _record_and_execute,
        **call_context,
    )
    return outcome, seen["args"]
def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
"""Execute multiple tool calls concurrently using a thread pool.
@ -198,7 +261,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
return
# ── Parse args + pre-execution bookkeeping ───────────────────────
parsed_calls = [] # list of (tool_call, function_name, function_args)
parsed_calls = [] # list of (tool_call, function_name, function_args, middleware_trace, block_result, blocked_by_guardrail)
for tool_call in tool_calls:
function_name = tool_call.function.name
@ -250,6 +313,14 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
except Exception:
pass
function_args, middleware_trace = _apply_tool_request_middleware_for_agent(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
)
# ── Block evaluation (BEFORE checkpoint preflight) ───────────
# We must know whether the tool will execute before touching
# checkpoint state (dedup slot, real snapshots).
@ -268,6 +339,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
status="blocked",
error_type="tool_scope_block",
error_message=_ts_scope_block,
middleware_trace=list(middleware_trace),
)
else:
try:
@ -280,6 +352,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
tool_call_id=getattr(tool_call, "id", "") or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
middleware_trace=list(middleware_trace),
)
except Exception:
block_message = None
@ -296,6 +369,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
status="blocked",
error_type="plugin_block",
error_message=block_message,
middleware_trace=list(middleware_trace),
)
else:
guardrail_decision = agent._tool_guardrails.before_call(function_name, function_args)
@ -312,6 +386,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
status="blocked",
error_type="guardrail_block",
error_message=getattr(guardrail_decision, "message", None) or "Tool blocked by guardrail policy",
middleware_trace=list(middleware_trace),
)
# ── Checkpoint preflight (only for tools that will execute) ──
@ -338,13 +413,13 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
except Exception:
pass
parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
parsed_calls.append((tool_call, function_name, function_args, middleware_trace, block_result, blocked_by_guardrail))
# ── Logging / callbacks ──────────────────────────────────────────
tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls)
tool_names_str = ", ".join(name for _, name, _, _, _, _ in parsed_calls)
if not agent.quiet_mode:
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
args_str = json.dumps(args, ensure_ascii=False)
if agent.verbose_logging:
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
@ -353,7 +428,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
args_preview = args_str[:agent.log_prefix_chars] + "..." if len(args_str) > agent.log_prefix_chars else args_str
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
for tc, name, args, middleware_trace, block_result, blocked_by_guardrail in parsed_calls:
if block_result is not None:
continue
if agent.tool_progress_callback:
@ -363,7 +438,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
except Exception as cb_err:
logging.debug(f"Tool progress callback error: {cb_err}")
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
for tc, name, args, middleware_trace, block_result, blocked_by_guardrail in parsed_calls:
if block_result is not None:
continue
if agent.tool_start_callback:
@ -373,18 +448,18 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
logging.debug(f"Tool start callback error: {cb_err}")
# ── Concurrent execution ─────────────────────────────────────────
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag)
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag, middleware_trace)
results = [None] * num_tools
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
if block_result is not None:
results[i] = (name, args, block_result, 0.0, True, True)
results[i] = (name, args, block_result, 0.0, True, True, middleware_trace)
# Touch activity before launching workers so the gateway knows
# we're executing tools (not stuck).
agent._current_tool = tool_names_str
agent._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}")
def _run_tool(index, tool_call, function_name, function_args):
def _run_tool(index, tool_call, function_name, function_args, middleware_trace):
"""Worker function executed in a thread."""
# Register this worker tid so the agent can fan out an interrupt
# to it — see AIAgent.interrupt(). Must happen first thing, and
@ -423,6 +498,8 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
tool_call.id,
messages=messages,
pre_tool_block_checked=True,
skip_tool_request_middleware=True,
tool_request_middleware_trace=list(middleware_trace),
)
except KeyboardInterrupt:
try:
@ -436,10 +513,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
start_time=start,
middleware_trace=list(middleware_trace),
)
duration = time.time() - start
logger.info("tool %s cancelled (%.2fs)", function_name, duration)
results[index] = (function_name, function_args, result, duration, True, False)
results[index] = (function_name, function_args, result, duration, True, False, middleware_trace)
return
except Exception as tool_error:
result = f"Error executing tool '{function_name}': {tool_error}"
@ -450,7 +528,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200])
else:
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
results[index] = (function_name, function_args, result, duration, is_error, False)
results[index] = (function_name, function_args, result, duration, is_error, False, middleware_trace)
finally:
# Tear down worker-tid tracking. Clear any interrupt bit we may
# have set so the next task scheduled onto this recycled tid
@ -475,7 +553,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
try:
runnable_calls = [
(i, tc, name, args)
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
if block_result is None
]
futures = []
@ -487,7 +565,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
# _approval_session_key) AND thread-local approval/sudo
# callbacks into the worker thread; clears callbacks on exit.
f = executor.submit(
propagate_context_to_thread(_run_tool), i, tc, name, args
propagate_context_to_thread(_run_tool), i, tc, name, args, parsed_calls[i][3]
)
futures.append(f)
@ -545,7 +623,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
spinner.stop(f"{completed}/{num_tools} tools completed in {total_dur:.1f}s total")
# ── Post-execution: display per-tool results ─────────────────────
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
r = results[i]
blocked = False
if r is None:
@ -562,6 +640,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
status="cancelled",
error_type="keyboard_interrupt",
error_message="Tool execution cancelled by user interrupt",
middleware_trace=list(middleware_trace),
)
else:
function_result = f"Error executing tool '{name}': thread did not return a result"
@ -575,10 +654,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
status="error",
error_type="thread_missing_result",
error_message=function_result,
middleware_trace=list(middleware_trace),
)
tool_duration = 0.0
else:
function_name, function_args, function_result, tool_duration, is_error, blocked = r
function_name, function_args, function_result, tool_duration, is_error, blocked, middleware_trace = r
if not blocked:
function_result = agent._append_guardrail_observation(
@ -738,6 +818,14 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
except Exception:
pass
function_args, middleware_trace = _apply_tool_request_middleware_for_agent(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
)
# Check plugin hooks for a block directive before executing.
_block_msg: Optional[str] = None
_block_error_type = "plugin_block"
@ -755,6 +843,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
tool_call_id=getattr(tool_call, "id", "") or "",
turn_id=getattr(agent, "_current_turn_id", "") or "",
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
middleware_trace=list(middleware_trace),
)
except Exception:
pass
@ -853,6 +942,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
status="blocked",
error_type=_block_error_type,
error_message=_block_msg,
middleware_trace=list(middleware_trace),
)
elif _guardrail_block_decision is not None:
# Tool blocked by tool-loop guardrail — synthesize exactly one
@ -869,71 +959,108 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
status="blocked",
error_type="guardrail_block",
error_message=getattr(_guardrail_block_decision, "message", None) or "Tool blocked by guardrail policy",
middleware_trace=list(middleware_trace),
)
elif function_name == "todo":
from tools.todo_tool import todo_tool as _todo_tool
function_result = _todo_tool(
todos=function_args.get("todos"),
merge=function_args.get("merge", False),
store=agent._todo_store,
def _execute(next_args: dict) -> Any:
from tools.todo_tool import todo_tool as _todo_tool
return _todo_tool(
todos=next_args.get("todos"),
merge=next_args.get("merge", False),
store=agent._todo_store,
)
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
tool_duration = time.time() - tool_start_time
if agent._should_emit_quiet_tool_messages():
agent._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
elif function_name == "session_search":
session_db = agent._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable
function_result = json.dumps({"success": False, "error": format_session_db_unavailable()})
else:
def _execute(next_args: dict) -> Any:
session_db = agent._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable
return json.dumps({"success": False, "error": format_session_db_unavailable()})
from tools.session_search_tool import session_search as _session_search
function_result = _session_search(
query=function_args.get("query", ""),
role_filter=function_args.get("role_filter"),
limit=function_args.get("limit", 3),
session_id=function_args.get("session_id"),
around_message_id=function_args.get("around_message_id"),
window=function_args.get("window", 5),
sort=function_args.get("sort"),
return _session_search(
query=next_args.get("query", ""),
role_filter=next_args.get("role_filter"),
limit=next_args.get("limit", 3),
session_id=next_args.get("session_id"),
around_message_id=next_args.get("around_message_id"),
window=next_args.get("window", 5),
sort=next_args.get("sort"),
db=session_db,
current_session_id=agent.session_id,
)
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
tool_duration = time.time() - tool_start_time
if agent._should_emit_quiet_tool_messages():
agent._vprint(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}")
elif function_name == "memory":
target = function_args.get("target", "memory")
from tools.memory_tool import memory_tool as _memory_tool
function_result = _memory_tool(
action=function_args.get("action"),
target=target,
content=function_args.get("content"),
old_text=function_args.get("old_text"),
store=agent._memory_store,
def _execute(next_args: dict) -> Any:
target = next_args.get("target", "memory")
from tools.memory_tool import memory_tool as _memory_tool
result = _memory_tool(
action=next_args.get("action"),
target=target,
content=next_args.get("content"),
old_text=next_args.get("old_text"),
store=agent._memory_store,
)
# Bridge: notify external memory provider of built-in memory writes
if agent._memory_manager and next_args.get("action") in {"add", "replace"}:
try:
agent._memory_manager.on_memory_write(
next_args.get("action", ""),
target,
next_args.get("content", ""),
metadata=agent._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", None),
),
)
except Exception:
pass
return result
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
# Bridge: notify external memory provider of built-in memory writes
if agent._memory_manager and function_args.get("action") in {"add", "replace"}:
try:
agent._memory_manager.on_memory_write(
function_args.get("action", ""),
target,
function_args.get("content", ""),
metadata=agent._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", None),
),
)
except Exception:
pass
tool_duration = time.time() - tool_start_time
if agent._should_emit_quiet_tool_messages():
agent._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
elif function_name == "clarify":
from tools.clarify_tool import clarify_tool as _clarify_tool
function_result = _clarify_tool(
question=function_args.get("question", ""),
choices=function_args.get("choices"),
callback=agent.clarify_callback,
def _execute(next_args: dict) -> Any:
from tools.clarify_tool import clarify_tool as _clarify_tool
return _clarify_tool(
question=next_args.get("question", ""),
choices=next_args.get("choices"),
callback=agent.clarify_callback,
)
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
tool_duration = time.time() - tool_start_time
if agent._should_emit_quiet_tool_messages():
@ -957,7 +1084,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
agent._delegate_spinner = spinner
_delegate_result = None
try:
function_result = agent._dispatch_delegate_task(function_args)
def _execute(next_args: dict) -> Any:
return agent._dispatch_delegate_task(next_args)
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
_delegate_result = function_result
finally:
agent._delegate_spinner = None
@ -978,7 +1114,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
spinner.start()
_ce_result = None
try:
function_result = agent.context_compressor.handle_tool_call(function_name, function_args, messages=messages)
def _execute(next_args: dict) -> Any:
return agent.context_compressor.handle_tool_call(function_name, next_args, messages=messages)
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
_ce_result = function_result
except Exception as tool_error:
function_result = json.dumps({"error": f"Context engine tool '{function_name}' failed: {tool_error}"})
@ -1002,7 +1147,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
spinner.start()
_mem_result = None
try:
function_result = agent._memory_manager.handle_tool_call(function_name, function_args)
def _execute(next_args: dict) -> Any:
return agent._memory_manager.handle_tool_call(function_name, next_args)
function_result, function_args = _run_agent_tool_execution_middleware(
agent,
function_name=function_name,
function_args=function_args,
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
execute=_execute,
)
_mem_result = function_result
except Exception as tool_error:
function_result = json.dumps({"error": f"Memory tool '{function_name}' failed: {tool_error}"})
@ -1032,8 +1186,10 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
skip_pre_tool_call_hook=True,
skip_tool_request_middleware=True,
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
tool_request_middleware_trace=list(middleware_trace),
)
_spinner_result = function_result
except KeyboardInterrupt:
@ -1044,6 +1200,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
start_time=tool_start_time,
middleware_trace=list(middleware_trace),
)
_spinner_result = function_result
try:
@ -1071,8 +1228,10 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
skip_pre_tool_call_hook=True,
skip_tool_request_middleware=True,
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
tool_request_middleware_trace=list(middleware_trace),
)
except KeyboardInterrupt:
_emit_cancelled_terminal_post_tool_call(
@ -1082,6 +1241,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
start_time=tool_start_time,
middleware_trace=list(middleware_trace),
)
try:
agent.interrupt("keyboard interrupt")
@ -1126,6 +1286,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
effective_task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", "") or "",
duration_ms=int(tool_duration * 1000),
middleware_trace=list(middleware_trace),
)
if not _execution_blocked:
function_result = agent._append_guardrail_observation(

260
docs/middleware/README.md Normal file
View file

@ -0,0 +1,260 @@
# Hermes Middleware
Hermes middleware is the behavior-changing companion to observer hooks.
Observer hooks report what happened. Middleware can change what happens by
rewriting a request before execution or by wrapping the execution callback
itself.
This contract is intentionally backend-neutral. A plugin can use it for local
policy, request shaping, tracing, adaptive routing, cache control, sandbox
selection, or handoff to runtimes such as NeMo Relay without changing Hermes'
planner, model provider adapters, tool registry, memory, or CLI UX.
With middleware enabled, plugins can:
- Rewrite LLM provider request kwargs before Hermes calls the provider.
- Rewrite tool arguments before guardrails, approval checks, hooks, and tool
execution see them.
- Wrap the actual LLM execution callback while preserving Hermes retry,
streaming, interrupt, and hook behavior.
- Wrap the actual tool execution callback while preserving Hermes guardrails,
approval, post-tool hooks, and tool-result transformation.
## Contract
Plugins register middleware from `register(ctx)`:
```python
def register(ctx):
ctx.register_middleware("llm_request", on_llm_request)
ctx.register_middleware("llm_execution", on_llm_execution)
ctx.register_middleware("tool_request", on_tool_request)
ctx.register_middleware("tool_execution", on_tool_execution)
```
Every middleware callback receives:
- `telemetry_schema_version`: currently `hermes.observer.v1`
- `middleware_schema_version`: currently `hermes.middleware.v1`
- Runtime context such as `session_id`, `task_id`, `turn_id`,
`api_request_id`, `provider`, `model`, `api_mode`, `tool_name`, and
`tool_call_id` when applicable.
Supported middleware kinds:
| Kind | Payload | Return shape | Purpose |
| --- | --- | --- | --- |
| `llm_request` | `request`, `original_request` | `{"request": {...}}` | Replace effective provider kwargs before provider execution. |
| `tool_request` | `tool_name`, `args`, `original_args` | `{"args": {...}}` | Replace effective tool args before hooks, guardrails, approvals, and execution. |
| `llm_execution` | `request`, `original_request`, `next_call` | Any provider response | Wrap or replace the actual provider call. |
| `tool_execution` | `tool_name`, `args`, `original_args`, `next_call` | Any tool result | Wrap or replace the actual tool call. |
Request middleware can return optional trace fields:
```python
return {
"request": updated_request,
"source": "my-plugin",
"reason": "selected fallback model",
}
```
Hermes stores those trace entries in later observer hook payloads as
`middleware_trace`.
Execution middleware receives a `next_call` callback. Call it to continue the
chain:
```python
def on_tool_execution(**kwargs):
result = kwargs["next_call"](kwargs["args"])
return result
```
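If multiple plugins register the same execution middleware kind, Hermes runs
them as a nested chain in registration order. Request middleware callbacks are
instead invoked one after another, in registration order, with the same
starting payload; if several of them return a replacement, the last valid
replacement becomes the effective payload. Middleware failures are fail-open:
Hermes logs a warning and continues with the next middleware or the base
runtime path.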
## Execution Order
### LLM Calls
For each provider request, Hermes applies middleware in this order:
1. Build provider kwargs from the current conversation.
2. Apply `llm_request` middleware.
3. Emit `pre_api_request` observer hooks with the effective request.
4. Run provider execution through `llm_execution` middleware.
5. Emit `post_api_request` or `api_request_error` observer hooks.
Request middleware sees the full provider kwargs, including `messages` or
Responses API `input`, model settings, tool definitions, stream options, and
provider-specific options. Execution middleware receives the same effective
request plus `next_call`.
### Tool Calls
For each tool call, Hermes applies middleware in this order:
1. Parse and coerce model-provided tool arguments.
2. Apply `tool_request` middleware.
3. Run the normal Hermes pre-execution path against the effective arguments:
tool availability checks, observer block directives, guardrails, and
approval checks.
4. Run tool execution through `tool_execution` middleware.
5. Emit `post_tool_call` observer hooks.
6. Apply `transform_tool_result` hooks before the result is appended back into
conversation context.
Tool request middleware runs before approval checks. Use it carefully: a
rewritten path, command, or URL is the value downstream policy will evaluate.
## Enablement
Middleware only runs for enabled plugins. For a bundled plugin:
```bash
hermes plugins enable <plugin-name>
```
For isolated local testing, point both plugin enablement and the agent run at
the same `HERMES_HOME`:
```bash
export HERMES_HOME=/tmp/hermes-middleware-test
mkdir -p "$HERMES_HOME"
hermes plugins enable <plugin-name>
hermes chat --query 'Reply exactly ok'
```
For source checkouts, prefer the source command so the runtime sees plugins and
middleware from the working tree:
```bash
uv sync
uv run hermes plugins enable <plugin-name>
uv run hermes chat --query 'Reply exactly ok'
```
## Generic Plugin Examples
The examples below are intentionally small. They show the middleware contract
shape without depending on NeMo Relay.
### LLM Request Middleware
This plugin tags provider requests and records a middleware trace entry:
```python
def register(ctx):
ctx.register_middleware("llm_request", tag_llm_request)
def tag_llm_request(**kwargs):
request = dict(kwargs["request"])
extra_body = dict(request.get("extra_body") or {})
extra_body.setdefault("metadata", {})["hermes_middleware_demo"] = True
request["extra_body"] = extra_body
return {
"request": request,
"source": "middleware-demo",
"reason": "tagged provider request",
}
```
The effective request is passed to `pre_api_request`, provider execution, and
`post_api_request`.
### Tool Request Middleware
This plugin constrains `terminal` calls to a known working directory:
```python
def register(ctx):
ctx.register_middleware("tool_request", normalize_terminal_workdir)
def normalize_terminal_workdir(**kwargs):
if kwargs.get("tool_name") != "terminal":
return None
args = dict(kwargs["args"])
args.setdefault("workdir", "/tmp/hermes-middleware-demo")
return {
"args": args,
"source": "middleware-demo",
"reason": "defaulted terminal workdir",
}
```
Because this runs before hooks and approvals, downstream telemetry and policy
observe the rewritten `workdir`.
### LLM Execution Middleware
This plugin wraps the provider call and preserves the raw provider response:
```python
import time
def register(ctx):
ctx.register_middleware("llm_execution", time_llm_execution)
def time_llm_execution(**kwargs):
started = time.monotonic()
response = kwargs["next_call"](kwargs["request"])
elapsed_ms = int((time.monotonic() - started) * 1000)
print(f"llm_execution elapsed_ms={elapsed_ms}")
return response
```
Return the same response shape Hermes expects from the provider adapter. Do not
wrap the response in a plugin-specific envelope unless the rest of the runtime
expects that envelope.
### Tool Execution Middleware
This plugin wraps tool execution while preserving the tool result:
```python
def register(ctx):
ctx.register_middleware("tool_execution", annotate_tool_execution)
def annotate_tool_execution(**kwargs):
result = kwargs["next_call"](kwargs["args"])
# Metrics, logging, or external routing can happen here.
return result
```
Execution middleware may call `next_call(modified_args)` to pass a changed
payload to later middleware and the base tool dispatcher.
Plugin-specific examples should live with the plugin that owns the behavior.
For NeMo Relay adaptive execution middleware, see
[`plugins/observability/nemo_relay/README.md`](../../plugins/observability/nemo_relay/README.md).
## Safety Notes
- Middleware should be deterministic for the same input unless it is explicitly
routing to a dynamic external system.
- Request middleware should return complete replacement payloads, not partial
patches.
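- Execution middleware should call `next_call(...)` exactly once unless it is
intentionally short-circuiting execution. A second call from the same callback
raises `RuntimeError` instead of running the provider or tool again.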
- If execution middleware raises before calling `next_call(...)`, Hermes treats
that as middleware failure and continues with the remaining middleware chain
and base execution.
- If execution middleware calls `next_call(...)` successfully and then raises
during post-processing, Hermes preserves the downstream result and does not
run the provider or tool a second time.
- If downstream provider or tool execution fails, middleware may let that error
propagate or translate it deliberately. Hermes does not convert downstream
failure into a successful `None` result.
- Tool request middleware runs before approvals. If it mutates file paths,
commands, URLs, or arguments, the mutated values are what guardrails and
approvals evaluate.
- Observer hooks remain the right place for read-only telemetry. Use middleware
only when a plugin needs to alter or wrap behavior.

313
hermes_cli/middleware.py Normal file
View file

@ -0,0 +1,313 @@
"""Hermes middleware contract helpers.
Observer hooks report what happened. Middleware can change what happens by
rewriting a request or wrapping the actual execution callback. Keep the small
contract helpers here so agent-loop call sites and plugins share one vocabulary.
"""
from __future__ import annotations
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
logger = logging.getLogger(__name__)
# Value stamped into payloads under ``telemetry_schema_version``.
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
# Value stamped into middleware payloads under ``middleware_schema_version``.
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"
# Middleware kind names. The ``*_request`` kinds rewrite a payload before
# execution; the ``*_execution`` kinds wrap the execution callback itself.
TOOL_REQUEST_MIDDLEWARE = "tool_request"
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
LLM_REQUEST_MIDDLEWARE = "llm_request"
LLM_EXECUTION_MIDDLEWARE = "llm_execution"
# Back-compat aliases for older PoC branches that used API terminology.
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE
# The set of recognised middleware kinds (the aliases above resolve to members
# of this set, so they need no separate entries).
VALID_MIDDLEWARE: set[str] = {
    TOOL_REQUEST_MIDDLEWARE,
    TOOL_EXECUTION_MIDDLEWARE,
    LLM_REQUEST_MIDDLEWARE,
    LLM_EXECUTION_MIDDLEWARE,
}
@dataclass
class RequestMiddlewareResult:
    """Outcome of one request-middleware pass over a payload.

    Returned by the ``apply_*_request_middleware`` helpers in this module.
    """

    # Effective payload after the middleware pass.
    payload: Any
    # Payload as it was before the middleware pass.
    original_payload: Any
    # True when at least one middleware replacement was accepted.
    changed: bool = False
    # One trace entry per accepted replacement, in the order they were applied.
    trace: List[Dict[str, Any]] = field(default_factory=list)
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return the keyword arguments as an observer payload dict.

    ``telemetry_schema_version`` is filled in with the current observer schema
    version unless the caller already supplied that key.
    """
    if "telemetry_schema_version" not in kwargs:
        kwargs["telemetry_schema_version"] = OBSERVER_SCHEMA_VERSION
    return kwargs
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
    """Return the keyword arguments as a middleware payload dict.

    Both schema-version keys are filled in with their current values unless
    the caller already supplied them.
    """
    schema_defaults = (
        ("telemetry_schema_version", OBSERVER_SCHEMA_VERSION),
        ("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION),
    )
    for key, version in schema_defaults:
        if key not in kwargs:
            kwargs[key] = version
    return kwargs
def _safe_copy(payload: Any) -> Any:
"""Deep-copy a request payload, tolerating non-deepcopyable members.
Request payloads are normally plain JSON-shaped dicts, but an LLM request
can occasionally carry non-deepcopyable objects (clients, callbacks, file
handles). A hard ``deepcopy`` failure there would otherwise abort the whole
request-middleware pass. Fall back to a shallow ``dict`` copy so middleware
still runs and the original nested objects are shared by reference rather
than corrupting the live payload.
"""
try:
return deepcopy(payload)
except Exception as exc: # pragma: no cover - exercised via fallback test
logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc)
if isinstance(payload, dict):
return dict(payload)
return payload
def apply_llm_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run registered ``llm_request`` middleware over provider kwargs.

    A callback replaces the effective request by returning a dict of the form
    ``{"request": {...}}``; any other return value is ignored. Callbacks are
    dispatched through a single ``_invoke_middleware`` call, so they are all
    handed the same starting copy rather than the previous callback's
    replacement. Replacements are applied in result order, which means the
    last valid one becomes the effective request.

    Args:
        request: Provider kwargs about to be sent.
        **context: Extra keyword arguments forwarded to every callback.

    Returns:
        A ``RequestMiddlewareResult``. When no middleware is registered the
        input object is returned untouched as both ``payload`` and
        ``original_payload``; otherwise both are copies.
    """
    if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
        # Fast path: nothing registered, so skip the copies entirely.
        return RequestMiddlewareResult(
            payload=request,
            original_payload=request,
            changed=False,
            trace=[],
        )

    pristine = _safe_copy(request)
    effective = _safe_copy(pristine)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        LLM_REQUEST_MIDDLEWARE,
        request=effective,
        original_request=pristine,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("request") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so later edits by the plugin cannot reach the live payload.
            effective = _safe_copy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=pristine,
        changed=bool(applied),
        trace=applied,
    )
def apply_tool_request_middleware(
    tool_name: str,
    args: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Run registered ``tool_request`` middleware over tool arguments.

    A callback replaces the effective arguments by returning a dict of the
    form ``{"args": {...}}``; any other return value is ignored. Callbacks are
    dispatched through a single ``_invoke_middleware`` call, so they are all
    handed the same starting copy rather than the previous callback's
    replacement. Replacements are applied in result order, which means the
    last valid one becomes the effective arguments.

    Args:
        tool_name: Name of the tool being called; forwarded to callbacks.
        args: Tool arguments before middleware.
        **context: Extra keyword arguments forwarded to every callback.

    Returns:
        A ``RequestMiddlewareResult``. When no middleware is registered the
        input object is returned untouched as both ``payload`` and
        ``original_payload``; otherwise both are copies.
    """
    if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
        # Fast path: nothing registered, so skip the copies entirely.
        return RequestMiddlewareResult(
            payload=args,
            original_payload=args,
            changed=False,
            trace=[],
        )

    pristine = _safe_copy(args)
    effective = _safe_copy(pristine)
    applied: List[Dict[str, Any]] = []

    outcomes = _invoke_middleware(
        TOOL_REQUEST_MIDDLEWARE,
        tool_name=tool_name,
        args=effective,
        original_args=pristine,
        **context,
    )
    for outcome in outcomes:
        replacement = outcome.get("args") if isinstance(outcome, dict) else None
        if isinstance(replacement, dict):
            # Copy so later edits by the plugin cannot reach the live payload.
            effective = _safe_copy(replacement)
            applied.append(_trace_entry(outcome))

    return RequestMiddlewareResult(
        payload=effective,
        original_payload=pristine,
        changed=bool(applied),
        trace=applied,
    )
def apply_api_request_middleware(
    request: Dict[str, Any],
    **context: Any,
) -> RequestMiddlewareResult:
    """Legacy ``api_request`` spelling of ``apply_llm_request_middleware``.

    Forwards its arguments unchanged and returns that function's result.
    """
    outcome = apply_llm_request_middleware(request, **context)
    return outcome
def run_llm_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Execute a provider call through the ``llm_execution`` middleware chain.

    Args:
        request: Effective provider kwargs.
        next_call: The real provider call; invoked directly when no
            middleware is registered, otherwise used as the chain terminus.
        **context: Extra keyword arguments forwarded to every callback. An
            ``original_request`` entry, if present, is passed along as the
            pre-middleware request; otherwise ``request`` is used for it.

    Returns:
        Whatever the chain (or ``next_call``) returns.
    """
    chain = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
    if chain:
        baseline_request = context.pop("original_request", request)
        return _run_execution_chain(
            LLM_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            request=request,
            original_request=baseline_request,
            **context,
        )
    return next_call(request)
def run_tool_execution_middleware(
    tool_name: str,
    args: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Execute a tool call through the ``tool_execution`` middleware chain.

    Args:
        tool_name: Name of the tool being called; forwarded to callbacks.
        args: Effective tool arguments.
        next_call: The real tool dispatch; invoked directly when no
            middleware is registered, otherwise used as the chain terminus.
        **context: Extra keyword arguments forwarded to every callback. An
            ``original_args`` entry, if present, is passed along as the
            pre-middleware arguments; otherwise ``args`` is used for it.

    Returns:
        Whatever the chain (or ``next_call``) returns.
    """
    chain = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
    if chain:
        baseline_args = context.pop("original_args", args)
        return _run_execution_chain(
            TOOL_EXECUTION_MIDDLEWARE,
            chain,
            next_call,
            tool_name=tool_name,
            args=args,
            original_args=baseline_args,
            **context,
        )
    return next_call(args)
def run_api_execution_middleware(
    request: Dict[str, Any],
    next_call: Callable[[Dict[str, Any]], Any],
    **context: Any,
) -> Any:
    """Legacy ``api_execution`` spelling of ``run_llm_execution_middleware``.

    Forwards its arguments unchanged and returns that function's result.
    """
    outcome = run_llm_execution_middleware(request, next_call, **context)
    return outcome
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Dispatch ``kind`` middleware with schema-stamped keyword arguments.

    The keyword arguments are passed through ``middleware_payload`` before
    being handed to the plugin layer's ``invoke_middleware``.
    """
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import invoke_middleware as _dispatch

    stamped = middleware_payload(**kwargs)
    return _dispatch(kind, **stamped)
def _has_middleware(kind: str) -> bool:
    """Ask the plugin layer whether any ``kind`` middleware is registered."""
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import has_middleware as _is_registered

    return _is_registered(kind)
def _get_middleware_callbacks(kind: str) -> List[Callable]:
    """Return a snapshot list of the callbacks registered for ``kind``.

    The result is a new list, so callers can iterate it without being affected
    by later registrations.
    """
    # Imported at call time rather than at module import.
    from hermes_cli.plugins import get_plugin_manager

    registry = get_plugin_manager()._middleware
    return list(registry.get(kind, []))
def _run_execution_chain(
    kind: str,
    callbacks: List[Callable],
    terminal_call: Callable[[Any], Any],
    **kwargs: Any,
) -> Any:
    """Run ``terminal_call`` wrapped by ``callbacks`` as a nested chain.

    Each callback is invoked with the schema-stamped ``kwargs``, the current
    payload under ``"request"`` or ``"args"`` (whichever key ``kwargs``
    carries), and a ``next_call`` function that continues to the next callback
    or, after the last one, to ``terminal_call``.

    Failure handling per callback frame:

    * Downstream failure (raised by ``next_call``) that the callback lets
      propagate: re-raised unchanged. The callback sees the *original*
      exception object, so it can catch specific exception types and
      translate them deliberately.
    * Callback raises before calling ``next_call``: logged, and the chain
      continues with the next callback using the unmodified payload.
    * Callback raises after ``next_call`` succeeded: logged, and the
      downstream result is returned so the provider/tool is not run twice.
    * Callback raises its own error after ``next_call`` failed: logged and
      re-raised (a deliberate translation of the downstream failure).

    Args:
        kind: Middleware kind, used in log and error messages.
        callbacks: Callbacks in the order they should wrap execution
            (index 0 is outermost).
        terminal_call: The real provider/tool call at the end of the chain.
        **kwargs: Context forwarded to every callback; must contain the
            payload under ``"request"`` or ``"args"``.

    Returns:
        The value produced by the outermost callback, or by the fail-open
        path described above.

    Raises:
        Exception: Whatever downstream execution or a translating callback
            raises. ``next_call`` itself raises ``RuntimeError`` inside a
            callback that invokes it more than once.
    """
    payload_key = "request" if "request" in kwargs else "args"

    def call_at(index: int, payload: Any) -> Any:
        if index >= len(callbacks):
            return terminal_call(payload)

        callback = callbacks[index]
        next_called = False
        next_succeeded = False
        next_result: Any = None
        # The exact exception object raised by downstream execution, if any.
        # Compared by identity below to tell a propagated downstream failure
        # apart from an error raised by the middleware itself.
        downstream_error: Optional[BaseException] = None

        def next_call(next_payload: Any = None) -> Any:
            nonlocal next_called, next_succeeded, next_result, downstream_error
            # ``next_call`` is single-use per middleware frame. Calling it more
            # than once would re-run the downstream provider/tool, so a second
            # invocation is a contract violation rather than a retry. Surface it
            # instead of silently executing the terminal call twice.
            if next_called:
                raise RuntimeError(
                    f"Middleware '{kind}' callback "
                    f"{getattr(callback, '__name__', repr(callback))} called "
                    "next_call() more than once; downstream execution is single-use"
                )
            next_called = True
            try:
                next_result = call_at(index + 1, payload if next_payload is None else next_payload)
            except Exception as exc:
                # Remember the failure but re-raise it as-is so the middleware
                # can handle it by its real type.
                downstream_error = exc
                raise
            next_succeeded = True
            return next_result

        call_kwargs = middleware_payload(**kwargs)
        call_kwargs[payload_key] = payload
        call_kwargs["next_call"] = next_call
        try:
            return callback(**call_kwargs)
        except Exception as exc:
            if exc is downstream_error:
                # Downstream failure the middleware chose not to handle:
                # propagate it untouched, without fail-open recovery.
                raise
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(callback, "__name__", repr(callback)),
                exc,
            )
            if next_succeeded:
                return next_result
            if next_called:
                raise
            return call_at(index + 1, payload)
        finally:
            # Drop the stored exception so its traceback does not keep this
            # frame alive through a reference cycle.
            downstream_error = None

    return call_at(0, kwargs[payload_key])
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
entry: Dict[str, Any] = {}
for key in ("source", "reason", "name"):
value = result.get(key)
if isinstance(value, str) and value:
entry[key] = value
if not entry:
entry["source"] = "plugin"
return entry

View file

@ -49,7 +49,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
from hermes_constants import get_hermes_home
from utils import env_var_enabled
from hermes_cli.config import cfg_get
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
from hermes_cli.middleware import OBSERVER_SCHEMA_VERSION, VALID_MIDDLEWARE
def get_bundled_plugins_dir() -> Path:
@ -277,6 +277,7 @@ class LoadedPlugin:
module: Optional[types.ModuleType] = None
tools_registered: List[str] = field(default_factory=list)
hooks_registered: List[str] = field(default_factory=list)
middleware_registered: List[str] = field(default_factory=list)
commands_registered: List[str] = field(default_factory=list)
enabled: bool = False
error: Optional[str] = None
@ -952,6 +953,27 @@ class PluginContext:
self._manager._hooks.setdefault(hook_name, []).append(callback)
logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name)
# -- middleware registration -------------------------------------------
def register_middleware(self, kind: str, callback: Callable) -> None:
    """Attach ``callback`` to the manager's middleware list for ``kind``.

    Middleware differs from observer hooks in that it can change behavior:
    request middleware may rewrite the effective payload, and execution
    middleware may wrap the real callback. A ``kind`` outside
    ``VALID_MIDDLEWARE`` is still stored (forward compatibility) but triggers
    a warning so typos in plugin code are noticed.
    """
    is_known_kind = kind in VALID_MIDDLEWARE
    if not is_known_kind:
        logger.warning(
            "Plugin '%s' registered unknown middleware '%s' "
            "(valid: %s)",
            self.manifest.name,
            kind,
            ", ".join(sorted(VALID_MIDDLEWARE)),
        )
    bucket = self._manager._middleware.setdefault(kind, [])
    bucket.append(callback)
    logger.debug("Plugin %s registered middleware: %s", self.manifest.name, kind)
# -- skill registration -------------------------------------------------
def register_skill(
@ -1010,6 +1032,7 @@ class PluginManager:
def __init__(self) -> None:
self._plugins: Dict[str, LoadedPlugin] = {}
self._hooks: Dict[str, List[Callable]] = {}
self._middleware: Dict[str, List[Callable]] = {}
self._plugin_tool_names: Set[str] = set()
self._plugin_platform_names: Set[str] = set()
self._cli_commands: Dict[str, dict] = {}
@ -1039,6 +1062,7 @@ class PluginManager:
if force:
self._plugins.clear()
self._hooks.clear()
self._middleware.clear()
self._plugin_tool_names.clear()
self._cli_commands.clear()
self._plugin_commands.clear()
@ -1449,15 +1473,28 @@ class PluginManager:
for h in p.hooks_registered
}
)
loaded.middleware_registered = list(
{
kind
for kind, cbs in self._middleware.items()
if cbs
}
- {
kind
for name, p in self._plugins.items()
for kind in p.middleware_registered
}
)
loaded.commands_registered = [
c for c in self._plugin_commands
if self._plugin_commands[c].get("plugin") == manifest.name
]
loaded.enabled = True
logger.debug(
" registered: %d tool(s), %d hook(s), %d slash command(s), %d CLI command(s)",
" registered: %d tool(s), %d hook(s), %d middleware, %d slash command(s), %d CLI command(s)",
len(loaded.tools_registered),
len(loaded.hooks_registered),
len(loaded.middleware_registered),
len(loaded.commands_registered),
sum(
1 for c in self._cli_commands
@ -1575,6 +1612,33 @@ class PluginManager:
"""Return True when at least one callback is registered for a hook."""
return bool(self._hooks.get(hook_name))
def has_middleware(self, kind: str) -> bool:
    """Report whether ``kind`` has at least one registered callback."""
    registered = self._middleware.get(kind)
    return bool(registered)
def invoke_middleware(self, kind: str, **kwargs: Any) -> List[Any]:
    """Invoke every callback registered for ``kind`` and gather the results.

    Each callback is called with the same keyword arguments. A callback that
    raises is logged at warning level and skipped, so one plugin cannot break
    the base runtime path. ``None`` returns are dropped; interpreting the
    remaining values is up to the caller-specific contract.

    Returns:
        The non-``None`` return values, in callback order.
    """
    collected: List[Any] = []
    for handler in self._middleware.get(kind, []):
        try:
            outcome = handler(**kwargs)
        except Exception as exc:
            logger.warning(
                "Middleware '%s' callback %s raised: %s",
                kind,
                getattr(handler, "__name__", repr(handler)),
                exc,
            )
            continue
        if outcome is not None:
            collected.append(outcome)
    return collected
# -----------------------------------------------------------------------
# Introspection
# -----------------------------------------------------------------------
@ -1594,6 +1658,7 @@ class PluginManager:
"enabled": loaded.enabled,
"tools": len(loaded.tools_registered),
"hooks": len(loaded.hooks_registered),
"middleware": len(loaded.middleware_registered),
"commands": len(loaded.commands_registered),
"error": loaded.error,
}
@ -1655,6 +1720,23 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]:
return get_plugin_manager().invoke_hook(hook_name, **kwargs)
def invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
    """Run the ``kind`` middleware callbacks on the shared plugin manager.

    Returns the list of non-``None`` values those callbacks returned.
    """
    manager = get_plugin_manager()
    return manager.invoke_middleware(kind, **kwargs)
def has_middleware(kind: str) -> bool:
    """Report whether the shared plugin manager has ``kind`` middleware.

    Uses the manager's ``has_middleware`` method when it is callable, and
    otherwise falls back to looking ``kind`` up in its ``_middleware`` mapping
    (treated as empty when that attribute is missing).
    """
    manager = get_plugin_manager()
    probe = getattr(manager, "has_middleware", None)
    if not callable(probe):
        return bool(getattr(manager, "_middleware", {}).get(kind))
    return bool(probe(kind))
def has_hook(hook_name: str) -> bool:
    """Report whether the shared plugin manager has callbacks for ``hook_name``."""
    manager = get_plugin_manager()
    return manager.has_hook(hook_name)
@ -1683,6 +1765,7 @@ def get_pre_tool_call_block_message(
tool_call_id: str = "",
turn_id: str = "",
api_request_id: str = "",
middleware_trace: Optional[List[Dict[str, Any]]] = None,
) -> Optional[str]:
"""Check ``pre_tool_call`` hooks for a blocking directive.
@ -1709,6 +1792,7 @@ def get_pre_tool_call_block_message(
tool_call_id=tool_call_id,
turn_id=turn_id,
api_request_id=api_request_id,
middleware_trace=list(middleware_trace or []),
)
for result in hook_results:

View file

@ -823,6 +823,7 @@ def _emit_post_tool_call_hook(
status: Optional[str] = None,
error_type: Optional[str] = None,
error_message: Optional[str] = None,
middleware_trace: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""Emit the ``post_tool_call`` observer hook.
@ -853,6 +854,7 @@ def _emit_post_tool_call_hook(
status=status,
error_type=error_type,
error_message=error_message,
middleware_trace=list(middleware_trace or []),
)
except Exception as _hook_err:
logger.debug("post_tool_call hook error: %s", _hook_err)
@ -869,6 +871,8 @@ def handle_function_call(
user_task: Optional[str] = None,
enabled_tools: Optional[List[str]] = None,
skip_pre_tool_call_hook: bool = False,
skip_tool_request_middleware: bool = False,
tool_request_middleware_trace: Optional[List[Dict[str, Any]]] = None,
enabled_toolsets: Optional[List[str]] = None,
disabled_toolsets: Optional[List[str]] = None,
) -> str:
@ -900,6 +904,7 @@ def handle_function_call(
function_args = coerce_tool_args(function_name, function_args)
if not isinstance(function_args, dict):
function_args = {}
_tool_middleware_trace = list(tool_request_middleware_trace or [])
# ── Tool Search bridge dispatch ──────────────────────────────────
# tool_search and tool_describe are pure catalog reads — handle them
@ -970,10 +975,32 @@ def handle_function_call(
user_task=user_task,
enabled_tools=enabled_tools,
skip_pre_tool_call_hook=skip_pre_tool_call_hook,
skip_tool_request_middleware=skip_tool_request_middleware,
tool_request_middleware_trace=list(_tool_middleware_trace),
enabled_toolsets=enabled_toolsets,
disabled_toolsets=disabled_toolsets,
)
_tool_original_args = dict(function_args)
if not skip_tool_request_middleware:
try:
from hermes_cli.middleware import apply_tool_request_middleware
_tool_request_mw = apply_tool_request_middleware(
function_name,
function_args,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
turn_id=turn_id or "",
api_request_id=api_request_id or "",
)
function_args = _tool_request_mw.payload
_tool_original_args = _tool_request_mw.original_payload
_tool_middleware_trace = _tool_request_mw.trace
except Exception as _mw_err:
logger.debug("tool_request middleware error: %s", _mw_err)
try:
if function_name in _AGENT_LOOP_TOOLS:
return json.dumps({"error": f"{function_name} must be handled by the agent loop"})
@ -1000,6 +1027,7 @@ def handle_function_call(
tool_call_id=tool_call_id or "",
turn_id=turn_id or "",
api_request_id=api_request_id or "",
middleware_trace=list(_tool_middleware_trace),
)
except Exception as _hook_err:
logger.debug("pre_tool_call hook error: %s", _hook_err)
@ -1018,6 +1046,7 @@ def handle_function_call(
status="blocked",
error_type="plugin_block",
error_message=block_message,
middleware_trace=list(_tool_middleware_trace),
)
return result
@ -1082,7 +1111,19 @@ def handle_function_call(
task_id=task_id,
user_task=user_task,
)
result = _dispatch(function_args)
from hermes_cli.middleware import run_tool_execution_middleware
result = run_tool_execution_middleware(
function_name,
function_args,
_dispatch,
original_args=_tool_original_args,
task_id=task_id or "",
session_id=session_id or "",
tool_call_id=tool_call_id or "",
turn_id=turn_id or "",
api_request_id=api_request_id or "",
)
finally:
if _approval_tokens is not None and reset_current_observability_context is not None:
try:
@ -1101,6 +1142,7 @@ def handle_function_call(
turn_id=turn_id,
api_request_id=api_request_id,
duration_ms=duration_ms,
middleware_trace=list(_tool_middleware_trace),
)
# Generic tool-result canonicalization seam: plugins receive the

View file

@ -165,6 +165,28 @@ When `HERMES_NEMO_RELAY_PLUGINS_TOML` is set and initializes successfully, NeMo
Relay owns exporter lifecycle through that config. The direct
`HERMES_NEMO_RELAY_ATOF_*` fallback setup is skipped.
To enable NeMo Relay managed execution intercepts for provider and tool calls,
include an adaptive component in the same `plugins.toml`:
```toml
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
```
When the adaptive component is enabled and the installed NeMo Relay runtime
exposes `llm.execute(...)` / `tools.execute(...)`, Hermes routes LLM and tool
execution through those middleware boundaries. The observer hooks still emit
session, turn, approval, and subagent marks; the plugin skips its manual
`llm.call` and `tools.call` spans for executions that are already managed by
NeMo Relay.
For the full generic Hermes middleware contract, see
[`docs/middleware/README.md`](../../../docs/middleware/README.md).
## Canonical Local Examples
The examples below use the official `nemo-relay==0.3` distribution and a local
@ -366,3 +388,166 @@ subagent IDs, role/status fields when present, and derived
`parent_trajectory_id` / `child_trajectory_id` values. This keeps the ATOF
stream lossless for later ATIF conversion that can compact subagents into
separate trajectories.
## Adaptive Middleware Example
The `observability/nemo_relay` plugin uses Hermes execution middleware to hand
LLM and tool calls to NeMo Relay managed execution when an adaptive component is
enabled.
Minimal `plugins.toml`:
```toml
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
```
Enable it for Hermes:
```bash
export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/plugins.toml
```
When the adaptive component is enabled and the installed NeMo Relay runtime
exposes `llm.execute(...)` and `tools.execute(...)`, Hermes routes execution
through these boundaries:
```text
Hermes provider call
-> llm_execution middleware
-> nemo_relay.llm.execute(...)
-> Hermes provider adapter next_call(...)
Hermes tool call
-> tool_execution middleware
-> nemo_relay.tools.execute(...)
-> Hermes tool dispatcher next_call(...)
```
The plugin still emits observer marks for sessions, turns, approvals, and
subagents. When adaptive managed execution is active, it skips manual
`llm.call` and `tools.call` observer spans to avoid duplicate LLM/tool events
for the same execution.
### Local Adaptive E2E
This example enables both NeMo Relay observability export and adaptive execution
middleware for a local Hermes run.
```bash
pip install "nemo-relay==0.3"
export HERMES_HOME=/tmp/hermes-middleware-test/hermes-home
mkdir -p "$HERMES_HOME" /tmp/hermes-middleware-test/nemo-relay
cat > "$HERMES_HOME/config.yaml" <<'YAML'
model:
  provider: custom
  default: qwen3.6:35b
  base_url: http://127.0.0.1:11434/v1
  api_key: ollama
plugins:
  enabled:
    - observability/nemo_relay
YAML
cat > /tmp/hermes-middleware-test/nemo-relay/plugins.toml <<'TOML'
version = 1
[[components]]
kind = "observability"
enabled = true
[components.config]
version = 1
[components.config.atof]
enabled = true
output_directory = "/tmp/hermes-middleware-test/atof"
filename = "middleware-events.jsonl"
mode = "overwrite"
[components.config.atif]
enabled = true
output_directory = "/tmp/hermes-middleware-test/atif"
filename_template = "middleware-trajectory-{session_id}.json"
agent_name = "Hermes Middleware E2E"
agent_version = "local"
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
TOML
export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/nemo-relay/plugins.toml
hermes chat \
--query 'Use the terminal tool exactly once to run printf middleware_execution_ok. Then reply with exactly the command output.' \
--provider custom \
--model qwen3.6:35b \
--toolsets terminal \
--max-turns 4 \
--quiet \
--accept-hooks
```
Expected CLI output:
```text
session_id: middleware-demo-session
middleware_execution_ok
```
Expected ATOF shape:
```jsonl
{"kind":"scope","category":"llm","name":"custom","scope_category":"start","metadata":{"session_id":"middleware-demo-session"},"data":{"mode":"route"}}
{"kind":"scope","category":"tool","name":"terminal","scope_category":"start","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal"},"data":{"mode":"route"}}
{"kind":"scope","category":"tool","name":"terminal","scope_category":"end","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal","status":"ok"},"data":"{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"}
```
Expected ATIF shape:
```json
{
"schema_version": "ATIF-v1.7",
"session_id": "middleware-demo-session",
"agent": {
"name": "Hermes Middleware E2E",
"version": "local",
"model_name": "qwen3.6:35b"
},
"steps": [
{
"source": "agent",
"tool_calls": [
{
"function_name": "terminal",
"arguments": {"command": "printf middleware_execution_ok"}
}
],
"observation": {
"results": [
{
"source_call_id": "call_terminal",
"content": "{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"
}
]
}
},
{
"source": "agent",
"message": "middleware_execution_ok"
}
]
}
```

View file

@ -42,6 +42,9 @@ class _SubagentParent:
@dataclass
class _Settings:
plugins_toml_path: str = ""
plugins_config: dict[str, Any] | None = None
adaptive_enabled: bool = False
adaptive_mode: str = "observe"
atof_enabled: bool = False
atof_output_directory: str = ""
atof_filename: str = "hermes-atof.jsonl"
@ -67,17 +70,15 @@ class _Runtime:
self._configure_atof()
def _configure_plugins_toml(self) -> bool:
if not self.settings.plugins_toml_path:
if not self.settings.plugins_config:
return False
plugin_mod = getattr(self.nemo_relay, "plugin", None)
initialize = getattr(plugin_mod, "initialize", None)
if not callable(initialize):
return False
config_path = Path(self.settings.plugins_toml_path)
try:
config = tomllib.loads(config_path.read_text(encoding="utf-8"))
self._ensure_plugin_config_output_dirs(config)
result = initialize(config)
self._ensure_plugin_config_output_dirs(self.settings.plugins_config)
result = initialize(self.settings.plugins_config)
if inspect.isawaitable(result):
asyncio.run(result)
return True
@ -221,6 +222,100 @@ class _Runtime:
self.subagent_parents.pop(child_session_id, None)
self.mark("hermes.subagent.stop", kwargs)
def managed_llm_enabled(self) -> bool:
    """Report whether LLM calls should be handed to NeMo Relay managed execution.

    True only when the adaptive component is enabled in the settings and the
    loaded ``nemo_relay`` module exposes both a callable ``llm.execute`` and a
    callable ``LLMRequest``.
    """
    if not self.settings.adaptive_enabled:
        return False
    relay = self.nemo_relay
    llm_namespace = getattr(relay, "llm", None)
    if not callable(getattr(llm_namespace, "execute", None)):
        return False
    return callable(getattr(relay, "LLMRequest", None))
def managed_tool_enabled(self) -> bool:
    """Report whether tool calls should be handed to NeMo Relay managed execution.

    True only when the adaptive component is enabled in the settings and the
    loaded ``nemo_relay`` module exposes a callable ``tools.execute``.
    """
    if not self.settings.adaptive_enabled:
        return False
    tools_namespace = getattr(self.nemo_relay, "tools", None)
    return callable(getattr(tools_namespace, "execute", None))
def execute_llm(self, kwargs: dict[str, Any]) -> Any:
    """Run one LLM call through ``nemo_relay.llm.execute``.

    ``kwargs`` is the middleware keyword payload. The keys read here are
    ``request``, ``next_call``, ``provider``, ``model``, ``turn_id``,
    ``api_request_id`` and ``api_call_count``.

    Returns the raw value produced by ``next_call`` when the managed
    execution invoked it, otherwise whatever ``llm.execute`` returned. When
    ``next_call`` is not callable, the JSON-able request body is returned
    unchanged.
    """
    session = self.ensure_session(kwargs)
    base_body = _jsonable(kwargs.get("request") or {})
    relay_request = self.nemo_relay.LLMRequest({}, base_body)
    downstream = kwargs.get("next_call")
    if not callable(downstream):
        return base_body

    # Raw downstream responses, kept so the caller receives the provider's
    # own object rather than the normalized payload handed to NeMo Relay.
    captured: list[Any] = []

    def _impl(next_request: Any) -> Any:
        candidate = getattr(next_request, "content", next_request)
        if not isinstance(candidate, dict):
            # Fall back to the original body if the relay handed back
            # something that is not a request dict.
            candidate = base_body
        raw = downstream(candidate)
        captured.append(raw)
        return _llm_response_payload(raw)

    async def _managed_execute() -> Any:
        outcome = self.nemo_relay.llm.execute(
            str(kwargs.get("provider") or "llm"),
            relay_request,
            _impl,
            handle=session.handle,
            data=_jsonable(
                {
                    "turn_id": kwargs.get("turn_id"),
                    "api_request_id": kwargs.get("api_request_id"),
                    "api_call_count": kwargs.get("api_call_count"),
                    "mode": self.settings.adaptive_mode,
                }
            ),
            metadata=_metadata(kwargs),
            model_name=str(kwargs.get("model") or ""),
        )
        # ``llm.execute`` may be sync or async; support both.
        return (await outcome) if inspect.isawaitable(outcome) else outcome

    managed_result = _resolve_awaitable(_managed_execute())
    if captured:
        return captured[-1]
    return managed_result
def execute_tool(self, kwargs: dict[str, Any]) -> Any:
    """Run one tool call through ``nemo_relay.tools.execute``.

    ``kwargs`` is the middleware keyword payload. The keys read here are
    ``tool_name``, ``args``, ``next_call``, ``turn_id``, ``api_request_id``
    and ``tool_call_id``.

    Returns the raw value produced by ``next_call`` when the managed
    execution invoked it, otherwise whatever ``tools.execute`` returned. When
    ``next_call`` is not callable, the JSON-able args are returned unchanged.
    """
    session = self.ensure_session(kwargs)
    tool_name = str(kwargs.get("tool_name") or "tool")
    base_args = _jsonable(kwargs.get("args") or {})
    downstream = kwargs.get("next_call")
    if not callable(downstream):
        return base_args

    # Raw downstream results, kept so the caller gets the tool's own return
    # value rather than the JSON-able copy handed to NeMo Relay.
    captured: list[Any] = []

    def _impl(next_args: Any) -> Any:
        if not isinstance(next_args, dict):
            # Fall back to the original args if the relay passed a non-dict.
            next_args = base_args
        raw = downstream(next_args)
        captured.append(raw)
        return _jsonable(raw)

    async def _managed_execute() -> Any:
        outcome = self.nemo_relay.tools.execute(
            tool_name,
            base_args,
            _impl,
            handle=session.handle,
            data=_jsonable(
                {
                    "turn_id": kwargs.get("turn_id"),
                    "api_request_id": kwargs.get("api_request_id"),
                    "tool_call_id": kwargs.get("tool_call_id"),
                    "mode": self.settings.adaptive_mode,
                }
            ),
            metadata=_metadata(kwargs),
        )
        # ``tools.execute`` may be sync or async; support both.
        return (await outcome) if inspect.isawaitable(outcome) else outcome

    managed_result = _resolve_awaitable(_managed_execute())
    if captured:
        return captured[-1]
    return managed_result
def register(ctx) -> None:
ctx.register_hook("on_session_start", on_session_start)
@ -238,6 +333,8 @@ def register(ctx) -> None:
ctx.register_hook("post_approval_response", on_post_approval_response)
ctx.register_hook("subagent_start", on_subagent_start)
ctx.register_hook("subagent_stop", on_subagent_stop)
ctx.register_middleware("llm_execution", on_llm_execution_middleware)
ctx.register_middleware("tool_execution", on_tool_execution_middleware)
def on_session_start(**kwargs: Any) -> None:
@ -280,6 +377,8 @@ def on_pre_api_request(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_llm_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -303,6 +402,8 @@ def on_post_api_request(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_llm_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -324,6 +425,8 @@ def on_api_request_error(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_llm_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -345,6 +448,8 @@ def on_pre_tool_call(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_tool_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -365,6 +470,8 @@ def on_post_tool_call(**kwargs: Any) -> None:
runtime = _get_runtime()
if runtime is None:
return
if runtime.managed_tool_enabled():
return
def _record() -> None:
state = runtime.ensure_session(kwargs)
@ -406,6 +513,28 @@ def on_subagent_stop(**kwargs: Any) -> None:
_safe(lambda: runtime.mark_subagent_stop(kwargs))
def on_llm_execution_middleware(**kwargs: Any) -> Any:
    """``llm_execution`` middleware entry point.

    Delegates to the runtime's managed LLM execution when a runtime exists and
    managed LLM execution is enabled. Otherwise calls ``next_call`` with the
    request, or returns the request itself when ``next_call`` is not callable.
    """
    runtime = _get_runtime()
    if runtime is not None and runtime.managed_llm_enabled():
        return runtime.execute_llm(kwargs)
    request = kwargs.get("request") or {}
    next_call = kwargs.get("next_call")
    return next_call(request) if callable(next_call) else request
def on_tool_execution_middleware(**kwargs: Any) -> Any:
    """``tool_execution`` middleware entry point.

    Delegates to the runtime's managed tool execution when a runtime exists
    and managed tool execution is enabled. Otherwise calls ``next_call`` with
    the args, or returns the args themselves when ``next_call`` is not callable.
    """
    runtime = _get_runtime()
    if runtime is not None and runtime.managed_tool_enabled():
        return runtime.execute_tool(kwargs)
    args = kwargs.get("args") or {}
    next_call = kwargs.get("next_call")
    return next_call(args) if callable(next_call) else args
def _get_runtime() -> Optional[_Runtime]:
global _RUNTIME
with _LOCK:
@ -429,8 +558,14 @@ def _get_runtime() -> Optional[_Runtime]:
def _load_settings() -> _Settings:
plugins_toml_path = _env("HERMES_NEMO_RELAY_PLUGINS_TOML")
plugins_config = _load_plugins_config(plugins_toml_path)
adaptive_config = _enabled_component_config(plugins_config, "adaptive")
return _Settings(
plugins_toml_path=_env("HERMES_NEMO_RELAY_PLUGINS_TOML"),
plugins_toml_path=plugins_toml_path,
plugins_config=plugins_config,
adaptive_enabled=adaptive_config is not None,
adaptive_mode=_adaptive_mode(adaptive_config),
atof_enabled=_env_bool("HERMES_NEMO_RELAY_ATOF_ENABLED"),
atof_output_directory=_env("HERMES_NEMO_RELAY_ATOF_OUTPUT_DIRECTORY"),
atof_filename=_env("HERMES_NEMO_RELAY_ATOF_FILENAME") or "hermes-atof.jsonl",
@ -445,6 +580,44 @@ def _load_settings() -> _Settings:
)
def _load_plugins_config(path: str) -> dict[str, Any] | None:
if not path:
return None
try:
return tomllib.loads(Path(path).read_text(encoding="utf-8"))
except Exception as exc:
logger.debug("NeMo Relay plugins.toml load failed: %s", exc, exc_info=True)
return None
def _enabled_component_config(
plugins_config: dict[str, Any] | None,
kind: str,
) -> dict[str, Any] | None:
if not isinstance(plugins_config, dict):
return None
components = plugins_config.get("components")
if not isinstance(components, list):
return None
for component in components:
if not isinstance(component, dict):
continue
if component.get("kind") != kind or not component.get("enabled", True):
continue
config = component.get("config")
return config if isinstance(config, dict) else {}
return None
def _adaptive_mode(config: dict[str, Any] | None) -> str:
if not isinstance(config, dict):
return "observe"
mode = config.get("mode")
if isinstance(mode, str) and mode.strip():
return mode.strip()
return "observe"
def _env(name: str) -> str:
    """Return environment variable *name* stripped of whitespace, or ``""`` if unset."""
    raw = os.environ.get(name, "")
    return raw.strip()
@ -549,12 +722,78 @@ def _jsonable(value: Any) -> Any:
return _jsonable(value.model_dump(mode="json"))
except Exception:
pass
try:
if hasattr(value, "__dict__"):
return _jsonable(vars(value))
except Exception:
pass
try:
return json.loads(json.dumps(value, default=str))
except Exception:
return str(value)
def _value(obj: Any, key: str, default: Any = None) -> Any:
    """Look up *key* on *obj*: dict item for dicts, attribute for anything else."""
    if not isinstance(obj, dict):
        return getattr(obj, key, default)
    return obj.get(key, default)
def _llm_response_payload(response: Any) -> Any:
    """Return the LLM response shape NeMo Relay's ATIF conversion expects.

    A response whose JSON-able form already contains ``assistant_message`` is
    returned as that JSON-able form. Otherwise the first entry of ``choices``
    (when present) is flattened into ``model`` / ``assistant_message`` /
    ``finish_reason`` / ``usage``.
    """
    payload = _jsonable(response)
    payload_is_dict = isinstance(payload, dict)
    if payload_is_dict and "assistant_message" in payload:
        return payload

    # Prefer the live object's ``choices``; fall back to the JSON-able copy.
    choices = _value(response, "choices")
    if choices is None and payload_is_dict:
        choices = payload.get("choices")
    first_choice = None
    if isinstance(choices, list) and choices:
        first_choice = choices[0]
    message = _value(first_choice, "message")
    finish_reason = _value(first_choice, "finish_reason")

    assistant_message: dict[str, Any] = {"role": "assistant", "content": ""}
    if message is None:
        if payload_is_dict:
            # No chat-style message: use flat text fields if the payload has them.
            assistant_message["content"] = payload.get("content") or payload.get("output_text") or ""
    else:
        assistant_message["role"] = _value(message, "role", "assistant") or "assistant"
        content = _value(message, "content")
        if content is not None:
            assistant_message["content"] = _jsonable(content)
        tool_calls = _tool_calls_payload(_value(message, "tool_calls"))
        if tool_calls:
            assistant_message["tool_calls"] = tool_calls
        reasoning = _value(message, "reasoning_content")
        if reasoning is not None:
            assistant_message["reasoning_content"] = _jsonable(reasoning)

    fallback_model = payload.get("model") if payload_is_dict else None
    fallback_usage = payload.get("usage") if payload_is_dict else None
    return {
        "model": _value(response, "model", fallback_model),
        "assistant_message": assistant_message,
        "finish_reason": finish_reason,
        "usage": _jsonable(_value(response, "usage", fallback_usage)),
    }
def _tool_calls_payload(tool_calls: Any) -> list[dict[str, Any]]:
    """Normalize a list of tool calls into plain ``id`` / ``type`` / ``function`` dicts.

    Anything that is not a list yields an empty list. Each call may be a dict
    or an attribute-style object.
    """
    if not isinstance(tool_calls, list):
        return []

    def _normalize(call: Any) -> dict[str, Any]:
        function = _value(call, "function")
        return {
            "id": _value(call, "id"),
            "type": _value(call, "type", "function") or "function",
            "function": {
                "name": _value(function, "name"),
                "arguments": _value(function, "arguments"),
            },
        }

    return [_normalize(call) for call in tool_calls]
def _safe(fn) -> None:
try:
fn()
@ -562,6 +801,35 @@ def _safe(fn) -> None:
logger.debug("NeMo Relay hook handling failed: %s", exc, exc_info=True)
def _resolve_awaitable(value: Any) -> Any:
    """Synchronously resolve *value* if it is awaitable; otherwise return it as-is.

    When no event loop is running in the current thread, the awaitable is
    driven with ``asyncio.run``. When a loop is already running,
    ``asyncio.run`` cannot be used in this thread, so the awaitable is run on
    a helper thread and this call blocks until it finishes.

    Any exception raised while resolving, including ``BaseException``
    subclasses, is re-raised in the calling thread.
    """
    if not inspect.isawaitable(value):
        return value

    async def _await_value() -> Any:
        # ``inspect.isawaitable`` accepts any object with ``__await__``, but
        # ``asyncio.run`` rejects non-coroutine awaitables on Python < 3.14.
        # Wrapping in a coroutine makes every awaitable acceptable.
        return await value

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(_await_value())

    outcome: dict[str, Any] = {}

    def _runner() -> None:
        try:
            outcome["value"] = asyncio.run(_await_value())
        except BaseException as exc:  # pragma: no cover - re-raised below
            outcome["error"] = exc

    thread = threading.Thread(
        target=_runner,
        name="hermes-nemo-relay-awaitable",
        daemon=True,
    )
    thread.start()
    thread.join()
    if "error" in outcome:
        raise outcome["error"]
    return outcome.get("value")
def reset_for_tests() -> None:
global _RUNTIME
with _LOCK:

View file

@ -4955,10 +4955,22 @@ class AIAgent:
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
tool_call_id: Optional[str] = None, messages: list = None,
pre_tool_block_checked: bool = False) -> str:
pre_tool_block_checked: bool = False,
skip_tool_request_middleware: bool = False,
tool_request_middleware_trace: Optional[list[dict[str, Any]]] = None) -> str:
"""Forwarder — see ``agent.agent_runtime_helpers.invoke_tool``."""
from agent.agent_runtime_helpers import invoke_tool
return invoke_tool(self, function_name, function_args, effective_task_id, tool_call_id, messages, pre_tool_block_checked)
return invoke_tool(
self,
function_name,
function_args,
effective_task_id,
tool_call_id,
messages,
pre_tool_block_checked,
skip_tool_request_middleware,
tool_request_middleware_trace,
)
@staticmethod
def _wrap_verbose(label: str, text: str, indent: str = " ") -> str:

View file

@ -18,8 +18,15 @@ from hermes_cli.plugins import (
get_plugin_command_handler,
get_plugin_commands,
get_pre_tool_call_block_message,
has_middleware,
resolve_plugin_command_result,
)
from hermes_cli.middleware import (
VALID_MIDDLEWARE,
apply_llm_request_middleware,
apply_tool_request_middleware,
run_tool_execution_middleware,
)
# ── Helpers ────────────────────────────────────────────────────────────────
@ -96,6 +103,223 @@ class TestPluginDiscovery:
assert "hello_plugin" in mgr._plugins
assert mgr._plugins["hello_plugin"].enabled
def test_plugin_can_register_and_invoke_middleware(self, tmp_path, monkeypatch):
    """A discovered plugin can register request middleware that the manager invokes."""
    plugins_dir = tmp_path / "hermes_test" / "plugins"
    # The plugin registers one llm_request and one tool_request middleware,
    # each of which adds an ``mw`` flag to the payload it receives.
    _make_plugin_dir(
        plugins_dir,
        "mw_plugin",
        register_body=(
            "ctx.register_middleware('llm_request', "
            "lambda **kw: {'request': {**kw['request'], 'mw': True}})\n"
            " ctx.register_middleware('tool_request', "
            "lambda **kw: {'args': {**kw['args'], 'mw': True}})"
        ),
    )
    monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
    mgr = PluginManager()
    mgr.discover_and_load()
    assert "llm_request" in VALID_MIDDLEWARE
    assert "tool_request" in VALID_MIDDLEWARE
    assert set(mgr._plugins["mw_plugin"].middleware_registered) == {"llm_request", "tool_request"}
    # invoke_middleware returns one result per registered callback.
    assert mgr.invoke_middleware("llm_request", request={"messages": []}) == [
        {"request": {"messages": [], "mw": True}}
    ]
    assert mgr.invoke_middleware("tool_request", args={"path": "README.md"}) == [
        {"args": {"path": "README.md", "mw": True}}
    ]
    assert mgr.has_middleware("llm_request") is True
def test_execution_middleware_does_not_retry_downstream_failure(self, monkeypatch):
    """A failure raised by the downstream tool propagates and the tool runs once."""
    calls = []

    def middleware(**kwargs):
        # Pass-through middleware: forwards args straight to the next stage.
        return kwargs["next_call"](kwargs["args"])

    manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)

    def terminal(args):
        calls.append(args)
        raise RuntimeError("tool failed")

    with pytest.raises(RuntimeError, match="tool failed"):
        run_tool_execution_middleware("terminal", {"command": "false"}, terminal)
    # Exactly one recorded invocation: the failing tool was not re-run.
    assert calls == [{"command": "false"}]
def test_middleware_helpers_skip_no_listener_work(self, monkeypatch):
    """With no middleware registered, helpers return the caller's objects untouched."""
    manager = types.SimpleNamespace(_middleware={})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    request = {"messages": []}
    args = {"path": "README.md"}
    llm_result = apply_llm_request_middleware(request)
    tool_result = apply_tool_request_middleware("read_file", args)
    # Identity (``is``) checks: the payloads must be the very same objects,
    # not copies.
    assert llm_result.payload is request
    assert llm_result.original_payload is request
    assert llm_result.changed is False
    assert llm_result.trace == []
    assert tool_result.payload is args
    assert tool_result.original_payload is args
    assert tool_result.changed is False
    assert tool_result.trace == []
    assert run_tool_execution_middleware("terminal", args, lambda payload: payload) is args
    assert has_middleware("tool_request") is False
def test_request_middleware_changed_tracks_trace_not_deep_equality(self, monkeypatch):
    """``changed`` is True when middleware reports a result, even if the payload is equal."""

    def same_payload_middleware(**kwargs):
        # Returns the args unmodified plus a trace field.
        return {"args": kwargs["args"], "source": "same-payload"}

    manager = types.SimpleNamespace(
        _middleware={"tool_request": [same_payload_middleware]},
        invoke_middleware=lambda kind, **kwargs: [same_payload_middleware(**kwargs)],
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    args = {"path": "README.md"}
    result = apply_tool_request_middleware("read_file", args)
    assert result.payload == args
    assert result.original_payload == args
    # Payload is equal to the original, yet the result is flagged as changed.
    assert result.changed is True
    assert result.trace == [{"source": "same-payload"}]
def test_execution_middleware_post_next_call_error_does_not_retry(self, monkeypatch):
    """A middleware error raised after ``next_call`` keeps the downstream result."""
    calls = []

    def middleware(**kwargs):
        result = kwargs["next_call"](kwargs["args"])
        # Fails only after the downstream tool has already produced a result.
        raise RuntimeError(f"post-processing failed after {result}")

    manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)

    def terminal(args):
        calls.append(args)
        return "terminal-result"

    result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
    assert result == "terminal-result"
    # The tool ran exactly once despite the middleware error.
    assert calls == [{"command": "printf ok"}]
def test_execution_middleware_pre_next_call_error_fails_open_to_remaining_chain(self, monkeypatch):
    """A middleware that raises before ``next_call`` is skipped; the rest of the chain runs."""
    calls = []

    def failing_middleware(**kwargs):
        calls.append("failing")
        raise RuntimeError("middleware setup failed")

    def downstream_middleware(**kwargs):
        calls.append("downstream")
        return kwargs["next_call"]({**kwargs["args"], "rewritten": True})

    manager = types.SimpleNamespace(_middleware={"tool_execution": [failing_middleware, downstream_middleware]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)

    def terminal(args):
        calls.append(("terminal", args))
        return args

    result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
    # The second middleware's rewrite reaches the tool.
    assert result == {"command": "printf ok", "rewritten": True}
    assert calls == ["failing", "downstream", ("terminal", {"command": "printf ok", "rewritten": True})]
def test_execution_middleware_translated_downstream_failure_is_not_masked(self, monkeypatch):
    """An exception re-raised by middleware in translated form reaches the caller."""
    calls = []

    def middleware(**kwargs):
        try:
            return kwargs["next_call"](kwargs["args"])
        except Exception as exc:
            # Wraps the downstream error in a new message.
            raise RuntimeError(f"translated downstream failure: {exc}") from exc

    manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)

    def terminal(args):
        calls.append(args)
        raise RuntimeError("terminal failed")

    with pytest.raises(RuntimeError, match="translated downstream failure: terminal failed"):
        run_tool_execution_middleware("terminal", {"command": "false"}, terminal)
    # The tool ran exactly once.
    assert calls == [{"command": "false"}]
def test_execution_middleware_downstream_base_exception_is_not_wrapped(self, monkeypatch):
    """A ``KeyboardInterrupt`` from the tool propagates unchanged through the chain."""
    calls = []

    def middleware(**kwargs):
        try:
            return kwargs["next_call"](kwargs["args"])
        except Exception as exc:
            # ``except Exception`` does not catch KeyboardInterrupt, so this
            # branch must not run in this test.
            raise RuntimeError(f"middleware should not catch base exception: {exc}") from exc

    manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)

    def terminal(args):
        calls.append(args)
        raise KeyboardInterrupt()

    with pytest.raises(KeyboardInterrupt):
        run_tool_execution_middleware("terminal", {"command": "interrupt"}, terminal)
    assert calls == [{"command": "interrupt"}]
def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch):
    """Calling ``next_call`` twice in one middleware must not re-run the tool."""
    calls = []

    def middleware(**kwargs):
        first = kwargs["next_call"](kwargs["args"])
        # Deliberate misuse: a second next_call() must not re-run the
        # downstream tool. The chain surfaces it as an error and preserves
        # the first (successful) downstream result.
        kwargs["next_call"](kwargs["args"])
        return first

    manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)

    def terminal(args):
        calls.append(args)
        return "terminal-result"

    result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
    assert result == "terminal-result"
    # Only one recorded tool invocation despite two next_call() attempts.
    assert calls == [{"command": "printf ok"}]
def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch):
    """Request middleware still runs when the args contain a non-deepcopyable value."""
    import threading

    recorded = {}

    def middleware(**kwargs):
        # Records the args it was given and requests no change.
        recorded["args"] = kwargs["args"]
        return None

    manager = types.SimpleNamespace(
        _middleware={"tool_request": [middleware]},
        invoke_middleware=lambda kind, **kwargs: [middleware(**kwargs)],
    )
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    # threading.Lock is not deepcopyable; a hard deepcopy would raise.
    args = {"command": "noop", "lock": threading.Lock()}
    result = apply_tool_request_middleware("terminal", args)
    # Middleware ran (payload was copied via the shallow fallback) and the
    # non-deepcopyable member is shared by reference rather than aborting.
    assert recorded["args"]["command"] == "noop"
    assert result.payload["command"] == "noop"
    assert result.payload["lock"] is args["lock"]
def test_discover_project_plugins(self, tmp_path, monkeypatch):
"""Plugins in ./.hermes/plugins/ are discovered."""
project_dir = tmp_path / "project"

View file

@ -27,8 +27,16 @@ class _FakeNemoRelay:
pop=self._scope_pop,
event=self._scope_event,
)
self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end)
self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end)
self.llm = SimpleNamespace(
call=self._llm_call,
call_end=self._llm_call_end,
execute=self._llm_execute,
)
self.tools = SimpleNamespace(
call=self._tool_call,
call_end=self._tool_call_end,
execute=self._tool_execute,
)
self.plugin = SimpleNamespace(initialize=self._plugin_initialize)
self.LLMRequest = _FakeLLMRequest
self.AtofExporterConfig = _FakeAtofExporterConfig
@ -55,6 +63,12 @@ class _FakeNemoRelay:
def _llm_call_end(self, handle, response, **kwargs):
    """Fake ``llm.call_end``: record the call as an ``llm.call_end`` event."""
    record = ("llm.call_end", handle, response, kwargs)
    self.events.append(record)
def _llm_execute(self, name, request, func, **kwargs):
    """Fake ``llm.execute``: record start, call *func* with an intercepted request, record end.

    The request passed to *func* carries the original headers and a content
    dict with ``intercepted: True`` merged underneath the original content.
    """
    self.events.append(("llm.execute.start", name, request.content, kwargs))
    headers = request.headers
    intercepted_content = {"intercepted": True, **request.content}
    outcome = func(_FakeLLMRequest(headers, intercepted_content))
    self.events.append(("llm.execute.end", name, outcome, kwargs))
    return outcome
def _tool_call(self, name, args, **kwargs):
handle = ("tool", name)
self.events.append(("tool.call", name, args, kwargs))
@ -63,6 +77,12 @@ class _FakeNemoRelay:
def _tool_call_end(self, handle, result, **kwargs):
    """Fake ``tools.call_end``: record the call as a ``tool.call_end`` event."""
    record = ("tool.call_end", handle, result, kwargs)
    self.events.append(record)
def _tool_execute(self, name, args, func, **kwargs):
    """Fake ``tools.execute``: record start, call *func* with intercepted args, record end.

    *func* receives a new dict with ``intercepted: True`` merged underneath
    the original args.
    """
    self.events.append(("tool.execute.start", name, args, kwargs))
    intercepted_args = {"intercepted": True, **args}
    outcome = func(intercepted_args)
    self.events.append(("tool.execute.end", name, outcome, kwargs))
    return outcome
def _make_atof_exporter(self, config):
    """Build a fake ATOF exporter that shares this fake's event list."""
    exporter = _FakeAtofExporter(self.events, config)
    return exporter
@ -425,6 +445,221 @@ output_directory = "{atif_dir}"
assert atif_dir.is_dir()
def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
    """Adaptive LLM middleware returns the raw provider response and reports a normalized payload."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)
    plugins_toml = tmp_path / "plugins.toml"
    plugins_toml.write_text(
        """
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
    seen_request = {}
    # Attribute-style (non-dict) response, with a tool call and no text content.
    raw_choice = SimpleNamespace(
        message=SimpleNamespace(
            role="assistant",
            content=None,
            tool_calls=[
                SimpleNamespace(
                    id="tool-1",
                    type="function",
                    function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'),
                )
            ],
            reasoning_content="need a tool",
        ),
        finish_reason="tool_calls",
    )

    def next_call(request):
        seen_request.update(request)
        return SimpleNamespace(
            id="resp-1",
            model="demo-model",
            choices=[raw_choice],
            usage=SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8),
        )

    response = plugin.on_llm_execution_middleware(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
        provider="anthropic",
        model="demo-model",
        api_call_count=1,
        request={"messages": [{"role": "user", "content": "hi"}]},
        next_call=next_call,
    )
    # The caller gets the provider's own response object back.
    assert response.model == "demo-model"
    assert response.choices == [raw_choice]
    # The fake relay's rewritten request reached next_call.
    assert seen_request["intercepted"] is True
    execute_start = next(event for event in fake.events if event[0] == "llm.execute.start")
    assert execute_start[3]["data"]["mode"] == "route"
    execute_end = next(event for event in fake.events if event[0] == "llm.execute.end")
    # The relay itself saw the normalized payload; ``content=None`` became "".
    assert execute_end[2] == {
        "model": "demo-model",
        "assistant_message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "tool-1",
                    "type": "function",
                    "function": {"name": "terminal", "arguments": '{"command":"pwd"}'},
                }
            ],
            "reasoning_content": "need a tool",
        },
        "finish_reason": "tool_calls",
        "usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
    }
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
    """Without an adaptive config the LLM middleware just calls ``next_call``."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)
    response = plugin.on_llm_execution_middleware(
        session_id="s1",
        provider="anthropic",
        model="demo-model",
        request={"messages": []},
        next_call=lambda request: {"raw": request},
    )
    assert response == {"raw": {"messages": []}}
    # The fake relay's llm.execute was never invoked.
    assert not any(event[0] == "llm.execute.start" for event in fake.events)
def test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
    """Adaptive tool middleware routes through ``tools.execute`` and returns the raw tool result."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)
    plugins_toml = tmp_path / "plugins.toml"
    plugins_toml.write_text(
        """
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
    seen_args = {}

    def next_call(args):
        seen_args.update(args)
        return {"raw": True, "args": args}

    response = plugin.on_tool_execution_middleware(
        session_id="s1",
        task_id="t1",
        turn_id="turn-1",
        api_request_id="api-1",
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
        next_call=next_call,
    )
    # The result is next_call's own return value, built from the args the
    # fake relay rewrote (``intercepted`` added).
    assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
    assert seen_args["intercepted"] is True
    execute_start = next(event for event in fake.events if event[0] == "tool.execute.start")
    assert execute_start[3]["data"]["mode"] == "route"
    assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
    """Without an adaptive config the tool middleware just calls ``next_call``."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)
    response = plugin.on_tool_execution_middleware(
        session_id="s1",
        tool_name="terminal",
        args={"command": "pwd"},
        next_call=lambda args: {"raw": args},
    )
    assert response == {"raw": {"command": "pwd"}}
    # The fake relay's tools.execute was never invoked.
    assert not any(event[0] == "tool.execute.start" for event in fake.events)
def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch):
    """With adaptive execution on, observer hooks emit no ``llm.call`` / ``tool.call`` events."""
    fake = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, fake)
    plugins_toml = tmp_path / "plugins.toml"
    plugins_toml.write_text(
        """
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
""",
        encoding="utf-8",
    )
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
    base = {
        "session_id": "s1",
        "task_id": "t1",
        "turn_id": "turn-1",
        "api_request_id": "api-1",
    }
    # Fire the observer hooks first, then the execution middleware.
    plugin.on_pre_api_request(
        **base,
        provider="anthropic",
        model="demo-model",
        request={"body": {"messages": [{"role": "user", "content": "hi"}]}},
    )
    plugin.on_post_api_request(**base, response={"ok": True})
    plugin.on_pre_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", args={"command": "pwd"})
    plugin.on_post_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", result={"ok": True})
    plugin.on_llm_execution_middleware(
        **base,
        provider="anthropic",
        model="demo-model",
        request={"messages": [{"role": "user", "content": "hi"}]},
        next_call=lambda request: {"raw": request},
    )
    plugin.on_tool_execution_middleware(
        **base,
        tool_name="terminal",
        tool_call_id="tool-1",
        args={"command": "pwd"},
        next_call=lambda args: {"raw": args},
    )
    event_names = [event[0] for event in fake.events]
    # No manual observer spans were recorded...
    assert "llm.call" not in event_names
    assert "llm.call_end" not in event_names
    assert "tool.call" not in event_names
    assert "tool.call_end" not in event_names
    # ...while the managed execution events were.
    assert "llm.execute.start" in event_names
    assert "tool.execute.start" in event_names
def test_nemo_relay_plugin_noops_without_dependency(monkeypatch):
monkeypatch.delitem(sys.modules, "nemo_relay", raising=False)
sys.modules.pop("plugins.observability.nemo_relay", None)

View file

@ -2466,8 +2466,10 @@ class TestConcurrentToolExecution:
api_request_id="",
enabled_tools=list(agent.valid_tool_names),
skip_pre_tool_call_hook=True,
skip_tool_request_middleware=True,
enabled_toolsets=agent.enabled_toolsets,
disabled_toolsets=agent.disabled_toolsets,
tool_request_middleware_trace=[],
)
assert result == "result"
@ -2647,6 +2649,89 @@ class TestConcurrentToolExecution:
assert post_call[1]["result"] == '{"ok":true}'
assert post_call[1]["status"] == "ok"
def test_sequential_agent_level_tool_execution_middleware_wraps_inline_dispatch(self, agent, monkeypatch):
    """Sequential built-in tool paths should expose the adaptive execution boundary.

    Request middleware rewrites the args first, execution middleware then sees
    those rewritten args and adds its own change, and the ``post_tool_call``
    hook receives the final args together with the request-middleware trace.
    """
    tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
    mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
    messages = []
    hook_calls = []
    seen = {}

    def request_middleware(**kwargs):
        return {
            "args": {**kwargs["args"], "request_rewritten": True},
            "source": "request-test",
        }

    def execution_middleware(**kwargs):
        # Capture what the execution stage received before rewriting further.
        seen["middleware_args"] = kwargs["args"]
        return kwargs["next_call"]({**kwargs["args"], "merge": True})

    manager = SimpleNamespace(_middleware={
        "tool_request": [request_middleware],
        "tool_execution": [execution_middleware],
    })
    monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
    monkeypatch.setattr(
        "hermes_cli.plugins.invoke_middleware",
        lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [],
    )
    monkeypatch.setattr(
        "hermes_cli.plugins.get_pre_tool_call_block_message",
        lambda *args, **kwargs: None,
    )
    # Record every hook invocation; ``or []`` makes the stub return a list.
    monkeypatch.setattr(
        "hermes_cli.plugins.invoke_hook",
        lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
    )
    monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
    with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
        agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")
    assert seen["middleware_args"] == {"todos": [], "request_rewritten": True}
    mock_todo.assert_called_once_with(todos=[], merge=True, store=agent._todo_store)
    post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
    assert post_call[1]["tool_name"] == "todo"
    assert post_call[1]["args"] == {"todos": [], "request_rewritten": True, "merge": True}
    assert post_call[1]["middleware_trace"] == [{"source": "request-test"}]
def test_concurrent_agent_level_tool_preserves_request_middleware_trace(self, agent, monkeypatch):
    """Concurrent dispatch of ``todo`` reports request-middleware output to ``post_tool_call``.

    Only a request-middleware stub is registered (no execution middleware).
    The ``post_tool_call`` hook is expected to receive the rewritten args and
    a trace containing the stub's ``source``.
    """
    tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
    assistant_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
    conversation = []
    recorded_hooks = []

    def request_middleware(**kwargs):
        # Copy the incoming args and flag them as rewritten.
        rewritten = dict(kwargs["args"])
        rewritten["request_rewritten"] = True
        return {"args": rewritten, "source": "request-test"}

    def fake_invoke_middleware(kind, **kwargs):
        if kind != "tool_request":
            return []
        return [request_middleware(**kwargs)]

    def no_block_message(*args, **kwargs):
        return None

    def record_hook(hook_name, **kwargs):
        recorded_hooks.append((hook_name, kwargs))
        return []

    def always_has_hook(name):
        return True

    manager = SimpleNamespace(
        _middleware={"tool_request": [request_middleware], "tool_execution": []}
    )

    def get_manager():
        return manager

    plugin_patches = (
        ("hermes_cli.plugins.get_plugin_manager", get_manager),
        ("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware),
        ("hermes_cli.plugins.get_pre_tool_call_block_message", no_block_message),
        ("hermes_cli.plugins.invoke_hook", record_hook),
        ("hermes_cli.plugins.has_hook", always_has_hook),
    )
    for target, stub in plugin_patches:
        monkeypatch.setattr(target, stub)

    with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}'):
        agent._execute_tool_calls_concurrent(assistant_msg, conversation, "task-1")

    _, post_kwargs = next(entry for entry in recorded_hooks if entry[0] == "post_tool_call")
    assert post_kwargs["tool_name"] == "todo"
    assert post_kwargs["args"] == {"todos": [], "request_rewritten": True}
    assert post_kwargs["middleware_trace"] == [{"source": "request-test"}]
def test_agent_runtime_post_hook_ownership_predicate_covers_agent_tools(self, agent):
"""Sequential and concurrent agent-level paths share post-hook ownership."""
from agent.agent_runtime_helpers import agent_runtime_owns_post_tool_hook

View file

@ -64,6 +64,7 @@ class TestHandleFunctionCall:
tool_call_id="call-1",
turn_id="",
api_request_id="",
middleware_trace=[],
),
call(
"post_tool_call",
@ -79,6 +80,7 @@ class TestHandleFunctionCall:
status="ok",
error_type=None,
error_message=None,
middleware_trace=[],
),
call(
"transform_tool_result",
@ -145,6 +147,60 @@ class TestHandleFunctionCall:
assert "post_tool_call" not in fired
assert "transform_tool_result" not in fired
def test_tool_request_and_execution_middleware_wrap_registry_dispatch(self, monkeypatch):
    """``handle_function_call`` applies request then execution middleware around dispatch.

    The request stub adds ``rewritten`` to the args, the execution stub adds
    ``wrapped`` before delegating to ``next_call``, and the patched registry
    dispatch echoes the args it received. The test checks the args seen at
    each stage and that both ``pre_tool_call`` and ``post_tool_call`` hooks
    carry the request-middleware trace.
    """
    observed = {}
    recorded_hooks = []

    def fake_invoke_middleware(kind, **kwargs):
        if kind != "tool_request":
            return []
        rewritten_args = {**kwargs["args"], "rewritten": True}
        return [
            {
                "args": rewritten_args,
                "source": "test-middleware",
                "reason": "rewrite",
            }
        ]

    def execution_middleware(**kwargs):
        # Remember what arrived here, then forward an extended copy downstream.
        incoming = kwargs["args"]
        observed["execution_args"] = incoming
        return kwargs["next_call"]({**incoming, "wrapped": True})

    def fake_dispatch(tool_name, args, **kwargs):
        observed["dispatch"] = (tool_name, args, kwargs)
        payload = {"ok": True, "args": args}
        return json.dumps(payload)

    def record_hook(hook_name, **kwargs):
        recorded_hooks.append((hook_name, kwargs))
        return []

    def always_has_hook(name):
        return True

    class Manager:
        _middleware = {
            "tool_request": [fake_invoke_middleware],
            "tool_execution": [execution_middleware],
        }

    manager = Manager()

    def get_manager():
        return manager

    patches = (
        ("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware),
        ("hermes_cli.plugins.get_plugin_manager", get_manager),
        ("hermes_cli.plugins.invoke_hook", record_hook),
        ("hermes_cli.plugins.has_hook", always_has_hook),
        ("model_tools.registry.dispatch", fake_dispatch),
    )
    for target, stub in patches:
        monkeypatch.setattr(target, stub)

    raw_result = handle_function_call(
        "web_search",
        {"q": "test"},
        task_id="task-1",
        tool_call_id="tool-1",
        session_id="session-1",
    )
    result = json.loads(raw_result)

    assert observed["execution_args"] == {"q": "test", "rewritten": True}
    dispatched_args = observed["dispatch"][1]
    assert dispatched_args == {"q": "test", "rewritten": True, "wrapped": True}
    assert result["args"] == {"q": "test", "rewritten": True, "wrapped": True}

    expected_trace = [{"source": "test-middleware", "reason": "rewrite"}]
    _, pre_kwargs = next(entry for entry in recorded_hooks if entry[0] == "pre_tool_call")
    _, post_kwargs = next(entry for entry in recorded_hooks if entry[0] == "post_tool_call")
    assert pre_kwargs["middleware_trace"] == expected_trace
    assert post_kwargs["middleware_trace"] == expected_trace
# =========================================================================
# Agent loop tools