fix(gateway): offload handoff watcher SessionDB polling off the event loop

The Discord gateway heartbeat stalled ('Shard ID None heartbeat blocked
for more than N seconds') because _handoff_watcher polled the synchronous,
blocking SQLite-backed SessionDB directly on the asyncio event loop every
2s. Each list_pending/claim/complete/fail call performed blocking disk I/O
on the loop thread, starving the Discord heartbeat coroutine.

Wrap every blocking SessionDB call inside the watcher loop in
asyncio.to_thread(...) so the SQLite work runs on a worker thread and the
event loop (and heartbeat) stays responsive. These four call sites are the
only synchronous self._session_db.* calls inside the watcher loop body.

Adds tests/gateway/test_handoff_watcher_async_db.py asserting the watcher
offloads its SessionDB calls via asyncio.to_thread (mutation-killing:
reverting any to_thread wrap fails the corresponding assertion).

Fixes #40695

Co-authored-by: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com>
This commit is contained in:
r266-tech 2026-06-24 18:40:23 +05:30
parent a4a74ca9e9
commit f0c5d812b0
2 changed files with 168 additions and 4 deletions

View file

@ -5974,23 +5974,23 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if self._session_db is None:
await asyncio.sleep(interval)
continue
pending = self._session_db.list_pending_handoffs()
pending = await asyncio.to_thread(self._session_db.list_pending_handoffs)
for row in pending:
session_id = row.get("id")
if not session_id:
continue
if not self._session_db.claim_handoff(session_id):
if not await asyncio.to_thread(self._session_db.claim_handoff, session_id):
# Another tick or another gateway already claimed it.
continue
try:
await self._process_handoff(row)
self._session_db.complete_handoff(session_id)
await asyncio.to_thread(self._session_db.complete_handoff, session_id)
except Exception as exc:
logger.warning(
"Handoff for session %s failed: %s",
session_id, exc, exc_info=True,
)
self._session_db.fail_handoff(session_id, str(exc))
await asyncio.to_thread(self._session_db.fail_handoff, session_id, str(exc))
except asyncio.CancelledError:
raise
except Exception as exc:

View file

@ -0,0 +1,164 @@
"""Regression test for #40695 (salvage of keystone PR #40782).
The Discord gateway heartbeat was stalling because the handoff watcher
(``GatewayRunner._handoff_watcher``) polled the synchronous, blocking
SQLite-backed ``SessionDB`` directly on the asyncio event loop every 2s
('Shard ID None heartbeat blocked for more than N seconds').
The fix (mirroring PR #40782) wraps every blocking ``SessionDB`` call inside
the watcher loop in ``asyncio.to_thread(...)`` so the SQLite I/O runs on a
worker thread and never blocks the event loop / Discord heartbeat.
These tests assert that behaviour contract. They are mutation-killing:
reverting any ``asyncio.to_thread(self._session_db.<call>)`` wrap back to a
direct synchronous call on the loop makes the relevant assertion fail.
"""
import asyncio
import types
import pytest
import gateway.run as run
class _RecordingSessionDB:
"""SessionDB stand-in that records the thread each method runs on.
If the watcher calls these methods directly on the event loop (the bug),
they run on the loop thread. If they are wrapped in ``asyncio.to_thread``
(the fix), they run on a *different* worker thread.
"""
def __init__(self, loop_thread_ident):
self._loop_thread_ident = loop_thread_ident
self.threads = {}
self.calls = []
def _record(self, name):
import threading
self.threads.setdefault(name, []).append(threading.get_ident())
self.calls.append(name)
def ran_off_loop(self, name):
"""True iff every call to ``name`` ran on a non-loop thread."""
idents = self.threads.get(name, [])
return bool(idents) and all(i != self._loop_thread_ident for i in idents)
def list_pending_handoffs(self):
self._record("list_pending_handoffs")
return [{"id": "sess-1"}]
def claim_handoff(self, session_id):
self._record("claim_handoff")
return True
def complete_handoff(self, session_id):
self._record("complete_handoff")
def fail_handoff(self, session_id, error):
self._record("fail_handoff")
def _make_fake_runner(session_db, *, fail_process=False):
"""Build a minimal object that exposes exactly what the loop body touches."""
fake = types.SimpleNamespace()
fake._session_db = session_db
# _running yields True for the first loop check, then False so the loop
# exits after a single tick.
states = iter([True, False])
class _Running:
def __bool__(_self):
try:
return next(states)
except StopIteration:
return False
fake._running = _Running()
async def _process_handoff(row):
if fail_process:
raise RuntimeError("boom")
fake._process_handoff = _process_handoff
return fake
async def _run_one_tick(fake, monkeypatch):
    """Drive ``GatewayRunner._handoff_watcher`` against ``fake`` without real sleeps.

    ``run.asyncio.sleep`` is replaced by a coroutine that returns immediately,
    and the watcher is awaited under a 5 second timeout so a loop that fails
    to terminate surfaces as a timeout rather than a hang.
    """

    async def _no_sleep(_seconds):
        return None

    monkeypatch.setattr(run.asyncio, "sleep", _no_sleep)

    # Call the method unbound so the minimal stand-in serves as ``self``.
    watcher = run.GatewayRunner._handoff_watcher
    await asyncio.wait_for(watcher(fake, interval=0.0), timeout=5)
@pytest.mark.asyncio
async def test_watcher_offloads_db_calls_to_threads(monkeypatch):
    """Success path: list_pending/claim/complete must each run off the loop thread."""
    import threading

    # The test coroutine runs on the event-loop thread, so this is its id.
    db = _RecordingSessionDB(threading.get_ident())
    runner = _make_fake_runner(db, fail_process=False)
    await _run_one_tick(runner, monkeypatch)

    success_path = ("list_pending_handoffs", "claim_handoff", "complete_handoff")

    # Sanity: the tick really reached every call on the success path.
    for method in success_path:
        assert method in db.calls

    # Contract: none of those calls executed on the event-loop thread. A call
    # made directly from the watcher coroutine would record the loop's id and
    # trip the matching assertion.
    for method in success_path:
        assert db.ran_off_loop(method)
@pytest.mark.asyncio
async def test_watcher_offloads_fail_handoff_to_thread(monkeypatch):
    """Error path: fail_handoff must run off the loop thread as well."""
    import threading

    # The test coroutine runs on the event-loop thread, so this is its id.
    db = _RecordingSessionDB(threading.get_ident())
    # A failing _process_handoff steers the watcher into its error branch.
    runner = _make_fake_runner(db, fail_process=True)
    await _run_one_tick(runner, monkeypatch)

    assert "fail_handoff" in db.calls
    assert db.ran_off_loop("fail_handoff")
@pytest.mark.asyncio
async def test_watcher_wraps_calls_via_asyncio_to_thread(monkeypatch):
    """Explicitly assert the offload goes through ``asyncio.to_thread``.

    Patches ``run.asyncio.to_thread`` with a spy that records the name of each
    callable handed to it and then delegates to the real implementation. One
    successful tick and one failing tick are run so that all four SessionDB
    wraps (list_pending/claim/complete and fail) are exercised; dropping any
    wrap removes its callable from the recorded names and fails an assertion.
    """
    wrapped = []
    real_to_thread = run.asyncio.to_thread

    async def _spy_to_thread(func, *args, **kwargs):
        # Bound methods expose __name__; fall back to repr for anything else.
        wrapped.append(getattr(func, "__name__", repr(func)))
        return await real_to_thread(func, *args, **kwargs)

    monkeypatch.setattr(run.asyncio, "to_thread", _spy_to_thread)

    # Success path: list -> claim -> process -> complete.
    ok_db = _RecordingSessionDB(loop_thread_ident=-1)
    await _run_one_tick(_make_fake_runner(ok_db, fail_process=False), monkeypatch)
    assert "list_pending_handoffs" in wrapped
    assert "claim_handoff" in wrapped
    assert "complete_handoff" in wrapped

    # Error path: a failing _process_handoff must route fail_handoff through
    # to_thread too. Start from a clean record so the success tick above
    # cannot mask a missing wrap.
    wrapped.clear()
    failing_db = _RecordingSessionDB(loop_thread_ident=-1)
    await _run_one_tick(_make_fake_runner(failing_db, fail_process=True), monkeypatch)
    assert "fail_handoff" in wrapped