mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-14 09:11:54 +00:00
fix(nemo-relay): preserve downstream errors in adaptive execution (#42691)
Based on #42658 by @mnajafian-nv. Preserves the real downstream provider/tool exception when NeMo Relay's managed adaptive execution wraps a failing callback as an internal runtime error. Without this, the original exception (and its retry-classification signal, e.g. status_code) is lost behind Relay's wrapper. Salvage changes on top of the original PR: - Tolerant Relay-wrapper match: _is_relay_wrapped_callback_error now uses str.startswith on the "internal error: <cls>: <msg>" prefix instead of exact equality, so a future Relay version appending a traceback/suffix doesn't silently defeat the unwrap. On a total format change it returns False and falls back to the pre-fix behavior (surfacing Relay's error) rather than masking it. - Deduplicated the LLM and tool execute paths into a shared _run_managed_with_downstream_preservation helper, removing ~20 lines of copy-pasted nonlocal/try-except scaffolding that could drift out of sync. - Added a real-middleware regression guard (test_nemo_relay_downstream_unwrap_matches_real_middleware_wrapper_shape) that drives hermes_cli.middleware._run_execution_chain and asserts the plugin's _original_downstream_error unwraps the actual private _DownstreamExecutionError wrapper. The original synthetic tests modeled the wrapper with a local class, so a rename or shape change in core middleware would not have been caught; this test fails loudly if that contract drifts. Co-authored-by: mnajafian-nv <mnajafian@nvidia.com>
This commit is contained in:
parent
8d99b5bc4f
commit
85852b71d8
2 changed files with 479 additions and 57 deletions
|
|
@ -9,6 +9,7 @@ import logging
|
|||
import os
|
||||
import threading
|
||||
import tomllib
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
|
@ -284,6 +285,43 @@ class _Runtime:
|
|||
and callable(getattr(getattr(self.nemo_relay, "tools", None), "execute", None))
|
||||
)
|
||||
|
||||
def _run_managed_with_downstream_preservation(
|
||||
self,
|
||||
next_call: Callable[[Any], Any],
|
||||
normalize_payload: Callable[[Any], Any],
|
||||
shape_response: Callable[[Any], Any],
|
||||
make_managed_execute: Callable[[Callable[[Any], Any]], Any],
|
||||
) -> Any:
|
||||
# NeMo Relay's native managed execution may wrap a failing callback as an
|
||||
# internal runtime error, hiding the real downstream provider/tool
|
||||
# exception. Capture the original here and re-raise it after managed
|
||||
# execution so Hermes retry classification still sees it. The LLM and tool
|
||||
# paths share this scaffolding; they differ only in payload normalization,
|
||||
# response shaping, and the Relay call itself.
|
||||
raw_response: dict[str, Any] = {"set": False, "value": None}
|
||||
callback_error: Exception | None = None
|
||||
downstream_error: BaseException | None = None
|
||||
|
||||
def _impl(next_payload: Any) -> Any:
|
||||
nonlocal callback_error, downstream_error
|
||||
try:
|
||||
raw = next_call(normalize_payload(next_payload))
|
||||
except Exception as exc:
|
||||
callback_error = exc
|
||||
downstream_error = _original_downstream_error(exc)
|
||||
raise
|
||||
raw_response["set"] = True
|
||||
raw_response["value"] = raw
|
||||
return shape_response(raw)
|
||||
|
||||
try:
|
||||
managed_result = _resolve_awaitable(make_managed_execute(_impl))
|
||||
except Exception as exc:
|
||||
if downstream_error is not None and _is_relay_wrapped_callback_error(exc, callback_error):
|
||||
raise downstream_error
|
||||
raise
|
||||
return raw_response["value"] if raw_response["set"] else managed_result
|
||||
|
||||
def execute_llm(self, kwargs: dict[str, Any]) -> Any:
|
||||
state = self.ensure_session(kwargs)
|
||||
request_body = _jsonable(kwargs.get("request") or {})
|
||||
|
|
@ -292,38 +330,37 @@ class _Runtime:
|
|||
if not callable(next_call):
|
||||
return request_body
|
||||
|
||||
raw_response: dict[str, Any] = {"set": False, "value": None}
|
||||
|
||||
def _impl(next_request: Any) -> Any:
|
||||
def _normalize(next_request: Any) -> Any:
|
||||
next_body = getattr(next_request, "content", next_request)
|
||||
raw = next_call(next_body if isinstance(next_body, dict) else request_body)
|
||||
raw_response["set"] = True
|
||||
raw_response["value"] = raw
|
||||
return _llm_response_payload(raw)
|
||||
return next_body if isinstance(next_body, dict) else request_body
|
||||
|
||||
async def _managed_execute() -> Any:
|
||||
result = self.nemo_relay.llm.execute(
|
||||
str(kwargs.get("provider") or "llm"),
|
||||
request,
|
||||
_impl,
|
||||
handle=state.handle,
|
||||
data=_jsonable(
|
||||
{
|
||||
"turn_id": kwargs.get("turn_id"),
|
||||
"api_request_id": kwargs.get("api_request_id"),
|
||||
"api_call_count": kwargs.get("api_call_count"),
|
||||
"mode": self.settings.adaptive_mode,
|
||||
}
|
||||
),
|
||||
metadata=_metadata(kwargs),
|
||||
model_name=str(kwargs.get("model") or ""),
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
def _make_managed(impl: Callable[[Any], Any]) -> Any:
|
||||
async def _managed_execute() -> Any:
|
||||
result = self.nemo_relay.llm.execute(
|
||||
str(kwargs.get("provider") or "llm"),
|
||||
request,
|
||||
impl,
|
||||
handle=state.handle,
|
||||
data=_jsonable(
|
||||
{
|
||||
"turn_id": kwargs.get("turn_id"),
|
||||
"api_request_id": kwargs.get("api_request_id"),
|
||||
"api_call_count": kwargs.get("api_call_count"),
|
||||
"mode": self.settings.adaptive_mode,
|
||||
}
|
||||
),
|
||||
metadata=_metadata(kwargs),
|
||||
model_name=str(kwargs.get("model") or ""),
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
managed_result = _resolve_awaitable(_managed_execute())
|
||||
return raw_response["value"] if raw_response["set"] else managed_result
|
||||
return _managed_execute()
|
||||
|
||||
return self._run_managed_with_downstream_preservation(
|
||||
next_call, _normalize, _llm_response_payload, _make_managed
|
||||
)
|
||||
|
||||
def execute_tool(self, kwargs: dict[str, Any]) -> Any:
|
||||
state = self.ensure_session(kwargs)
|
||||
|
|
@ -333,37 +370,35 @@ class _Runtime:
|
|||
if not callable(next_call):
|
||||
return args
|
||||
|
||||
raw_response: dict[str, Any] = {"set": False, "value": None}
|
||||
def _normalize(next_args: Any) -> Any:
|
||||
return next_args if isinstance(next_args, dict) else args
|
||||
|
||||
def _impl(next_args: Any) -> Any:
|
||||
effective_args = next_args if isinstance(next_args, dict) else args
|
||||
raw = next_call(effective_args)
|
||||
raw_response["set"] = True
|
||||
raw_response["value"] = raw
|
||||
return _jsonable(raw)
|
||||
def _make_managed(impl: Callable[[Any], Any]) -> Any:
|
||||
async def _managed_execute() -> Any:
|
||||
result = self.nemo_relay.tools.execute(
|
||||
tool_name,
|
||||
args,
|
||||
impl,
|
||||
handle=state.handle,
|
||||
data=_jsonable(
|
||||
{
|
||||
"turn_id": kwargs.get("turn_id"),
|
||||
"api_request_id": kwargs.get("api_request_id"),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"mode": self.settings.adaptive_mode,
|
||||
}
|
||||
),
|
||||
metadata=_metadata(kwargs),
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
|
||||
async def _managed_execute() -> Any:
|
||||
result = self.nemo_relay.tools.execute(
|
||||
tool_name,
|
||||
args,
|
||||
_impl,
|
||||
handle=state.handle,
|
||||
data=_jsonable(
|
||||
{
|
||||
"turn_id": kwargs.get("turn_id"),
|
||||
"api_request_id": kwargs.get("api_request_id"),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"mode": self.settings.adaptive_mode,
|
||||
}
|
||||
),
|
||||
metadata=_metadata(kwargs),
|
||||
)
|
||||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
return result
|
||||
return _managed_execute()
|
||||
|
||||
managed_result = _resolve_awaitable(_managed_execute())
|
||||
return raw_response["value"] if raw_response["set"] else managed_result
|
||||
return self._run_managed_with_downstream_preservation(
|
||||
next_call, _normalize, _jsonable, _make_managed
|
||||
)
|
||||
|
||||
|
||||
def register(ctx) -> None:
|
||||
|
|
@ -806,6 +841,30 @@ def _value(obj: Any, key: str, default: Any = None) -> Any:
|
|||
return getattr(obj, key, default)
|
||||
|
||||
|
||||
def _original_downstream_error(exc: Exception) -> BaseException:
|
||||
# Hermes wraps downstream execution failures in a local/private exception
|
||||
# class, so detect the wrapper by shape instead of importing it here.
|
||||
original = getattr(exc, "original", None)
|
||||
if exc.__class__.__name__ == "_DownstreamExecutionError" and isinstance(original, BaseException):
|
||||
return original
|
||||
return exc
|
||||
|
||||
|
||||
def _is_relay_wrapped_callback_error(exc: Exception, callback_error: Exception | None) -> bool:
|
||||
# NeMo Relay re-wraps a failing callback as ``RuntimeError("internal error:
|
||||
# <ClassName>: <message>")``. Match by prefix rather than exact equality so a
|
||||
# trailing traceback/suffix in a future Relay version doesn't silently defeat
|
||||
# the unwrap; the class-name + message prefix still discriminates the real
|
||||
# downstream failure from unrelated Relay-internal errors. If Relay drops the
|
||||
# leading ``internal error:`` shape entirely, this returns False and Hermes
|
||||
# falls back to surfacing Relay's error (the pre-fix behavior) rather than
|
||||
# masking it.
|
||||
if callback_error is None or not isinstance(exc, RuntimeError):
|
||||
return False
|
||||
expected = f"internal error: {callback_error.__class__.__name__}: {callback_error}"
|
||||
return str(exc).startswith(expected)
|
||||
|
||||
|
||||
def _llm_response_payload(response: Any) -> Any:
|
||||
"""Return the LLM response shape NeMo Relay's ATIF conversion expects."""
|
||||
payload = _jsonable(response)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import warnings
|
|||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from hermes_cli.plugins import PluginManager
|
||||
|
|
@ -153,6 +154,33 @@ def _fresh_plugin(monkeypatch, fake):
|
|||
return plugin
|
||||
|
||||
|
||||
def _wrapped_downstream_error(original):
|
||||
class _DownstreamExecutionError(Exception):
|
||||
def __init__(self, original):
|
||||
super().__init__(str(original))
|
||||
self.original = original
|
||||
|
||||
return _DownstreamExecutionError(original)
|
||||
|
||||
|
||||
def _enable_adaptive_plugin(tmp_path, monkeypatch) -> None:
|
||||
plugins_toml = tmp_path / "plugins.toml"
|
||||
plugins_toml.write_text(
|
||||
"""
|
||||
version = 1
|
||||
|
||||
[[components]]
|
||||
kind = "adaptive"
|
||||
enabled = true
|
||||
|
||||
[components.config.tool_parallelism]
|
||||
mode = "observe_only"
|
||||
""",
|
||||
encoding="utf-8",
|
||||
)
|
||||
monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(plugins_toml))
|
||||
|
||||
|
||||
def test_manifest_fields():
|
||||
data = yaml.safe_load((PLUGIN_DIR / "plugin.yaml").read_text())
|
||||
assert data["name"] == "nemo_relay"
|
||||
|
|
@ -783,6 +811,220 @@ mode = "observe_only"
|
|||
}
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_preserves_downstream_error(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
def native_like_execute(name, request, func, **kwargs):
|
||||
fake.events.append(("llm.execute.start", name, request.content, kwargs))
|
||||
try:
|
||||
return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"internal error: {type(exc).__name__}: {exc}") from None
|
||||
|
||||
fake.llm.execute = native_like_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
class ProviderAuthError(Exception):
|
||||
status_code = 403
|
||||
|
||||
provider_error = ProviderAuthError("provider auth failed")
|
||||
|
||||
def next_call(request):
|
||||
raise _wrapped_downstream_error(provider_error)
|
||||
|
||||
with pytest.raises(ProviderAuthError) as caught:
|
||||
plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is provider_error
|
||||
assert caught.value.status_code == 403
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_preserves_downstream_error_with_relay_suffix(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
# Guards the startswith (vs exact ==) match in _is_relay_wrapped_callback_error:
|
||||
# Relay re-wraps the callback failure with its canonical prefix but APPENDS a
|
||||
# trailing suffix. Exact equality would miss this and surface Relay's wrapper;
|
||||
# prefix matching must still recover the original downstream error.
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
def native_like_execute(name, request, func, **kwargs):
|
||||
try:
|
||||
return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"internal error: {type(exc).__name__}: {exc} (retried 3x)") from None
|
||||
|
||||
fake.llm.execute = native_like_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
class ProviderAuthError(Exception):
|
||||
status_code = 403
|
||||
|
||||
provider_error = ProviderAuthError("provider auth failed")
|
||||
|
||||
def next_call(request):
|
||||
raise _wrapped_downstream_error(provider_error)
|
||||
|
||||
with pytest.raises(ProviderAuthError) as caught:
|
||||
plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is provider_error
|
||||
assert caught.value.status_code == 403
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_keeps_unrelated_internal_error(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
relay_error = RuntimeError("internal error: relay setup failed")
|
||||
|
||||
def internal_error_execute(name, request, func, **kwargs):
|
||||
raise relay_error
|
||||
|
||||
fake.llm.execute = internal_error_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
with pytest.raises(RuntimeError) as caught:
|
||||
plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=lambda request: {"raw": request},
|
||||
)
|
||||
|
||||
assert caught.value is relay_error
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_keeps_wrapped_relay_error_after_downstream_failure(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
fake = _FakeNemoRelay()
|
||||
relay_error = RuntimeError("internal error: RuntimeError: relay policy blocked after downstream")
|
||||
|
||||
def translated_execute(name, request, func, **kwargs):
|
||||
try:
|
||||
return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
|
||||
except Exception:
|
||||
raise relay_error
|
||||
|
||||
fake.llm.execute = translated_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
def next_call(request):
|
||||
raise _wrapped_downstream_error(RuntimeError("provider failed"))
|
||||
|
||||
with pytest.raises(RuntimeError) as caught:
|
||||
plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is relay_error
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_llm_execution_keeps_relay_translated_error(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
class RelayPolicyError(Exception):
|
||||
pass
|
||||
|
||||
relay_error = RelayPolicyError("relay policy blocked")
|
||||
|
||||
def translated_execute(name, request, func, **kwargs):
|
||||
try:
|
||||
return func(_FakeLLMRequest(request.headers, {"intercepted": True, **request.content}))
|
||||
except Exception:
|
||||
raise relay_error
|
||||
|
||||
fake.llm.execute = translated_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
provider_error = RuntimeError("provider failed")
|
||||
|
||||
def next_call(request):
|
||||
raise _wrapped_downstream_error(provider_error)
|
||||
|
||||
with pytest.raises(RelayPolicyError) as caught:
|
||||
plugin.on_llm_execution_middleware(
|
||||
session_id="s1",
|
||||
provider="anthropic",
|
||||
model="demo-model",
|
||||
request={"messages": [{"role": "user", "content": "hi"}]},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is relay_error
|
||||
|
||||
|
||||
def test_nemo_relay_downstream_unwrap_matches_real_middleware_wrapper_shape(monkeypatch):
|
||||
# Regression guard against core/plugin drift. The synthetic tests above model
|
||||
# the downstream-error wrapper with a local class, so they keep passing even
|
||||
# if core middleware renames its private ``_DownstreamExecutionError`` or drops
|
||||
# ``.original`` -- the exact shape the plugin matches by name at
|
||||
# ``_original_downstream_error``. Capture the wrapper the REAL
|
||||
# ``hermes_cli.middleware._run_execution_chain`` hands to a middleware
|
||||
# callback's ``next_call`` and assert the plugin's detector unwraps it to the
|
||||
# original exception. If core middleware changes the wrapper shape, this fails
|
||||
# here instead of silently defeating the unwrap in production.
|
||||
from hermes_cli import middleware
|
||||
|
||||
from plugins.observability.nemo_relay import _original_downstream_error
|
||||
|
||||
class ProviderError(Exception):
|
||||
status_code = 403
|
||||
|
||||
provider_error = ProviderError("provider auth failed")
|
||||
captured: dict[str, Exception] = {}
|
||||
|
||||
def terminal_call(payload):
|
||||
raise provider_error
|
||||
|
||||
def capturing_callback(**kwargs):
|
||||
next_call = kwargs["next_call"]
|
||||
try:
|
||||
return next_call(kwargs.get("request"))
|
||||
except Exception as exc:
|
||||
captured["wrapper"] = exc
|
||||
# Surface the original so the chain unwinds without re-wrapping noise.
|
||||
raise _original_downstream_error(exc) from None
|
||||
|
||||
with pytest.raises(ProviderError) as caught:
|
||||
middleware._run_execution_chain(
|
||||
"llm",
|
||||
[capturing_callback],
|
||||
terminal_call,
|
||||
request={"messages": []},
|
||||
)
|
||||
|
||||
wrapper = captured["wrapper"]
|
||||
# The wrapper the plugin sees must match what _original_downstream_error keys on.
|
||||
assert wrapper.__class__.__name__ == "_DownstreamExecutionError"
|
||||
assert isinstance(getattr(wrapper, "original", None), BaseException)
|
||||
assert _original_downstream_error(wrapper) is provider_error
|
||||
assert caught.value is provider_error
|
||||
assert caught.value.status_code == 403
|
||||
|
||||
|
||||
def _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text: str) -> str:
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
|
|
@ -920,6 +1162,127 @@ mode = "observe_only"
|
|||
assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_preserves_downstream_error(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
def native_like_execute(name, args, func, **kwargs):
|
||||
fake.events.append(("tool.execute.start", name, args, kwargs))
|
||||
try:
|
||||
return func({"intercepted": True, **args})
|
||||
except Exception as exc:
|
||||
raise RuntimeError(f"internal error: {type(exc).__name__}: {exc}") from None
|
||||
|
||||
fake.tools.execute = native_like_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
class ToolAuthError(Exception):
|
||||
status_code = 403
|
||||
|
||||
tool_error = ToolAuthError("tool auth failed")
|
||||
|
||||
def next_call(args):
|
||||
raise _wrapped_downstream_error(tool_error)
|
||||
|
||||
with pytest.raises(ToolAuthError) as caught:
|
||||
plugin.on_tool_execution_middleware(
|
||||
session_id="s1",
|
||||
tool_name="terminal",
|
||||
args={"command": "pwd"},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is tool_error
|
||||
assert caught.value.status_code == 403
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_keeps_unrelated_internal_error(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
relay_error = RuntimeError("internal error: relay setup failed")
|
||||
|
||||
def internal_error_execute(name, args, func, **kwargs):
|
||||
raise relay_error
|
||||
|
||||
fake.tools.execute = internal_error_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
with pytest.raises(RuntimeError) as caught:
|
||||
plugin.on_tool_execution_middleware(
|
||||
session_id="s1",
|
||||
tool_name="terminal",
|
||||
args={"command": "pwd"},
|
||||
next_call=lambda args: {"raw": args},
|
||||
)
|
||||
|
||||
assert caught.value is relay_error
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_keeps_wrapped_relay_error_after_downstream_failure(
|
||||
tmp_path, monkeypatch
|
||||
):
|
||||
fake = _FakeNemoRelay()
|
||||
relay_error = RuntimeError("internal error: RuntimeError: relay policy blocked after downstream")
|
||||
|
||||
def translated_execute(name, args, func, **kwargs):
|
||||
try:
|
||||
return func({"intercepted": True, **args})
|
||||
except Exception:
|
||||
raise relay_error
|
||||
|
||||
fake.tools.execute = translated_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
def next_call(args):
|
||||
raise _wrapped_downstream_error(RuntimeError("tool failed"))
|
||||
|
||||
with pytest.raises(RuntimeError) as caught:
|
||||
plugin.on_tool_execution_middleware(
|
||||
session_id="s1",
|
||||
tool_name="terminal",
|
||||
args={"command": "pwd"},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is relay_error
|
||||
|
||||
|
||||
def test_nemo_relay_adaptive_tool_execution_keeps_relay_translated_error(tmp_path, monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
|
||||
class RelayPolicyError(Exception):
|
||||
pass
|
||||
|
||||
relay_error = RelayPolicyError("relay policy blocked")
|
||||
|
||||
def translated_execute(name, args, func, **kwargs):
|
||||
try:
|
||||
return func({"intercepted": True, **args})
|
||||
except Exception:
|
||||
raise relay_error
|
||||
|
||||
fake.tools.execute = translated_execute
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
_enable_adaptive_plugin(tmp_path, monkeypatch)
|
||||
|
||||
tool_error = RuntimeError("tool failed")
|
||||
|
||||
def next_call(args):
|
||||
raise _wrapped_downstream_error(tool_error)
|
||||
|
||||
with pytest.raises(RelayPolicyError) as caught:
|
||||
plugin.on_tool_execution_middleware(
|
||||
session_id="s1",
|
||||
tool_name="terminal",
|
||||
args={"command": "pwd"},
|
||||
next_call=next_call,
|
||||
)
|
||||
|
||||
assert caught.value is relay_error
|
||||
|
||||
|
||||
def test_nemo_relay_tool_execution_middleware_calls_through_without_adaptive(monkeypatch):
|
||||
fake = _FakeNemoRelay()
|
||||
plugin = _fresh_plugin(monkeypatch, fake)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue