fix(nemo-relay): preserve downstream errors in adaptive execution (#42691)

Based on #42658 by @mnajafian-nv.

Preserves the real downstream provider/tool exception when NeMo Relay's
managed adaptive execution wraps a failing callback as an internal runtime
error. Without this, the original exception (and its retry-classification
signal, e.g. status_code) is lost behind Relay's wrapper.

Salvage changes on top of the original PR:

- Tolerant Relay-wrapper match: _is_relay_wrapped_callback_error now uses
  str.startswith on the "internal error: <cls>: <msg>" prefix instead of
  exact equality, so a future Relay version appending a traceback/suffix
  doesn't silently defeat the unwrap. On a total format change it returns
  False and falls back to the pre-fix behavior (surfacing Relay's error)
  rather than masking it.
- Deduplicated the LLM and tool execute paths into a shared
  _run_managed_with_downstream_preservation helper, removing ~20 lines of
  copy-pasted nonlocal/try-except scaffolding that could drift out of sync.
- Added a real-middleware regression guard
  (test_nemo_relay_downstream_unwrap_matches_real_middleware_wrapper_shape)
  that drives hermes_cli.middleware._run_execution_chain and asserts the
  plugin's _original_downstream_error unwraps the actual private
  _DownstreamExecutionError wrapper. The original synthetic tests modeled the
  wrapper with a local class, so a rename or shape change in core middleware
  would not have been caught; this test fails loudly if that contract drifts.

Co-authored-by: mnajafian-nv <mnajafian@nvidia.com>
This commit is contained in:
kshitij 2026-06-09 02:31:10 -07:00 committed by GitHub
parent 8d99b5bc4f
commit 85852b71d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 479 additions and 57 deletions

View file

@ -9,6 +9,7 @@ import logging
import os
import threading
import tomllib
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional
@ -284,6 +285,43 @@ class _Runtime:
and callable(getattr(getattr(self.nemo_relay, "tools", None), "execute", None))
)
def _run_managed_with_downstream_preservation(
    self,
    next_call: Callable[[Any], Any],
    normalize_payload: Callable[[Any], Any],
    shape_response: Callable[[Any], Any],
    make_managed_execute: Callable[[Callable[[Any], Any]], Any],
) -> Any:
    """Run a Relay-managed call without losing the real downstream failure.

    NeMo Relay's managed execution may re-wrap a failing callback as an
    internal runtime error, which hides the provider/tool exception that
    Hermes retry classification needs. The callback built here records the
    failure it saw so it can be re-raised once managed execution returns.

    Args:
        next_call: Downstream callable invoked with the normalized payload.
        normalize_payload: Maps the payload handed over by Relay to the
            payload passed to ``next_call``.
        shape_response: Maps the raw downstream result to the value
            returned to Relay from the callback.
        make_managed_execute: Receives the callback and returns the managed
            execution (possibly awaitable) that is resolved here.

    Returns:
        The raw ``next_call`` result when the callback completed at least
        once, otherwise whatever the managed execution resolved to.

    Raises:
        BaseException: The unwrapped downstream error when the managed
            execution failed with Relay's wrapper around the callback
            failure; any other exception propagates unchanged.
    """
    # Single mutable record shared between the callback and this frame.
    outcome: dict[str, Any] = {
        "has_raw": False,
        "raw": None,
        "callback_error": None,
        "downstream_error": None,
    }

    def _impl(next_payload: Any) -> Any:
        try:
            raw_result = next_call(normalize_payload(next_payload))
        except Exception as failure:
            # Remember both the exception as raised and its unwrapped form;
            # the first is used to recognise Relay's wrapper, the second is
            # what the caller ultimately gets.
            outcome["callback_error"] = failure
            outcome["downstream_error"] = _original_downstream_error(failure)
            raise
        outcome["has_raw"] = True
        outcome["raw"] = raw_result
        return shape_response(raw_result)

    try:
        managed_result = _resolve_awaitable(make_managed_execute(_impl))
    except Exception as relay_failure:
        preserved = outcome["downstream_error"]
        if preserved is None or not _is_relay_wrapped_callback_error(
            relay_failure, outcome["callback_error"]
        ):
            # Not a re-wrapped callback failure: surface Relay's own error.
            raise
        raise preserved

    if outcome["has_raw"]:
        return outcome["raw"]
    return managed_result
def execute_llm(self, kwargs: dict[str, Any]) -> Any:
state = self.ensure_session(kwargs)
request_body = _jsonable(kwargs.get("request") or {})
@ -292,38 +330,37 @@ class _Runtime:
if not callable(next_call):
return request_body
raw_response: dict[str, Any] = {"set": False, "value": None}
def _impl(next_request: Any) -> Any:
def _normalize(next_request: Any) -> Any:
next_body = getattr(next_request, "content", next_request)
raw = next_call(next_body if isinstance(next_body, dict) else request_body)
raw_response["set"] = True
raw_response["value"] = raw
return _llm_response_payload(raw)
return next_body if isinstance(next_body, dict) else request_body
async def _managed_execute() -> Any:
result = self.nemo_relay.llm.execute(
str(kwargs.get("provider") or "llm"),
request,
_impl,
handle=state.handle,
data=_jsonable(
{
"turn_id": kwargs.get("turn_id"),
"api_request_id": kwargs.get("api_request_id"),
"api_call_count": kwargs.get("api_call_count"),
"mode": self.settings.adaptive_mode,
}
),
metadata=_metadata(kwargs),
model_name=str(kwargs.get("model") or ""),
)
if inspect.isawaitable(result):
return await result
return result
def _make_managed(impl: Callable[[Any], Any]) -> Any:
async def _managed_execute() -> Any:
result = self.nemo_relay.llm.execute(
str(kwargs.get("provider") or "llm"),
request,
impl,
handle=state.handle,
data=_jsonable(
{
"turn_id": kwargs.get("turn_id"),
"api_request_id": kwargs.get("api_request_id"),
"api_call_count": kwargs.get("api_call_count"),
"mode": self.settings.adaptive_mode,
}
),
metadata=_metadata(kwargs),
model_name=str(kwargs.get("model") or ""),
)
if inspect.isawaitable(result):
return await result
return result
managed_result = _resolve_awaitable(_managed_execute())
return raw_response["value"] if raw_response["set"] else managed_result
return _managed_execute()
return self._run_managed_with_downstream_preservation(
next_call, _normalize, _llm_response_payload, _make_managed
)
def execute_tool(self, kwargs: dict[str, Any]) -> Any:
state = self.ensure_session(kwargs)
@ -333,37 +370,35 @@ class _Runtime:
if not callable(next_call):
return args
raw_response: dict[str, Any] = {"set": False, "value": None}
def _normalize(next_args: Any) -> Any:
return next_args if isinstance(next_args, dict) else args
def _impl(next_args: Any) -> Any:
effective_args = next_args if isinstance(next_args, dict) else args
raw = next_call(effective_args)
raw_response["set"] = True
raw_response["value"] = raw
return _jsonable(raw)
def _make_managed(impl: Callable[[Any], Any]) -> Any:
async def _managed_execute() -> Any:
result = self.nemo_relay.tools.execute(
tool_name,
args,
impl,
handle=state.handle,
data=_jsonable(
{
"turn_id": kwargs.get("turn_id"),
"api_request_id": kwargs.get("api_request_id"),
"tool_call_id": kwargs.get("tool_call_id"),
"mode": self.settings.adaptive_mode,
}
),
metadata=_metadata(kwargs),
)
if inspect.isawaitable(result):
return await result
return result
async def _managed_execute() -> Any:
result = self.nemo_relay.tools.execute(
tool_name,
args,
_impl,
handle=state.handle,
data=_jsonable(
{
"turn_id": kwargs.get("turn_id"),
"api_request_id": kwargs.get("api_request_id"),
"tool_call_id": kwargs.get("tool_call_id"),
"mode": self.settings.adaptive_mode,
}
),
metadata=_metadata(kwargs),
)
if inspect.isawaitable(result):
return await result
return result
return _managed_execute()
managed_result = _resolve_awaitable(_managed_execute())
return raw_response["value"] if raw_response["set"] else managed_result
return self._run_managed_with_downstream_preservation(
next_call, _normalize, _jsonable, _make_managed
)
def register(ctx) -> None:
@ -806,6 +841,30 @@ def _value(obj: Any, key: str, default: Any = None) -> Any:
return getattr(obj, key, default)
def _original_downstream_error(exc: Exception) -> BaseException:
# Hermes wraps downstream execution failures in a local/private exception
# class, so detect the wrapper by shape instead of importing it here.
original = getattr(exc, "original", None)
if exc.__class__.__name__ == "_DownstreamExecutionError" and isinstance(original, BaseException):
return original
return exc
def _is_relay_wrapped_callback_error(exc: Exception, callback_error: Exception | None) -> bool:
# NeMo Relay re-wraps a failing callback as ``RuntimeError("internal error:
# <ClassName>: <message>")``. Match by prefix rather than exact equality so a
# trailing traceback/suffix in a future Relay version doesn't silently defeat
# the unwrap; the class-name + message prefix still discriminates the real
# downstream failure from unrelated Relay-internal errors. If Relay drops the
# leading ``internal error:`` shape entirely, this returns False and Hermes
# falls back to surfacing Relay's error (the pre-fix behavior) rather than
# masking it.
if callback_error is None or not isinstance(exc, RuntimeError):
return False
expected = f"internal error: {callback_error.__class__.__name__}: {callback_error}"
return str(exc).startswith(expected)
def _llm_response_payload(response: Any) -> Any:
"""Return the LLM response shape NeMo Relay's ATIF conversion expects."""
payload = _jsonable(response)

View file

@ -12,6 +12,7 @@ import warnings
from pathlib import Path
from types import SimpleNamespace
import pytest
import yaml
from hermes_cli.plugins import PluginManager
@ -153,6 +154,33 @@ def _fresh_plugin(monkeypatch, fake):
return plugin
def _wrapped_downstream_error(original):
class _DownstreamExecutionError(Exception):
def __init__(self, original):
super().__init__(str(original))
self.original = original
return _DownstreamExecutionError(original)
def _enable_adaptive_plugin(tmp_path, monkeypatch) -> None:
plugins_toml = tmp_path / "plugins.toml"
plugins_toml.write_text(
"""
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config.tool_parallelism]
mode = "observe_only"
""",
encoding="utf-8",
)
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
def test_manifest_fields():
data = yaml.safe_load((PLUGIN_DIR / "plugin.yaml").read_text())
assert data["name"] == "nemo_relay"
@ -783,6 +811,220 @@ mode = "observe_only"
}
def test_nemo_relay_adaptive_llm_execution_preserves_downstream_error(tmp_path, monkeypatch):
    """The LLM path re-raises the provider error rather than Relay's wrapper."""

    class ProviderAuthError(Exception):
        status_code = 403

    provider_error = ProviderAuthError("provider auth failed")
    fake = _FakeNemoRelay()

    def native_like_execute(name, request, func, **kwargs):
        # Record the call, then re-wrap any callback failure the way the
        # plugin expects Relay to: "internal error: <cls>: <msg>".
        fake.events.append(("llm.execute.start", name, request.content, kwargs))
        try:
            return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
        except Exception as exc:
            raise RuntimeError(f"internal error: {type(exc).__name__}: {exc}") from None

    def next_call(request):
        raise _wrapped_downstream_error(provider_error)

    fake.llm.execute = native_like_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(ProviderAuthError) as caught:
        plugin.on_llm_execution_middleware(
            session_id="s1",
            provider="anthropic",
            model="demo-model",
            request={"messages": [{"role": "user", "content": "hi"}]},
            next_call=next_call,
        )

    raised = caught.value
    assert raised is provider_error
    assert raised.status_code == 403
def test_nemo_relay_adaptive_llm_execution_preserves_downstream_error_with_relay_suffix(
    tmp_path, monkeypatch
):
    """Prefix matching still recovers the provider error when Relay appends text.

    Guards the ``startswith`` (vs exact ``==``) match in
    ``_is_relay_wrapped_callback_error``: the Relay stand-in re-wraps the
    callback failure with the canonical prefix but appends a trailing suffix.
    Exact equality would miss this and surface Relay's wrapper.
    """

    class ProviderAuthError(Exception):
        status_code = 403

    provider_error = ProviderAuthError("provider auth failed")
    fake = _FakeNemoRelay()

    def native_like_execute(name, request, func, **kwargs):
        try:
            return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
        except Exception as exc:
            raise RuntimeError(f"internal error: {type(exc).__name__}: {exc} (retried 3x)") from None

    def next_call(request):
        raise _wrapped_downstream_error(provider_error)

    fake.llm.execute = native_like_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(ProviderAuthError) as caught:
        plugin.on_llm_execution_middleware(
            session_id="s1",
            provider="anthropic",
            model="demo-model",
            request={"messages": [{"role": "user", "content": "hi"}]},
            next_call=next_call,
        )

    raised = caught.value
    assert raised is provider_error
    assert raised.status_code == 403
def test_nemo_relay_adaptive_llm_execution_keeps_unrelated_internal_error(tmp_path, monkeypatch):
    """A Relay error raised without any callback failure propagates as-is."""
    relay_error = RuntimeError("internal error: relay setup failed")
    fake = _FakeNemoRelay()

    def internal_error_execute(name, request, func, **kwargs):
        # Fails before ever invoking the callback.
        raise relay_error

    fake.llm.execute = internal_error_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(RuntimeError) as caught:
        plugin.on_llm_execution_middleware(
            session_id="s1",
            provider="anthropic",
            model="demo-model",
            request={"messages": [{"role": "user", "content": "hi"}]},
            next_call=lambda request: {"raw": request},
        )

    assert caught.value is relay_error
def test_nemo_relay_adaptive_llm_execution_keeps_wrapped_relay_error_after_downstream_failure(
    tmp_path, monkeypatch
):
    """An ``internal error:`` RuntimeError that is not the callback's re-wrap is kept.

    The callback fails, but the Relay stand-in raises its own
    ``internal error: RuntimeError: ...`` message that does not correspond
    to the callback failure, so Relay's error must surface.
    """
    relay_error = RuntimeError("internal error: RuntimeError: relay policy blocked after downstream")
    fake = _FakeNemoRelay()

    def translated_execute(name, request, func, **kwargs):
        try:
            return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
        except Exception:
            raise relay_error

    def next_call(request):
        raise _wrapped_downstream_error(RuntimeError("provider failed"))

    fake.llm.execute = translated_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(RuntimeError) as caught:
        plugin.on_llm_execution_middleware(
            session_id="s1",
            provider="anthropic",
            model="demo-model",
            request={"messages": [{"role": "user", "content": "hi"}]},
            next_call=next_call,
        )

    assert caught.value is relay_error
def test_nemo_relay_adaptive_llm_execution_keeps_relay_translated_error(tmp_path, monkeypatch):
    """A non-RuntimeError raised by Relay after a callback failure is kept."""

    class RelayPolicyError(Exception):
        pass

    relay_error = RelayPolicyError("relay policy blocked")
    provider_error = RuntimeError("provider failed")
    fake = _FakeNemoRelay()

    def translated_execute(name, request, func, **kwargs):
        # Swallow the callback failure and raise Relay's own exception type.
        try:
            return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
        except Exception:
            raise relay_error

    def next_call(request):
        raise _wrapped_downstream_error(provider_error)

    fake.llm.execute = translated_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(RelayPolicyError) as caught:
        plugin.on_llm_execution_middleware(
            session_id="s1",
            provider="anthropic",
            model="demo-model",
            request={"messages": [{"role": "user", "content": "hi"}]},
            next_call=next_call,
        )

    assert caught.value is relay_error
def test_nemo_relay_downstream_unwrap_matches_real_middleware_wrapper_shape(monkeypatch):
    """Guard the plugin's wrapper detection against drift in core middleware.

    The synthetic downstream-error tests model the wrapper with a local
    class, so they keep passing even if core middleware renames its private
    ``_DownstreamExecutionError`` or drops ``.original`` -- the shape
    ``_original_downstream_error`` matches by name. Here the wrapper is
    captured from the real ``hermes_cli.middleware._run_execution_chain`` as
    seen by a middleware callback's ``next_call``, and the plugin's detector
    must unwrap it to the original exception.
    """
    from hermes_cli import middleware
    from plugins.observability.nemo_relay import _original_downstream_error

    class ProviderError(Exception):
        status_code = 403

    provider_error = ProviderError("provider auth failed")
    seen: dict[str, Exception] = {}

    def terminal_call(payload):
        raise provider_error

    def capturing_callback(**kwargs):
        downstream = kwargs["next_call"]
        try:
            return downstream(kwargs.get("request"))
        except Exception as exc:
            seen["wrapper"] = exc
            # Surface the original so the chain unwinds without re-wrapping noise.
            raise _original_downstream_error(exc) from None

    with pytest.raises(ProviderError) as caught:
        middleware._run_execution_chain(
            "llm",
            [capturing_callback],
            terminal_call,
            request={"messages": []},
        )

    # The wrapper the plugin sees must match what _original_downstream_error keys on.
    wrapper = seen["wrapper"]
    assert wrapper.__class__.__name__ == "_DownstreamExecutionError"
    assert isinstance(getattr(wrapper, "original", None), BaseException)
    assert _original_downstream_error(wrapper) is provider_error

    raised = caught.value
    assert raised is provider_error
    assert raised.status_code == 403
def _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text: str) -> str:
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
@ -920,6 +1162,127 @@ mode = "observe_only"
assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
def test_nemo_relay_adaptive_tool_execution_preserves_downstream_error(tmp_path, monkeypatch):
    """The tool path re-raises the tool error rather than Relay's wrapper."""

    class ToolAuthError(Exception):
        status_code = 403

    tool_error = ToolAuthError("tool auth failed")
    fake = _FakeNemoRelay()

    def native_like_execute(name, args, func, **kwargs):
        # Record the call, then re-wrap any callback failure the way the
        # plugin expects Relay to: "internal error: <cls>: <msg>".
        fake.events.append(("tool.execute.start", name, args, kwargs))
        try:
            return func({"intercepted": True, **args})
        except Exception as exc:
            raise RuntimeError(f"internal error: {type(exc).__name__}: {exc}") from None

    def next_call(args):
        raise _wrapped_downstream_error(tool_error)

    fake.tools.execute = native_like_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(ToolAuthError) as caught:
        plugin.on_tool_execution_middleware(
            session_id="s1",
            tool_name="terminal",
            args={"command": "pwd"},
            next_call=next_call,
        )

    raised = caught.value
    assert raised is tool_error
    assert raised.status_code == 403
def test_nemo_relay_adaptive_tool_execution_keeps_unrelated_internal_error(tmp_path, monkeypatch):
    """A Relay error raised without any callback failure propagates as-is."""
    relay_error = RuntimeError("internal error: relay setup failed")
    fake = _FakeNemoRelay()

    def internal_error_execute(name, args, func, **kwargs):
        # Fails before ever invoking the callback.
        raise relay_error

    fake.tools.execute = internal_error_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(RuntimeError) as caught:
        plugin.on_tool_execution_middleware(
            session_id="s1",
            tool_name="terminal",
            args={"command": "pwd"},
            next_call=lambda args: {"raw": args},
        )

    assert caught.value is relay_error
def test_nemo_relay_adaptive_tool_execution_keeps_wrapped_relay_error_after_downstream_failure(
    tmp_path, monkeypatch
):
    """An ``internal error:`` RuntimeError that is not the callback's re-wrap is kept.

    The callback fails, but the Relay stand-in raises its own
    ``internal error: RuntimeError: ...`` message that does not correspond
    to the callback failure, so Relay's error must surface.
    """
    relay_error = RuntimeError("internal error: RuntimeError: relay policy blocked after downstream")
    fake = _FakeNemoRelay()

    def translated_execute(name, args, func, **kwargs):
        try:
            return func({"intercepted": True, **args})
        except Exception:
            raise relay_error

    def next_call(args):
        raise _wrapped_downstream_error(RuntimeError("tool failed"))

    fake.tools.execute = translated_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(RuntimeError) as caught:
        plugin.on_tool_execution_middleware(
            session_id="s1",
            tool_name="terminal",
            args={"command": "pwd"},
            next_call=next_call,
        )

    assert caught.value is relay_error
def test_nemo_relay_adaptive_tool_execution_keeps_relay_translated_error(tmp_path, monkeypatch):
    """A non-RuntimeError raised by Relay after a callback failure is kept."""

    class RelayPolicyError(Exception):
        pass

    relay_error = RelayPolicyError("relay policy blocked")
    tool_error = RuntimeError("tool failed")
    fake = _FakeNemoRelay()

    def translated_execute(name, args, func, **kwargs):
        # Swallow the callback failure and raise Relay's own exception type.
        try:
            return func({"intercepted": True, **args})
        except Exception:
            raise relay_error

    def next_call(args):
        raise _wrapped_downstream_error(tool_error)

    fake.tools.execute = translated_execute
    plugin = _fresh_plugin(monkeypatch, fake)
    _enable_adaptive_plugin(tmp_path, monkeypatch)

    with pytest.raises(RelayPolicyError) as caught:
        plugin.on_tool_execution_middleware(
            session_id="s1",
            tool_name="terminal",
            args={"command": "pwd"},
            next_call=next_call,
        )

    assert caught.value is relay_error
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)