fix: add activity heartbeats to prevent false gateway inactivity timeouts (#10501)

Multiple gaps in activity tracking could cause the gateway's inactivity
timeout to fire while the agent is actively working:

1. Streaming wait loop had no periodic heartbeat — the outer thread only
   touched activity when the stale-stream detector fired (180-300s), and
   for local providers (Ollama) the stale timeout was infinity, meaning
   zero heartbeats. Now touches activity every 30s.

2. Concurrent tool execution never set the activity callback on worker
   threads (threading.local invisible across threads) and never set
   _current_tool. Workers now set the callback, and the concurrent wait
   uses a polling loop with 30s heartbeats.

3. Modal backend's execute() override had its own polling loop without
   any activity callback. Now matches _wait_for_process cadence (10s).
This commit is contained in:
Teknium 2026-04-15 13:29:05 -07:00 committed by GitHub
parent 0d25e1c146
commit a418ddbd8b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 73 additions and 3 deletions

View file

@ -5522,9 +5522,27 @@ class AIAgent:
t = threading.Thread(target=_call, daemon=True)
t.start()
_last_heartbeat = time.time()
_HEARTBEAT_INTERVAL = 30.0 # seconds between gateway activity touches
while t.is_alive():
t.join(timeout=0.3)
# Periodic heartbeat: touch the agent's activity tracker so the
# gateway's inactivity monitor knows we're alive while waiting
# for stream chunks. Without this, long thinking pauses (e.g.
# reasoning models) or slow prefill on local providers (Ollama)
# trigger false inactivity timeouts. The _call thread touches
# activity on each chunk, but the gap between API call start
# and first chunk can exceed the gateway timeout — especially
# when the stale-stream timeout is disabled (local providers).
_hb_now = time.time()
if _hb_now - _last_heartbeat >= _HEARTBEAT_INTERVAL:
_last_heartbeat = _hb_now
_waiting_secs = int(_hb_now - last_chunk_time["t"])
self._touch_activity(
f"waiting for stream response ({_waiting_secs}s, no chunks yet)"
)
# Detect stale streams: connections kept alive by SSE pings
# but delivering no real chunks. Kill the client so the
# inner retry loop can start a fresh connection.
@ -7141,8 +7159,22 @@ class AIAgent:
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
results = [None] * num_tools
# Touch activity before launching workers so the gateway knows
# we're executing tools (not stuck).
self._current_tool = tool_names_str
self._touch_activity(f"executing {num_tools} tools concurrently: {tool_names_str}")
def _run_tool(index, tool_call, function_name, function_args):
"""Worker function executed in a thread."""
# Set the activity callback on THIS worker thread so
# _wait_for_process (terminal commands) can fire heartbeats.
# The callback is thread-local; the main thread's callback
# is invisible to worker threads.
try:
from tools.environments.base import set_activity_callback
set_activity_callback(self._touch_activity)
except Exception:
pass
start = time.time()
try:
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id)
@ -7172,8 +7204,26 @@ class AIAgent:
f = executor.submit(_run_tool, i, tc, name, args)
futures.append(f)
# Wait for all to complete (exceptions are captured inside _run_tool)
concurrent.futures.wait(futures)
# Wait for all to complete with periodic heartbeats so the
# gateway's inactivity monitor doesn't kill us during long
# concurrent tool batches.
_conc_start = time.time()
while True:
done, not_done = concurrent.futures.wait(
futures, timeout=30.0,
)
if not not_done:
break
_conc_elapsed = int(time.time() - _conc_start)
_still_running = [
parsed_calls[futures.index(f)][1]
for f in not_done
if f in futures
]
self._touch_activity(
f"concurrent tools running ({_conc_elapsed}s, "
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
)
finally:
if spinner:
# Build a summary message for the spinner stop

View file

@ -296,7 +296,7 @@ def test_managed_modal_execute_times_out_and_cancels(monkeypatch):
modal_common = sys.modules["tools.environments.modal_utils"]
calls = []
monotonic_values = iter([0.0, 12.5])
monotonic_values = iter([0.0, 0.0, 0.0, 12.5, 12.5])
def fake_request(method, url, headers=None, json=None, timeout=None):
calls.append((method, url, json, timeout))

View file

@ -105,6 +105,10 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
if self._client_timeout_grace_seconds is not None:
deadline = time.monotonic() + prepared.timeout + self._client_timeout_grace_seconds
_last_activity_touch = time.monotonic()
_modal_exec_start = time.monotonic()
_ACTIVITY_INTERVAL = 10.0 # match _wait_for_process cadence
while True:
if is_interrupted():
try:
@ -128,6 +132,22 @@ class BaseModalExecutionEnvironment(BaseEnvironment):
pass
return self._timeout_result_for_modal(prepared.timeout)
# Periodic activity touch so the gateway knows we're alive
_now = time.monotonic()
if _now - _last_activity_touch >= _ACTIVITY_INTERVAL:
_last_activity_touch = _now
try:
from tools.environments.base import _get_activity_callback
_cb = _get_activity_callback()
except Exception:
_cb = None
if _cb:
try:
_elapsed = int(_now - _modal_exec_start)
_cb(f"modal command running ({_elapsed}s elapsed)")
except Exception:
pass
time.sleep(self._poll_interval_seconds)
def _before_execute(self) -> None: