fix(tools): serialize concurrent hermes_tools RPC calls from execute_code

The sandbox-side `_call()` in both the UDS and file-based transports was
not thread-safe, so scripts that call tools from multiple threads (e.g.
`ThreadPoolExecutor` over `terminal()`) inside a single `execute_code`
run could silently receive each other's responses.

Root cause:

* UDS transport — a single module-level `_sock` was shared across all
  threads; the newline-framed protocol has no request-id; and the
  server-side RPC loop handles one connection serially. With concurrent
  callers, each thread would `sendall()` then race to `recv()` the next
  newline-terminated response from the shared buffer, so responses got
  delivered to the wrong caller.

* File transport — `_seq += 1` is a non-atomic read-modify-write, so
  two threads could allocate the same sequence number and clobber each
  other's request/response files.
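
The file-transport hazard above can be sketched in isolation: `_seq += 1`
compiles to separate load/add/store steps, so two threads can read the same
value and both store N+1. The sketch below uses illustrative names
(`allocate_seq`, the worker loop) that are not from the patch; only `_seq`
and `_seq_lock` match the diff.

```python
import threading

_seq = 0
_seq_lock = threading.Lock()

def allocate_seq():
    """Allocate the next sequence number; the lock makes the
    read-modify-write atomic and duplicates impossible."""
    global _seq
    with _seq_lock:
        _seq += 1
        return _seq  # snapshot inside the lock, not `_seq` read later

seqs = []
barrier = threading.Barrier(8)

def worker():
    barrier.wait()  # release all threads at once to maximize contention
    for _ in range(1000):
        seqs.append(allocate_seq())  # list.append is thread-safe in CPython

threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert len(seqs) == len(set(seqs)) == 8000  # no duplicate sequence numbers
```

Dropping the `with _seq_lock:` guard makes duplicate allocations possible
(though timing-dependent), which is exactly the clobbered-files failure mode.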

Fix: guard `_call()` with a `threading.Lock` in the UDS case (covering
send+recv), and guard `_seq` allocation with a lock in the file case.
No protocol change.
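
The UDS-side pattern can be sketched with a `socketpair` and a toy serial
echo server standing in for the agent process; all names and framing details
here are illustrative, not the project's real code. The point is that the
lock must cover the whole send+recv round trip, because the newline-framed
protocol gives `recv` no way to tell whose response arrived.

```python
import json
import socket
import threading
from concurrent.futures import ThreadPoolExecutor

client, server = socket.socketpair()
_call_lock = threading.Lock()

def echo_server():
    # Stand-in for the parent-process RPC loop: one connection, served serially.
    buf = b""
    while True:
        chunk = server.recv(65536)
        if not chunk:
            return
        buf += chunk
        while b"\n" in buf:
            line, buf = buf.split(b"\n", 1)
            req = json.loads(line)
            server.sendall((json.dumps({"echo": req["args"]}) + "\n").encode())

threading.Thread(target=echo_server, daemon=True).start()

def _call(args):
    request = json.dumps({"tool": "terminal", "args": args}) + "\n"
    with _call_lock:  # covers send AND recv: one round trip at a time
        client.sendall(request.encode())
        buf = b""
        while not buf.endswith(b"\n"):
            chunk = client.recv(65536)
            if not chunk:
                raise RuntimeError("disconnected")
            buf += chunk
    return json.loads(buf.decode())

with ThreadPoolExecutor(10) as ex:
    results = list(ex.map(_call, range(10)))  # map preserves input order
assert all(r["echo"] == i for i, r in enumerate(results))
```

Without the lock, two threads can interleave their `recv` calls on the shared
socket and drain each other's newline-terminated responses.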

Regression tests cover both levels: a generated-source check (the lock is
present and used) and an end-to-end concurrency test that runs a sandboxed
ThreadPoolExecutor of 10 `terminal()` calls against a slow mock dispatcher
and asserts that every caller sees its own tagged response. The test fails
without the fix (10/10 mismatched, matching the real-world repro) and
passes with it.
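
The shape of that end-to-end test can be approximated in-process, with a pair
of queues standing in for the id-less channel and a deliberate sleep widening
the race window; this is a hypothetical re-creation, not the actual test code.

```python
import queue
import threading
import time
from concurrent.futures import ThreadPoolExecutor

requests, responses = queue.Queue(), queue.Queue()
lock = threading.Lock()

def slow_dispatcher():
    # Mock dispatcher: serves one request at a time, slowly, with no
    # request-id to route responses back to specific callers.
    while True:
        tag = requests.get()
        if tag is None:
            return
        time.sleep(0.01)  # widen the race window
        responses.put(f"result-for-{tag}")

threading.Thread(target=slow_dispatcher, daemon=True).start()

def terminal(tag):
    with lock:  # the fix under test: serialize the whole round trip
        requests.put(tag)
        return responses.get()

with ThreadPoolExecutor(max_workers=10) as ex:
    results = list(ex.map(terminal, range(10)))
requests.put(None)  # stop the dispatcher

# Every caller must see its own tagged response, none of its neighbors'.
assert results == [f"result-for-{t}" for t in range(10)]
```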
Heltman 2026-04-30 13:20:05 +08:00 committed by Teknium
parent 3858f9419e
commit 19f9be1dff
2 changed files with 103 additions and 17 deletions

@@ -224,9 +224,14 @@ def retry(fn, max_attempts=3, delay=2):
 _UDS_TRANSPORT_HEADER = '''\
 """Auto-generated Hermes tools RPC stubs."""
-import json, os, socket, shlex, time
+import json, os, socket, shlex, threading, time
 _sock = None
+# The RPC server handles a single client connection serially and has no
+# request-id in the protocol, so concurrent _call() invocations from multiple
+# threads (e.g. ThreadPoolExecutor) would race on the shared socket and get
+# each other's responses. Serialize the entire send+recv round-trip.
+_call_lock = threading.Lock()
 ''' + _COMMON_HELPERS + '''\
 def _connect():
def _connect():
@@ -239,17 +244,18 @@ def _connect():
 def _call(tool_name, args):
     """Send a tool call to the parent process and return the parsed result."""
-    conn = _connect()
     request = json.dumps({"tool": tool_name, "args": args}) + "\\n"
-    conn.sendall(request.encode())
-    buf = b""
-    while True:
-        chunk = conn.recv(65536)
-        if not chunk:
-            raise RuntimeError("Agent process disconnected")
-        buf += chunk
-        if buf.endswith(b"\\n"):
-            break
+    with _call_lock:
+        conn = _connect()
+        conn.sendall(request.encode())
+        buf = b""
+        while True:
+            chunk = conn.recv(65536)
+            if not chunk:
+                raise RuntimeError("Agent process disconnected")
+            buf += chunk
+            if buf.endswith(b"\\n"):
+                break
     raw = buf.decode().strip()
     result = json.loads(raw)
     if isinstance(result, str):
@@ -265,24 +271,30 @@ def _call(tool_name, args):
 _FILE_TRANSPORT_HEADER = '''\
 """Auto-generated Hermes tools RPC stubs (file-based transport)."""
-import json, os, shlex, tempfile, time
+import json, os, shlex, tempfile, threading, time
 _RPC_DIR = os.environ.get("HERMES_RPC_DIR") or os.path.join(tempfile.gettempdir(), "hermes_rpc")
 _seq = 0
+# `_seq += 1` is not atomic (read-modify-write), so concurrent _call()
+# invocations from multiple threads could allocate the same sequence number
+# and clobber each other's request files. Guard seq allocation with a lock.
+_seq_lock = threading.Lock()
 ''' + _COMMON_HELPERS + '''\
 def _call(tool_name, args):
     """Send a tool call request via file-based RPC and wait for response."""
     global _seq
-    _seq += 1
-    seq_str = f"{_seq:06d}"
+    with _seq_lock:
+        _seq += 1
+        seq = _seq
+    seq_str = f"{seq:06d}"
     req_file = os.path.join(_RPC_DIR, f"req_{seq_str}")
     res_file = os.path.join(_RPC_DIR, f"res_{seq_str}")
     # Write request atomically (write to .tmp, then rename)
     tmp = req_file + ".tmp"
     with open(tmp, "w") as f:
-        json.dump({"tool": tool_name, "args": args, "seq": _seq}, f)
+        json.dump({"tool": tool_name, "args": args, "seq": seq}, f)
     os.rename(tmp, req_file)
     # Wait for response with adaptive polling