From 8d5d36d793583fbdc679d880674268178ed9b81b Mon Sep 17 00:00:00 2001 From: aimable100 <129232709+aimable100@users.noreply.github.com> Date: Sat, 13 Jun 2026 21:27:59 -0700 Subject: [PATCH] fix(dispatch): forward session_id into registry.dispatch (#28479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Both the regular and execute_code dispatch paths forward task_id into registry.dispatch via middleware _dispatch lambdas but silently dropped session_id. Dispatch-layer hooks (e.g. set_enforcement_fn) that correlate calls with the active session received "" for every invocation. Pass session_id=session_id at both _dispatch call sites inside handle_function_call, matching the existing task_id pattern. Hooks already received session_id; this closes the registry.dispatch gap. Rebased onto current main where dispatch is wrapped by run_tool_execution_middleware — the old direct-dispatch sites from #28479 no longer exist. test(dispatch): add tests for session_id forwarding (NousResearch#28479) Covers standard and execute_code paths through the middleware wrapper. Verifies task_id forwarding is not broken by the change. --- model_tools.py | 2 + tests/test_dispatch_session_id.py | 75 +++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) create mode 100644 tests/test_dispatch_session_id.py diff --git a/model_tools.py b/model_tools.py index 22719a5daef..0618138aa9a 100644 --- a/model_tools.py +++ b/model_tools.py @@ -1115,6 +1115,7 @@ def handle_function_call( return registry.dispatch( function_name, next_args, task_id=task_id, + session_id=session_id, enabled_tools=sandbox_enabled, ) else: @@ -1122,6 +1123,7 @@ def handle_function_call( return registry.dispatch( function_name, next_args, task_id=task_id, + session_id=session_id, user_task=user_task, ) from hermes_cli.middleware import run_tool_execution_middleware diff --git a/tests/test_dispatch_session_id.py b/tests/test_dispatch_session_id.py new file mode 100644 index 00000000000..69a6e592f0a --- /dev/null +++ b/tests/test_dispatch_session_id.py @@ -0,0 +1,75 @@ +"""Tests that handle_function_call forwards session_id into registry.dispatch.""" + +import json +from unittest.mock import MagicMock, patch + + +def _make_registry(captured: dict): + """Return a mock registry whose dispatch records the kwargs it receives.""" + registry = MagicMock() + + def _dispatch(name, args, **kwargs): + captured.update(kwargs) + return json.dumps({"result": "ok"}) + + registry.dispatch.side_effect = _dispatch + return registry + + +class TestSessionIdForwarding: + + def test_standard_path_forwards_session_id(self): + """registry.dispatch receives session_id on the normal tool path.""" + captured = {} + with patch("model_tools.registry", _make_registry(captured)): + from model_tools import handle_function_call + handle_function_call( + "web_search", + {"query": "test"}, + task_id="t1", + session_id="sess-abc", + skip_pre_tool_call_hook=True, + ) + assert captured.get("session_id") == "sess-abc" + + def test_execute_code_path_forwards_session_id(self): + """registry.dispatch receives session_id on the execute_code path.""" + captured = {} + with patch("model_tools.registry", _make_registry(captured)): + from model_tools import handle_function_call + handle_function_call( + "execute_code", + {"code": "print(1)"}, + task_id="t1", + session_id="sess-xyz", + skip_pre_tool_call_hook=True, + ) + assert captured.get("session_id") == "sess-xyz" + + def test_session_id_default_is_none(self): + """When session_id is omitted, dispatch receives None.""" + captured = {} + with patch("model_tools.registry", _make_registry(captured)): + from model_tools import handle_function_call + handle_function_call( + "web_search", + {"query": "test"}, + task_id="t1", + skip_pre_tool_call_hook=True, + ) + assert "session_id" in captured + assert captured["session_id"] is None + + def test_task_id_still_forwarded(self): + """Existing task_id forwarding is not broken by this change.""" + captured = {} + with patch("model_tools.registry", _make_registry(captured)): + from model_tools import handle_function_call + handle_function_call( + "web_search", + {"query": "test"}, + task_id="task-999", + session_id="sess-1", + skip_pre_tool_call_hook=True, + ) + assert captured.get("task_id") == "task-999"