fix(nemo-relay): align adaptive config with tool_parallelism mode

Signed-off-by: mnajafian-nv <mnajafian@nvidia.com>
This commit is contained in:
mnajafian-nv 2026-06-08 11:48:19 -07:00
parent a38003be3d
commit 021d1034d0
No known key found for this signature in database
GPG key ID: C0C3EEEE9FB11E38
3 changed files with 110 additions and 25 deletions

View file

@ -457,8 +457,8 @@ version = 1
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
[components.config.tool_parallelism]
mode = "observe_only"
""",
encoding="utf-8",
)
@ -506,7 +506,7 @@ mode = "route"
assert response.choices == [raw_choice]
assert seen_request["intercepted"] is True
execute_start = next(event for event in fake.events if event[0] == "llm.execute.start")
assert execute_start[3]["data"]["mode"] == "route"
assert execute_start[3]["data"]["mode"] == "observe_only"
execute_end = next(event for event in fake.events if event[0] == "llm.execute.end")
assert execute_end[2] == {
"model": "demo-model",
@ -527,6 +527,84 @@ mode = "route"
}
def _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text: str) -> str:
    """Return the mode reported on ``llm.execute.start`` for a plugins.toml.

    Writes ``plugins_toml_text`` to ``tmp_path / "plugins.toml"``, points the
    ``HERMES_NEMO_RELAY_PLUGINS_TOML`` environment variable at that file,
    drives a single call through ``on_llm_execution_middleware`` and returns
    ``data["mode"]`` from the first ``llm.execute.start`` event recorded on
    the fake relay.

    Raises:
        StopIteration: if no ``llm.execute.start`` event was recorded.
    """
    relay = _FakeNemoRelay()
    plugin = _fresh_plugin(monkeypatch, relay)

    config_path = tmp_path / "plugins.toml"
    config_path.write_text(plugins_toml_text, encoding="utf-8")
    monkeypatch.setenv("HERMES_NEMO_RELAY_PLUGINS_TOML", str(config_path))

    def _echo(request):
        # Stand-in for the downstream LLM call: wraps the request it receives.
        return {"raw": request}

    plugin.on_llm_execution_middleware(
        session_id="s1",
        provider="anthropic",
        model="demo-model",
        request={"messages": [{"role": "user", "content": "hi"}]},
        next_call=_echo,
    )

    # Events are indexed positionally: [0] is the event name, [3] carries "data".
    start_events = (evt for evt in relay.events if evt[0] == "llm.execute.start")
    first_start = next(start_events)
    return first_start[3]["data"]["mode"]
def test_nemo_relay_adaptive_llm_execution_middleware_defaults_to_observe_only_when_mode_is_unset(
    tmp_path, monkeypatch
):
    """An adaptive component whose config sets no ``mode`` reports ``observe_only``."""
    # The config table carries only ``version``; no mode key appears anywhere.
    plugins_toml_text = """
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
version = 1
"""
    observed_mode = _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text)
    assert observed_mode == "observe_only"
def test_nemo_relay_adaptive_llm_execution_middleware_accepts_legacy_top_level_mode(tmp_path, monkeypatch):
    """A ``mode`` set directly under ``[components.config]`` is reported as-is."""
    # Only the top-level mode key is present; there is no tool_parallelism table.
    plugins_toml_text = """
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
"""
    observed_mode = _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text)
    assert observed_mode == "route"
def test_nemo_relay_adaptive_llm_execution_middleware_prefers_tool_parallelism_mode(tmp_path, monkeypatch):
    """When both mode keys are set, the ``tool_parallelism`` one is reported."""
    # Top-level mode is "route" while tool_parallelism.mode is "schedule";
    # the expectation below is that the nested value wins.
    plugins_toml_text = """
version = 1
[[components]]
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
[components.config.tool_parallelism]
mode = "schedule"
"""
    observed_mode = _adaptive_llm_execute_mode(tmp_path, monkeypatch, plugins_toml_text)
    assert observed_mode == "schedule"
def test_nemo_relay_llm_execution_middleware_calls_through_without_adaptive(monkeypatch):
fake = _FakeNemoRelay()
plugin = _fresh_plugin(monkeypatch, fake)
@ -555,8 +633,8 @@ version = 1
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
[components.config.tool_parallelism]
mode = "observe_only"
""",
encoding="utf-8",
)
@ -582,7 +660,7 @@ mode = "route"
assert response == {"raw": True, "args": {"command": "pwd", "intercepted": True}}
assert seen_args["intercepted"] is True
execute_start = next(event for event in fake.events if event[0] == "tool.execute.start")
assert execute_start[3]["data"]["mode"] == "route"
assert execute_start[3]["data"]["mode"] == "observe_only"
assert execute_start[3]["data"]["tool_call_id"] == "tool-1"
@ -613,8 +691,8 @@ version = 1
kind = "adaptive"
enabled = true
[components.config]
mode = "route"
[components.config.tool_parallelism]
mode = "observe_only"
""",
encoding="utf-8",
)