diff --git a/run_agent.py b/run_agent.py
index 21a896063..ba0a9f93d 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -739,6 +739,7 @@ class AIAgent:
         # Interrupt mechanism for breaking out of tool loops
         self._interrupt_requested = False
         self._interrupt_message = None  # Optional message that triggered interrupt
+        self._execution_thread_id: int | None = None  # Set at run_conversation() start
         self._client_lock = threading.RLock()
 
         # Subagent delegation state
@@ -2832,8 +2833,10 @@ class AIAgent:
         """
         self._interrupt_requested = True
         self._interrupt_message = message
-        # Signal all tools to abort any in-flight operations immediately
-        _set_interrupt(True)
+        # Signal all tools to abort any in-flight operations immediately.
+        # Scope the interrupt to this agent's execution thread so other
+        # agents running in the same process (gateway) are not affected.
+        _set_interrupt(True, self._execution_thread_id)
         # Propagate interrupt to any running child agents (subagent delegation)
         with self._active_children_lock:
             children_copy = list(self._active_children)
@@ -2846,10 +2849,10 @@ class AIAgent:
             print("\n⚡ Interrupt requested" + (f": '{message[:40]}...'" if message and len(message) > 40 else f": '{message}'" if message else ""))
 
     def clear_interrupt(self) -> None:
-        """Clear any pending interrupt request and the global tool interrupt signal."""
+        """Clear any pending interrupt request and the per-thread tool interrupt signal."""
         self._interrupt_requested = False
         self._interrupt_message = None
-        _set_interrupt(False)
+        _set_interrupt(False, self._execution_thread_id)
 
     def _touch_activity(self, desc: str) -> None:
         """Update the last-activity timestamp and description (thread-safe)."""
@@ -7799,6 +7802,11 @@ class AIAgent:
         compression_attempts = 0
         _turn_exit_reason = "unknown"  # Diagnostic: why the loop ended
 
+        # Record the execution thread so interrupt()/clear_interrupt() can
+        # scope the tool-level interrupt signal to THIS agent's thread only.
+        # Must be set before clear_interrupt() which uses it.
+        self._execution_thread_id = threading.current_thread().ident
+
         # Clear any stale interrupt state at start
         self.clear_interrupt()
 
diff --git a/tests/run_agent/test_interrupt_propagation.py b/tests/run_agent/test_interrupt_propagation.py
index 7f8cb01c3..a746efdac 100644
--- a/tests/run_agent/test_interrupt_propagation.py
+++ b/tests/run_agent/test_interrupt_propagation.py
@@ -22,23 +22,22 @@ class TestInterruptPropagationToChild(unittest.TestCase):
     def tearDown(self):
         set_interrupt(False)
 
+    def _make_bare_agent(self):
+        """Create a bare AIAgent via __new__ with all interrupt-related attrs."""
+        from run_agent import AIAgent
+        agent = AIAgent.__new__(AIAgent)
+        agent._interrupt_requested = False
+        agent._interrupt_message = None
+        agent._execution_thread_id = None  # defaults to current thread in set_interrupt
+        agent._active_children = []
+        agent._active_children_lock = threading.Lock()
+        agent.quiet_mode = True
+        return agent
+
     def test_parent_interrupt_sets_child_flag(self):
         """When parent.interrupt() is called, child._interrupt_requested should be set."""
-        from run_agent import AIAgent
-
-        parent = AIAgent.__new__(AIAgent)
-        parent._interrupt_requested = False
-        parent._interrupt_message = None
-        parent._active_children = []
-        parent._active_children_lock = threading.Lock()
-        parent.quiet_mode = True
-
-        child = AIAgent.__new__(AIAgent)
-        child._interrupt_requested = False
-        child._interrupt_message = None
-        child._active_children = []
-        child._active_children_lock = threading.Lock()
-        child.quiet_mode = True
+        parent = self._make_bare_agent()
+        child = self._make_bare_agent()
 
         parent._active_children.append(child)
 
@@ -49,40 +48,26 @@ class TestInterruptPropagationToChild(unittest.TestCase):
         assert child._interrupt_message == "new user message"
         assert is_interrupted() is True
 
-    def test_child_clear_interrupt_at_start_clears_global(self):
-        """child.clear_interrupt() at start of run_conversation clears the GLOBAL event.
-
-        This is the intended behavior at startup, but verify it doesn't
-        accidentally clear an interrupt intended for a running child.
+    def test_child_clear_interrupt_at_start_clears_thread(self):
+        """child.clear_interrupt() at start of run_conversation clears the
+        per-thread interrupt flag for the current thread.
         """
-        from run_agent import AIAgent
-
-        child = AIAgent.__new__(AIAgent)
+        child = self._make_bare_agent()
         child._interrupt_requested = True
         child._interrupt_message = "msg"
-        child.quiet_mode = True
-        child._active_children = []
-        child._active_children_lock = threading.Lock()
 
-        # Global is set
+        # Interrupt for current thread is set
         set_interrupt(True)
         assert is_interrupted() is True
 
-        # child.clear_interrupt() clears both
+        # child.clear_interrupt() clears both instance flag and thread flag
         child.clear_interrupt()
         assert child._interrupt_requested is False
         assert is_interrupted() is False
 
     def test_interrupt_during_child_api_call_detected(self):
         """Interrupt set during _interruptible_api_call is detected within 0.5s."""
-        from run_agent import AIAgent
-
-        child = AIAgent.__new__(AIAgent)
-        child._interrupt_requested = False
-        child._interrupt_message = None
-        child._active_children = []
-        child._active_children_lock = threading.Lock()
-        child.quiet_mode = True
+        child = self._make_bare_agent()
         child.api_mode = "chat_completions"
         child.log_prefix = ""
         child._client_kwargs = {"api_key": "test", "base_url": "http://localhost:1234"}
@@ -117,21 +102,8 @@ class TestInterruptPropagationToChild(unittest.TestCase):
 
     def test_concurrent_interrupt_propagation(self):
         """Simulates exact CLI flow: parent runs delegate in thread, main thread interrupts."""
-        from run_agent import AIAgent
-
-        parent = AIAgent.__new__(AIAgent)
-        parent._interrupt_requested = False
-        parent._interrupt_message = None
-        parent._active_children = []
-        parent._active_children_lock = threading.Lock()
-        parent.quiet_mode = True
-
-        child = AIAgent.__new__(AIAgent)
-        child._interrupt_requested = False
-        child._interrupt_message = None
-        child._active_children = []
-        child._active_children_lock = threading.Lock()
-        child.quiet_mode = True
+        parent = self._make_bare_agent()
+        child = self._make_bare_agent()
 
         # Register child (simulating what _run_single_child does)
         parent._active_children.append(child)
@@ -157,5 +129,86 @@ class TestInterruptPropagationToChild(unittest.TestCase):
         set_interrupt(False)
 
 
+class TestPerThreadInterruptIsolation(unittest.TestCase):
+    """Verify that interrupting one agent does NOT affect another agent's thread.
+
+    This is the core fix for the gateway cross-session interrupt leak:
+    multiple agents run in separate threads within the same process, and
+    interrupting agent A must not kill agent B's running tools.
+    """
+
+    def setUp(self):
+        set_interrupt(False)
+
+    def tearDown(self):
+        set_interrupt(False)
+
+    def test_interrupt_only_affects_target_thread(self):
+        """set_interrupt(True, tid) only makes is_interrupted() True on that thread."""
+        results = {}
+        barrier = threading.Barrier(2)
+
+        def thread_a():
+            """Agent A's execution thread — will be interrupted."""
+            tid = threading.current_thread().ident
+            results["a_tid"] = tid
+            barrier.wait(timeout=5)  # sync with thread B
+            time.sleep(0.2)  # let the interrupt arrive
+            results["a_interrupted"] = is_interrupted()
+
+        def thread_b():
+            """Agent B's execution thread — should NOT be affected."""
+            tid = threading.current_thread().ident
+            results["b_tid"] = tid
+            barrier.wait(timeout=5)  # sync with thread A
+            time.sleep(0.2)
+            results["b_interrupted"] = is_interrupted()
+
+        ta = threading.Thread(target=thread_a)
+        tb = threading.Thread(target=thread_b)
+        ta.start()
+        tb.start()
+
+        # Wait (bounded) for both threads to register their TIDs
+        deadline = time.monotonic() + 5
+        while "a_tid" not in results or "b_tid" not in results:
+            assert time.monotonic() < deadline, "worker threads did not start"
+            time.sleep(0.01)
+
+        try:
+            # Interrupt ONLY thread A (simulates gateway interrupting agent A)
+            set_interrupt(True, results["a_tid"])
+
+            ta.join(timeout=3)
+            tb.join(timeout=3)
+        finally:
+            # Thread idents may be recycled after a thread exits, so a stale
+            # entry would leak this interrupt into an unrelated later thread.
+            set_interrupt(False, results["a_tid"])
+
+        assert results.get("a_interrupted") is True, "Thread A should see the interrupt"
+        assert results.get("b_interrupted") is False, "Thread B must NOT see thread A's interrupt"
+
+    def test_clear_interrupt_only_clears_target_thread(self):
+        """Clearing one thread's interrupt doesn't clear another's."""
+        tid_a = 99990001
+        tid_b = 99990002
+        set_interrupt(True, tid_a)
+        set_interrupt(True, tid_b)
+        try:
+            # Clear only A
+            set_interrupt(False, tid_a)
+
+            # Simulate checking from thread B's perspective
+            from tools.interrupt import _interrupted_threads, _lock
+            with _lock:
+                assert tid_a not in _interrupted_threads
+                assert tid_b in _interrupted_threads
+        finally:
+            # Always drop the synthetic idents, even if an assertion fails
+            set_interrupt(False, tid_a)
+            set_interrupt(False, tid_b)
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/tools/test_code_execution.py b/tests/tools/test_code_execution.py
index 33653c360..e015e5d42 100644
--- a/tests/tools/test_code_execution.py
+++ b/tests/tools/test_code_execution.py
@@ -780,14 +780,18 @@ class TestLoadConfig(unittest.TestCase):
 @unittest.skipIf(sys.platform == "win32", "UDS not available on Windows")
 class TestInterruptHandling(unittest.TestCase):
     def test_interrupt_event_stops_execution(self):
-        """When _interrupt_event is set, execute_code should stop the script."""
+        """When interrupt is set for the execution thread, execute_code should stop."""
         code = "import time; time.sleep(60); print('should not reach')"
+        from tools.interrupt import set_interrupt
+
+        # Capture the main thread ID so we can target the interrupt correctly.
+        # execute_code runs in the current thread; set_interrupt needs its ID.
+        main_tid = threading.current_thread().ident
 
         def set_interrupt_after_delay():
             import time as _t
             _t.sleep(1)
-            from tools.terminal_tool import _interrupt_event
-            _interrupt_event.set()
+            set_interrupt(True, main_tid)
 
         t = threading.Thread(target=set_interrupt_after_delay, daemon=True)
         t.start()
@@ -804,8 +808,7 @@ class TestInterruptHandling(unittest.TestCase):
             self.assertEqual(result["status"], "interrupted")
             self.assertIn("interrupted", result["output"])
         finally:
-            from tools.terminal_tool import _interrupt_event
-            _interrupt_event.clear()
+            set_interrupt(False, main_tid)
             t.join(timeout=3)
 
 
diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py
index 7837d70d6..d6c561e2c 100644
--- a/tools/code_execution_tool.py
+++ b/tools/code_execution_tool.py
@@ -924,8 +924,8 @@ def execute_code(
 
     # --- Local execution path (UDS) --- below this line is unchanged ---
 
-    # Import interrupt event from terminal_tool (cooperative cancellation)
-    from tools.terminal_tool import _interrupt_event
+    # Import per-thread interrupt check (cooperative cancellation)
+    from tools.interrupt import is_interrupted as _is_interrupted
 
     # Resolve config
     _cfg = _load_config()
@@ -1114,7 +1114,7 @@ def execute_code(
         status = "success"
 
         while proc.poll() is None:
-            if _interrupt_event.is_set():
+            if _is_interrupted():
                 _kill_process_group(proc)
                 status = "interrupted"
                 break
diff --git a/tools/interrupt.py b/tools/interrupt.py
index e5c9b1e27..9bc8b83ae 100644
--- a/tools/interrupt.py
+++ b/tools/interrupt.py
@@ -1,8 +1,12 @@
-"""Shared interrupt signaling for all tools.
+"""Per-thread interrupt signaling for all tools.
 
-Provides a global threading.Event that any tool can check to determine
-if the user has requested an interrupt. The agent's interrupt() method
-sets this event, and tools poll it during long-running operations.
+Provides thread-scoped interrupt tracking so that interrupting one agent
+session does not kill tools running in other sessions. This is critical
+in the gateway where multiple agents run concurrently in the same process.
+
+The agent stores its execution thread ID at the start of run_conversation()
+and passes it to set_interrupt() both to raise and to clear an interrupt.
+Tools call is_interrupted(), which checks the CURRENT thread only.
 
 Usage in tools:
     from tools.interrupt import is_interrupted
@@ -12,17 +16,81 @@ Usage in tools:
 
 import threading
+import time
 
-_interrupt_event = threading.Event()
+# Set of thread idents that have been interrupted.
+_interrupted_threads: set[int] = set()
+_lock = threading.Lock()
 
 
-def set_interrupt(active: bool) -> None:
-    """Called by the agent to signal or clear the interrupt."""
-    if active:
-        _interrupt_event.set()
-    else:
-        _interrupt_event.clear()
+def set_interrupt(active: bool, thread_id: int | None = None) -> None:
+    """Set or clear interrupt for a specific thread.
+
+    Args:
+        active: True to signal interrupt, False to clear it.
+        thread_id: Target thread ident. When None, targets the
+            current thread (backward compat for CLI/tests).
+    """
+    tid = thread_id if thread_id is not None else threading.current_thread().ident
+    with _lock:
+        if active:
+            _interrupted_threads.add(tid)
+        else:
+            _interrupted_threads.discard(tid)
 
 
 def is_interrupted() -> bool:
-    """Check if an interrupt has been requested. Safe to call from any thread."""
-    return _interrupt_event.is_set()
+    """Check if an interrupt has been requested for the current thread.
+
+    Safe to call from any thread — each thread only sees its own
+    interrupt state.
+    """
+    tid = threading.current_thread().ident
+    with _lock:
+        return tid in _interrupted_threads
+
+
+# ---------------------------------------------------------------------------
+# Backward-compatible _interrupt_event proxy
+# ---------------------------------------------------------------------------
+# Legacy call sites may still import _interrupt_event directly and call
+# .is_set() / .set() / .clear() / .wait(). This shim maps those calls to
+# the per-thread functions above so existing code keeps working while the
+# underlying mechanism is thread-scoped.
+
+class _ThreadAwareEventProxy:
+    """Drop-in proxy that maps threading.Event methods to per-thread state."""
+
+    # Seconds between state checks while wait() is blocking.
+    _POLL_INTERVAL = 0.05
+
+    def is_set(self) -> bool:
+        return is_interrupted()
+
+    def set(self) -> None:  # noqa: A003
+        set_interrupt(True)
+
+    def clear(self) -> None:
+        set_interrupt(False)
+
+    def wait(self, timeout: float | None = None) -> bool:
+        """Block until the calling thread is interrupted or *timeout* elapses.
+
+        Mirrors threading.Event.wait(): returns True as soon as the flag is
+        set, False if the timeout expires first; timeout=None waits
+        indefinitely. Per-thread state has no Event to block on, so the
+        flag is polled every _POLL_INTERVAL seconds instead of returning
+        immediately (which would turn callers' wait loops into busy-spins).
+        """
+        deadline = None if timeout is None else time.monotonic() + timeout
+        while not is_interrupted():
+            pause = self._POLL_INTERVAL
+            if deadline is not None:
+                remaining = deadline - time.monotonic()
+                if remaining <= 0:
+                    return False
+                pause = min(pause, remaining)
+            time.sleep(pause)
+        return True
+
+
+_interrupt_event = _ThreadAwareEventProxy()
diff --git a/tools/process_registry.py b/tools/process_registry.py
index 1761221f0..044a4e776 100644
--- a/tools/process_registry.py
+++ b/tools/process_registry.py
@@ -686,7 +686,7 @@ class ProcessRegistry:
         and output snapshot.
         """
         from tools.ansi_strip import strip_ansi
-        from tools.terminal_tool import _interrupt_event
+        from tools.interrupt import is_interrupted as _is_interrupted
 
         try:
             default_timeout = int(os.getenv("TERMINAL_TIMEOUT", "180"))
@@ -723,7 +723,7 @@ class ProcessRegistry:
                     result["timeout_note"] = timeout_note
                 return result
 
-            if _interrupt_event.is_set():
+            if _is_interrupted():
                 result = {
                     "status": "interrupted",
                     "output": strip_ansi(session.output_buffer[-1000:]),