mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
Merge pull request #29724 from bbednarski9/bbednarski/nmf-41B-nemoflow-plugin
feat(middleware): add adaptive middleware to hermes-agent, consumed by NeMo-Relay
This commit is contained in:
commit
d4a7bfd3aa
14 changed files with 2170 additions and 151 deletions
|
|
@ -1620,13 +1620,37 @@ def switch_model(agent, new_model, new_provider, api_key='', base_url='', api_mo
|
|||
|
||||
def invoke_tool(agent, function_name: str, function_args: dict, effective_task_id: str,
|
||||
tool_call_id: Optional[str] = None, messages: list = None,
|
||||
pre_tool_block_checked: bool = False) -> str:
|
||||
pre_tool_block_checked: bool = False,
|
||||
skip_tool_request_middleware: bool = False,
|
||||
tool_request_middleware_trace: Optional[List[Dict[str, Any]]] = None) -> str:
|
||||
"""Invoke a single tool and return the result string. No display logic.
|
||||
|
||||
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
|
||||
tools. Used by the concurrent execution path; the sequential path retains
|
||||
its own inline invocation for backward-compatible display handling.
|
||||
"""
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
|
||||
_tool_middleware_trace = list(tool_request_middleware_trace or [])
|
||||
try:
|
||||
from hermes_cli.middleware import apply_tool_request_middleware
|
||||
|
||||
if not skip_tool_request_middleware:
|
||||
_tool_request_mw = apply_tool_request_middleware(
|
||||
function_name,
|
||||
function_args,
|
||||
task_id=effective_task_id or "",
|
||||
session_id=getattr(agent, "session_id", "") or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
)
|
||||
function_args = _tool_request_mw.payload
|
||||
_tool_middleware_trace = _tool_request_mw.trace
|
||||
except Exception as _mw_err:
|
||||
logger.debug("tool_request middleware error: %s", _mw_err)
|
||||
|
||||
# Check plugin hooks for a block directive before executing anything.
|
||||
block_message: Optional[str] = None
|
||||
if not pre_tool_block_checked:
|
||||
|
|
@ -1640,6 +1664,7 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
|
|||
tool_call_id=tool_call_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -1659,6 +1684,7 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
|
|||
status="blocked",
|
||||
error_type="plugin_block",
|
||||
error_message=block_message,
|
||||
middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -1666,12 +1692,13 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
|
|||
|
||||
tool_start_time = time.monotonic()
|
||||
|
||||
def _finish_agent_tool(result: Any) -> Any:
|
||||
def _finish_agent_tool(result: Any, observed_args: Optional[dict] = None) -> Any:
|
||||
hook_args = observed_args if isinstance(observed_args, dict) else function_args
|
||||
try:
|
||||
from model_tools import _emit_post_tool_call_hook
|
||||
_emit_post_tool_call_hook(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
function_args=hook_args,
|
||||
result=result,
|
||||
task_id=effective_task_id or "",
|
||||
session_id=getattr(agent, "session_id", "") or "",
|
||||
|
|
@ -1679,89 +1706,116 @@ def invoke_tool(agent, function_name: str, function_args: dict, effective_task_i
|
|||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
duration_ms=int((time.monotonic() - tool_start_time) * 1000),
|
||||
middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
|
||||
if function_name == "todo":
|
||||
from tools.todo_tool import todo_tool as _todo_tool
|
||||
return _finish_agent_tool(
|
||||
_todo_tool(
|
||||
todos=function_args.get("todos"),
|
||||
merge=function_args.get("merge", False),
|
||||
store=agent._todo_store,
|
||||
def _execute(next_args: dict) -> Any:
|
||||
from tools.todo_tool import todo_tool as _todo_tool
|
||||
return _finish_agent_tool(
|
||||
_todo_tool(
|
||||
todos=next_args.get("todos"),
|
||||
merge=next_args.get("merge", False),
|
||||
store=agent._todo_store,
|
||||
),
|
||||
next_args,
|
||||
)
|
||||
)
|
||||
elif function_name == "session_search":
|
||||
session_db = agent._get_session_db_for_recall()
|
||||
if not session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
return _finish_agent_tool(json.dumps({"success": False, "error": format_session_db_unavailable()}))
|
||||
from tools.session_search_tool import session_search as _session_search
|
||||
return _finish_agent_tool(
|
||||
_session_search(
|
||||
query=function_args.get("query", ""),
|
||||
role_filter=function_args.get("role_filter"),
|
||||
limit=function_args.get("limit", 3),
|
||||
session_id=function_args.get("session_id"),
|
||||
around_message_id=function_args.get("around_message_id"),
|
||||
window=function_args.get("window", 5),
|
||||
sort=function_args.get("sort"),
|
||||
db=session_db,
|
||||
current_session_id=agent.session_id,
|
||||
def _execute(next_args: dict) -> Any:
|
||||
session_db = agent._get_session_db_for_recall()
|
||||
if not session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
return _finish_agent_tool(json.dumps({"success": False, "error": format_session_db_unavailable()}), next_args)
|
||||
from tools.session_search_tool import session_search as _session_search
|
||||
return _finish_agent_tool(
|
||||
_session_search(
|
||||
query=next_args.get("query", ""),
|
||||
role_filter=next_args.get("role_filter"),
|
||||
limit=next_args.get("limit", 3),
|
||||
session_id=next_args.get("session_id"),
|
||||
around_message_id=next_args.get("around_message_id"),
|
||||
window=next_args.get("window", 5),
|
||||
sort=next_args.get("sort"),
|
||||
db=session_db,
|
||||
current_session_id=agent.session_id,
|
||||
),
|
||||
next_args,
|
||||
)
|
||||
)
|
||||
elif function_name == "memory":
|
||||
target = function_args.get("target", "memory")
|
||||
from tools.memory_tool import memory_tool as _memory_tool
|
||||
result = _memory_tool(
|
||||
action=function_args.get("action"),
|
||||
target=target,
|
||||
content=function_args.get("content"),
|
||||
old_text=function_args.get("old_text"),
|
||||
store=agent._memory_store,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if agent._memory_manager and function_args.get("action") in {"add", "replace"}:
|
||||
try:
|
||||
agent._memory_manager.on_memory_write(
|
||||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
metadata=agent._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return _finish_agent_tool(result)
|
||||
elif agent._memory_manager and agent._memory_manager.has_tool(function_name):
|
||||
return _finish_agent_tool(agent._memory_manager.handle_tool_call(function_name, function_args))
|
||||
elif function_name == "clarify":
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
return _finish_agent_tool(
|
||||
_clarify_tool(
|
||||
question=function_args.get("question", ""),
|
||||
choices=function_args.get("choices"),
|
||||
callback=agent.clarify_callback,
|
||||
def _execute(next_args: dict) -> Any:
|
||||
target = next_args.get("target", "memory")
|
||||
from tools.memory_tool import memory_tool as _memory_tool
|
||||
result = _memory_tool(
|
||||
action=next_args.get("action"),
|
||||
target=target,
|
||||
content=next_args.get("content"),
|
||||
old_text=next_args.get("old_text"),
|
||||
store=agent._memory_store,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if agent._memory_manager and next_args.get("action") in {"add", "replace"}:
|
||||
try:
|
||||
agent._memory_manager.on_memory_write(
|
||||
next_args.get("action", ""),
|
||||
target,
|
||||
next_args.get("content", ""),
|
||||
metadata=agent._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return _finish_agent_tool(result, next_args)
|
||||
elif agent._memory_manager and agent._memory_manager.has_tool(function_name):
|
||||
def _execute(next_args: dict) -> Any:
|
||||
return _finish_agent_tool(agent._memory_manager.handle_tool_call(function_name, next_args), next_args)
|
||||
elif function_name == "clarify":
|
||||
def _execute(next_args: dict) -> Any:
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
return _finish_agent_tool(
|
||||
_clarify_tool(
|
||||
question=next_args.get("question", ""),
|
||||
choices=next_args.get("choices"),
|
||||
callback=agent.clarify_callback,
|
||||
),
|
||||
next_args,
|
||||
)
|
||||
)
|
||||
elif function_name == "delegate_task":
|
||||
return _finish_agent_tool(agent._dispatch_delegate_task(function_args))
|
||||
def _execute(next_args: dict) -> Any:
|
||||
return _finish_agent_tool(agent._dispatch_delegate_task(next_args), next_args)
|
||||
else:
|
||||
return _ra().handle_function_call(
|
||||
function_name, function_args, effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=agent.session_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
|
||||
skip_pre_tool_call_hook=True,
|
||||
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
|
||||
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
|
||||
)
|
||||
def _execute(next_args: dict) -> Any:
|
||||
return _ra().handle_function_call(
|
||||
function_name, next_args, effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
session_id=agent.session_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
|
||||
skip_pre_tool_call_hook=True,
|
||||
skip_tool_request_middleware=True,
|
||||
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
|
||||
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
|
||||
tool_request_middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
|
||||
from hermes_cli.middleware import run_tool_execution_middleware
|
||||
|
||||
return run_tool_execution_middleware(
|
||||
function_name,
|
||||
function_args,
|
||||
lambda next_args: _execute(next_args if isinstance(next_args, dict) else function_args),
|
||||
original_args=function_args,
|
||||
task_id=effective_task_id or "",
|
||||
session_id=getattr(agent, "session_id", "") or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1239,6 +1239,28 @@ def run_conversation(
|
|||
_sanitize_structure_non_ascii(api_kwargs)
|
||||
if agent.api_mode == "codex_responses":
|
||||
api_kwargs = agent._get_transport().preflight_kwargs(api_kwargs, allow_stream=False)
|
||||
try:
|
||||
from hermes_cli.middleware import apply_llm_request_middleware
|
||||
|
||||
_llm_request_mw = apply_llm_request_middleware(
|
||||
api_kwargs,
|
||||
task_id=effective_task_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
session_id=agent.session_id or "",
|
||||
platform=agent.platform or "",
|
||||
model=agent.model,
|
||||
provider=agent.provider,
|
||||
base_url=agent.base_url,
|
||||
api_mode=agent.api_mode,
|
||||
api_call_count=api_call_count,
|
||||
)
|
||||
api_kwargs = _llm_request_mw.payload
|
||||
_original_api_kwargs = _llm_request_mw.original_payload
|
||||
_llm_middleware_trace = _llm_request_mw.trace
|
||||
except Exception:
|
||||
_original_api_kwargs = dict(api_kwargs)
|
||||
_llm_middleware_trace = []
|
||||
|
||||
try:
|
||||
from hermes_cli.plugins import (
|
||||
|
|
@ -1291,6 +1313,7 @@ def run_conversation(
|
|||
request_char_count=total_chars,
|
||||
max_tokens=agent.max_tokens,
|
||||
started_at=api_start_time,
|
||||
middleware_trace=list(_llm_middleware_trace),
|
||||
request=_request_payload,
|
||||
)
|
||||
except Exception:
|
||||
|
|
@ -1349,7 +1372,24 @@ def run_conversation(
|
|||
)
|
||||
return agent._interruptible_api_call(next_api_kwargs)
|
||||
|
||||
response = _perform_api_call(api_kwargs)
|
||||
from hermes_cli.middleware import run_llm_execution_middleware
|
||||
|
||||
response = run_llm_execution_middleware(
|
||||
api_kwargs,
|
||||
_perform_api_call,
|
||||
original_request=_original_api_kwargs,
|
||||
task_id=effective_task_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
session_id=agent.session_id or "",
|
||||
platform=agent.platform or "",
|
||||
model=agent.model,
|
||||
provider=agent.provider,
|
||||
base_url=agent.base_url,
|
||||
api_mode=agent.api_mode,
|
||||
api_call_count=api_call_count,
|
||||
middleware_trace=list(_llm_middleware_trace),
|
||||
)
|
||||
|
||||
api_duration = time.time() - api_start_time
|
||||
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ def _emit_terminal_post_tool_call(
|
|||
status: str | None = None,
|
||||
error_type: str | None = None,
|
||||
error_message: str | None = None,
|
||||
middleware_trace: Optional[list[dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
try:
|
||||
from model_tools import _emit_post_tool_call_hook
|
||||
|
|
@ -86,6 +87,7 @@ def _emit_terminal_post_tool_call(
|
|||
status=status,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
middleware_trace=list(middleware_trace or []),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -111,6 +113,7 @@ def _emit_cancelled_terminal_post_tool_call(
|
|||
start_time: float,
|
||||
reason: str = "user interrupt",
|
||||
error_type: str = "keyboard_interrupt",
|
||||
middleware_trace: Optional[list[dict[str, Any]]] = None,
|
||||
) -> str:
|
||||
result = _cancelled_tool_result(reason)
|
||||
_emit_terminal_post_tool_call(
|
||||
|
|
@ -124,6 +127,7 @@ def _emit_cancelled_terminal_post_tool_call(
|
|||
status="cancelled",
|
||||
error_type=error_type,
|
||||
error_message=f"Tool execution cancelled by {reason}",
|
||||
middleware_trace=list(middleware_trace or []),
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
@ -177,6 +181,65 @@ def _tool_search_scoped_names(agent) -> frozenset:
|
|||
return names
|
||||
|
||||
|
||||
def _apply_tool_request_middleware_for_agent(
|
||||
agent,
|
||||
*,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
effective_task_id: str,
|
||||
tool_call_id: str,
|
||||
) -> tuple[dict, list[dict[str, Any]]]:
|
||||
try:
|
||||
from hermes_cli.middleware import apply_tool_request_middleware
|
||||
|
||||
result = apply_tool_request_middleware(
|
||||
function_name,
|
||||
function_args,
|
||||
task_id=effective_task_id or "",
|
||||
session_id=getattr(agent, "session_id", "") or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
)
|
||||
payload = result.payload if isinstance(result.payload, dict) else function_args
|
||||
return payload, list(result.trace)
|
||||
except Exception as exc:
|
||||
logger.debug("tool_request middleware error: %s", exc)
|
||||
return function_args, []
|
||||
|
||||
|
||||
def _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
*,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
effective_task_id: str,
|
||||
tool_call_id: str,
|
||||
execute,
|
||||
) -> tuple[Any, dict]:
|
||||
observed_args = function_args
|
||||
|
||||
def _execute(next_args: dict) -> Any:
|
||||
nonlocal observed_args
|
||||
observed_args = next_args if isinstance(next_args, dict) else function_args
|
||||
return execute(observed_args)
|
||||
|
||||
from hermes_cli.middleware import run_tool_execution_middleware
|
||||
|
||||
result = run_tool_execution_middleware(
|
||||
function_name,
|
||||
function_args,
|
||||
_execute,
|
||||
original_args=function_args,
|
||||
task_id=effective_task_id or "",
|
||||
session_id=getattr(agent, "session_id", "") or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
)
|
||||
return result, observed_args
|
||||
|
||||
|
||||
def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||
"""Execute multiple tool calls concurrently using a thread pool.
|
||||
|
||||
|
|
@ -198,7 +261,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
return
|
||||
|
||||
# ── Parse args + pre-execution bookkeeping ───────────────────────
|
||||
parsed_calls = [] # list of (tool_call, function_name, function_args)
|
||||
parsed_calls = [] # list of (tool_call, function_name, function_args, middleware_trace, block_result, blocked_by_guardrail)
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
|
||||
|
|
@ -250,6 +313,14 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
function_args, middleware_trace = _apply_tool_request_middleware_for_agent(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
)
|
||||
|
||||
# ── Block evaluation (BEFORE checkpoint preflight) ───────────
|
||||
# We must know whether the tool will execute before touching
|
||||
# checkpoint state (dedup slot, real snapshots).
|
||||
|
|
@ -268,6 +339,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
status="blocked",
|
||||
error_type="tool_scope_block",
|
||||
error_message=_ts_scope_block,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
|
|
@ -280,6 +352,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
except Exception:
|
||||
block_message = None
|
||||
|
|
@ -296,6 +369,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
status="blocked",
|
||||
error_type="plugin_block",
|
||||
error_message=block_message,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
else:
|
||||
guardrail_decision = agent._tool_guardrails.before_call(function_name, function_args)
|
||||
|
|
@ -312,6 +386,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
status="blocked",
|
||||
error_type="guardrail_block",
|
||||
error_message=getattr(guardrail_decision, "message", None) or "Tool blocked by guardrail policy",
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
|
||||
# ── Checkpoint preflight (only for tools that will execute) ──
|
||||
|
|
@ -338,13 +413,13 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
|
||||
parsed_calls.append((tool_call, function_name, function_args, middleware_trace, block_result, blocked_by_guardrail))
|
||||
|
||||
# ── Logging / callbacks ──────────────────────────────────────────
|
||||
tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls)
|
||||
tool_names_str = ", ".join(name for _, name, _, _, _, _ in parsed_calls)
|
||||
if not agent.quiet_mode:
|
||||
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
|
||||
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
|
||||
args_str = json.dumps(args, ensure_ascii=False)
|
||||
if agent.verbose_logging:
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
|
||||
|
|
@ -353,7 +428,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
args_preview = args_str[:agent.log_prefix_chars] + "..." if len(args_str) > agent.log_prefix_chars else args_str
|
||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||
|
||||
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
|
||||
for tc, name, args, middleware_trace, block_result, blocked_by_guardrail in parsed_calls:
|
||||
if block_result is not None:
|
||||
continue
|
||||
if agent.tool_progress_callback:
|
||||
|
|
@ -363,7 +438,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
except Exception as cb_err:
|
||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||
|
||||
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
|
||||
for tc, name, args, middleware_trace, block_result, blocked_by_guardrail in parsed_calls:
|
||||
if block_result is not None:
|
||||
continue
|
||||
if agent.tool_start_callback:
|
||||
|
|
@ -373,18 +448,18 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
logging.debug(f"Tool start callback error: {cb_err}")
|
||||
|
||||
# ── Concurrent execution ─────────────────────────────────────────
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag)
|
||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag, middleware_trace)
|
||||
results = [None] * num_tools
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||
if block_result is not None:
|
||||
results[i] = (name, args, block_result, 0.0, True, True)
|
||||
results[i] = (name, args, block_result, 0.0, True, True, middleware_trace)
|
||||
|
||||
# Touch activity before launching workers so the gateway knows
|
||||
# we're executing tools (not stuck).
|
||||
agent._current_tool = tool_names_str
|
||||
agent._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}")
|
||||
|
||||
def _run_tool(index, tool_call, function_name, function_args):
|
||||
def _run_tool(index, tool_call, function_name, function_args, middleware_trace):
|
||||
"""Worker function executed in a thread."""
|
||||
# Register this worker tid so the agent can fan out an interrupt
|
||||
# to it — see AIAgent.interrupt(). Must happen first thing, and
|
||||
|
|
@ -423,6 +498,8 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
tool_call.id,
|
||||
messages=messages,
|
||||
pre_tool_block_checked=True,
|
||||
skip_tool_request_middleware=True,
|
||||
tool_request_middleware_trace=list(middleware_trace),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
try:
|
||||
|
|
@ -436,10 +513,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
start_time=start,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
duration = time.time() - start
|
||||
logger.info("tool %s cancelled (%.2fs)", function_name, duration)
|
||||
results[index] = (function_name, function_args, result, duration, True, False)
|
||||
results[index] = (function_name, function_args, result, duration, True, False, middleware_trace)
|
||||
return
|
||||
except Exception as tool_error:
|
||||
result = f"Error executing tool '{function_name}': {tool_error}"
|
||||
|
|
@ -450,7 +528,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200])
|
||||
else:
|
||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
|
||||
results[index] = (function_name, function_args, result, duration, is_error, False)
|
||||
results[index] = (function_name, function_args, result, duration, is_error, False, middleware_trace)
|
||||
finally:
|
||||
# Tear down worker-tid tracking. Clear any interrupt bit we may
|
||||
# have set so the next task scheduled onto this recycled tid
|
||||
|
|
@ -475,7 +553,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
try:
|
||||
runnable_calls = [
|
||||
(i, tc, name, args)
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
|
||||
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
|
||||
if block_result is None
|
||||
]
|
||||
futures = []
|
||||
|
|
@ -487,7 +565,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
# _approval_session_key) AND thread-local approval/sudo
|
||||
# callbacks into the worker thread; clears callbacks on exit.
|
||||
f = executor.submit(
|
||||
propagate_context_to_thread(_run_tool), i, tc, name, args
|
||||
propagate_context_to_thread(_run_tool), i, tc, name, args, parsed_calls[i][3]
|
||||
)
|
||||
futures.append(f)
|
||||
|
||||
|
|
@ -545,7 +623,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")
|
||||
|
||||
# ── Post-execution: display per-tool results ─────────────────────
|
||||
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||
for i, (tc, name, args, middleware_trace, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||
r = results[i]
|
||||
blocked = False
|
||||
if r is None:
|
||||
|
|
@ -562,6 +640,7 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
status="cancelled",
|
||||
error_type="keyboard_interrupt",
|
||||
error_message="Tool execution cancelled by user interrupt",
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
else:
|
||||
function_result = f"Error executing tool '{name}': thread did not return a result"
|
||||
|
|
@ -575,10 +654,11 @@ def execute_tool_calls_concurrent(agent, assistant_message, messages: list, effe
|
|||
status="error",
|
||||
error_type="thread_missing_result",
|
||||
error_message=function_result,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
tool_duration = 0.0
|
||||
else:
|
||||
function_name, function_args, function_result, tool_duration, is_error, blocked = r
|
||||
function_name, function_args, function_result, tool_duration, is_error, blocked, middleware_trace = r
|
||||
|
||||
if not blocked:
|
||||
function_result = agent._append_guardrail_observation(
|
||||
|
|
@ -738,6 +818,14 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
function_args, middleware_trace = _apply_tool_request_middleware_for_agent(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
)
|
||||
|
||||
# Check plugin hooks for a block directive before executing.
|
||||
_block_msg: Optional[str] = None
|
||||
_block_error_type = "plugin_block"
|
||||
|
|
@ -755,6 +843,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
turn_id=getattr(agent, "_current_turn_id", "") or "",
|
||||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -853,6 +942,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
status="blocked",
|
||||
error_type=_block_error_type,
|
||||
error_message=_block_msg,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
elif _guardrail_block_decision is not None:
|
||||
# Tool blocked by tool-loop guardrail — synthesize exactly one
|
||||
|
|
@ -869,71 +959,108 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
status="blocked",
|
||||
error_type="guardrail_block",
|
||||
error_message=getattr(_guardrail_block_decision, "message", None) or "Tool blocked by guardrail policy",
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
elif function_name == "todo":
|
||||
from tools.todo_tool import todo_tool as _todo_tool
|
||||
function_result = _todo_tool(
|
||||
todos=function_args.get("todos"),
|
||||
merge=function_args.get("merge", False),
|
||||
store=agent._todo_store,
|
||||
def _execute(next_args: dict) -> Any:
|
||||
from tools.todo_tool import todo_tool as _todo_tool
|
||||
return _todo_tool(
|
||||
todos=next_args.get("todos"),
|
||||
merge=next_args.get("merge", False),
|
||||
store=agent._todo_store,
|
||||
)
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if agent._should_emit_quiet_tool_messages():
|
||||
agent._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "session_search":
|
||||
session_db = agent._get_session_db_for_recall()
|
||||
if not session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
function_result = json.dumps({"success": False, "error": format_session_db_unavailable()})
|
||||
else:
|
||||
def _execute(next_args: dict) -> Any:
|
||||
session_db = agent._get_session_db_for_recall()
|
||||
if not session_db:
|
||||
from hermes_state import format_session_db_unavailable
|
||||
return json.dumps({"success": False, "error": format_session_db_unavailable()})
|
||||
from tools.session_search_tool import session_search as _session_search
|
||||
function_result = _session_search(
|
||||
query=function_args.get("query", ""),
|
||||
role_filter=function_args.get("role_filter"),
|
||||
limit=function_args.get("limit", 3),
|
||||
session_id=function_args.get("session_id"),
|
||||
around_message_id=function_args.get("around_message_id"),
|
||||
window=function_args.get("window", 5),
|
||||
sort=function_args.get("sort"),
|
||||
return _session_search(
|
||||
query=next_args.get("query", ""),
|
||||
role_filter=next_args.get("role_filter"),
|
||||
limit=next_args.get("limit", 3),
|
||||
session_id=next_args.get("session_id"),
|
||||
around_message_id=next_args.get("around_message_id"),
|
||||
window=next_args.get("window", 5),
|
||||
sort=next_args.get("sort"),
|
||||
db=session_db,
|
||||
current_session_id=agent.session_id,
|
||||
)
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if agent._should_emit_quiet_tool_messages():
|
||||
agent._vprint(f" {_get_cute_tool_message_impl('session_search', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "memory":
|
||||
target = function_args.get("target", "memory")
|
||||
from tools.memory_tool import memory_tool as _memory_tool
|
||||
function_result = _memory_tool(
|
||||
action=function_args.get("action"),
|
||||
target=target,
|
||||
content=function_args.get("content"),
|
||||
old_text=function_args.get("old_text"),
|
||||
store=agent._memory_store,
|
||||
def _execute(next_args: dict) -> Any:
|
||||
target = next_args.get("target", "memory")
|
||||
from tools.memory_tool import memory_tool as _memory_tool
|
||||
result = _memory_tool(
|
||||
action=next_args.get("action"),
|
||||
target=target,
|
||||
content=next_args.get("content"),
|
||||
old_text=next_args.get("old_text"),
|
||||
store=agent._memory_store,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if agent._memory_manager and next_args.get("action") in {"add", "replace"}:
|
||||
try:
|
||||
agent._memory_manager.on_memory_write(
|
||||
next_args.get("action", ""),
|
||||
target,
|
||||
next_args.get("content", ""),
|
||||
metadata=agent._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", None),
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
return result
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
# Bridge: notify external memory provider of built-in memory writes
|
||||
if agent._memory_manager and function_args.get("action") in {"add", "replace"}:
|
||||
try:
|
||||
agent._memory_manager.on_memory_write(
|
||||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
metadata=agent._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", None),
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if agent._should_emit_quiet_tool_messages():
|
||||
agent._vprint(f" {_get_cute_tool_message_impl('memory', function_args, tool_duration, result=function_result)}")
|
||||
elif function_name == "clarify":
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
function_result = _clarify_tool(
|
||||
question=function_args.get("question", ""),
|
||||
choices=function_args.get("choices"),
|
||||
callback=agent.clarify_callback,
|
||||
def _execute(next_args: dict) -> Any:
|
||||
from tools.clarify_tool import clarify_tool as _clarify_tool
|
||||
return _clarify_tool(
|
||||
question=next_args.get("question", ""),
|
||||
choices=next_args.get("choices"),
|
||||
callback=agent.clarify_callback,
|
||||
)
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
tool_duration = time.time() - tool_start_time
|
||||
if agent._should_emit_quiet_tool_messages():
|
||||
|
|
@ -957,7 +1084,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
agent._delegate_spinner = spinner
|
||||
_delegate_result = None
|
||||
try:
|
||||
function_result = agent._dispatch_delegate_task(function_args)
|
||||
def _execute(next_args: dict) -> Any:
|
||||
return agent._dispatch_delegate_task(next_args)
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
_delegate_result = function_result
|
||||
finally:
|
||||
agent._delegate_spinner = None
|
||||
|
|
@ -978,7 +1114,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
spinner.start()
|
||||
_ce_result = None
|
||||
try:
|
||||
function_result = agent.context_compressor.handle_tool_call(function_name, function_args, messages=messages)
|
||||
def _execute(next_args: dict) -> Any:
|
||||
return agent.context_compressor.handle_tool_call(function_name, next_args, messages=messages)
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
_ce_result = function_result
|
||||
except Exception as tool_error:
|
||||
function_result = json.dumps({"error": f"Context engine tool '{function_name}' failed: {tool_error}"})
|
||||
|
|
@ -1002,7 +1147,16 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
spinner.start()
|
||||
_mem_result = None
|
||||
try:
|
||||
function_result = agent._memory_manager.handle_tool_call(function_name, function_args)
|
||||
def _execute(next_args: dict) -> Any:
|
||||
return agent._memory_manager.handle_tool_call(function_name, next_args)
|
||||
function_result, function_args = _run_agent_tool_execution_middleware(
|
||||
agent,
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
execute=_execute,
|
||||
)
|
||||
_mem_result = function_result
|
||||
except Exception as tool_error:
|
||||
function_result = json.dumps({"error": f"Memory tool '{function_name}' failed: {tool_error}"})
|
||||
|
|
@ -1032,8 +1186,10 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
|
||||
skip_pre_tool_call_hook=True,
|
||||
skip_tool_request_middleware=True,
|
||||
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
|
||||
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
|
||||
tool_request_middleware_trace=list(middleware_trace),
|
||||
)
|
||||
_spinner_result = function_result
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -1044,6 +1200,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
start_time=tool_start_time,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
_spinner_result = function_result
|
||||
try:
|
||||
|
|
@ -1071,8 +1228,10 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
api_request_id=getattr(agent, "_current_api_request_id", "") or "",
|
||||
enabled_tools=list(agent.valid_tool_names) if agent.valid_tool_names else None,
|
||||
skip_pre_tool_call_hook=True,
|
||||
skip_tool_request_middleware=True,
|
||||
enabled_toolsets=getattr(agent, "enabled_toolsets", None),
|
||||
disabled_toolsets=getattr(agent, "disabled_toolsets", None),
|
||||
tool_request_middleware_trace=list(middleware_trace),
|
||||
)
|
||||
except KeyboardInterrupt:
|
||||
_emit_cancelled_terminal_post_tool_call(
|
||||
|
|
@ -1082,6 +1241,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
start_time=tool_start_time,
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
try:
|
||||
agent.interrupt("keyboard interrupt")
|
||||
|
|
@ -1126,6 +1286,7 @@ def execute_tool_calls_sequential(agent, assistant_message, messages: list, effe
|
|||
effective_task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", "") or "",
|
||||
duration_ms=int(tool_duration * 1000),
|
||||
middleware_trace=list(middleware_trace),
|
||||
)
|
||||
if not _execution_blocked:
|
||||
function_result = agent._append_guardrail_observation(
|
||||
|
|
|
|||
260
docs/middleware/README.md
Normal file
260
docs/middleware/README.md
Normal file
|
|
@ -0,0 +1,260 @@
|
|||
# Hermes Middleware
|
||||
|
||||
Hermes middleware is the behavior-changing companion to observer hooks.
|
||||
Observer hooks report what happened. Middleware can change what happens by
|
||||
rewriting a request before execution or by wrapping the execution callback
|
||||
itself.
|
||||
|
||||
This contract is intentionally backend-neutral. A plugin can use it for local
|
||||
policy, request shaping, tracing, adaptive routing, cache control, sandbox
|
||||
selection, or handoff to runtimes such as NeMo Relay without changing Hermes'
|
||||
planner, model provider adapters, tool registry, memory, or CLI UX.
|
||||
|
||||
With middleware enabled, plugins can:
|
||||
|
||||
- Rewrite LLM provider request kwargs before Hermes calls the provider.
|
||||
- Rewrite tool arguments before guardrails, approval checks, hooks, and tool
|
||||
execution see them.
|
||||
- Wrap the actual LLM execution callback while preserving Hermes retry,
|
||||
streaming, interrupt, and hook behavior.
|
||||
- Wrap the actual tool execution callback while preserving Hermes guardrails,
|
||||
approval, post-tool hooks, and tool-result transformation.
|
||||
|
||||
## Contract
|
||||
|
||||
Plugins register middleware from `register(ctx)`:
|
||||
|
||||
```python
|
||||
def register(ctx):
|
||||
ctx.register_middleware("llm_request", on_llm_request)
|
||||
ctx.register_middleware("llm_execution", on_llm_execution)
|
||||
ctx.register_middleware("tool_request", on_tool_request)
|
||||
ctx.register_middleware("tool_execution", on_tool_execution)
|
||||
```
|
||||
|
||||
Every middleware callback receives:
|
||||
|
||||
- `telemetry_schema_version`: currently `hermes.observer.v1`
|
||||
- `middleware_schema_version`: currently `hermes.middleware.v1`
|
||||
- Runtime context such as `session_id`, `task_id`, `turn_id`,
|
||||
`api_request_id`, `provider`, `model`, `api_mode`, `tool_name`, and
|
||||
`tool_call_id` when applicable.
|
||||
|
||||
Supported middleware kinds:
|
||||
|
||||
| Kind | Payload | Return shape | Purpose |
|
||||
| --- | --- | --- | --- |
|
||||
| `llm_request` | `request`, `original_request` | `{"request": {...}}` | Replace effective provider kwargs before provider execution. |
|
||||
| `tool_request` | `tool_name`, `args`, `original_args` | `{"args": {...}}` | Replace effective tool args before hooks, guardrails, approvals, and execution. |
|
||||
| `llm_execution` | `request`, `original_request`, `next_call` | Any provider response | Wrap or replace the actual provider call. |
|
||||
| `tool_execution` | `tool_name`, `args`, `original_args`, `next_call` | Any tool result | Wrap or replace the actual tool call. |
|
||||
|
||||
Request middleware can return optional trace fields:
|
||||
|
||||
```python
|
||||
return {
|
||||
"request": updated_request,
|
||||
"source": "my-plugin",
|
||||
"reason": "selected fallback model",
|
||||
}
|
||||
```
|
||||
|
||||
Hermes stores those trace entries in later observer hook payloads as
|
||||
`middleware_trace`.
|
||||
|
||||
Execution middleware receives a `next_call` callback. Call it to continue the
|
||||
chain:
|
||||
|
||||
```python
|
||||
def on_tool_execution(**kwargs):
|
||||
result = kwargs["next_call"](kwargs["args"])
|
||||
return result
|
||||
```
|
||||
|
||||
If multiple plugins register the same execution middleware kind, Hermes runs
|
||||
them as a nested chain in registration order. Middleware failures are fail-open:
|
||||
Hermes logs a warning and continues with the next middleware or the base
|
||||
runtime path.
|
||||
|
||||
## Execution Order
|
||||
|
||||
### LLM Calls
|
||||
|
||||
For each provider request, Hermes applies middleware in this order:
|
||||
|
||||
1. Build provider kwargs from the current conversation.
|
||||
2. Apply `llm_request` middleware.
|
||||
3. Emit `pre_api_request` observer hooks with the effective request.
|
||||
4. Run provider execution through `llm_execution` middleware.
|
||||
5. Emit `post_api_request` or `api_request_error` observer hooks.
|
||||
|
||||
Request middleware sees the full provider kwargs, including `messages` or
|
||||
Responses API `input`, model settings, tool definitions, stream options, and
|
||||
provider-specific options. Execution middleware receives the same effective
|
||||
request plus `next_call`.
|
||||
|
||||
### Tool Calls
|
||||
|
||||
For each tool call, Hermes applies middleware in this order:
|
||||
|
||||
1. Parse and coerce model-provided tool arguments.
|
||||
2. Apply `tool_request` middleware.
|
||||
3. Run the normal Hermes pre-execution path against the effective arguments:
|
||||
tool availability checks, observer block directives, guardrails, and
|
||||
approval checks.
|
||||
4. Run tool execution through `tool_execution` middleware.
|
||||
5. Emit `post_tool_call` observer hooks.
|
||||
6. Apply `transform_tool_result` hooks before the result is appended back into
|
||||
conversation context.
|
||||
|
||||
Tool request middleware runs before approval checks. Use it carefully: a
|
||||
rewritten path, command, or URL is the value downstream policy will evaluate.
|
||||
|
||||
## Enablement
|
||||
|
||||
Middleware only runs for enabled plugins. For a bundled plugin:
|
||||
|
||||
```bash
|
||||
hermes plugins enable <plugin-name>
|
||||
```
|
||||
|
||||
For isolated local testing, use one `HERMES_HOME` for plugin enablement and the
|
||||
agent run:
|
||||
|
||||
```bash
|
||||
export HERMES_HOME=/tmp/hermes-middleware-test
|
||||
mkdir -p "$HERMES_HOME"
|
||||
hermes plugins enable <plugin-name>
|
||||
hermes chat --query 'Reply exactly ok'
|
||||
```
|
||||
|
||||
For source checkouts, prefer the source command so the runtime sees plugins and
|
||||
middleware from the working tree:
|
||||
|
||||
```bash
|
||||
uv sync
|
||||
uv run hermes plugins enable <plugin-name>
|
||||
uv run hermes chat --query 'Reply exactly ok'
|
||||
```
|
||||
|
||||
## Generic Plugin Examples
|
||||
|
||||
The examples below are intentionally small. They show the middleware contract
|
||||
shape without depending on NeMo Relay.
|
||||
|
||||
### LLM Request Middleware
|
||||
|
||||
This plugin tags provider requests and records a middleware trace entry:
|
||||
|
||||
```python
|
||||
def register(ctx):
|
||||
ctx.register_middleware("llm_request", tag_llm_request)
|
||||
|
||||
|
||||
def tag_llm_request(**kwargs):
|
||||
request = dict(kwargs["request"])
|
||||
extra_body = dict(request.get("extra_body") or {})
|
||||
extra_body.setdefault("metadata", {})["hermes_middleware_demo"] = True
|
||||
request["extra_body"] = extra_body
|
||||
return {
|
||||
"request": request,
|
||||
"source": "middleware-demo",
|
||||
"reason": "tagged provider request",
|
||||
}
|
||||
```
|
||||
|
||||
The effective request is passed to `pre_api_request`, provider execution, and
|
||||
`post_api_request`.
|
||||
|
||||
### Tool Request Middleware
|
||||
|
||||
This plugin constrains `terminal` calls to a known working directory:
|
||||
|
||||
```python
|
||||
def register(ctx):
|
||||
ctx.register_middleware("tool_request", normalize_terminal_workdir)
|
||||
|
||||
|
||||
def normalize_terminal_workdir(**kwargs):
|
||||
if kwargs.get("tool_name") != "terminal":
|
||||
return None
|
||||
args = dict(kwargs["args"])
|
||||
args.setdefault("workdir", "/tmp/hermes-middleware-demo")
|
||||
return {
|
||||
"args": args,
|
||||
"source": "middleware-demo",
|
||||
"reason": "defaulted terminal workdir",
|
||||
}
|
||||
```
|
||||
|
||||
Because this runs before hooks and approvals, downstream telemetry and policy
|
||||
observe the rewritten `workdir`.
|
||||
|
||||
### LLM Execution Middleware
|
||||
|
||||
This plugin wraps the provider call and preserves the raw provider response:
|
||||
|
||||
```python
|
||||
import time
|
||||
|
||||
|
||||
def register(ctx):
|
||||
ctx.register_middleware("llm_execution", time_llm_execution)
|
||||
|
||||
|
||||
def time_llm_execution(**kwargs):
|
||||
started = time.monotonic()
|
||||
response = kwargs["next_call"](kwargs["request"])
|
||||
elapsed_ms = int((time.monotonic() - started) * 1000)
|
||||
print(f"llm_execution elapsed_ms={elapsed_ms}")
|
||||
return response
|
||||
```
|
||||
|
||||
Return the same response shape Hermes expects from the provider adapter. Do not
|
||||
wrap the response in a plugin-specific envelope unless the rest of the runtime
|
||||
expects that envelope.
|
||||
|
||||
### Tool Execution Middleware
|
||||
|
||||
This plugin wraps tool execution while preserving the tool result:
|
||||
|
||||
```python
|
||||
def register(ctx):
|
||||
ctx.register_middleware("tool_execution", annotate_tool_execution)
|
||||
|
||||
|
||||
def annotate_tool_execution(**kwargs):
|
||||
result = kwargs["next_call"](kwargs["args"])
|
||||
# Metrics, logging, or external routing can happen here.
|
||||
return result
|
||||
```
|
||||
|
||||
Execution middleware may call `next_call(modified_args)` to pass a changed
|
||||
payload to later middleware and the base tool dispatcher.
|
||||
|
||||
Plugin-specific examples should live with the plugin that owns the behavior.
|
||||
For NeMo Relay adaptive execution middleware, see
|
||||
[`plugins/observability/nemo_relay/README.md`](../../plugins/observability/nemo_relay/README.md).
|
||||
|
||||
## Safety Notes
|
||||
|
||||
- Middleware should be deterministic for the same input unless it is explicitly
|
||||
routing to a dynamic external system.
|
||||
- Request middleware should return complete replacement payloads, not partial
|
||||
patches.
|
||||
- Execution middleware should call `next_call(...)` exactly once unless it is
|
||||
intentionally short-circuiting execution.
|
||||
- If execution middleware raises before calling `next_call(...)`, Hermes treats
|
||||
that as middleware failure and continues with the remaining middleware chain
|
||||
and base execution.
|
||||
- If execution middleware calls `next_call(...)` successfully and then raises
|
||||
during post-processing, Hermes preserves the downstream result and does not
|
||||
run the provider or tool a second time.
|
||||
- If downstream provider or tool execution fails, middleware may let that error
|
||||
propagate or translate it deliberately. Hermes does not convert downstream
|
||||
failure into a successful `None` result.
|
||||
- Tool request middleware runs before approvals. If it mutates file paths,
|
||||
commands, URLs, or arguments, the mutated values are what guardrails and
|
||||
approvals evaluate.
|
||||
- Observer hooks remain the right place for read-only telemetry. Use middleware
|
||||
only when a plugin needs to alter or wrap behavior.
|
||||
313
hermes_cli/middleware.py
Normal file
313
hermes_cli/middleware.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
"""Hermes middleware contract helpers.
|
||||
|
||||
Observer hooks report what happened. Middleware can change what happens by
|
||||
rewriting a request or wrapping the actual execution callback. Keep the small
|
||||
contract helpers here so agent-loop call sites and plugins share one vocabulary.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
|
||||
MIDDLEWARE_SCHEMA_VERSION = "hermes.middleware.v1"
|
||||
|
||||
TOOL_REQUEST_MIDDLEWARE = "tool_request"
|
||||
TOOL_EXECUTION_MIDDLEWARE = "tool_execution"
|
||||
LLM_REQUEST_MIDDLEWARE = "llm_request"
|
||||
LLM_EXECUTION_MIDDLEWARE = "llm_execution"
|
||||
|
||||
# Back-compat aliases for older PoC branches that used API terminology.
|
||||
API_REQUEST_MIDDLEWARE = LLM_REQUEST_MIDDLEWARE
|
||||
API_EXECUTION_MIDDLEWARE = LLM_EXECUTION_MIDDLEWARE
|
||||
|
||||
VALID_MIDDLEWARE: set[str] = {
|
||||
TOOL_REQUEST_MIDDLEWARE,
|
||||
TOOL_EXECUTION_MIDDLEWARE,
|
||||
LLM_REQUEST_MIDDLEWARE,
|
||||
LLM_EXECUTION_MIDDLEWARE,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestMiddlewareResult:
|
||||
"""Result of applying request middleware to a mutable payload."""
|
||||
|
||||
payload: Any
|
||||
original_payload: Any
|
||||
changed: bool = False
|
||||
trace: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def observer_payload(**kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION)
|
||||
return kwargs
|
||||
|
||||
|
||||
def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
|
||||
kwargs.setdefault("telemetry_schema_version", OBSERVER_SCHEMA_VERSION)
|
||||
kwargs.setdefault("middleware_schema_version", MIDDLEWARE_SCHEMA_VERSION)
|
||||
return kwargs
|
||||
|
||||
|
||||
def _safe_copy(payload: Any) -> Any:
|
||||
"""Deep-copy a request payload, tolerating non-deepcopyable members.
|
||||
|
||||
Request payloads are normally plain JSON-shaped dicts, but an LLM request
|
||||
can occasionally carry non-deepcopyable objects (clients, callbacks, file
|
||||
handles). A hard ``deepcopy`` failure there would otherwise abort the whole
|
||||
request-middleware pass. Fall back to a shallow ``dict`` copy so middleware
|
||||
still runs and the original nested objects are shared by reference rather
|
||||
than corrupting the live payload.
|
||||
"""
|
||||
try:
|
||||
return deepcopy(payload)
|
||||
except Exception as exc: # pragma: no cover - exercised via fallback test
|
||||
logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc)
|
||||
if isinstance(payload, dict):
|
||||
return dict(payload)
|
||||
return payload
|
||||
|
||||
|
||||
def apply_llm_request_middleware(
|
||||
request: Dict[str, Any],
|
||||
**context: Any,
|
||||
) -> RequestMiddlewareResult:
|
||||
"""Apply registered LLM request middleware.
|
||||
|
||||
Middleware may return ``{"request": {...}}`` to replace the effective
|
||||
provider kwargs before Hermes sends them.
|
||||
"""
|
||||
if not _has_middleware(LLM_REQUEST_MIDDLEWARE):
|
||||
return RequestMiddlewareResult(
|
||||
payload=request,
|
||||
original_payload=request,
|
||||
changed=False,
|
||||
trace=[],
|
||||
)
|
||||
|
||||
original_request = _safe_copy(request)
|
||||
current_request = _safe_copy(original_request)
|
||||
trace: List[Dict[str, Any]] = []
|
||||
|
||||
for result in _invoke_middleware(
|
||||
LLM_REQUEST_MIDDLEWARE,
|
||||
request=current_request,
|
||||
original_request=original_request,
|
||||
**context,
|
||||
):
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
next_request = result.get("request")
|
||||
if not isinstance(next_request, dict):
|
||||
continue
|
||||
current_request = _safe_copy(next_request)
|
||||
trace.append(_trace_entry(result))
|
||||
|
||||
return RequestMiddlewareResult(
|
||||
payload=current_request,
|
||||
original_payload=original_request,
|
||||
changed=bool(trace),
|
||||
trace=trace,
|
||||
)
|
||||
|
||||
|
||||
def apply_tool_request_middleware(
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
**context: Any,
|
||||
) -> RequestMiddlewareResult:
|
||||
"""Apply registered tool request middleware.
|
||||
|
||||
Middleware may return ``{"args": {...}}`` to replace the effective tool
|
||||
arguments before hooks, guardrails, approvals, and execution see them.
|
||||
"""
|
||||
if not _has_middleware(TOOL_REQUEST_MIDDLEWARE):
|
||||
return RequestMiddlewareResult(
|
||||
payload=args,
|
||||
original_payload=args,
|
||||
changed=False,
|
||||
trace=[],
|
||||
)
|
||||
|
||||
original_args = _safe_copy(args)
|
||||
current_args = _safe_copy(original_args)
|
||||
trace: List[Dict[str, Any]] = []
|
||||
|
||||
for result in _invoke_middleware(
|
||||
TOOL_REQUEST_MIDDLEWARE,
|
||||
tool_name=tool_name,
|
||||
args=current_args,
|
||||
original_args=original_args,
|
||||
**context,
|
||||
):
|
||||
if not isinstance(result, dict):
|
||||
continue
|
||||
next_args = result.get("args")
|
||||
if not isinstance(next_args, dict):
|
||||
continue
|
||||
current_args = _safe_copy(next_args)
|
||||
trace.append(_trace_entry(result))
|
||||
|
||||
return RequestMiddlewareResult(
|
||||
payload=current_args,
|
||||
original_payload=original_args,
|
||||
changed=bool(trace),
|
||||
trace=trace,
|
||||
)
|
||||
|
||||
|
||||
def apply_api_request_middleware(
|
||||
request: Dict[str, Any],
|
||||
**context: Any,
|
||||
) -> RequestMiddlewareResult:
|
||||
"""Compatibility wrapper for older ``api_request`` naming."""
|
||||
return apply_llm_request_middleware(request, **context)
|
||||
|
||||
|
||||
def run_llm_execution_middleware(
|
||||
request: Dict[str, Any],
|
||||
next_call: Callable[[Dict[str, Any]], Any],
|
||||
**context: Any,
|
||||
) -> Any:
|
||||
"""Run provider execution through registered LLM execution middleware."""
|
||||
callbacks = _get_middleware_callbacks(LLM_EXECUTION_MIDDLEWARE)
|
||||
if not callbacks:
|
||||
return next_call(request)
|
||||
return _run_execution_chain(
|
||||
LLM_EXECUTION_MIDDLEWARE,
|
||||
callbacks,
|
||||
next_call,
|
||||
request=request,
|
||||
original_request=context.pop("original_request", request),
|
||||
**context,
|
||||
)
|
||||
|
||||
|
||||
def run_tool_execution_middleware(
|
||||
tool_name: str,
|
||||
args: Dict[str, Any],
|
||||
next_call: Callable[[Dict[str, Any]], Any],
|
||||
**context: Any,
|
||||
) -> Any:
|
||||
"""Run tool execution through registered tool execution middleware."""
|
||||
callbacks = _get_middleware_callbacks(TOOL_EXECUTION_MIDDLEWARE)
|
||||
if not callbacks:
|
||||
return next_call(args)
|
||||
return _run_execution_chain(
|
||||
TOOL_EXECUTION_MIDDLEWARE,
|
||||
callbacks,
|
||||
next_call,
|
||||
tool_name=tool_name,
|
||||
args=args,
|
||||
original_args=context.pop("original_args", args),
|
||||
**context,
|
||||
)
|
||||
|
||||
|
||||
def run_api_execution_middleware(
|
||||
request: Dict[str, Any],
|
||||
next_call: Callable[[Dict[str, Any]], Any],
|
||||
**context: Any,
|
||||
) -> Any:
|
||||
"""Compatibility wrapper for older ``api_execution`` naming."""
|
||||
return run_llm_execution_middleware(request, next_call, **context)
|
||||
|
||||
|
||||
def _invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
|
||||
from hermes_cli.plugins import invoke_middleware
|
||||
|
||||
return invoke_middleware(kind, **middleware_payload(**kwargs))
|
||||
|
||||
|
||||
def _has_middleware(kind: str) -> bool:
|
||||
from hermes_cli.plugins import has_middleware
|
||||
|
||||
return has_middleware(kind)
|
||||
|
||||
|
||||
def _get_middleware_callbacks(kind: str) -> List[Callable]:
|
||||
from hermes_cli.plugins import get_plugin_manager
|
||||
|
||||
return list(get_plugin_manager()._middleware.get(kind, []))
|
||||
|
||||
|
||||
def _run_execution_chain(
|
||||
kind: str,
|
||||
callbacks: List[Callable],
|
||||
terminal_call: Callable[[Any], Any],
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
payload_key = "request" if "request" in kwargs else "args"
|
||||
|
||||
class _DownstreamExecutionError(Exception):
|
||||
def __init__(self, original: BaseException) -> None:
|
||||
super().__init__(str(original))
|
||||
self.original = original
|
||||
|
||||
def call_at(index: int, payload: Any) -> Any:
|
||||
if index >= len(callbacks):
|
||||
return terminal_call(payload)
|
||||
|
||||
callback = callbacks[index]
|
||||
next_called = False
|
||||
next_succeeded = False
|
||||
next_result: Any = None
|
||||
|
||||
def next_call(next_payload: Any = None) -> Any:
|
||||
nonlocal next_called, next_succeeded, next_result
|
||||
# ``next_call`` is single-use per middleware frame. Calling it more
|
||||
# than once would re-run the downstream provider/tool, so a second
|
||||
# invocation is a contract violation rather than a retry. Surface it
|
||||
# instead of silently executing the terminal call twice.
|
||||
if next_called:
|
||||
raise RuntimeError(
|
||||
f"Middleware '{kind}' callback "
|
||||
f"{getattr(callback, '__name__', repr(callback))} called "
|
||||
"next_call() more than once; downstream execution is single-use"
|
||||
)
|
||||
next_called = True
|
||||
try:
|
||||
next_result = call_at(index + 1, payload if next_payload is None else next_payload)
|
||||
next_succeeded = True
|
||||
return next_result
|
||||
except Exception as exc:
|
||||
raise _DownstreamExecutionError(exc) from exc
|
||||
|
||||
call_kwargs = middleware_payload(**kwargs)
|
||||
call_kwargs[payload_key] = payload
|
||||
call_kwargs["next_call"] = next_call
|
||||
try:
|
||||
return callback(**call_kwargs)
|
||||
except _DownstreamExecutionError as exc:
|
||||
raise exc.original
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Middleware '%s' callback %s raised: %s",
|
||||
kind,
|
||||
getattr(callback, "__name__", repr(callback)),
|
||||
exc,
|
||||
)
|
||||
if next_succeeded:
|
||||
return next_result
|
||||
if next_called:
|
||||
raise
|
||||
return call_at(index + 1, payload)
|
||||
|
||||
return call_at(0, kwargs[payload_key])
|
||||
|
||||
|
||||
def _trace_entry(result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
entry: Dict[str, Any] = {}
|
||||
for key in ("source", "reason", "name"):
|
||||
value = result.get(key)
|
||||
if isinstance(value, str) and value:
|
||||
entry[key] = value
|
||||
if not entry:
|
||||
entry["source"] = "plugin"
|
||||
return entry
|
||||
|
|
@ -49,7 +49,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
|
|||
from hermes_constants import get_hermes_home
|
||||
from utils import env_var_enabled
|
||||
from hermes_cli.config import cfg_get
|
||||
OBSERVER_SCHEMA_VERSION = "hermes.observer.v1"
|
||||
from hermes_cli.middleware import OBSERVER_SCHEMA_VERSION, VALID_MIDDLEWARE
|
||||
|
||||
|
||||
def get_bundled_plugins_dir() -> Path:
|
||||
|
|
@ -277,6 +277,7 @@ class LoadedPlugin:
|
|||
module: Optional[types.ModuleType] = None
|
||||
tools_registered: List[str] = field(default_factory=list)
|
||||
hooks_registered: List[str] = field(default_factory=list)
|
||||
middleware_registered: List[str] = field(default_factory=list)
|
||||
commands_registered: List[str] = field(default_factory=list)
|
||||
enabled: bool = False
|
||||
error: Optional[str] = None
|
||||
|
|
@ -952,6 +953,27 @@ class PluginContext:
|
|||
self._manager._hooks.setdefault(hook_name, []).append(callback)
|
||||
logger.debug("Plugin %s registered hook: %s", self.manifest.name, hook_name)
|
||||
|
||||
# -- middleware registration -------------------------------------------
|
||||
|
||||
def register_middleware(self, kind: str, callback: Callable) -> None:
|
||||
"""Register a behavior-changing middleware callback.
|
||||
|
||||
Middleware is separate from observer hooks: request middleware may
|
||||
rewrite the effective payload, and execution middleware may wrap the
|
||||
real callback. Unknown kinds are stored for forward compatibility but
|
||||
warned so plugin authors can catch typos.
|
||||
"""
|
||||
if kind not in VALID_MIDDLEWARE:
|
||||
logger.warning(
|
||||
"Plugin '%s' registered unknown middleware '%s' "
|
||||
"(valid: %s)",
|
||||
self.manifest.name,
|
||||
kind,
|
||||
", ".join(sorted(VALID_MIDDLEWARE)),
|
||||
)
|
||||
self._manager._middleware.setdefault(kind, []).append(callback)
|
||||
logger.debug("Plugin %s registered middleware: %s", self.manifest.name, kind)
|
||||
|
||||
# -- skill registration -------------------------------------------------
|
||||
|
||||
def register_skill(
|
||||
|
|
@ -1010,6 +1032,7 @@ class PluginManager:
|
|||
def __init__(self) -> None:
|
||||
self._plugins: Dict[str, LoadedPlugin] = {}
|
||||
self._hooks: Dict[str, List[Callable]] = {}
|
||||
self._middleware: Dict[str, List[Callable]] = {}
|
||||
self._plugin_tool_names: Set[str] = set()
|
||||
self._plugin_platform_names: Set[str] = set()
|
||||
self._cli_commands: Dict[str, dict] = {}
|
||||
|
|
@ -1039,6 +1062,7 @@ class PluginManager:
|
|||
if force:
|
||||
self._plugins.clear()
|
||||
self._hooks.clear()
|
||||
self._middleware.clear()
|
||||
self._plugin_tool_names.clear()
|
||||
self._cli_commands.clear()
|
||||
self._plugin_commands.clear()
|
||||
|
|
@ -1449,15 +1473,28 @@ class PluginManager:
|
|||
for h in p.hooks_registered
|
||||
}
|
||||
)
|
||||
loaded.middleware_registered = list(
|
||||
{
|
||||
kind
|
||||
for kind, cbs in self._middleware.items()
|
||||
if cbs
|
||||
}
|
||||
- {
|
||||
kind
|
||||
for name, p in self._plugins.items()
|
||||
for kind in p.middleware_registered
|
||||
}
|
||||
)
|
||||
loaded.commands_registered = [
|
||||
c for c in self._plugin_commands
|
||||
if self._plugin_commands[c].get("plugin") == manifest.name
|
||||
]
|
||||
loaded.enabled = True
|
||||
logger.debug(
|
||||
" registered: %d tool(s), %d hook(s), %d slash command(s), %d CLI command(s)",
|
||||
" registered: %d tool(s), %d hook(s), %d middleware, %d slash command(s), %d CLI command(s)",
|
||||
len(loaded.tools_registered),
|
||||
len(loaded.hooks_registered),
|
||||
len(loaded.middleware_registered),
|
||||
len(loaded.commands_registered),
|
||||
sum(
|
||||
1 for c in self._cli_commands
|
||||
|
|
@ -1575,6 +1612,33 @@ class PluginManager:
|
|||
"""Return True when at least one callback is registered for a hook."""
|
||||
return bool(self._hooks.get(hook_name))
|
||||
|
||||
def has_middleware(self, kind: str) -> bool:
|
||||
"""Return True when at least one callback is registered for middleware."""
|
||||
return bool(self._middleware.get(kind))
|
||||
|
||||
def invoke_middleware(self, kind: str, **kwargs: Any) -> List[Any]:
|
||||
"""Call registered middleware callbacks for *kind*.
|
||||
|
||||
Each callback is isolated so one plugin cannot break the base runtime
|
||||
path. Middleware that wants to change behavior must return the shape
|
||||
documented by the caller-specific contract.
|
||||
"""
|
||||
callbacks = self._middleware.get(kind, [])
|
||||
results: List[Any] = []
|
||||
for cb in callbacks:
|
||||
try:
|
||||
ret = cb(**kwargs)
|
||||
if ret is not None:
|
||||
results.append(ret)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Middleware '%s' callback %s raised: %s",
|
||||
kind,
|
||||
getattr(cb, "__name__", repr(cb)),
|
||||
exc,
|
||||
)
|
||||
return results
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Introspection
|
||||
# -----------------------------------------------------------------------
|
||||
|
|
@ -1594,6 +1658,7 @@ class PluginManager:
|
|||
"enabled": loaded.enabled,
|
||||
"tools": len(loaded.tools_registered),
|
||||
"hooks": len(loaded.hooks_registered),
|
||||
"middleware": len(loaded.middleware_registered),
|
||||
"commands": len(loaded.commands_registered),
|
||||
"error": loaded.error,
|
||||
}
|
||||
|
|
@ -1655,6 +1720,23 @@ def invoke_hook(hook_name: str, **kwargs: Any) -> List[Any]:
|
|||
return get_plugin_manager().invoke_hook(hook_name, **kwargs)
|
||||
|
||||
|
||||
def invoke_middleware(kind: str, **kwargs: Any) -> List[Any]:
|
||||
"""Invoke registered middleware callbacks.
|
||||
|
||||
Returns a list of non-``None`` return values from middleware callbacks.
|
||||
"""
|
||||
return get_plugin_manager().invoke_middleware(kind, **kwargs)
|
||||
|
||||
|
||||
def has_middleware(kind: str) -> bool:
|
||||
"""Return True when middleware callbacks are registered for ``kind``."""
|
||||
manager = get_plugin_manager()
|
||||
method = getattr(manager, "has_middleware", None)
|
||||
if callable(method):
|
||||
return bool(method(kind))
|
||||
return bool(getattr(manager, "_middleware", {}).get(kind))
|
||||
|
||||
|
||||
def has_hook(hook_name: str) -> bool:
|
||||
"""Return True when a hook has registered callbacks."""
|
||||
return get_plugin_manager().has_hook(hook_name)
|
||||
|
|
@ -1683,6 +1765,7 @@ def get_pre_tool_call_block_message(
|
|||
tool_call_id: str = "",
|
||||
turn_id: str = "",
|
||||
api_request_id: str = "",
|
||||
middleware_trace: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> Optional[str]:
|
||||
"""Check ``pre_tool_call`` hooks for a blocking directive.
|
||||
|
||||
|
|
@ -1709,6 +1792,7 @@ def get_pre_tool_call_block_message(
|
|||
tool_call_id=tool_call_id,
|
||||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
middleware_trace=list(middleware_trace or []),
|
||||
)
|
||||
|
||||
for result in hook_results:
|
||||
|
|
|
|||
|
|
@ -823,6 +823,7 @@ def _emit_post_tool_call_hook(
|
|||
status: Optional[str] = None,
|
||||
error_type: Optional[str] = None,
|
||||
error_message: Optional[str] = None,
|
||||
middleware_trace: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
"""Emit the ``post_tool_call`` observer hook.
|
||||
|
||||
|
|
@ -853,6 +854,7 @@ def _emit_post_tool_call_hook(
|
|||
status=status,
|
||||
error_type=error_type,
|
||||
error_message=error_message,
|
||||
middleware_trace=list(middleware_trace or []),
|
||||
)
|
||||
except Exception as _hook_err:
|
||||
logger.debug("post_tool_call hook error: %s", _hook_err)
|
||||
|
|
@ -869,6 +871,8 @@ def handle_function_call(
|
|||
user_task: Optional[str] = None,
|
||||
enabled_tools: Optional[List[str]] = None,
|
||||
skip_pre_tool_call_hook: bool = False,
|
||||
skip_tool_request_middleware: bool = False,
|
||||
tool_request_middleware_trace: Optional[List[Dict[str, Any]]] = None,
|
||||
enabled_toolsets: Optional[List[str]] = None,
|
||||
disabled_toolsets: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
|
|
@ -900,6 +904,7 @@ def handle_function_call(
|
|||
function_args = coerce_tool_args(function_name, function_args)
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
_tool_middleware_trace = list(tool_request_middleware_trace or [])
|
||||
|
||||
# ── Tool Search bridge dispatch ──────────────────────────────────
|
||||
# tool_search and tool_describe are pure catalog reads — handle them
|
||||
|
|
@ -970,10 +975,32 @@ def handle_function_call(
|
|||
user_task=user_task,
|
||||
enabled_tools=enabled_tools,
|
||||
skip_pre_tool_call_hook=skip_pre_tool_call_hook,
|
||||
skip_tool_request_middleware=skip_tool_request_middleware,
|
||||
tool_request_middleware_trace=list(_tool_middleware_trace),
|
||||
enabled_toolsets=enabled_toolsets,
|
||||
disabled_toolsets=disabled_toolsets,
|
||||
)
|
||||
|
||||
_tool_original_args = dict(function_args)
|
||||
if not skip_tool_request_middleware:
|
||||
try:
|
||||
from hermes_cli.middleware import apply_tool_request_middleware
|
||||
|
||||
_tool_request_mw = apply_tool_request_middleware(
|
||||
function_name,
|
||||
function_args,
|
||||
task_id=task_id or "",
|
||||
session_id=session_id or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
turn_id=turn_id or "",
|
||||
api_request_id=api_request_id or "",
|
||||
)
|
||||
function_args = _tool_request_mw.payload
|
||||
_tool_original_args = _tool_request_mw.original_payload
|
||||
_tool_middleware_trace = _tool_request_mw.trace
|
||||
except Exception as _mw_err:
|
||||
logger.debug("tool_request middleware error: %s", _mw_err)
|
||||
|
||||
try:
|
||||
if function_name in _AGENT_LOOP_TOOLS:
|
||||
return json.dumps({"error": f"{function_name} must be handled by the agent loop"})
|
||||
|
|
@ -1000,6 +1027,7 @@ def handle_function_call(
|
|||
tool_call_id=tool_call_id or "",
|
||||
turn_id=turn_id or "",
|
||||
api_request_id=api_request_id or "",
|
||||
middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
except Exception as _hook_err:
|
||||
logger.debug("pre_tool_call hook error: %s", _hook_err)
|
||||
|
|
@ -1018,6 +1046,7 @@ def handle_function_call(
|
|||
status="blocked",
|
||||
error_type="plugin_block",
|
||||
error_message=block_message,
|
||||
middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
return result
|
||||
|
||||
|
|
@ -1082,7 +1111,19 @@ def handle_function_call(
|
|||
task_id=task_id,
|
||||
user_task=user_task,
|
||||
)
|
||||
result = _dispatch(function_args)
|
||||
from hermes_cli.middleware import run_tool_execution_middleware
|
||||
|
||||
result = run_tool_execution_middleware(
|
||||
function_name,
|
||||
function_args,
|
||||
_dispatch,
|
||||
original_args=_tool_original_args,
|
||||
task_id=task_id or "",
|
||||
session_id=session_id or "",
|
||||
tool_call_id=tool_call_id or "",
|
||||
turn_id=turn_id or "",
|
||||
api_request_id=api_request_id or "",
|
||||
)
|
||||
finally:
|
||||
if _approval_tokens is not None and reset_current_observability_context is not None:
|
||||
try:
|
||||
|
|
@ -1101,6 +1142,7 @@ def handle_function_call(
|
|||
turn_id=turn_id,
|
||||
api_request_id=api_request_id,
|
||||
duration_ms=duration_ms,
|
||||
middleware_trace=list(_tool_middleware_trace),
|
||||
)
|
||||
|
||||
# Generic tool-result canonicalization seam: plugins receive the
|
||||
|
|
|
|||
|
|
@ -165,6 +165,28 @@ When `HERMES_NEMO_RELAY_PLUGINS_TOML` is set and initializes successfully, NeMo
|
|||
Relay owns exporter lifecycle through that config. The direct
|
||||
`HERMES_NEMO_RELAY_ATOF_*` fallback setup is skipped.
|
||||
|
||||
To enable NeMo Relay managed execution intercepts for provider and tool calls,
|
||||
include an adaptive component in the same `plugins.toml`:
|
||||
|
||||
```toml
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
mode = "route"
|
||||
```
|
||||
|
||||
When the adaptive component is enabled and the installed NeMo Relay runtime
|
||||
exposes `llm.execute(...)` / `tools.execute(...)`, Hermes routes LLM and tool
|
||||
execution through those middleware boundaries. The observer hooks still emit
|
||||
session, turn, approval, and subagent marks; the plugin skips its manual
|
||||
`llm.call` and `tools.call` spans for executions that are already managed by
|
||||
NeMo Relay.
|
||||
|
||||
For the full generic Hermes middleware contract, see
|
||||
[`docs/middleware/README.md`](../../../docs/middleware/README.md).
|
||||
|
||||
## Canonical Local Examples
|
||||
|
||||
The examples below use the official `nemo-relay==0.3` distribution and a local
|
||||
|
|
@ -366,3 +388,166 @@ subagent IDs, role/status fields when present, and derived
|
|||
`parent_trajectory_id` / `child_trajectory_id` values. This keeps the ATOF
|
||||
stream lossless for later ATIF conversion that can compact subagents into
|
||||
separate trajectories.
|
||||
|
||||
## Adaptive Middleware Example
|
||||
|
||||
The `observability/nemo_relay` plugin uses Hermes execution middleware to hand
|
||||
LLM and tool calls to NeMo Relay managed execution when an adaptive component is
|
||||
enabled.
|
||||
|
||||
Minimal `plugins.toml`:
|
||||
|
||||
```toml
|
||||
version = 1
|
||||
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
mode = "route"
|
||||
```
|
||||
|
||||
Enable it for Hermes:
|
||||
|
||||
```bash
|
||||
export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/plugins.toml
|
||||
```
|
||||
|
||||
When the adaptive component is enabled and the installed NeMo Relay runtime
|
||||
exposes `llm.execute(...)` and `tools.execute(...)`, Hermes routes execution
|
||||
through these boundaries:
|
||||
|
||||
```text
|
||||
Hermes provider call
|
||||
-> llm_execution middleware
|
||||
-> nemo_relay.llm.execute(...)
|
||||
-> Hermes provider adapter next_call(...)
|
||||
|
||||
Hermes tool call
|
||||
-> tool_execution middleware
|
||||
-> nemo_relay.tools.execute(...)
|
||||
-> Hermes tool dispatcher next_call(...)
|
||||
```
|
||||
|
||||
The plugin still emits observer marks for sessions, turns, approvals, and
|
||||
subagents. When adaptive managed execution is active, it skips manual
|
||||
`llm.call` and `tools.call` observer spans to avoid duplicate LLM/tool events
|
||||
for the same execution.
|
||||
|
||||
### Local Adaptive E2E
|
||||
|
||||
This example enables both NeMo Relay observability export and adaptive execution
|
||||
middleware for a local Hermes run.
|
||||
|
||||
```bash
|
||||
pip install "nemo-relay==0.3"
|
||||
|
||||
export HERMES_HOME=/tmp/hermes-middleware-test/hermes-home
|
||||
mkdir -p "$HERMES_HOME" /tmp/hermes-middleware-test/nemo-relay
|
||||
|
||||
cat > "$HERMES_HOME/config.yaml" <<'YAML'
|
||||
model:
|
||||
provider: custom
|
||||
default: qwen3.6:35b
|
||||
base_url: http://127.0.0.1:11434/v1
|
||||
api_key: ollama
|
||||
plugins:
|
||||
enabled:
|
||||
- observability/nemo_relay
|
||||
YAML
|
||||
|
||||
cat > /tmp/hermes-middleware-test/nemo-relay/plugins.toml <<'TOML'
|
||||
version = 1
|
||||
|
||||
[[components]]
|
||||
kind = "observability"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
version = 1
|
||||
|
||||
[components.config.atof]
|
||||
enabled = true
|
||||
output_directory = "/tmp/hermes-middleware-test/atof"
|
||||
filename = "middleware-events.jsonl"
|
||||
mode = "overwrite"
|
||||
|
||||
[components.config.atif]
|
||||
enabled = true
|
||||
output_directory = "/tmp/hermes-middleware-test/atif"
|
||||
filename_template = "middleware-trajectory-{session_id}.json"
|
||||
agent_name = "Hermes Middleware E2E"
|
||||
agent_version = "local"
|
||||
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
mode = "route"
|
||||
TOML
|
||||
|
||||
export HERMES_NEMO_RELAY_PLUGINS_TOML=/tmp/hermes-middleware-test/nemo-relay/plugins.toml
|
||||
|
||||
hermes chat \
|
||||
--query 'Use the terminal tool exactly once to run printf middleware_execution_ok. Then reply with exactly the command output.' \
|
||||
--provider custom \
|
||||
--model qwen3.6:35b \
|
||||
--toolsets terminal \
|
||||
--max-turns 4 \
|
||||
--quiet \
|
||||
--accept-hooks
|
||||
```
|
||||
|
||||
Expected CLI output:
|
||||
|
||||
```text
|
||||
session_id: middleware-demo-session
|
||||
middleware_execution_ok
|
||||
```
|
||||
|
||||
Expected ATOF shape:
|
||||
|
||||
```jsonl
|
||||
{"kind":"scope","category":"llm","name":"custom","scope_category":"start","metadata":{"session_id":"middleware-demo-session"},"data":{"mode":"route"}}
|
||||
{"kind":"scope","category":"tool","name":"terminal","scope_category":"start","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal"},"data":{"mode":"route"}}
|
||||
{"kind":"scope","category":"tool","name":"terminal","scope_category":"end","metadata":{"session_id":"middleware-demo-session","tool_call_id":"call_terminal","status":"ok"},"data":"{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"}
|
||||
```
|
||||
|
||||
Expected ATIF shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"schema_version": "ATIF-v1.7",
|
||||
"session_id": "middleware-demo-session",
|
||||
"agent": {
|
||||
"name": "Hermes Middleware E2E",
|
||||
"version": "local",
|
||||
"model_name": "qwen3.6:35b"
|
||||
},
|
||||
"steps": [
|
||||
{
|
||||
"source": "agent",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function_name": "terminal",
|
||||
"arguments": {"command": "printf middleware_execution_ok"}
|
||||
}
|
||||
],
|
||||
"observation": {
|
||||
"results": [
|
||||
{
|
||||
"source_call_id": "call_terminal",
|
||||
"content": "{\"output\":\"middleware_execution_ok\",\"exit_code\":0,\"error\":null}"
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
"source": "agent",
|
||||
"message": "middleware_execution_ok"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -42,6 +42,9 @@ class _SubagentParent:
|
|||
@dataclass
|
||||
class _Settings:
|
||||
plugins_toml_path: str = ""
|
||||
plugins_config: dict[str, Any] | None = None
|
||||
adaptive_enabled: bool = False
|
||||
adaptive_mode: str = "observe"
|
||||
atof_enabled: bool = False
|
||||
atof_output_directory: str = ""
|
||||
atof_filename: str = "hermes-atof.jsonl"
|
||||
|
|
@ -67,17 +70,15 @@ class _Runtime:
|
|||
self._configure_atof()
|
||||
|
||||
def _configure_plugins_toml(self) -> bool:
|
||||
if not self.settings.plugins_toml_path:
|
||||
if not self.settings.plugins_config:
|
||||
return False
|
||||
plugin_mod = getattr(self.nemo_relay, "plugin", None)
|
||||
initialize = getattr(plugin_mod, "initialize", None)
|
||||
if not callable(initialize):
|
||||
return False
|
||||
config_path = Path(self.settings.plugins_toml_path)
|
||||
try:
|
||||
config = tomllib.loads(config_path.read_text(encoding="utf-8"))
|
||||
self._ensure_plugin_config_output_dirs(config)
|
||||
result = initialize(config)
|
||||
self._ensure_plugin_config_output_dirs(self.settings.plugins_config)
|
||||
result = initialize(self.settings.plugins_config)
|
||||
if inspect.isawaitable(result):
|
||||
asyncio.run(result)
|
||||
return True
|
||||
|
|
@ -221,6 +222,100 @@ class _Runtime:
|
|||
self.subagent_parents.pop(child_session_id, None)
|
||||
self.mark("hermes.subagent.stop", kwargs)
|
||||
|
||||
def managed_llm_enabled(self) -> bool:
|
||||
return (
|
||||
self.settings.adaptive_enabled
|
||||
and callable(getattr(getattr(self.nemo_relay, "llm", None), "execute", None))
|
||||
and callable(getattr(self.nemo_relay, "LLMRequest", None))
|
||||
)
|
||||
|
||||
def managed_tool_enabled(self) -> bool:
|
||||
return (
|
||||
self.settings.adaptive_enabled
|
||||
and callable(getattr(getattr(self.nemo_relay, "tools", None), "execute", None))
|
||||
)
|
||||
|
||||
def execute_llm(self, kwargs: dict[str, Any]) -> Any:
|
||||
state = self.ensure_session(kwargs)
|
||||
request_body = _jsonable(kwargs.get("request") or {})
|
||||
request = self.nemo_relay.LLMRequest({}, request_body)
|
||||
next_call = kwargs.get("next_call")
|
||||
if not callable(next_call):
|
||||
return request_body
|
||||
|
||||
raw_response: dict[str, Any] = {"set": False, "value": None}
|
||||
|
||||
def _impl(next_request: Any) -> Any:
|
||||
next_body = getattr(next_request, "content", next_request)
|
||||
raw = next_call(next_body if isinstance(next_body, dict) else request_body)
|
||||
raw_response["set"] = True
|
||||
raw_response["value"] = raw
|
||||
return _llm_response_payload(raw)
|
||||
|
||||
async def _managed_execute() -> Any:
|
||||
result = self.nemo_relay.llm.execute(
|
||||
str(kwargs.get("provider") or "llm"),
|
||||
request,
|
||||
_impl,
|
||||
handle=state.handle,
|
||||
data=_jsonable(
|
||||
{
|
||||
"turn_id": kwargs.get("turn_id"),
|
||||
"api_request_id": kwargs.get("api_request_id"),
|
||||
"api_call_count": kwargs.get("api_call_count"),
|
||||
"mode": self.settings.adaptive_mode,
|
||||
}
|
||||
),
|
||||
metadata=_metadata(kwargs),
|
||||
model_name=str(kwargs.get("model") or ""),
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
managed_result = _resolve_awaitable(_managed_execute())
|
||||
return raw_response["value"] if raw_response["set"] else managed_result
|
||||
|
||||
def execute_tool(self, kwargs: dict[str, Any]) -> Any:
|
||||
state = self.ensure_session(kwargs)
|
||||
tool_name = str(kwargs.get("tool_name") or "tool")
|
||||
args = _jsonable(kwargs.get("args") or {})
|
||||
next_call = kwargs.get("next_call")
|
||||
if not callable(next_call):
|
||||
return args
|
||||
|
||||
raw_response: dict[str, Any] = {"set": False, "value": None}
|
||||
|
||||
def _impl(next_args: Any) -> Any:
|
||||
effective_args = next_args if isinstance(next_args, dict) else args
|
||||
raw = next_call(effective_args)
|
||||
raw_response["set"] = True
|
||||
raw_response["value"] = raw
|
||||
return _jsonable(raw)
|
||||
|
||||
async def _managed_execute() -> Any:
|
||||
result = self.nemo_relay.tools.execute(
|
||||
tool_name,
|
||||
args,
|
||||
_impl,
|
||||
handle=state.handle,
|
||||
data=_jsonable(
|
||||
{
|
||||
"turn_id": kwargs.get("turn_id"),
|
||||
"api_request_id": kwargs.get("api_request_id"),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"mode": self.settings.adaptive_mode,
|
||||
}
|
||||
),
|
||||
metadata=_metadata(kwargs),
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
managed_result = _resolve_awaitable(_managed_execute())
|
||||
return raw_response["value"] if raw_response["set"] else managed_result
|
||||
|
||||
|
||||
def register(ctx) -> None:
|
||||
ctx.register_hook("on_session_start", on_session_start)
|
||||
|
|
@ -238,6 +333,8 @@ def register(ctx) -> None:
|
|||
ctx.register_hook("post_approval_response", on_post_approval_response)
|
||||
ctx.register_hook("subagent_start", on_subagent_start)
|
||||
ctx.register_hook("subagent_stop", on_subagent_stop)
|
||||
ctx.register_middleware("llm_execution", on_llm_execution_middleware)
|
||||
ctx.register_middleware("tool_execution", on_tool_execution_middleware)
|
||||
|
||||
|
||||
def on_session_start(**kwargs: Any) -> None:
|
||||
|
|
@ -280,6 +377,8 @@ def on_pre_api_request(**kwargs: Any) -> None:
|
|||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
return
|
||||
if runtime.managed_llm_enabled():
|
||||
return
|
||||
|
||||
def _record() -> None:
|
||||
state = runtime.ensure_session(kwargs)
|
||||
|
|
@ -303,6 +402,8 @@ def on_post_api_request(**kwargs: Any) -> None:
|
|||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
return
|
||||
if runtime.managed_llm_enabled():
|
||||
return
|
||||
|
||||
def _record() -> None:
|
||||
state = runtime.ensure_session(kwargs)
|
||||
|
|
@ -324,6 +425,8 @@ def on_api_request_error(**kwargs: Any) -> None:
|
|||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
return
|
||||
if runtime.managed_llm_enabled():
|
||||
return
|
||||
|
||||
def _record() -> None:
|
||||
state = runtime.ensure_session(kwargs)
|
||||
|
|
@ -345,6 +448,8 @@ def on_pre_tool_call(**kwargs: Any) -> None:
|
|||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
return
|
||||
if runtime.managed_tool_enabled():
|
||||
return
|
||||
|
||||
def _record() -> None:
|
||||
state = runtime.ensure_session(kwargs)
|
||||
|
|
@ -365,6 +470,8 @@ def on_post_tool_call(**kwargs: Any) -> None:
|
|||
runtime = _get_runtime()
|
||||
if runtime is None:
|
||||
return
|
||||
if runtime.managed_tool_enabled():
|
||||
return
|
||||
|
||||
def _record() -> None:
|
||||
state = runtime.ensure_session(kwargs)
|
||||
|
|
@ -406,6 +513,28 @@ def on_subagent_stop(**kwargs: Any) -> None:
|
|||
_safe(lambda: runtime.mark_subagent_stop(kwargs))
|
||||
|
||||
|
||||
def on_llm_execution_middleware(**kwargs: Any) -> Any:
|
||||
runtime = _get_runtime()
|
||||
next_call = kwargs.get("next_call")
|
||||
request = kwargs.get("request") or {}
|
||||
if runtime is not None and runtime.managed_llm_enabled():
|
||||
return runtime.execute_llm(kwargs)
|
||||
if callable(next_call):
|
||||
return next_call(request)
|
||||
return request
|
||||
|
||||
|
||||
def on_tool_execution_middleware(**kwargs: Any) -> Any:
|
||||
runtime = _get_runtime()
|
||||
next_call = kwargs.get("next_call")
|
||||
args = kwargs.get("args") or {}
|
||||
if runtime is not None and runtime.managed_tool_enabled():
|
||||
return runtime.execute_tool(kwargs)
|
||||
if callable(next_call):
|
||||
return next_call(args)
|
||||
return args
|
||||
|
||||
|
||||
def _get_runtime() -> Optional[_Runtime]:
|
||||
global _RUNTIME
|
||||
with _LOCK:
|
||||
|
|
@ -429,8 +558,14 @@ def _get_runtime() -> Optional[_Runtime]:
|
|||
|
||||
|
||||
def _load_settings() -> _Settings:
|
||||
plugins_toml_path = _env("HERMES_NEMO_RELAY_PLUGINS_TOML")
|
||||
plugins_config = _load_plugins_config(plugins_toml_path)
|
||||
adaptive_config = _enabled_component_config(plugins_config, "adaptive")
|
||||
return _Settings(
|
||||
plugins_toml_path=_env("HERMES_NEMO_RELAY_PLUGINS_TOML"),
|
||||
plugins_toml_path=plugins_toml_path,
|
||||
plugins_config=plugins_config,
|
||||
adaptive_enabled=adaptive_config is not None,
|
||||
adaptive_mode=_adaptive_mode(adaptive_config),
|
||||
atof_enabled=_env_bool("HERMES_NEMO_RELAY_ATOF_ENABLED"),
|
||||
atof_output_directory=_env("HERMES_NEMO_RELAY_ATOF_OUTPUT_DIRECTORY"),
|
||||
atof_filename=_env("HERMES_NEMO_RELAY_ATOF_FILENAME") or "hermes-atof.jsonl",
|
||||
|
|
@ -445,6 +580,44 @@ def _load_settings() -> _Settings:
|
|||
)
|
||||
|
||||
|
||||
def _load_plugins_config(path: str) -> dict[str, Any] | None:
|
||||
if not path:
|
||||
return None
|
||||
try:
|
||||
return tomllib.loads(Path(path).read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
logger.debug("NeMo Relay plugins.toml load failed: %s", exc, exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def _enabled_component_config(
|
||||
plugins_config: dict[str, Any] | None,
|
||||
kind: str,
|
||||
) -> dict[str, Any] | None:
|
||||
if not isinstance(plugins_config, dict):
|
||||
return None
|
||||
components = plugins_config.get("components")
|
||||
if not isinstance(components, list):
|
||||
return None
|
||||
for component in components:
|
||||
if not isinstance(component, dict):
|
||||
continue
|
||||
if component.get("kind") != kind or not component.get("enabled", True):
|
||||
continue
|
||||
config = component.get("config")
|
||||
return config if isinstance(config, dict) else {}
|
||||
return None
|
||||
|
||||
|
||||
def _adaptive_mode(config: dict[str, Any] | None) -> str:
|
||||
if not isinstance(config, dict):
|
||||
return "observe"
|
||||
mode = config.get("mode")
|
||||
if isinstance(mode, str) and mode.strip():
|
||||
return mode.strip()
|
||||
return "observe"
|
||||
|
||||
|
||||
def _env(name: str) -> str:
|
||||
return os.environ.get(name, "").strip()
|
||||
|
||||
|
|
@ -549,12 +722,78 @@ def _jsonable(value: Any) -> Any:
|
|||
return _jsonable(value.model_dump(mode="json"))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if hasattr(value, "__dict__"):
|
||||
return _jsonable(vars(value))
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
return json.loads(json.dumps(value, default=str))
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
|
||||
def _value(obj: Any, key: str, default: Any = None) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
return obj.get(key, default)
|
||||
return getattr(obj, key, default)
|
||||
|
||||
|
||||
def _llm_response_payload(response: Any) -> Any:
|
||||
"""Return the LLM response shape NeMo Relay's ATIF conversion expects."""
|
||||
payload = _jsonable(response)
|
||||
if isinstance(payload, dict) and "assistant_message" in payload:
|
||||
return payload
|
||||
|
||||
choices = _value(response, "choices")
|
||||
if choices is None and isinstance(payload, dict):
|
||||
choices = payload.get("choices")
|
||||
first_choice = choices[0] if isinstance(choices, list) and choices else None
|
||||
message = _value(first_choice, "message")
|
||||
finish_reason = _value(first_choice, "finish_reason")
|
||||
|
||||
assistant_message: dict[str, Any] = {"role": "assistant", "content": ""}
|
||||
if message is not None:
|
||||
assistant_message["role"] = _value(message, "role", "assistant") or "assistant"
|
||||
content = _value(message, "content")
|
||||
if content is not None:
|
||||
assistant_message["content"] = _jsonable(content)
|
||||
tool_calls = _tool_calls_payload(_value(message, "tool_calls"))
|
||||
if tool_calls:
|
||||
assistant_message["tool_calls"] = tool_calls
|
||||
reasoning = _value(message, "reasoning_content")
|
||||
if reasoning is not None:
|
||||
assistant_message["reasoning_content"] = _jsonable(reasoning)
|
||||
elif isinstance(payload, dict):
|
||||
assistant_message["content"] = payload.get("content") or payload.get("output_text") or ""
|
||||
|
||||
return {
|
||||
"model": _value(response, "model", payload.get("model") if isinstance(payload, dict) else None),
|
||||
"assistant_message": assistant_message,
|
||||
"finish_reason": finish_reason,
|
||||
"usage": _jsonable(_value(response, "usage", payload.get("usage") if isinstance(payload, dict) else None)),
|
||||
}
|
||||
|
||||
|
||||
def _tool_calls_payload(tool_calls: Any) -> list[dict[str, Any]]:
|
||||
if not isinstance(tool_calls, list):
|
||||
return []
|
||||
normalized: list[dict[str, Any]] = []
|
||||
for call in tool_calls:
|
||||
function = _value(call, "function")
|
||||
normalized.append(
|
||||
{
|
||||
"id": _value(call, "id"),
|
||||
"type": _value(call, "type", "function") or "function",
|
||||
"function": {
|
||||
"name": _value(function, "name"),
|
||||
"arguments": _value(function, "arguments"),
|
||||
},
|
||||
}
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def _safe(fn) -> None:
|
||||
try:
|
||||
fn()
|
||||
|
|
@ -562,6 +801,35 @@ def _safe(fn) -> None:
|
|||
logger.debug("NeMo Relay hook handling failed: %s", exc, exc_info=True)
|
||||
|
||||
|
||||
def _resolve_awaitable(value: Any) -> Any:
|
||||
if not inspect.isawaitable(value):
|
||||
return value
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(value)
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
error: dict[str, BaseException] = {}
|
||||
|
||||
def _runner() -> None:
|
||||
try:
|
||||
result["value"] = asyncio.run(value)
|
||||
except BaseException as exc: # pragma: no cover - re-raised below
|
||||
error["exc"] = exc
|
||||
|
||||
thread = threading.Thread(
|
||||
target=_runner,
|
||||
name="hermes-nemo-relay-awaitable",
|
||||
daemon=True,
|
||||
)
|
||||
thread.start()
|
||||
thread.join()
|
||||
if "exc" in error:
|
||||
raise error["exc"]
|
||||
return result.get("value")
|
||||
|
||||
|
||||
def reset_for_tests() -> None:
|
||||
global _RUNTIME
|
||||
with _LOCK:
|
||||
|
|
|
|||
16
run_agent.py
16
run_agent.py
|
|
@ -4955,10 +4955,22 @@ class AIAgent:
|
|||
|
||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
|
||||
tool_call_id: Optional[str] = None, messages: list = None,
|
||||
pre_tool_block_checked: bool = False) -> str:
|
||||
pre_tool_block_checked: bool = False,
|
||||
skip_tool_request_middleware: bool = False,
|
||||
tool_request_middleware_trace: Optional[list[dict[str, Any]]] = None) -> str:
|
||||
"""Forwarder — see ``agent.agent_runtime_helpers.invoke_tool``."""
|
||||
from agent.agent_runtime_helpers import invoke_tool
|
||||
return invoke_tool(self, function_name, function_args, effective_task_id, tool_call_id, messages, pre_tool_block_checked)
|
||||
return invoke_tool(
|
||||
self,
|
||||
function_name,
|
||||
function_args,
|
||||
effective_task_id,
|
||||
tool_call_id,
|
||||
messages,
|
||||
pre_tool_block_checked,
|
||||
skip_tool_request_middleware,
|
||||
tool_request_middleware_trace,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _wrap_verbose(label: str, text: str, indent: str = " ") -> str:
|
||||
|
|
|
|||
|
|
@ -18,8 +18,15 @@ from hermes_cli.plugins import (
|
|||
get_plugin_command_handler,
|
||||
get_plugin_commands,
|
||||
get_pre_tool_call_block_message,
|
||||
has_middleware,
|
||||
resolve_plugin_command_result,
|
||||
)
|
||||
from hermes_cli.middleware import (
|
||||
VALID_MIDDLEWARE,
|
||||
apply_llm_request_middleware,
|
||||
apply_tool_request_middleware,
|
||||
run_tool_execution_middleware,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
|
@ -96,6 +103,223 @@ class TestPluginDiscovery:
|
|||
assert "hello_plugin" in mgr._plugins
|
||||
assert mgr._plugins["hello_plugin"].enabled
|
||||
|
||||
def test_plugin_can_register_and_invoke_middleware(self, tmp_path, monkeypatch):
|
||||
plugins_dir = tmp_path / "hermes_test" / "plugins"
|
||||
_make_plugin_dir(
|
||||
plugins_dir,
|
||||
"mw_plugin",
|
||||
register_body=(
|
||||
"ctx.register_middleware('llm_request', "
|
||||
"lambda **kw: {'request': {**kw['request'], 'mw': True}})\n"
|
||||
" ctx.register_middleware('tool_request', "
|
||||
"lambda **kw: {'args': {**kw['args'], 'mw': True}})"
|
||||
),
|
||||
)
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path / "hermes_test"))
|
||||
|
||||
mgr = PluginManager()
|
||||
mgr.discover_and_load()
|
||||
|
||||
assert "llm_request" in VALID_MIDDLEWARE
|
||||
assert "tool_request" in VALID_MIDDLEWARE
|
||||
assert set(mgr._plugins["mw_plugin"].middleware_registered) == {"llm_request", "tool_request"}
|
||||
assert mgr.invoke_middleware("llm_request", request={"messages": []}) == [
|
||||
{"request": {"messages": [], "mw": True}}
|
||||
]
|
||||
assert mgr.invoke_middleware("tool_request", args={"path": "README.md"}) == [
|
||||
{"args": {"path": "README.md", "mw": True}}
|
||||
]
|
||||
assert mgr.has_middleware("llm_request") is True
|
||||
|
||||
def test_execution_middleware_does_not_retry_downstream_failure(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def middleware(**kwargs):
|
||||
return kwargs["next_call"](kwargs["args"])
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(args)
|
||||
raise RuntimeError("tool failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="tool failed"):
|
||||
run_tool_execution_middleware("terminal", {"command": "false"}, terminal)
|
||||
|
||||
assert calls == [{"command": "false"}]
|
||||
|
||||
def test_middleware_helpers_skip_no_listener_work(self, monkeypatch):
|
||||
manager = types.SimpleNamespace(_middleware={})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
request = {"messages": []}
|
||||
args = {"path": "README.md"}
|
||||
|
||||
llm_result = apply_llm_request_middleware(request)
|
||||
tool_result = apply_tool_request_middleware("read_file", args)
|
||||
|
||||
assert llm_result.payload is request
|
||||
assert llm_result.original_payload is request
|
||||
assert llm_result.changed is False
|
||||
assert llm_result.trace == []
|
||||
assert tool_result.payload is args
|
||||
assert tool_result.original_payload is args
|
||||
assert tool_result.changed is False
|
||||
assert tool_result.trace == []
|
||||
assert run_tool_execution_middleware("terminal", args, lambda payload: payload) is args
|
||||
assert has_middleware("tool_request") is False
|
||||
|
||||
def test_request_middleware_changed_tracks_trace_not_deep_equality(self, monkeypatch):
|
||||
def same_payload_middleware(**kwargs):
|
||||
return {"args": kwargs["args"], "source": "same-payload"}
|
||||
|
||||
manager = types.SimpleNamespace(
|
||||
_middleware={"tool_request": [same_payload_middleware]},
|
||||
invoke_middleware=lambda kind, **kwargs: [same_payload_middleware(**kwargs)],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
args = {"path": "README.md"}
|
||||
result = apply_tool_request_middleware("read_file", args)
|
||||
|
||||
assert result.payload == args
|
||||
assert result.original_payload == args
|
||||
assert result.changed is True
|
||||
assert result.trace == [{"source": "same-payload"}]
|
||||
|
||||
def test_execution_middleware_post_next_call_error_does_not_retry(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def middleware(**kwargs):
|
||||
result = kwargs["next_call"](kwargs["args"])
|
||||
raise RuntimeError(f"post-processing failed after {result}")
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(args)
|
||||
return "terminal-result"
|
||||
|
||||
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
|
||||
|
||||
assert result == "terminal-result"
|
||||
assert calls == [{"command": "printf ok"}]
|
||||
|
||||
def test_execution_middleware_pre_next_call_error_fails_open_to_remaining_chain(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def failing_middleware(**kwargs):
|
||||
calls.append("failing")
|
||||
raise RuntimeError("middleware setup failed")
|
||||
|
||||
def downstream_middleware(**kwargs):
|
||||
calls.append("downstream")
|
||||
return kwargs["next_call"]({**kwargs["args"], "rewritten": True})
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [failing_middleware, downstream_middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(("terminal", args))
|
||||
return args
|
||||
|
||||
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
|
||||
|
||||
assert result == {"command": "printf ok", "rewritten": True}
|
||||
assert calls == ["failing", "downstream", ("terminal", {"command": "printf ok", "rewritten": True})]
|
||||
|
||||
def test_execution_middleware_translated_downstream_failure_is_not_masked(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def middleware(**kwargs):
|
||||
try:
|
||||
return kwargs["next_call"](kwargs["args"])
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"translated downstream failure: {exc}") from exc
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(args)
|
||||
raise RuntimeError("terminal failed")
|
||||
|
||||
with pytest.raises(RuntimeError, match="translated downstream failure: terminal failed"):
|
||||
run_tool_execution_middleware("terminal", {"command": "false"}, terminal)
|
||||
|
||||
assert calls == [{"command": "false"}]
|
||||
|
||||
def test_execution_middleware_downstream_base_exception_is_not_wrapped(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def middleware(**kwargs):
|
||||
try:
|
||||
return kwargs["next_call"](kwargs["args"])
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"middleware should not catch base exception: {exc}") from exc
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(args)
|
||||
raise KeyboardInterrupt()
|
||||
|
||||
with pytest.raises(KeyboardInterrupt):
|
||||
run_tool_execution_middleware("terminal", {"command": "interrupt"}, terminal)
|
||||
|
||||
assert calls == [{"command": "interrupt"}]
|
||||
|
||||
def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def middleware(**kwargs):
|
||||
first = kwargs["next_call"](kwargs["args"])
|
||||
# Deliberate misuse: a second next_call() must not re-run the
|
||||
# downstream tool. The chain surfaces it as an error and preserves
|
||||
# the first (successful) downstream result.
|
||||
kwargs["next_call"](kwargs["args"])
|
||||
return first
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(args)
|
||||
return "terminal-result"
|
||||
|
||||
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
|
||||
|
||||
assert result == "terminal-result"
|
||||
assert calls == [{"command": "printf ok"}]
|
||||
|
||||
def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch):
|
||||
import threading
|
||||
|
||||
recorded = {}
|
||||
|
||||
def middleware(**kwargs):
|
||||
recorded["args"] = kwargs["args"]
|
||||
return None
|
||||
|
||||
manager = types.SimpleNamespace(
|
||||
_middleware={"tool_request": [middleware]},
|
||||
invoke_middleware=lambda kind, **kwargs: [middleware(**kwargs)],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
# threading.Lock is not deepcopyable; a hard deepcopy would raise.
|
||||
args = {"command": "noop", "lock": threading.Lock()}
|
||||
result = apply_tool_request_middleware("terminal", args)
|
||||
|
||||
# Middleware ran (payload was copied via the shallow fallback) and the
|
||||
# non-deepcopyable member is shared by reference rather than aborting.
|
||||
assert recorded["args"]["command"] == "noop"
|
||||
assert result.payload["command"] == "noop"
|
||||
assert result.payload["lock"] is args["lock"]
|
||||
|
||||
def test_discover_project_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ./.hermes/plugins/ are discovered."""
|
||||
project_dir = tmp_path / "project"
|
||||
|
|
|
|||
|
|
@ -27,8 +27,16 @@ class _FakeNemoRelay:
|
|||
pop=self._scope_pop,
|
||||
event=self._scope_event,
|
||||
)
|
||||
self.llm = SimpleNamespace(call=self._llm_call, call_end=self._llm_call_end)
|
||||
self.tools = SimpleNamespace(call=self._tool_call, call_end=self._tool_call_end)
|
||||
self.llm = SimpleNamespace(
|
||||
call=self._llm_call,
|
||||
call_end=self._llm_call_end,
|
||||
execute=self._llm_execute,
|
||||
)
|
||||
self.tools = SimpleNamespace(
|
||||
call=self._tool_call,
|
||||
call_end=self._tool_call_end,
|
||||
execute=self._tool_execute,
|
||||
)
|
||||
self.plugin = SimpleNamespace(initialize=self._plugin_initialize)
|
||||
self.LLMRequest = _FakeLLMRequest
|
||||
self.AtofExporterConfig = _FakeAtofExporterConfig
|
||||
|
|
@ -55,6 +63,12 @@ class _FakeNemoRelay:
|
|||
def _llm_call_end(self, handle, response, **kwargs):
|
||||
self.events.append(("llm.call_end", handle, response, kwargs))
|
||||
|
||||
def _llm_execute(self, name, request, func, **kwargs):
|
||||
self.events.append(("llm.execute.start", name, request.content, kwargs))
|
||||
result = func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
|
||||
self.events.append(("llm.execute.end", name, result, kwargs))
|
||||
return result
|
||||
|
||||
def _tool_call(self, name, args, **kwargs):
|
||||
handle = ("tool", name)
|
||||
self.events.append(("tool.call", name, args, kwargs))
|
||||
|
|
@ -63,6 +77,12 @@ class _FakeNemoRelay:
|
|||
def _tool_call_end(self, handle, result, **kwargs):
|
||||
self.events.append(("tool.call_end", handle, result, kwargs))
|
||||
|
||||
def _tool_execute(self, name, args, func, **kwargs):
|
||||
self.events.append(("tool.execute.start", name, args, kwargs))
|
||||
result = func({"intercepted": True, **args})
|
||||
self.events.append(("tool.execute.end", name, result, kwargs))
|
||||
return result
|
||||
|
||||
def _make_atof_exporter(self, config):
|
||||
return _FakeAtofExporter(self.events, config)
|
||||
|
||||
|
|
@ -425,6 +445,221 @@ output_directory = "{atif_dir}"
|
|||
assert atif_dir.is_dir()
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
plugins_toml = tmp_path / "plugins.toml"
|
||||
plugins_toml.write_text(
|
||||
"""
|
||||
version = 1
|
||||
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
mode = "route"
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
|
||||
|
||||
seen_request = {}
|
||||
raw_choice = SimpleNamespace(
|
||||
message=SimpleNamespace(
|
||||
role="assistant",
|
||||
content=None,
|
||||
tool_calls=[
|
||||
SimpleNamespace(
|
||||
id="tool-1",
|
||||
type="function",
|
||||
function=SimpleNamespace(name="terminal", arguments='{"command":"pwd"}'),
|
||||
)
|
||||
],
|
||||
reasoning_content="need a tool",
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
|
||||
def next_call(request):
|
||||
seen_request.update(request)
|
||||
return SimpleNamespace(
|
||||
id="resp-1",
|
||||
model="demo-model",
|
||||
choices=[raw_choice],
|
||||
usage=SimpleNamespace(prompt_tokens=3, completion_tokens=5, total_tokens=8),
|
||||
)
|
||||
|
||||
response = plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
task_id="t1",
|
||||
turn_id="turn-1",
|
||||
api_request_id="api-1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
api_call_count=1,
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert response.model == "demo-model"
|
||||
assert response.choices == [raw_choice]
|
||||
assert seen_request["intercepted"] is True
|
||||
execute_start = next(event for event in fake.events if event[0] == "llm.execute.start")
|
||||
assert execute_start[3]["data"]["mode"] == "route"
|
||||
execute_end = next(event for event in fake.events if event[0] == "llm.execute.end")
|
||||
assert execute_end[2] == {
|
||||
"model": "demo-model",
|
||||
"assistant_message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tool-1",
|
||||
"type": "function",
|
||||
"function": {"name": "terminal", "arguments": '{"command":"pwd"}'},
|
||||
}
|
||||
],
|
||||
"reasoning_content": "need a tool",
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
"usage": {"prompt_tokens": 3, "completion_tokens": 5, "total_tokens": 8},
|
||||
}
|
||||
|
||||
|
||||
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
|
||||
response = plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": []},
|
||||
next_call=lambda request: {"raw": request},
|
||||
)
|
||||
|
||||
assert response == {"raw": {"messages": []}}
|
||||
assert not any(event[0] == "llm.execute.start" for event in fake.events)
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_middleware_preserves_raw_response(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
plugins_toml = tmp_path / "plugins.toml"
|
||||
plugins_toml.write_text(
|
||||
"""
|
||||
version = 1
|
||||
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
mode = "route"
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
|
||||
|
||||
seen_args = {}
|
||||
|
||||
def next_call(args):
|
||||
seen_args.update(args)
|
||||
return {"raw": True, "args": args}
|
||||
|
||||
response = plugin.on_tool_execution_middleware(
|
||||
session_id="s1",
|
||||
task_id="t1",
|
||||
turn_id="turn-1",
|
||||
api_request_id="api-1",
|
||||
tool_name="terminal",
|
||||
tool_call_id="tool-1",
|
||||
args={"command": "pwd"},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
|
||||
assert seen_args["intercepted"] is True
|
||||
execute_start = next(event for event in fake.events if event[0] == "tool.execute.start")
|
||||
assert execute_start[3]["data"]["mode"] == "route"
|
||||
assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
|
||||
|
||||
|
||||
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
|
||||
response = plugin.on_tool_execution_middleware(
|
||||
session_id="s1",
|
||||
tool_name="terminal",
|
||||
args={"command": "pwd"},
|
||||
next_call=lambda args: {"raw": args},
|
||||
)
|
||||
|
||||
assert response == {"raw": {"command": "pwd"}}
|
||||
assert not any(event[0] == "tool.execute.start" for event in fake.events)
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_execution_skips_duplicate_observer_spans(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
plugins_toml = tmp_path / "plugins.toml"
|
||||
plugins_toml.write_text(
|
||||
"""
|
||||
version = 1
|
||||
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config]
|
||||
mode = "route"
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
|
||||
|
||||
base = {
|
||||
"session_id": "s1",
|
||||
"task_id": "t1",
|
||||
"turn_id": "turn-1",
|
||||
"api_request_id": "api-1",
|
||||
}
|
||||
plugin.on_pre_api_request(
|
||||
**base,
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"body": {"messages": [{"role": "user", "content": "hi"}]}},
|
||||
)
|
||||
plugin.on_post_api_request(**base, response={"ok": True})
|
||||
plugin.on_pre_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", args={"command": "pwd"})
|
||||
plugin.on_post_tool_call(**base, tool_name="terminal", tool_call_id="tool-1", result={"ok": True})
|
||||
|
||||
plugin.on_llm_execution_middleware(
|
||||
**base,
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=lambda request: {"raw": request},
|
||||
)
|
||||
plugin.on_tool_execution_middleware(
|
||||
**base,
|
||||
tool_name="terminal",
|
||||
tool_call_id="tool-1",
|
||||
args={"command": "pwd"},
|
||||
next_call=lambda args: {"raw": args},
|
||||
)
|
||||
|
||||
event_names = [event[0] for event in fake.events]
|
||||
assert "llm.call" not in event_names
|
||||
assert "llm.call_end" not in event_names
|
||||
assert "tool.call" not in event_names
|
||||
assert "tool.call_end" not in event_names
|
||||
assert "llm.execute.start" in event_names
|
||||
assert "tool.execute.start" in event_names
|
||||
|
||||
|
||||
def test_nemo_relay_plugin_noops_without_dependency(monkeypatch):
|
||||
monkeypatch.delitem(sys.modules, "nemo_relay", raising=False)
|
||||
sys.modules.pop("plugins.observability.nemo_relay", None)
|
||||
|
|
|
|||
|
|
@ -2466,8 +2466,10 @@ class TestConcurrentToolExecution:
|
|||
api_request_id="",
|
||||
enabled_tools=list(agent.valid_tool_names),
|
||||
skip_pre_tool_call_hook=True,
|
||||
skip_tool_request_middleware=True,
|
||||
enabled_toolsets=agent.enabled_toolsets,
|
||||
disabled_toolsets=agent.disabled_toolsets,
|
||||
tool_request_middleware_trace=[],
|
||||
)
|
||||
assert result == "result"
|
||||
|
||||
|
|
@ -2647,6 +2649,89 @@ class TestConcurrentToolExecution:
|
|||
assert post_call[1]["result"] == '{"ok":true}'
|
||||
assert post_call[1]["status"] == "ok"
|
||||
|
||||
def test_sequential_agent_level_tool_execution_middleware_wraps_inline_dispatch(self, agent, monkeypatch):
|
||||
"""Sequential built-in tool paths should expose the adaptive execution boundary."""
|
||||
tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
|
||||
messages = []
|
||||
hook_calls = []
|
||||
seen = {}
|
||||
|
||||
def request_middleware(**kwargs):
|
||||
return {
|
||||
"args": {**kwargs["args"], "request_rewritten": True},
|
||||
"source": "request-test",
|
||||
}
|
||||
|
||||
def execution_middleware(**kwargs):
|
||||
seen["middleware_args"] = kwargs["args"]
|
||||
return kwargs["next_call"]({**kwargs["args"], "merge": True})
|
||||
|
||||
manager = SimpleNamespace(_middleware={
|
||||
"tool_request": [request_middleware],
|
||||
"tool_execution": [execution_middleware],
|
||||
})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.invoke_middleware",
|
||||
lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_pre_tool_call_block_message",
|
||||
lambda *args, **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.invoke_hook",
|
||||
lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
|
||||
|
||||
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}') as mock_todo:
|
||||
agent._execute_tool_calls_sequential(mock_msg, messages, "task-1")
|
||||
|
||||
assert seen["middleware_args"] == {"todos": [], "request_rewritten": True}
|
||||
mock_todo.assert_called_once_with(todos=[], merge=True, store=agent._todo_store)
|
||||
post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
|
||||
assert post_call[1]["tool_name"] == "todo"
|
||||
assert post_call[1]["args"] == {"todos": [], "request_rewritten": True, "merge": True}
|
||||
assert post_call[1]["middleware_trace"] == [{"source": "request-test"}]
|
||||
|
||||
def test_concurrent_agent_level_tool_preserves_request_middleware_trace(self, agent, monkeypatch):
|
||||
tool_call = _mock_tool_call(name="todo", arguments='{"todos":[]}', call_id="todo-1")
|
||||
mock_msg = _mock_assistant_msg(content="", tool_calls=[tool_call])
|
||||
messages = []
|
||||
hook_calls = []
|
||||
|
||||
def request_middleware(**kwargs):
|
||||
return {
|
||||
"args": {**kwargs["args"], "request_rewritten": True},
|
||||
"source": "request-test",
|
||||
}
|
||||
|
||||
manager = SimpleNamespace(_middleware={"tool_request": [request_middleware], "tool_execution": []})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.invoke_middleware",
|
||||
lambda kind, **kwargs: [request_middleware(**kwargs)] if kind == "tool_request" else [],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.get_pre_tool_call_block_message",
|
||||
lambda *args, **kwargs: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.invoke_hook",
|
||||
lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
|
||||
|
||||
with patch("tools.todo_tool.todo_tool", return_value='{"ok":true}'):
|
||||
agent._execute_tool_calls_concurrent(mock_msg, messages, "task-1")
|
||||
|
||||
post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
|
||||
assert post_call[1]["tool_name"] == "todo"
|
||||
assert post_call[1]["args"] == {"todos": [], "request_rewritten": True}
|
||||
assert post_call[1]["middleware_trace"] == [{"source": "request-test"}]
|
||||
|
||||
def test_agent_runtime_post_hook_ownership_predicate_covers_agent_tools(self, agent):
|
||||
"""Sequential and concurrent agent-level paths share post-hook ownership."""
|
||||
from agent.agent_runtime_helpers import agent_runtime_owns_post_tool_hook
|
||||
|
|
|
|||
|
|
@ -64,6 +64,7 @@ class TestHandleFunctionCall:
|
|||
tool_call_id="call-1",
|
||||
turn_id="",
|
||||
api_request_id="",
|
||||
middleware_trace=[],
|
||||
),
|
||||
call(
|
||||
"post_tool_call",
|
||||
|
|
@ -79,6 +80,7 @@ class TestHandleFunctionCall:
|
|||
status="ok",
|
||||
error_type=None,
|
||||
error_message=None,
|
||||
middleware_trace=[],
|
||||
),
|
||||
call(
|
||||
"transform_tool_result",
|
||||
|
|
@ -145,6 +147,60 @@ class TestHandleFunctionCall:
|
|||
assert "post_tool_call" not in fired
|
||||
assert "transform_tool_result" not in fired
|
||||
|
||||
def test_tool_request_and_execution_middleware_wrap_registry_dispatch(self, monkeypatch):
|
||||
seen = {}
|
||||
|
||||
def fake_invoke_middleware(kind, **kwargs):
|
||||
if kind == "tool_request":
|
||||
return [{
|
||||
"args": {**kwargs["args"], "rewritten": True},
|
||||
"source": "test-middleware",
|
||||
"reason": "rewrite",
|
||||
}]
|
||||
return []
|
||||
|
||||
def execution_middleware(**kwargs):
|
||||
seen["execution_args"] = kwargs["args"]
|
||||
return kwargs["next_call"]({**kwargs["args"], "wrapped": True})
|
||||
|
||||
def fake_dispatch(tool_name, args, **kwargs):
|
||||
seen["dispatch"] = (tool_name, args, kwargs)
|
||||
return json.dumps({"ok": True, "args": args})
|
||||
|
||||
manager = type(
|
||||
"Manager",
|
||||
(),
|
||||
{"_middleware": {"tool_request": [fake_invoke_middleware], "tool_execution": [execution_middleware]}},
|
||||
)()
|
||||
monkeypatch.setattr("hermes_cli.plugins.invoke_middleware", fake_invoke_middleware)
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
hook_calls = []
|
||||
monkeypatch.setattr(
|
||||
"hermes_cli.plugins.invoke_hook",
|
||||
lambda hook_name, **kwargs: hook_calls.append((hook_name, kwargs)) or [],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.plugins.has_hook", lambda name: True)
|
||||
monkeypatch.setattr("model_tools.registry.dispatch", fake_dispatch)
|
||||
|
||||
result = json.loads(
|
||||
handle_function_call(
|
||||
"web_search",
|
||||
{"q": "test"},
|
||||
task_id="task-1",
|
||||
tool_call_id="tool-1",
|
||||
session_id="session-1",
|
||||
)
|
||||
)
|
||||
|
||||
assert seen["execution_args"] == {"q": "test", "rewritten": True}
|
||||
assert seen["dispatch"][1] == {"q": "test", "rewritten": True, "wrapped": True}
|
||||
assert result["args"] == {"q": "test", "rewritten": True, "wrapped": True}
|
||||
expected_trace = [{"source": "test-middleware", "reason": "rewrite"}]
|
||||
pre_call = next(call for call in hook_calls if call[0] == "pre_tool_call")
|
||||
post_call = next(call for call in hook_calls if call[0] == "post_tool_call")
|
||||
assert pre_call[1]["middleware_trace"] == expected_trace
|
||||
assert post_call[1]["middleware_trace"] == expected_trace
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Agent loop tools
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue