From 85852b71d86d13b9ca7c22cf653ff6730b325e3c Mon Sep 17 00:00:00 2001 From: kshitij <82637225+kshitijk4poor@users.noreply.github.com> Date: Tue, 9 Jun 2026 02:31:10 -0700 Subject: [PATCH] fix(nemo-relay): preserve downstream errors in adaptive execution (#42691) Based on #42658 by @mnajafian-nv. Preserves the real downstream provider/tool exception when NeMo Relay's managed adaptive execution wraps a failing callback as an internal runtime error. Without this, the original exception (and its retry-classification signal, e.g. status_code) is lost behind Relay's wrapper. Salvage changes on top of the original PR: - Tolerant Relay-wrapper match: _is_relay_wrapped_callback_error now uses str.startswith on the "internal error: <ExceptionClass>: <message>" prefix instead of exact equality, so a future Relay version appending a traceback/suffix doesn't silently defeat the unwrap. On a total format change it returns False and falls back to the pre-fix behavior (surfacing Relay's error) rather than masking it. - Deduplicated the LLM and tool execute paths into a shared _run_managed_with_downstream_preservation helper, removing ~20 lines of copy-pasted nonlocal/try-except scaffolding that could drift out of sync. - Added a real-middleware regression guard (test_nemo_relay_downstream_unwrap_matches_real_middleware_wrapper_shape) that drives hermes_cli.middleware._run_execution_chain and asserts the plugin's _original_downstream_error unwraps the actual private _DownstreamExecutionError wrapper. The original synthetic tests modeled the wrapper with a local class, so a rename or shape change in core middleware would not have been caught; this test fails loudly if that contract drifts. 
Co-authored-by: mnajafian-nv --- plugins/observability/nemo_relay/__init__.py | 173 ++++++--- tests/plugins/test_nemo_relay_plugin.py | 363 +++++++++++++++++++ 2 files changed, 479 insertions(+), 57 deletions(-) diff --git a/plugins/observability/nemo_relay/__init__.py b/plugins/observability/nemo_relay/__init__.py index 894fa9a23e..4716b73ec0 100644 --- a/plugins/observability/nemo_relay/__init__.py +++ b/plugins/observability/nemo_relay/__init__.py @@ -9,6 +9,7 @@ import logging import os import threading import tomllib +from collections.abc import Callable from dataclasses import dataclass, field from pathlib import Path from typing import Any, Optional @@ -284,6 +285,43 @@ class _Runtime: and callable(getattr(getattr(self.nemo_relay, "tools", None), "execute", None)) ) + def _run_managed_with_downstream_preservation( + self, + next_call: Callable[[Any], Any], + normalize_payload: Callable[[Any], Any], + shape_response: Callable[[Any], Any], + make_managed_execute: Callable[[Callable[[Any], Any]], Any], + ) -> Any: + # NeMo Relay's native managed execution may wrap a failing callback as an + # internal runtime error, hiding the real downstream provider/tool + # exception. Capture the original here and re-raise it after managed + # execution so Hermes retry classification still sees it. The LLM and tool + # paths share this scaffolding; they differ only in payload normalization, + # response shaping, and the Relay call itself. 
+ raw_response: dict[str, Any] = {"set": False, "value": None} + callback_error: Exception | None = None + downstream_error: BaseException | None = None + + def _impl(next_payload: Any) -> Any: + nonlocal callback_error, downstream_error + try: + raw = next_call(normalize_payload(next_payload)) + except Exception as exc: + callback_error = exc + downstream_error = _original_downstream_error(exc) + raise + raw_response["set"] = True + raw_response["value"] = raw + return shape_response(raw) + + try: + managed_result = _resolve_awaitable(make_managed_execute(_impl)) + except Exception as exc: + if downstream_error is not None and _is_relay_wrapped_callback_error(exc, callback_error): + raise downstream_error + raise + return raw_response["value"] if raw_response["set"] else managed_result + def execute_llm(self, kwargs: dict[str, Any]) -> Any: state = self.ensure_session(kwargs) request_body = _jsonable(kwargs.get("request") or {}) @@ -292,38 +330,37 @@ class _Runtime: if not callable(next_call): return request_body - raw_response: dict[str, Any] = {"set": False, "value": None} - - def _impl(next_request: Any) -> Any: + def _normalize(next_request: Any) -> Any: next_body = getattr(next_request, "content", next_request) - raw = next_call(next_body if isinstance(next_body, dict) else request_body) - raw_response["set"] = True - raw_response["value"] = raw - return _llm_response_payload(raw) + return next_body if isinstance(next_body, dict) else request_body - async def _managed_execute() -> Any: - result = self.nemo_relay.llm.execute( - str(kwargs.get("provider") or "llm"), - request, - _impl, - handle=state.handle, - data=_jsonable( - { - "turn_id": kwargs.get("turn_id"), - "api_request_id": kwargs.get("api_request_id"), - "api_call_count": kwargs.get("api_call_count"), - "mode": self.settings.adaptive_mode, - } - ), - metadata=_metadata(kwargs), - model_name=str(kwargs.get("model") or ""), - ) - if inspect.isawaitable(result): - return await result - return result + 
def _make_managed(impl: Callable[[Any], Any]) -> Any: + async def _managed_execute() -> Any: + result = self.nemo_relay.llm.execute( + str(kwargs.get("provider") or "llm"), + request, + impl, + handle=state.handle, + data=_jsonable( + { + "turn_id": kwargs.get("turn_id"), + "api_request_id": kwargs.get("api_request_id"), + "api_call_count": kwargs.get("api_call_count"), + "mode": self.settings.adaptive_mode, + } + ), + metadata=_metadata(kwargs), + model_name=str(kwargs.get("model") or ""), + ) + if inspect.isawaitable(result): + return await result + return result - managed_result = _resolve_awaitable(_managed_execute()) - return raw_response["value"] if raw_response["set"] else managed_result + return _managed_execute() + + return self._run_managed_with_downstream_preservation( + next_call, _normalize, _llm_response_payload, _make_managed + ) def execute_tool(self, kwargs: dict[str, Any]) -> Any: state = self.ensure_session(kwargs) @@ -333,37 +370,35 @@ class _Runtime: if not callable(next_call): return args - raw_response: dict[str, Any] = {"set": False, "value": None} + def _normalize(next_args: Any) -> Any: + return next_args if isinstance(next_args, dict) else args - def _impl(next_args: Any) -> Any: - effective_args = next_args if isinstance(next_args, dict) else args - raw = next_call(effective_args) - raw_response["set"] = True - raw_response["value"] = raw - return _jsonable(raw) + def _make_managed(impl: Callable[[Any], Any]) -> Any: + async def _managed_execute() -> Any: + result = self.nemo_relay.tools.execute( + tool_name, + args, + impl, + handle=state.handle, + data=_jsonable( + { + "turn_id": kwargs.get("turn_id"), + "api_request_id": kwargs.get("api_request_id"), + "tool_call_id": kwargs.get("tool_call_id"), + "mode": self.settings.adaptive_mode, + } + ), + metadata=_metadata(kwargs), + ) + if inspect.isawaitable(result): + return await result + return result - async def _managed_execute() -> Any: - result = self.nemo_relay.tools.execute( - 
tool_name, - args, - _impl, - handle=state.handle, - data=_jsonable( - { - "turn_id": kwargs.get("turn_id"), - "api_request_id": kwargs.get("api_request_id"), - "tool_call_id": kwargs.get("tool_call_id"), - "mode": self.settings.adaptive_mode, - } - ), - metadata=_metadata(kwargs), - ) - if inspect.isawaitable(result): - return await result - return result + return _managed_execute() - managed_result = _resolve_awaitable(_managed_execute()) - return raw_response["value"] if raw_response["set"] else managed_result + return self._run_managed_with_downstream_preservation( + next_call, _normalize, _jsonable, _make_managed + ) def register(ctx) -> None: @@ -806,6 +841,30 @@ def _value(obj: Any, key: str, default: Any = None) -> Any: return getattr(obj, key, default) +def _original_downstream_error(exc: Exception) -> BaseException: + # Hermes wraps downstream execution failures in a local/private exception + # class, so detect the wrapper by shape instead of importing it here. + original = getattr(exc, "original", None) + if exc.__class__.__name__ == "_DownstreamExecutionError" and isinstance(original, BaseException): + return original + return exc + + +def _is_relay_wrapped_callback_error(exc: Exception, callback_error: Exception | None) -> bool: + # NeMo Relay re-wraps a failing callback as ``RuntimeError("internal error: + # <ExceptionClass>: <message>")``. Match by prefix rather than exact equality so a + # trailing traceback/suffix in a future Relay version doesn't silently defeat + # the unwrap; the class-name + message prefix still discriminates the real + # downstream failure from unrelated Relay-internal errors. If Relay drops the + # leading ``internal error:`` shape entirely, this returns False and Hermes + # falls back to surfacing Relay's error (the pre-fix behavior) rather than + # masking it. 
+ if callback_error is None or not isinstance(exc, RuntimeError): + return False + expected = f"internal error: {callback_error.__class__.__name__}: {callback_error}" + return str(exc).startswith(expected) + + def _llm_response_payload(response: Any) -> Any: """Return the LLM response shape NeMo Relay's ATIF conversion expects.""" payload = _jsonable(response) diff --git a/tests/plugins/test_nemo_relay_plugin.py b/tests/plugins/test_nemo_relay_plugin.py index 948e80f1e0..1e72520d29 100644 --- a/tests/plugins/test_nemo_relay_plugin.py +++ b/tests/plugins/test_nemo_relay_plugin.py @@ -12,6 +12,7 @@ import warnings from pathlib import Path from types import SimpleNamespace +import pytest import yaml from hermes_cli.plugins import PluginManager @@ -153,6 +154,33 @@ def _fresh_plugin(monkeypatch, fake): return plugin +def _wrapped_downstream_error(original): + class _DownstreamExecutionError(Exception): + def __init__(self, original): + super().__init__(str(original)) + self.original = original + + return _DownstreamExecutionError(original) + + +def _enable_adaptive_plugin(tmp_path, monkeypatch) -> None: + plugins_toml = tmp_path / "plugins.toml" + plugins_toml.write_text( + """ +version = 1 + +[[components]] +kind = "adaptive" +enabled = true + +[components.config.tool_parallelism] +mode = "observe_only" +""", + encoding="utf-8", + ) + monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml)) + + def test_manifest_fields(): data = yaml.safe_load((PLUGIN_DIR / "plugin.yaml").read_text()) assert data["name"] == "nemo_relay" @@ -783,6 +811,220 @@ mode = "observe_only" } +def test_nemo_relay_adaptive_llm_execution_preserves_downstream_error(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + + def native_like_execute(name, request, func, **kwargs): + fake.events.append(("llm.execute.start", name, request.content, kwargs)) + try: + return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content})) + except Exception as exc: + raise 
RuntimeError(f"internal error: {type(exc).__name__}: {exc}") from None + + fake.llm.execute = native_like_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + class ProviderAuthError(Exception): + status_code = 403 + + provider_error = ProviderAuthError("provider auth failed") + + def next_call(request): + raise _wrapped_downstream_error(provider_error) + + with pytest.raises(ProviderAuthError) as caught: + plugin.on_llm_execution_middleware( + session_id="s1", + provider="anthropic", + model="demo-model", + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=next_call, + ) + + assert caught.value is provider_error + assert caught.value.status_code == 403 + + +def test_nemo_relay_adaptive_llm_execution_preserves_downstream_error_with_relay_suffix( + tmp_path, monkeypatch +): + # Guards the startswith (vs exact ==) match in _is_relay_wrapped_callback_error: + # Relay re-wraps the callback failure with its canonical prefix but APPENDS a + # trailing suffix. Exact equality would miss this and surface Relay's wrapper; + # prefix matching must still recover the original downstream error. 
+ fake = _FakeNemoRelay() + + def native_like_execute(name, request, func, **kwargs): + try: + return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content})) + except Exception as exc: + raise RuntimeError(f"internal error: {type(exc).__name__}: {exc} (retried 3x)") from None + + fake.llm.execute = native_like_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + class ProviderAuthError(Exception): + status_code = 403 + + provider_error = ProviderAuthError("provider auth failed") + + def next_call(request): + raise _wrapped_downstream_error(provider_error) + + with pytest.raises(ProviderAuthError) as caught: + plugin.on_llm_execution_middleware( + session_id="s1", + provider="anthropic", + model="demo-model", + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=next_call, + ) + + assert caught.value is provider_error + assert caught.value.status_code == 403 + + +def test_nemo_relay_adaptive_llm_execution_keeps_unrelated_internal_error(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + + relay_error = RuntimeError("internal error: relay setup failed") + + def internal_error_execute(name, request, func, **kwargs): + raise relay_error + + fake.llm.execute = internal_error_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + with pytest.raises(RuntimeError) as caught: + plugin.on_llm_execution_middleware( + session_id="s1", + provider="anthropic", + model="demo-model", + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=lambda request: {"raw": request}, + ) + + assert caught.value is relay_error + + +def test_nemo_relay_adaptive_llm_execution_keeps_wrapped_relay_error_after_downstream_failure( + tmp_path, monkeypatch +): + fake = _FakeNemoRelay() + relay_error = RuntimeError("internal error: RuntimeError: relay policy blocked after downstream") + + def translated_execute(name, request, func, **kwargs): 
+ try: + return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content})) + except Exception: + raise relay_error + + fake.llm.execute = translated_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + def next_call(request): + raise _wrapped_downstream_error(RuntimeError("provider failed")) + + with pytest.raises(RuntimeError) as caught: + plugin.on_llm_execution_middleware( + session_id="s1", + provider="anthropic", + model="demo-model", + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=next_call, + ) + + assert caught.value is relay_error + + +def test_nemo_relay_adaptive_llm_execution_keeps_relay_translated_error(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + + class RelayPolicyError(Exception): + pass + + relay_error = RelayPolicyError("relay policy blocked") + + def translated_execute(name, request, func, **kwargs): + try: + return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content})) + except Exception: + raise relay_error + + fake.llm.execute = translated_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + provider_error = RuntimeError("provider failed") + + def next_call(request): + raise _wrapped_downstream_error(provider_error) + + with pytest.raises(RelayPolicyError) as caught: + plugin.on_llm_execution_middleware( + session_id="s1", + provider="anthropic", + model="demo-model", + request={"messages": [{"role": "user", "content": "hi"}]}, + next_call=next_call, + ) + + assert caught.value is relay_error + + +def test_nemo_relay_downstream_unwrap_matches_real_middleware_wrapper_shape(monkeypatch): + # Regression guard against core/plugin drift. 
The synthetic tests above model + # the downstream-error wrapper with a local class, so they keep passing even + # if core middleware renames its private ``_DownstreamExecutionError`` or drops + # ``.original`` -- the exact shape the plugin matches by name at + # ``_original_downstream_error``. Capture the wrapper the REAL + # ``hermes_cli.middleware._run_execution_chain`` hands to a middleware + # callback's ``next_call`` and assert the plugin's detector unwraps it to the + # original exception. If core middleware changes the wrapper shape, this fails + # here instead of silently defeating the unwrap in production. + from hermes_cli import middleware + + from plugins.observability.nemo_relay import _original_downstream_error + + class ProviderError(Exception): + status_code = 403 + + provider_error = ProviderError("provider auth failed") + captured: dict[str, Exception] = {} + + def terminal_call(payload): + raise provider_error + + def capturing_callback(**kwargs): + next_call = kwargs["next_call"] + try: + return next_call(kwargs.get("request")) + except Exception as exc: + captured["wrapper"] = exc + # Surface the original so the chain unwinds without re-wrapping noise. + raise _original_downstream_error(exc) from None + + with pytest.raises(ProviderError) as caught: + middleware._run_execution_chain( + "llm", + [capturing_callback], + terminal_call, + request={"messages": []}, + ) + + wrapper = captured["wrapper"] + # The wrapper the plugin sees must match what _original_downstream_error keys on. 
+ assert wrapper.__class__.__name__ == "_DownstreamExecutionError" + assert isinstance(getattr(wrapper, "original", None), BaseException) + assert _original_downstream_error(wrapper) is provider_error + assert caught.value is provider_error + assert caught.value.status_code == 403 + + def _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text: str) -> str: fake = _FakeNemoRelay() plugin = _fresh_plugin(monkeypatch, fake) @@ -920,6 +1162,127 @@ mode = "observe_only" assert execute_start[3]["data"]["tool_call_id"] == "tool-1" +def test_nemo_relay_adaptive_tool_execution_preserves_downstream_error(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + + def native_like_execute(name, args, func, **kwargs): + fake.events.append(("tool.execute.start", name, args, kwargs)) + try: + return func({"intercepted": True, **args}) + except Exception as exc: + raise RuntimeError(f"internal error: {type(exc).__name__}: {exc}") from None + + fake.tools.execute = native_like_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + class ToolAuthError(Exception): + status_code = 403 + + tool_error = ToolAuthError("tool auth failed") + + def next_call(args): + raise _wrapped_downstream_error(tool_error) + + with pytest.raises(ToolAuthError) as caught: + plugin.on_tool_execution_middleware( + session_id="s1", + tool_name="terminal", + args={"command": "pwd"}, + next_call=next_call, + ) + + assert caught.value is tool_error + assert caught.value.status_code == 403 + + +def test_nemo_relay_adaptive_tool_execution_keeps_unrelated_internal_error(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + + relay_error = RuntimeError("internal error: relay setup failed") + + def internal_error_execute(name, args, func, **kwargs): + raise relay_error + + fake.tools.execute = internal_error_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + with pytest.raises(RuntimeError) as caught: + 
plugin.on_tool_execution_middleware( + session_id="s1", + tool_name="terminal", + args={"command": "pwd"}, + next_call=lambda args: {"raw": args}, + ) + + assert caught.value is relay_error + + +def test_nemo_relay_adaptive_tool_execution_keeps_wrapped_relay_error_after_downstream_failure( + tmp_path, monkeypatch +): + fake = _FakeNemoRelay() + relay_error = RuntimeError("internal error: RuntimeError: relay policy blocked after downstream") + + def translated_execute(name, args, func, **kwargs): + try: + return func({"intercepted": True, **args}) + except Exception: + raise relay_error + + fake.tools.execute = translated_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + def next_call(args): + raise _wrapped_downstream_error(RuntimeError("tool failed")) + + with pytest.raises(RuntimeError) as caught: + plugin.on_tool_execution_middleware( + session_id="s1", + tool_name="terminal", + args={"command": "pwd"}, + next_call=next_call, + ) + + assert caught.value is relay_error + + +def test_nemo_relay_adaptive_tool_execution_keeps_relay_translated_error(tmp_path, monkeypatch): + fake = _FakeNemoRelay() + + class RelayPolicyError(Exception): + pass + + relay_error = RelayPolicyError("relay policy blocked") + + def translated_execute(name, args, func, **kwargs): + try: + return func({"intercepted": True, **args}) + except Exception: + raise relay_error + + fake.tools.execute = translated_execute + plugin = _fresh_plugin(monkeypatch, fake) + _enable_adaptive_plugin(tmp_path, monkeypatch) + + tool_error = RuntimeError("tool failed") + + def next_call(args): + raise _wrapped_downstream_error(tool_error) + + with pytest.raises(RelayPolicyError) as caught: + plugin.on_tool_execution_middleware( + session_id="s1", + tool_name="terminal", + args={"command": "pwd"}, + next_call=next_call, + ) + + assert caught.value is relay_error + + def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch): 
fake = _FakeNemoRelay() plugin = _fresh_plugin(monkeypatch, fake)