fix(middleware): single-use next_call guard + deepcopy-safe request copies

Address the two non-blocking follow-ups from review:

- next_call is now single-use per middleware frame. A second invocation
  raises instead of silently re-running the downstream provider/tool, so
  the terminal call cannot execute twice via the chain. The error surfaces
  through the existing handler, which preserves the first downstream result.
- Request-middleware payload copies go through _safe_copy(), which falls
  back to a shallow dict copy when deepcopy() fails on a non-deepcopyable
  member (clients, callbacks, file handles) instead of aborting the pass.

Adds regression coverage for both: double next_call() keeps the terminal
single-run, and a non-deepcopyable (threading.Lock) request payload still
runs middleware via the shallow fallback.
This commit is contained in:
kshitijk4poor 2026-06-06 23:07:25 +05:30
parent 5abe45674d
commit c4c5548eb4
2 changed files with 83 additions and 6 deletions

View file

@ -55,6 +55,25 @@ def middleware_payload(**kwargs: Any) -> Dict[str, Any]:
return kwargs
def _safe_copy(payload: Any) -> Any:
    """Copy a request payload, tolerating non-deepcopyable members.

    Request payloads are normally plain JSON-shaped dicts, but an LLM request
    can occasionally carry non-deepcopyable objects (clients, callbacks, file
    handles). A hard ``deepcopy`` failure there would otherwise abort the whole
    request-middleware pass.

    Strategy:

    1. Try a full ``deepcopy`` of the payload.
    2. If that fails and the payload is a ``dict``, build a new ``dict`` whose
       top-level members are deep-copied one at a time. Only the members that
       cannot be deep-copied are shared by reference, so middleware that
       mutates an ordinary nested value (a list, a nested dict) cannot leak
       that mutation into the caller's live payload.
    3. If that fails and the payload is not a ``dict``, return it unchanged.

    Args:
        payload: The request payload to copy.

    Returns:
        An independent deep copy when possible; otherwise a new ``dict`` with
        best-effort per-member copies, or ``payload`` itself when it is not a
        ``dict``.
    """
    try:
        return deepcopy(payload)
    except Exception as exc:
        # Best-effort by design: any copy failure must degrade, not abort.
        logger.debug(
            "deepcopy failed for request payload (%s); copying members individually",
            exc,
        )
    if not isinstance(payload, dict):
        return payload
    copied: Dict[Any, Any] = {}
    for key, value in payload.items():
        # Each member gets its own deepcopy call (and therefore its own memo):
        # a failed copy can leave partially-built objects in a shared memo, so
        # reusing one across members would be unsafe.
        try:
            copied[key] = deepcopy(value)
        except Exception:
            # Non-deepcopyable member (lock, client, handle): share by reference.
            copied[key] = value
    return copied
def apply_llm_request_middleware(
request: Dict[str, Any],
**context: Any,
@ -72,8 +91,8 @@ def apply_llm_request_middleware(
trace=[],
)
original_request = deepcopy(request)
current_request = deepcopy(original_request)
original_request = _safe_copy(request)
current_request = _safe_copy(original_request)
trace: List[Dict[str, Any]] = []
for result in _invoke_middleware(
@ -87,7 +106,7 @@ def apply_llm_request_middleware(
next_request = result.get("request")
if not isinstance(next_request, dict):
continue
current_request = deepcopy(next_request)
current_request = _safe_copy(next_request)
trace.append(_trace_entry(result))
return RequestMiddlewareResult(
@ -116,8 +135,8 @@ def apply_tool_request_middleware(
trace=[],
)
original_args = deepcopy(args)
current_args = deepcopy(original_args)
original_args = _safe_copy(args)
current_args = _safe_copy(original_args)
trace: List[Dict[str, Any]] = []
for result in _invoke_middleware(
@ -132,7 +151,7 @@ def apply_tool_request_middleware(
next_args = result.get("args")
if not isinstance(next_args, dict):
continue
current_args = deepcopy(next_args)
current_args = _safe_copy(next_args)
trace.append(_trace_entry(result))
return RequestMiddlewareResult(
@ -242,6 +261,16 @@ def _run_execution_chain(
def next_call(next_payload: Any = None) -> Any:
nonlocal next_called, next_succeeded, next_result
# ``next_call`` is single-use per middleware frame. Calling it more
# than once would re-run the downstream provider/tool, so a second
# invocation is a contract violation rather than a retry. Surface it
# instead of silently executing the terminal call twice.
if next_called:
raise RuntimeError(
f"Middleware '{kind}' callback "
f"{getattr(callback, '__name__', repr(callback))} called "
"next_call() more than once; downstream execution is single-use"
)
next_called = True
try:
next_result = call_at(index + 1, payload if next_payload is None else next_payload)

View file

@ -272,6 +272,54 @@ class TestPluginDiscovery:
assert calls == [{"command": "interrupt"}]
def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch):
    """A middleware that invokes ``next_call`` twice runs the terminal once.

    The second ``next_call`` is deliberate misuse; the test pins that the
    terminal callable is executed a single time and that the value returned
    to the caller is the first downstream result.
    """
    terminal_invocations = []

    def terminal(args):
        terminal_invocations.append(args)
        return "terminal-result"

    def middleware(**kwargs):
        first_result = kwargs["next_call"](kwargs["args"])
        # Deliberate misuse: calling next_call() again must not re-run the
        # downstream tool. The chain is expected to surface this as an error
        # while keeping the first (successful) downstream result.
        kwargs["next_call"](kwargs["args"])
        return first_result

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_execution": [middleware]},
    )
    monkeypatch.setattr(
        "hermes_cli.plugins.get_plugin_manager",
        lambda: fake_manager,
    )

    outcome = run_tool_execution_middleware(
        "terminal",
        {"command": "printf ok"},
        terminal,
    )

    assert outcome == "terminal-result"
    assert terminal_invocations == [{"command": "printf ok"}]
def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch):
    """Request middleware still runs when the args cannot be deep-copied.

    A ``threading.Lock`` in the tool args makes a plain ``deepcopy`` raise;
    the test pins that middleware is still invoked and that the lock object
    reaches the result payload by reference.
    """
    import threading

    seen = {}

    def middleware(**kwargs):
        seen["args"] = kwargs["args"]
        return None

    def invoke_all(kind, **kwargs):
        # Stand-in for the plugin manager's dispatcher: run the one middleware.
        return [middleware(**kwargs)]

    fake_manager = types.SimpleNamespace(
        _middleware={"tool_request": [middleware]},
        invoke_middleware=invoke_all,
    )
    monkeypatch.setattr(
        "hermes_cli.plugins.get_plugin_manager",
        lambda: fake_manager,
    )

    # threading.Lock is not deepcopyable; a hard deepcopy would raise.
    tool_args = {"command": "noop", "lock": threading.Lock()}

    outcome = apply_tool_request_middleware("terminal", tool_args)

    # Middleware ran despite the copy failure, and the non-deepcopyable
    # member is shared by reference rather than aborting the pass.
    assert seen["args"]["command"] == "noop"
    assert outcome.payload["command"] == "noop"
    assert outcome.payload["lock"] is tool_args["lock"]
def test_discover_project_plugins(self, tmp_path, monkeypatch):
"""Plugins in ./.hermes/plugins/ are discovered."""
project_dir = tmp_path / "project"