diff --git a/gateway/run.py b/gateway/run.py index 13f4cb6478..9c2b5b1db5 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -9443,6 +9443,19 @@ class GatewayRunner: return result next_message_id = getattr(pending_event, "message_id", None) + # Restart typing indicator so the user sees activity while + # the follow-up turn runs. The outer _process_message_background + # typing task is still alive but may be stale. + _followup_adapter = self.adapters.get(source.platform) + if _followup_adapter: + try: + await _followup_adapter.send_typing( + source.chat_id, + metadata=_status_thread_metadata, + ) + except Exception: + pass + return await self._run_agent( message=next_message, context_prompt=context_prompt, diff --git a/run_agent.py b/run_agent.py index 944217e6bc..d6dc9a0240 100644 --- a/run_agent.py +++ b/run_agent.py @@ -7549,24 +7549,50 @@ class AIAgent: # Wait for all to complete with periodic heartbeats so the # gateway's inactivity monitor doesn't kill us during long - # concurrent tool batches. + # concurrent tool batches. Also check for user interrupts + # so we don't block indefinitely when the user sends /stop + # or a new message during concurrent tool execution. _conc_start = time.time() + _interrupt_logged = False while True: done, not_done = concurrent.futures.wait( - futures, timeout=30.0, + futures, timeout=5.0, ) if not not_done: break + + # Check for interrupt — the per-thread interrupt signal + # already causes individual tools (terminal, execute_code) + # to abort, but tools without interrupt checks (web_search, + # read_file) will run to completion. Cancel any futures + # that haven't started yet so we don't block on them. + if self._interrupt_requested: + if not _interrupt_logged: + _interrupt_logged = True + self._vprint( + f"{self.log_prefix}⚡ Interrupt: cancelling " + f"{len(not_done)} pending concurrent tool(s)", + force=True, + ) + for f in not_done: + f.cancel() + # Give already-running tools a moment to notice the + # per-thread interrupt signal and exit gracefully. + concurrent.futures.wait(not_done, timeout=3.0) + break + _conc_elapsed = int(time.time() - _conc_start) - _still_running = [ - parsed_calls[futures.index(f)][1] - for f in not_done - if f in futures - ] - self._touch_activity( - f"concurrent tools running ({_conc_elapsed}s, " - f"{len(not_done)} remaining: {', '.join(_still_running[:3])})" - ) + # Heartbeat every ~30s (6 × 5s poll intervals) + if _conc_elapsed > 0 and _conc_elapsed % 30 < 6: + _still_running = [ + parsed_calls[futures.index(f)][1] + for f in not_done + if f in futures + ] + self._touch_activity( + f"concurrent tools running ({_conc_elapsed}s, " + f"{len(not_done)} remaining: {', '.join(_still_running[:3])})" + ) finally: if spinner: # Build a summary message for the spinner stop @@ -7578,8 +7604,11 @@ class AIAgent: for i, (tc, name, args) in enumerate(parsed_calls): r = results[i] if r is None: - # Shouldn't happen, but safety fallback - function_result = f"Error executing tool '{name}': thread did not return a result" + # Tool was cancelled (interrupt) or thread didn't return + if self._interrupt_requested: + function_result = f"[Tool execution cancelled — {name} was skipped due to user interrupt]" + else: + function_result = f"Error executing tool '{name}': thread did not return a result" tool_duration = 0.0 else: function_name, function_args, function_result, tool_duration, is_error = r diff --git a/tests/run_agent/test_concurrent_interrupt.py b/tests/run_agent/test_concurrent_interrupt.py new file mode 100644 index 0000000000..fdeb8dd690 --- /dev/null +++ b/tests/run_agent/test_concurrent_interrupt.py @@ -0,0 +1,139 @@ +"""Tests for interrupt handling in concurrent tool execution.""" + +import concurrent.futures +import threading +import time +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture(autouse=True) +def _isolate_hermes(tmp_path, monkeypatch): + monkeypatch.setenv("HERMES_HOME", str(tmp_path / ".hermes")) + (tmp_path / ".hermes").mkdir(exist_ok=True) + + +def _make_agent(monkeypatch): + """Create a minimal AIAgent-like object with just the methods under test.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "") + monkeypatch.setenv("HERMES_INFERENCE_PROVIDER", "") + # Avoid full AIAgent init — just import the class and build a stub + import run_agent as _ra + + class _Stub: + _interrupt_requested = False + log_prefix = "" + quiet_mode = True + verbose_logging = False + log_prefix_chars = 200 + _checkpoint_mgr = MagicMock(enabled=False) + _subdirectory_hints = MagicMock() + tool_progress_callback = None + tool_start_callback = None + tool_complete_callback = None + _todo_store = MagicMock() + _session_db = None + valid_tool_names = set() + _turns_since_memory = 0 + _iters_since_skill = 0 + _current_tool = None + _last_activity = 0 + _print_fn = print + + def _touch_activity(self, desc): + self._last_activity = time.time() + + def _vprint(self, msg, force=False): + pass + + def _safe_print(self, msg): + pass + + def _should_emit_quiet_tool_messages(self): + return False + + def _should_start_quiet_spinner(self): + return False + + def _has_stream_consumers(self): + return False + + stub = _Stub() + # Bind the real methods + stub._execute_tool_calls_concurrent = _ra.AIAgent._execute_tool_calls_concurrent.__get__(stub) + stub._invoke_tool = MagicMock(side_effect=lambda *a, **kw: '{"ok": true}') + return stub + + +class _FakeToolCall: + def __init__(self, name, args="{}", call_id="tc_1"): + self.function = MagicMock(name=name, arguments=args) + self.function.name = name + self.id = call_id + + +class _FakeAssistantMsg: + def __init__(self, tool_calls): + self.tool_calls = tool_calls + + +def test_concurrent_interrupt_cancels_pending(monkeypatch): + """When _interrupt_requested is set during concurrent execution, + the wait loop should exit early and cancelled tools get interrupt messages.""" + agent = _make_agent(monkeypatch) + + # Create a tool that blocks until interrupted + barrier = threading.Event() + + original_invoke = agent._invoke_tool + + def slow_tool(name, args, task_id, call_id=None): + if name == "slow_one": + # Block until the test sets the interrupt + barrier.wait(timeout=10) + return '{"slow": true}' + return '{"fast": true}' + + agent._invoke_tool = MagicMock(side_effect=slow_tool) + + tc1 = _FakeToolCall("fast_one", call_id="tc_fast") + tc2 = _FakeToolCall("slow_one", call_id="tc_slow") + msg = _FakeAssistantMsg([tc1, tc2]) + messages = [] + + def _set_interrupt_after_delay(): + time.sleep(0.3) + agent._interrupt_requested = True + barrier.set() # unblock the slow tool + + t = threading.Thread(target=_set_interrupt_after_delay) + t.start() + + agent._execute_tool_calls_concurrent(msg, messages, "test_task") + t.join() + + # Both tools should have results in messages + assert len(messages) == 2 + # The interrupt was detected + assert agent._interrupt_requested is True + + +def test_concurrent_preflight_interrupt_skips_all(monkeypatch): + """When _interrupt_requested is already set before concurrent execution, + all tools are skipped with cancellation messages.""" + agent = _make_agent(monkeypatch) + agent._interrupt_requested = True + + tc1 = _FakeToolCall("tool_a", call_id="tc_a") + tc2 = _FakeToolCall("tool_b", call_id="tc_b") + msg = _FakeAssistantMsg([tc1, tc2]) + messages = [] + + agent._execute_tool_calls_concurrent(msg, messages, "test_task") + + assert len(messages) == 2 + assert "skipped due to user interrupt" in messages[0]["content"] + assert "skipped due to user interrupt" in messages[1]["content"] + # _invoke_tool should never have been called + agent._invoke_tool.assert_not_called()