From 19f9be1dffaf803bbb5bcb0d86afc20475f037e3 Mon Sep 17 00:00:00 2001 From: Heltman <44333070+Heltman@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:20:05 +0800 Subject: [PATCH] fix(tools): serialize concurrent hermes_tools RPC calls from execute_code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The sandbox-side `_call()` in both the UDS and file-based transports was not thread-safe, so scripts that call tools from multiple threads (e.g. `ThreadPoolExecutor` over `terminal()`) inside a single `execute_code` run could silently receive each other's responses. Root cause: * UDS transport — a single module-level `_sock` was shared across all threads; the newline-framed protocol has no request-id; and the server-side RPC loop handles one connection serially. With concurrent callers, each thread would `sendall()` then race to `recv()` the next newline-terminated response from the shared socket, so responses got delivered to the wrong caller. * File transport — `_seq += 1` is a non-atomic read-modify-write, so two threads could allocate the same sequence number and clobber each other's request/response files. Fix: guard `_call()` with a `threading.Lock` in the UDS case (covering connect+send+recv), and guard `_seq` allocation with a lock in the file case. No protocol change. Regression tests cover both the generated source (the lock is present and used) and end-to-end behaviour: a sandboxed ThreadPoolExecutor issues 10 concurrent `terminal()` calls against a slow mock dispatcher, and the test asserts that every caller sees its own tagged response. The test fails without the fix (10/10 mismatched, matching real-world repro) and passes with it. 
--- tests/tools/test_code_execution.py | 78 +++++++++++++++++++++++++++++- tools/code_execution_tool.py | 42 ++++++++++------ 2 files changed, 103 insertions(+), 17 deletions(-) diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py index 6f6260ffe2..a580604658 100644 --- a/tests/tools/test_code_execution.py +++ b/tests/tools/test_code_execution.py @@ -114,14 +114,30 @@ class TestHermesToolsGeneration(unittest.TestCase): self.assertIn("def json_parse(", src) self.assertIn("def shell_quote(", src) self.assertIn("def retry(", src) - self.assertIn("import json, os, socket, shlex, time", src) + self.assertIn("import json, os, socket, shlex, threading, time", src) def test_file_transport_uses_tempfile_fallback_for_rpc_dir(self): src = generate_hermes_tools_module(["terminal"], transport="file") - self.assertIn("import json, os, shlex, tempfile, time", src) + self.assertIn("import json, os, shlex, tempfile, threading, time", src) self.assertIn("os.path.join(tempfile.gettempdir(), \"hermes_rpc\")", src) self.assertNotIn('os.environ.get("HERMES_RPC_DIR", "/tmp/hermes_rpc")', src) + def test_uds_transport_serializes_concurrent_calls(self): + """Regression: UDS _call() must hold a lock across send+recv so that + concurrent tool calls from multiple threads don't interleave on the + shared socket and receive each other's responses.""" + src = generate_hermes_tools_module(["terminal"], transport="uds") + self.assertIn("_call_lock = threading.Lock()", src) + self.assertIn("with _call_lock:", src) + + def test_file_transport_serializes_seq_allocation(self): + """Regression: file transport _call() must allocate `_seq` under a + lock, otherwise concurrent threads can pick the same seq and clobber + each other's request files.""" + src = generate_hermes_tools_module(["terminal"], transport="file") + self.assertIn("_seq_lock = threading.Lock()", src) + self.assertIn("with _seq_lock:", src) + class TestExecuteCodeRemoteTempDir(unittest.TestCase): def 
test_execute_remote_uses_backend_temp_dir_for_sandbox(self): @@ -226,6 +242,64 @@ print(f"file lines: {r2['total_lines']}") result = self._run("raise ValueError('test error')") self.assertEqual(result["status"], "error") + def test_concurrent_tool_calls_match_responses(self): + """Regression for the UDS RPC race: multiple threads inside the + sandbox calling terminal() concurrently must each receive their own + response, not another thread's. + + Before the fix, `_sock` and the recv-loop were shared without a + lock, so responses (written FIFO by the single-threaded server) + got delivered to whichever client thread happened to win the + recv() race. That surfaced as each thread seeing another thread's + output. + + The mock dispatcher sleeps briefly to guarantee the requests + overlap on the socket. + """ + code = ''' +import threading +from concurrent.futures import ThreadPoolExecutor +from hermes_tools import terminal + +N = 10 + +def call(i): + r = terminal(f"echo TAG-{i}") + return i, r.get("output", "") + +with ThreadPoolExecutor(max_workers=N) as ex: + results = list(ex.map(call, range(N))) + +mismatches = [(i, out) for i, out in results if f"TAG-{i}" not in out] +if mismatches: + print(f"MISMATCH {len(mismatches)}/{N}: {mismatches[:3]}") +else: + print(f"OK {N}/{N}") +''' + + def slow_mock(function_name, function_args, task_id=None, user_task=None): + import time as _t + if function_name == "terminal": + _t.sleep(0.05) # ensure requests overlap on the socket + cmd = function_args.get("command", "") + # Echo semantics: strip leading "echo " and return the rest + out = cmd[5:] if cmd.startswith("echo ") else f"mock: {cmd}" + return json.dumps({"output": out, "exit_code": 0}) + return _mock_handle_function_call( + function_name, function_args, task_id=task_id, user_task=user_task + ) + + with patch("model_tools.handle_function_call", side_effect=slow_mock): + raw = execute_code( + code=code, + task_id="test-concurrent", + 
enabled_tools=list(SANDBOX_ALLOWED_TOOLS), + ) + result = json.loads(raw) + self.assertEqual(result["status"], "success", msg=result) + self.assertIn("OK 10/10", result["output"], + msg=f"Concurrent tool calls mismatched: {result['output']!r}") + def test_excluded_tool_returns_error(self): """Script calling a tool not in the allow-list gets an error from RPC.""" code = """ diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index c91907c4d1..ffcf726fcd 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -224,9 +224,14 @@ def retry(fn, max_attempts=3, delay=2): _UDS_TRANSPORT_HEADER = '''\ """Auto-generated Hermes tools RPC stubs.""" -import json, os, socket, shlex, time +import json, os, socket, shlex, threading, time _sock = None +# The RPC server handles a single client connection serially and has no +# request-id in the protocol, so concurrent _call() invocations from multiple +# threads (e.g. ThreadPoolExecutor) would race on the shared socket and get +# each other's responses. Serialize the entire send+recv round-trip. 
+_call_lock = threading.Lock() ''' + _COMMON_HELPERS + '''\ def _connect(): @@ -239,17 +244,18 @@ def _connect(): def _call(tool_name, args): """Send a tool call to the parent process and return the parsed result.""" - conn = _connect() request = json.dumps({"tool": tool_name, "args": args}) + "\\n" - conn.sendall(request.encode()) - buf = b"" - while True: - chunk = conn.recv(65536) - if not chunk: - raise RuntimeError("Agent process disconnected") - buf += chunk - if buf.endswith(b"\\n"): - break + with _call_lock: + conn = _connect() + conn.sendall(request.encode()) + buf = b"" + while True: + chunk = conn.recv(65536) + if not chunk: + raise RuntimeError("Agent process disconnected") + buf += chunk + if buf.endswith(b"\\n"): + break raw = buf.decode().strip() result = json.loads(raw) if isinstance(result, str): @@ -265,24 +271,30 @@ def _call(tool_name, args): _FILE_TRANSPORT_HEADER = '''\ """Auto-generated Hermes tools RPC stubs (file-based transport).""" -import json, os, shlex, tempfile, time +import json, os, shlex, tempfile, threading, time _RPC_DIR = os.environ.get("HERMES_RPC_DIR") or os.path.join(tempfile.gettempdir(), "hermes_rpc") _seq = 0 +# `_seq += 1` is not atomic (read-modify-write), so concurrent _call() +# invocations from multiple threads could allocate the same sequence number +# and clobber each other's request files. Guard seq allocation with a lock. 
+_seq_lock = threading.Lock() ''' + _COMMON_HELPERS + '''\ def _call(tool_name, args): """Send a tool call request via file-based RPC and wait for response.""" global _seq - _seq += 1 - seq_str = f"{_seq:06d}" + with _seq_lock: + _seq += 1 + seq = _seq + seq_str = f"{seq:06d}" req_file = os.path.join(_RPC_DIR, f"req_{seq_str}") res_file = os.path.join(_RPC_DIR, f"res_{seq_str}") # Write request atomically (write to .tmp, then rename) tmp = req_file + ".tmp" with open(tmp, "w") as f: - json.dump({"tool": tool_name, "args": args, "seq": _seq}, f) + json.dump({"tool": tool_name, "args": args, "seq": seq}, f) os.rename(tmp, req_file) # Wait for response with adaptive polling