From a418ddbd8b9e7d3d158ac2dcb7ca281d1c9f602f Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Wed, 15 Apr 2026 13:29:05 -0700 Subject: [PATCH] fix: add activity heartbeats to prevent false gateway inactivity timeouts (#10501) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Multiple gaps in activity tracking could cause the gateway's inactivity timeout to fire while the agent is actively working: 1. Streaming wait loop had no periodic heartbeat — the outer thread only touched activity when the stale-stream detector fired (180-300s), and for local providers (Ollama) the stale timeout was infinity, meaning zero heartbeats. Now touches activity every 30s. 2. Concurrent tool execution never set the activity callback on worker threads (threading.local invisible across threads) and never set _current_tool. Workers now set the callback, and the concurrent wait uses a polling loop with 30s heartbeats. 3. Modal backend's execute() override had its own polling loop without any activity callback. Now matches _wait_for_process cadence (10s). --- run_agent.py | 54 ++++++++++++++++++- tests/tools/test_managed_modal_environment.py | 2 +- tools/environments/modal_utils.py | 20 +++++++ 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/run_agent.py b/run_agent.py index 1676d2f5a..b0bfa53da 100644 --- a/run_agent.py +++ b/run_agent.py @@ -5522,9 +5522,27 @@ class AIAgent: t = threading.Thread(target=_call, daemon=True) t.start() + _last_heartbeat = time.time() + _HEARTBEAT_INTERVAL = 30.0 # seconds between gateway activity touches while t.is_alive(): t.join(timeout=0.3) + # Periodic heartbeat: touch the agent's activity tracker so the + # gateway's inactivity monitor knows we're alive while waiting + # for stream chunks. Without this, long thinking pauses (e.g. + # reasoning models) or slow prefill on local providers (Ollama) + # trigger false inactivity timeouts. The _call thread touches + # activity on each chunk, but the gap between API call start + # and first chunk can exceed the gateway timeout — especially + # when the stale-stream timeout is disabled (local providers). + _hb_now = time.time() + if _hb_now - _last_heartbeat >= _HEARTBEAT_INTERVAL: + _last_heartbeat = _hb_now + _waiting_secs = int(_hb_now - last_chunk_time["t"]) + self._touch_activity( + f"waiting for stream response ({_waiting_secs}s, no chunks yet)" + ) + # Detect stale streams: connections kept alive by SSE pings # but delivering no real chunks. Kill the client so the # inner retry loop can start a fresh connection. @@ -7141,8 +7159,22 @@ class AIAgent: # Each slot holds (function_name, function_args, function_result, duration, error_flag) results = [None] * num_tools + # Touch activity before launching workers so the gateway knows + # we're executing tools (not stuck). + self._current_tool = tool_names_str + self._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}") + def _run_tool(index, tool_call, function_name, function_args): """Worker function executed in a thread.""" + # Set the activity callback on THIS worker thread so + # _wait_for_process (terminal commands) can fire heartbeats. + # The callback is thread-local; the main thread's callback + # is invisible to worker threads. + try: + from tools.environments.base import set_activity_callback + set_activity_callback(self._touch_activity) + except Exception: + pass start = time.time() try: result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id) @@ -7172,8 +7204,26 @@ class AIAgent: f = executor.submit(_run_tool, i, tc, name, args) futures.append(f) - # Wait for all to complete (exceptions are captured inside _run_tool) - concurrent.futures.wait(futures) + # Wait for all to complete with periodic heartbeats so the + # gateway's inactivity monitor doesn't kill us during long + # concurrent tool batches. + _conc_start = time.time() + while True: + done, not_done = concurrent.futures.wait( + futures, timeout=30.0, + ) + if not not_done: + break + _conc_elapsed = int(time.time() - _conc_start) + _still_running = [ + parsed_calls[futures.index(f)][1] + for f in not_done + if f in futures + ] + self._touch_activity( + f"concurrent tools running ({_conc_elapsed}s, " + f"{len(not_done)} remaining: {', '.join(_still_running[:3])})" + ) finally: if spinner: # Build a summary message for the spinner stop diff --git a/tests/tools/test_managed_modal_environment.py b/tests/tools/test_managed_modal_environment.py index 1d7241e0b..d36418336 100644 --- a/tests/tools/test_managed_modal_environment.py +++ b/tests/tools/test_managed_modal_environment.py @@ -296,7 +296,7 @@ def test_managed_modal_execute_times_out_and_cancels(monkeypatch): modal_common = sys.modules["tools.environments.modal_utils"] calls = [] - monotonic_values = iter([0.0, 12.5]) + monotonic_values = iter([0.0, 0.0, 0.0, 12.5, 12.5]) def fake_request(method, url, headers=None, json=None, timeout=None): calls.append((method, url, json, timeout)) diff --git a/tools/environments/modal_utils.py b/tools/environments/modal_utils.py index 0db819471..161aad261 100644 --- a/tools/environments/modal_utils.py +++ b/tools/environments/modal_utils.py @@ -105,6 +105,10 @@ class BaseModalExecutionEnvironment(BaseEnvironment): if self._client_timeout_grace_seconds is not None: deadline = time.monotonic() + prepared.timeout + self._client_timeout_grace_seconds + _last_activity_touch = time.monotonic() + _modal_exec_start = time.monotonic() + _ACTIVITY_INTERVAL = 10.0 # match _wait_for_process cadence + while True: if is_interrupted(): try: @@ -128,6 +132,22 @@ class BaseModalExecutionEnvironment(BaseEnvironment): pass return self._timeout_result_for_modal(prepared.timeout) + # Periodic activity touch so the gateway knows we're alive + _now = time.monotonic() + if _now - _last_activity_touch >= _ACTIVITY_INTERVAL: + _last_activity_touch = _now + try: + from tools.environments.base import _get_activity_callback + _cb = _get_activity_callback() + except Exception: + _cb = None + if _cb: + try: + _elapsed = int(_now - _modal_exec_start) + _cb(f"modal command running ({_elapsed}s elapsed)") + except Exception: + pass + time.sleep(self._poll_interval_seconds) def _before_execute(self) -> None: