diff --git a/hermes_cli/middleware.py b/hermes_cli/middleware.py index 277368dffb3..8795952a2b7 100644 --- a/hermes_cli/middleware.py +++ b/hermes_cli/middleware.py @@ -55,6 +55,25 @@ def middleware_payload(**kwargs: Any) -> Dict[str, Any]: return kwargs +def _safe_copy(payload: Any) -> Any: + """Deep-copy a request payload, tolerating non-deepcopyable members. + + Request payloads are normally plain JSON-shaped dicts, but an LLM request + can occasionally carry non-deepcopyable objects (clients, callbacks, file + handles). A hard ``deepcopy`` failure there would otherwise abort the whole + request-middleware pass. Fall back to a shallow ``dict`` copy so middleware + still runs and the original nested objects are shared by reference rather + than corrupting the live payload. + """ + try: + return deepcopy(payload) + except Exception as exc: # pragma: no cover - exercised via fallback test + logger.debug("deepcopy failed for request payload (%s); using shallow copy", exc) + if isinstance(payload, dict): + return dict(payload) + return payload + + def apply_llm_request_middleware( request: Dict[str, Any], **context: Any, @@ -72,8 +91,8 @@ def apply_llm_request_middleware( trace=[], ) - original_request = deepcopy(request) - current_request = deepcopy(original_request) + original_request = _safe_copy(request) + current_request = _safe_copy(original_request) trace: List[Dict[str, Any]] = [] for result in _invoke_middleware( @@ -87,7 +106,7 @@ def apply_llm_request_middleware( next_request = result.get("request") if not isinstance(next_request, dict): continue - current_request = deepcopy(next_request) + current_request = _safe_copy(next_request) trace.append(_trace_entry(result)) return RequestMiddlewareResult( @@ -116,8 +135,8 @@ def apply_tool_request_middleware( trace=[], ) - original_args = deepcopy(args) - current_args = deepcopy(original_args) + original_args = _safe_copy(args) + current_args = _safe_copy(original_args) trace: List[Dict[str, Any]] = [] for result in _invoke_middleware( @@ -132,7 +151,7 @@ def apply_tool_request_middleware( next_args = result.get("args") if not isinstance(next_args, dict): continue - current_args = deepcopy(next_args) + current_args = _safe_copy(next_args) trace.append(_trace_entry(result)) return RequestMiddlewareResult( @@ -242,6 +261,16 @@ def _run_execution_chain( def next_call(next_payload: Any = None) -> Any: nonlocal next_called, next_succeeded, next_result + # ``next_call`` is single-use per middleware frame. Calling it more + # than once would re-run the downstream provider/tool, so a second + # invocation is a contract violation rather than a retry. Surface it + # instead of silently executing the terminal call twice. + if next_called: + raise RuntimeError( + f"Middleware '{kind}' callback " + f"{getattr(callback, '__name__', repr(callback))} called " + "next_call() more than once; downstream execution is single-use" + ) next_called = True try: next_result = call_at(index + 1, payload if next_payload is None else next_payload) diff --git a/tests/hermes_cli/test_plugins.py b/tests/hermes_cli/test_plugins.py index ddd1dab56e4..bb889450d00 100644 --- a/tests/hermes_cli/test_plugins.py +++ b/tests/hermes_cli/test_plugins.py @@ -272,6 +272,54 @@ class TestPluginDiscovery: assert calls == [{"command": "interrupt"}] + def test_execution_middleware_double_next_call_does_not_run_terminal_twice(self, monkeypatch): + calls = [] + + def middleware(**kwargs): + first = kwargs["next_call"](kwargs["args"]) + # Deliberate misuse: a second next_call() must not re-run the + # downstream tool. The chain surfaces it as an error and preserves + # the first (successful) downstream result. + kwargs["next_call"](kwargs["args"]) + return first + + manager = types.SimpleNamespace(_middleware={"tool_execution": [middleware]}) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + def terminal(args): + calls.append(args) + return "terminal-result" + + result = run_tool_execution_middleware("terminal", {"command": "printf ok"}, terminal) + + assert result == "terminal-result" + assert calls == [{"command": "printf ok"}] + + def test_request_middleware_tolerates_non_deepcopyable_payload(self, monkeypatch): + import threading + + recorded = {} + + def middleware(**kwargs): + recorded["args"] = kwargs["args"] + return None + + manager = types.SimpleNamespace( + _middleware={"tool_request": [middleware]}, + invoke_middleware=lambda kind, **kwargs: [middleware(**kwargs)], + ) + monkeypatch.setattr("hermes_cli.plugins.get_plugin_manager", lambda: manager) + + # threading.Lock is not deepcopyable; a hard deepcopy would raise. + args = {"command": "noop", "lock": threading.Lock()} + result = apply_tool_request_middleware("terminal", args) + + # Middleware ran (payload was copied via the shallow fallback) and the + # non-deepcopyable member is shared by reference rather than aborting. + assert recorded["args"]["command"] == "noop" + assert result.payload["command"] == "noop" + assert result.payload["lock"] is args["lock"] + def test_discover_project_plugins(self, tmp_path, monkeypatch): """Plugins in ./.hermes/plugins/ are discovered.""" project_dir = tmp_path / "project"