mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
fix(middleware): single-use next_call guard + deepcopy-safe request copies
Address the two non-blocking follow-ups from review: - next_call is now single-use per middleware frame. A second invocation raises instead of silently re-running the downstream provider/tool, so the terminal call cannot execute twice via the chain. The error surfaces through the existing handler, which preserves the first downstream result. - Request-middleware payload copies go through _safe_copy(), which falls back to a shallow dict copy when deepcopy() fails on a non-deepcopyable member (clients, callbacks, file handles) instead of aborting the pass. Adds regression coverage for both: double next_call() keeps the terminal single-run, and a non-deepcopyable (threading.Lock) request payload still runs middleware via the shallow fallback.
This commit is contained in:
parent
5abe45674d
commit
c4c5548eb4
2 changed files with 83 additions and 6 deletions
|
|
@ -55,6 +55,25 @@ def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
|
|||
return kwargs
|
||||
|
||||
|
||||
def _safe_copy(payload: Any) -> Any:
|
||||
"""Deep-copy a request payload, tolerating non-deepcopyable members.
|
||||
|
||||
Request payloads are normally plain JSON-shaped dicts, but an LLM request
|
||||
can occasionally carry non-deepcopyable objects (clients, callbacks, file
|
||||
handles). A hard ``deepcopy`` failure there would otherwise abort the whole
|
||||
request-middleware pass. Fall back to a shallow ``dict`` copy so middleware
|
||||
still runs and the original nested objects are shared by reference rather
|
||||
than corrupting the live payload.
|
||||
"""
|
||||
try:
|
||||
return deepcopy(payload)
|
||||
except Exception as exc: # pragma: no cover - exercised via fallback test
|
||||
logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc)
|
||||
if isinstance(payload, dict):
|
||||
return dict(payload)
|
||||
return payload
|
||||
|
||||
|
||||
def apply_llm_request_middleware(
|
||||
request: Dict[str, Any],
|
||||
**context: Any,
|
||||
|
|
@ -72,8 +91,8 @@ def apply_llm_request_middleware(
|
|||
trace=[],
|
||||
)
|
||||
|
||||
original_request = deepcopy(request)
|
||||
current_request = deepcopy(original_request)
|
||||
original_request = _safe_copy(request)
|
||||
current_request = _safe_copy(original_request)
|
||||
trace: List[Dict[str, Any]] = []
|
||||
|
||||
for result in _invoke_middleware(
|
||||
|
|
@ -87,7 +106,7 @@ def apply_llm_request_middleware(
|
|||
next_request = result.get("request")
|
||||
if not isinstance(next_request, dict):
|
||||
continue
|
||||
current_request = deepcopy(next_request)
|
||||
current_request = _safe_copy(next_request)
|
||||
trace.append(_trace_entry(result))
|
||||
|
||||
return RequestMiddlewareResult(
|
||||
|
|
@ -116,8 +135,8 @@ def apply_tool_request_middleware(
|
|||
trace=[],
|
||||
)
|
||||
|
||||
original_args = deepcopy(args)
|
||||
current_args = deepcopy(original_args)
|
||||
original_args = _safe_copy(args)
|
||||
current_args = _safe_copy(original_args)
|
||||
trace: List[Dict[str, Any]] = []
|
||||
|
||||
for result in _invoke_middleware(
|
||||
|
|
@ -132,7 +151,7 @@ def apply_tool_request_middleware(
|
|||
next_args = result.get("args")
|
||||
if not isinstance(next_args, dict):
|
||||
continue
|
||||
current_args = deepcopy(next_args)
|
||||
current_args = _safe_copy(next_args)
|
||||
trace.append(_trace_entry(result))
|
||||
|
||||
return RequestMiddlewareResult(
|
||||
|
|
@ -242,6 +261,16 @@ def _run_execution_chain(
|
|||
|
||||
def next_call(next_payload: Any = None) -> Any:
|
||||
nonlocal next_called, next_succeeded, next_result
|
||||
# ``next_call`` is single-use per middleware frame. Calling it more
|
||||
# than once would re-run the downstream provider/tool, so a second
|
||||
# invocation is a contract violation rather than a retry. Surface it
|
||||
# instead of silently executing the terminal call twice.
|
||||
if next_called:
|
||||
raise RuntimeError(
|
||||
f"Middleware '{kind}' callback "
|
||||
f"{getattr(callback, '__name__', repr(callback))} called "
|
||||
"next_call() more than once; downstream execution is single-use"
|
||||
)
|
||||
next_called = True
|
||||
try:
|
||||
next_result = call_at(index + 1, payload if next_payload is None else next_payload)
|
||||
|
|
|
|||
|
|
@ -272,6 +272,54 @@ class TestPluginDiscovery:
|
|||
|
||||
assert calls == [{"command": "interrupt"}]
|
||||
|
||||
def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch):
|
||||
calls = []
|
||||
|
||||
def middleware(**kwargs):
|
||||
first = kwargs["next_call"](kwargs["args"])
|
||||
# Deliberate misuse: a second next_call() must not re-run the
|
||||
# downstream tool. The chain surfaces it as an error and preserves
|
||||
# the first (successful) downstream result.
|
||||
kwargs["next_call"](kwargs["args"])
|
||||
return first
|
||||
|
||||
manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]})
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
def terminal(args):
|
||||
calls.append(args)
|
||||
return "terminal-result"
|
||||
|
||||
result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal)
|
||||
|
||||
assert result == "terminal-result"
|
||||
assert calls == [{"command": "printf ok"}]
|
||||
|
||||
def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch):
|
||||
import threading
|
||||
|
||||
recorded = {}
|
||||
|
||||
def middleware(**kwargs):
|
||||
recorded["args"] = kwargs["args"]
|
||||
return None
|
||||
|
||||
manager = types.SimpleNamespace(
|
||||
_middleware={"tool_request": [middleware]},
|
||||
invoke_middleware=lambda kind, **kwargs: [middleware(**kwargs)],
|
||||
)
|
||||
monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager)
|
||||
|
||||
# threading.Lock is not deepcopyable; a hard deepcopy would raise.
|
||||
args = {"command": "noop", "lock": threading.Lock()}
|
||||
result = apply_tool_request_middleware("terminal", args)
|
||||
|
||||
# Middleware ran (payload was copied via the shallow fallback) and the
|
||||
# non-deepcopyable member is shared by reference rather than aborting.
|
||||
assert recorded["args"]["command"] == "noop"
|
||||
assert result.payload["command"] == "noop"
|
||||
assert result.payload["lock"] is args["lock"]
|
||||
|
||||
def test_discover_project_plugins(self, tmp_path, monkeypatch):
|
||||
"""Plugins in ./.hermes/plugins/ are discovered."""
|
||||
project_dir = tmp_path / "project"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue