fix: scope tool interrupt signal per-thread to prevent cross-session leaks (#7930)

The interrupt mechanism in tools/interrupt.py used a process-global
threading.Event. In the gateway, multiple agents run concurrently in
the same process via run_in_executor. When any agent was interrupted
(user sends a follow-up message), the global flag killed ALL agents'
running tools — terminal commands, browser ops, web requests — across
all sessions.

Changes:
- tools/interrupt.py: Replace single threading.Event with a set of
  interrupted thread IDs. set_interrupt() targets a specific thread;
  is_interrupted() checks the current thread. Includes a backward-
  compatible _ThreadAwareEventProxy for legacy _interrupt_event usage.
- run_agent.py: Store execution thread ID at start of run_conversation().
  interrupt() and clear_interrupt() pass it to set_interrupt() so only
  this agent's thread is affected.
- tools/code_execution_tool.py: Use is_interrupted() instead of
  directly checking _interrupt_event.is_set().
- tools/process_registry.py: Same — use is_interrupted().
- tests: Update interrupt tests for per-thread semantics. Add new
  TestPerThreadInterruptIsolation with two tests verifying cross-thread
  isolation.
This commit is contained in:
Teknium 2026-04-11 14:02:58 -07:00 committed by GitHub
parent 75380de430
commit dfc820345d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 183 additions and 78 deletions

View file

@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase):
def tearDown(self):
set_interrupt(False)
def _make_bare_agent(self):
"""Create a bare AIAgent via __new__ with all interrupt-related attrs."""
from run_agent import AIAgent
agent = AIAgent.__new__(AIAgent)
agent._interrupt_requested = False
agent._interrupt_message = None
agent._execution_thread_id = None # defaults to current thread in set_interrupt
agent._active_children = []
agent._active_children_lock = threading.Lock()
agent.quiet_mode = True
return agent
def test_parent_interrupt_sets_child_flag(self):
"""When parent.interrupt() is called, child._interrupt_requested should be set."""
from run_agent import AIAgent
parent = AIAgent.__new__(AIAgent)
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent = self._make_bare_agent()
child = self._make_bare_agent()
parent._active_children.append(child)
@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase):
assert child._interrupt_message == "new user message"
assert is_interrupted() is True
def test_child_clear_interrupt_at_start_clears_global(self):
"""child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
This is the intended behavior at startup, but verify it doesn't
accidentally clear an interrupt intended for a running child.
def test_child_clear_interrupt_at_start_clears_thread(self):
"""child.clear_interrupt() at start of run_conversation clears the
per-thread interrupt flag for the current thread.
"""
from run_agent import AIAgent
child = AIAgent.__new__(AIAgent)
child = self._make_bare_agent()
child._interrupt_requested = True
child._interrupt_message = "msg"
child.quiet_mode = True
child._active_children = []
child._active_children_lock = threading.Lock()
# Global is set
# Interrupt for current thread is set
set_interrupt(True)
assert is_interrupted() is True
# child.clear_interrupt() clears both
# child.clear_interrupt() clears both instance flag and thread flag
child.clear_interrupt()
assert child._interrupt_requested is False
assert is_interrupted() is False
def test_interrupt_during_child_api_call_detected(self):
"""Interrupt set during _interruptible_api_call is detected within 0.5s."""
from run_agent import AIAgent
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
child = self._make_bare_agent()
child.api_mode = "chat_completions"
child.log_prefix = ""
child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
def test_concurrent_interrupt_propagation(self):
"""Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
from run_agent import AIAgent
parent = AIAgent.__new__(AIAgent)
parent._interrupt_requested = False
parent._interrupt_message = None
parent._active_children = []
parent._active_children_lock = threading.Lock()
parent.quiet_mode = True
child = AIAgent.__new__(AIAgent)
child._interrupt_requested = False
child._interrupt_message = None
child._active_children = []
child._active_children_lock = threading.Lock()
child.quiet_mode = True
parent = self._make_bare_agent()
child = self._make_bare_agent()
# Register child (simulating what _run_single_child does)
parent._active_children.append(child)
@ -157,5 +129,79 @@ class TestInterruptPropagationToChild(unittest.TestCase):
set_interrupt(False)
class TestPerThreadInterruptIsolation(unittest.TestCase):
"""Verify that interrupting one agent does NOT affect another agent's thread.
This is the core fix for the gateway cross-session interrupt leak:
multiple agents run in separate threads within the same process, and
interrupting agent A must not kill agent B's running tools.
"""
def setUp(self):
set_interrupt(False)
def tearDown(self):
set_interrupt(False)
def test_interrupt_only_affects_target_thread(self):
"""set_interrupt(True, tid) only makes is_interrupted() True on that thread."""
results = {}
barrier = threading.Barrier(2)
def thread_a():
"""Agent A's execution thread — will be interrupted."""
tid = threading.current_thread().ident
results["a_tid"] = tid
barrier.wait(timeout=5) # sync with thread B
time.sleep(0.2) # let the interrupt arrive
results["a_interrupted"] = is_interrupted()
def thread_b():
"""Agent B's execution thread — should NOT be affected."""
tid = threading.current_thread().ident
results["b_tid"] = tid
barrier.wait(timeout=5) # sync with thread A
time.sleep(0.2)
results["b_interrupted"] = is_interrupted()
ta = threading.Thread(target=thread_a)
tb = threading.Thread(target=thread_b)
ta.start()
tb.start()
# Wait for both threads to register their TIDs
time.sleep(0.05)
while "a_tid" not in results or "b_tid" not in results:
time.sleep(0.01)
# Interrupt ONLY thread A (simulates gateway interrupting agent A)
set_interrupt(True, results["a_tid"])
ta.join(timeout=3)
tb.join(timeout=3)
assert results["a_interrupted"] is True, "Thread A should see the interrupt"
assert results["b_interrupted"] is False, "Thread B must NOT see thread A's interrupt"
def test_clear_interrupt_only_clears_target_thread(self):
"""Clearing one thread's interrupt doesn't clear another's."""
tid_a = 99990001
tid_b = 99990002
set_interrupt(True, tid_a)
set_interrupt(True, tid_b)
# Clear only A
set_interrupt(False, tid_a)
# Simulate checking from thread B's perspective
from tools.interrupt import _interrupted_threads, _lock
with _lock:
assert tid_a not in _interrupted_threads
assert tid_b in _interrupted_threads
# Cleanup
set_interrupt(False, tid_b)
if __name__ == "__main__":
unittest.main()