From f0c5d812b0dc5b1df8d1dc24f395c69ef1cb4338 Mon Sep 17 00:00:00 2001 From: r266-tech Date: Wed, 24 Jun 2026 18:40:23 +0530 Subject: [PATCH] fix(gateway): offload handoff watcher SessionDB polling off the event loop The Discord gateway heartbeat stalled ('Shard ID None heartbeat blocked for more than N seconds') because _handoff_watcher polled the synchronous, blocking SQLite-backed SessionDB directly on the asyncio event loop every 2s. Each list_pending/claim/complete/fail call performed blocking disk I/O on the loop thread, starving the Discord heartbeat coroutine. Wrap every blocking SessionDB call inside the watcher loop in asyncio.to_thread(...) so the SQLite work runs on a worker thread and the event loop (and heartbeat) stays responsive. These four call sites are the only synchronous self._session_db.* calls inside the watcher loop body. Adds tests/gateway/test_handoff_watcher_async_db.py asserting the watcher offloads its SessionDB calls via asyncio.to_thread (mutation-survivable: reverting any to_thread wrap fails the corresponding assertion). Fixes #40695 Co-authored-by: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com> --- gateway/run.py | 8 +- .../gateway/test_handoff_watcher_async_db.py | 164 ++++++++++++++++++ 2 files changed, 168 insertions(+), 4 deletions(-) create mode 100644 tests/gateway/test_handoff_watcher_async_db.py diff --git a/gateway/run.py b/gateway/run.py index 0691827bf45..fdd30423e47 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -5974,23 +5974,23 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if self._session_db is None: await asyncio.sleep(interval) continue - pending = self._session_db.list_pending_handoffs() + pending = await asyncio.to_thread(self._session_db.list_pending_handoffs) for row in pending: session_id = row.get("id") if not session_id: continue - if not self._session_db.claim_handoff(session_id): + if not await asyncio.to_thread(self._session_db.claim_handoff, session_id): # Another tick or another gateway already claimed it. continue try: await self._process_handoff(row) - self._session_db.complete_handoff(session_id) + await asyncio.to_thread(self._session_db.complete_handoff, session_id) except Exception as exc: logger.warning( "Handoff for session %s failed: %s", session_id, exc, exc_info=True, ) - self._session_db.fail_handoff(session_id, str(exc)) + await asyncio.to_thread(self._session_db.fail_handoff, session_id, str(exc)) except asyncio.CancelledError: raise except Exception as exc: diff --git a/tests/gateway/test_handoff_watcher_async_db.py b/tests/gateway/test_handoff_watcher_async_db.py new file mode 100644 index 00000000000..c10093d07a5 --- /dev/null +++ b/tests/gateway/test_handoff_watcher_async_db.py @@ -0,0 +1,164 @@ +"""Regression test for #40695 (salvage of keystone PR #40782). + +The Discord gateway heartbeat was stalling because the handoff watcher +(``GatewayRunner._handoff_watcher``) polled the synchronous, blocking +SQLite-backed ``SessionDB`` directly on the asyncio event loop every 2s +('Shard ID None heartbeat blocked for more than N seconds'). + +The fix (mirroring PR #40782) wraps every blocking ``SessionDB`` call inside +the watcher loop in ``asyncio.to_thread(...)`` so the SQLite I/O runs on a +worker thread and never blocks the event loop / Discord heartbeat. + +These tests assert that behaviour contract. They are mutation-survivable: +reverting any ``asyncio.to_thread(self._session_db.)`` wrap back to a +direct synchronous call on the loop makes the relevant assertion fail. +""" + +import asyncio +import types + +import pytest + +import gateway.run as run + + +class _RecordingSessionDB: + """SessionDB stand-in that records the thread each method runs on. + + If the watcher calls these methods directly on the event loop (the bug), + they run on the loop thread. If they are wrapped in ``asyncio.to_thread`` + (the fix), they run on a *different* worker thread. + """ + + def __init__(self, loop_thread_ident): + self._loop_thread_ident = loop_thread_ident + self.threads = {} + self.calls = [] + + def _record(self, name): + import threading + + self.threads.setdefault(name, []).append(threading.get_ident()) + self.calls.append(name) + + def ran_off_loop(self, name): + """True iff every call to ``name`` ran on a non-loop thread.""" + idents = self.threads.get(name, []) + return bool(idents) and all(i != self._loop_thread_ident for i in idents) + + def list_pending_handoffs(self): + self._record("list_pending_handoffs") + return [{"id": "sess-1"}] + + def claim_handoff(self, session_id): + self._record("claim_handoff") + return True + + def complete_handoff(self, session_id): + self._record("complete_handoff") + + def fail_handoff(self, session_id, error): + self._record("fail_handoff") + + +def _make_fake_runner(session_db, *, fail_process=False): + """Build a minimal object that exposes exactly what the loop body touches.""" + fake = types.SimpleNamespace() + fake._session_db = session_db + # _running yields True for the first loop check, then False so the loop + # exits after a single tick. + states = iter([True, False]) + + class _Running: + def __bool__(_self): + try: + return next(states) + except StopIteration: + return False + + fake._running = _Running() + + async def _process_handoff(row): + if fail_process: + raise RuntimeError("boom") + + fake._process_handoff = _process_handoff + return fake + + +async def _run_one_tick(fake, monkeypatch): + """Run the watcher for a single tick with sleeps neutralised.""" + + async def _no_sleep(_seconds): + return None + + monkeypatch.setattr(run.asyncio, "sleep", _no_sleep) + # Bind the real (patched) method onto our minimal stand-in. + coro = run.GatewayRunner._handoff_watcher(fake, interval=0.0) + await asyncio.wait_for(coro, timeout=5) + + +@pytest.mark.asyncio +async def test_watcher_offloads_db_calls_to_threads(monkeypatch): + """The success path must run list_pending/claim/complete off the loop.""" + import threading + + loop_ident = threading.get_ident() + db = _RecordingSessionDB(loop_ident) + fake = _make_fake_runner(db, fail_process=False) + + await _run_one_tick(fake, monkeypatch) + + # Sanity: the watcher actually exercised the calls this tick. + assert "list_pending_handoffs" in db.calls + assert "claim_handoff" in db.calls + assert "complete_handoff" in db.calls + + # Contract: each blocking SessionDB call ran on a worker thread, NOT the + # asyncio event-loop thread. Reverting a to_thread wrap makes the + # corresponding call run on the loop thread and this fails. + assert db.ran_off_loop("list_pending_handoffs") + assert db.ran_off_loop("claim_handoff") + assert db.ran_off_loop("complete_handoff") + + +@pytest.mark.asyncio +async def test_watcher_offloads_fail_handoff_to_thread(monkeypatch): + """The error path must run fail_handoff off the loop too.""" + import threading + + loop_ident = threading.get_ident() + db = _RecordingSessionDB(loop_ident) + fake = _make_fake_runner(db, fail_process=True) + + await _run_one_tick(fake, monkeypatch) + + assert "fail_handoff" in db.calls + assert db.ran_off_loop("fail_handoff") + + +@pytest.mark.asyncio +async def test_watcher_wraps_calls_via_asyncio_to_thread(monkeypatch): + """Explicitly assert the offload goes through asyncio.to_thread. + + Patches ``run.asyncio.to_thread`` and records which SessionDB callables + were handed to it. Mutation-survivable: dropping any wrap removes its + callable from the recorded set. + """ + db = _RecordingSessionDB(loop_thread_ident=-1) + fake = _make_fake_runner(db, fail_process=False) + + wrapped = [] + real_to_thread = run.asyncio.to_thread + + async def _spy_to_thread(func, *args, **kwargs): + wrapped.append(getattr(func, "__name__", repr(func))) + return await real_to_thread(func, *args, **kwargs) + + monkeypatch.setattr(run.asyncio, "to_thread", _spy_to_thread) + + await _run_one_tick(fake, monkeypatch) + + assert "list_pending_handoffs" in wrapped + assert "claim_handoff" in wrapped + assert "complete_handoff" in wrapped