diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 65f7226e1..645a642ba 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -2033,12 +2033,26 @@ class BasePlatformAdapter(ABC): Used during gateway shutdown/replacement so active sessions from the old process do not keep running after adapters are being torn down. """ - tasks = [task for task in self._background_tasks if not task.done()] - for task in tasks: - self._expected_cancelled_tasks.add(task) - task.cancel() - if tasks: + # Loop until no new tasks appear. Without this, a message + # arriving during the `await asyncio.gather` below would spawn + # a fresh _process_message_background task (added to + # self._background_tasks at line ~1668 via handle_message), + # and the _background_tasks.clear() at the end of this method + # would drop the reference — the task runs untracked against a + # disconnecting adapter, logs send-failures, and may linger + # until it completes on its own. Retrying the drain until the + # task set stabilizes closes the window. + MAX_DRAIN_ROUNDS = 5 + for _ in range(MAX_DRAIN_ROUNDS): + tasks = [task for task in self._background_tasks if not task.done()] + if not tasks: + break + for task in tasks: + self._expected_cancelled_tasks.add(task) + task.cancel() await asyncio.gather(*tasks, return_exceptions=True) + # Loop: late-arrival tasks spawned during the gather above + # will be in self._background_tasks now. Re-check. self._background_tasks.clear() self._expected_cancelled_tasks.clear() self._pending_messages.clear() diff --git a/tests/gateway/test_cancel_background_drain.py b/tests/gateway/test_cancel_background_drain.py new file mode 100644 index 000000000..c95fdc062 --- /dev/null +++ b/tests/gateway/test_cancel_background_drain.py @@ -0,0 +1,148 @@ +"""Regression test: cancel_background_tasks must drain late-arrival tasks. + +During gateway shutdown, a message arriving while +cancel_background_tasks is mid-await can spawn a fresh +_process_message_background task via handle_message, which is added +to self._background_tasks. Without the re-drain loop, the subsequent +_background_tasks.clear() drops the reference; the task runs +untracked against a disconnecting adapter. +""" + +import asyncio +from unittest.mock import AsyncMock + +import pytest + +from gateway.config import Platform, PlatformConfig +from gateway.platforms.base import BasePlatformAdapter, MessageEvent, MessageType +from gateway.session import SessionSource, build_session_key + + +class _StubAdapter(BasePlatformAdapter): + async def connect(self): + pass + + async def disconnect(self): + pass + + async def send(self, chat_id, text, **kwargs): + return None + + async def get_chat_info(self, chat_id): + return {} + + +def _make_adapter(): + adapter = _StubAdapter(PlatformConfig(enabled=True, token="t"), Platform.TELEGRAM) + adapter._send_with_retry = AsyncMock(return_value=None) + return adapter + + +def _event(text, cid="42"): + return MessageEvent( + text=text, + message_type=MessageType.TEXT, + source=SessionSource(platform=Platform.TELEGRAM, chat_id=cid, chat_type="dm"), + ) + + +@pytest.mark.asyncio +async def test_cancel_background_tasks_drains_late_arrivals(): + """A message that arrives during the gather window must be picked + up by the re-drain loop, not leaked as an untracked task.""" + adapter = _make_adapter() + sk = build_session_key( + SessionSource(platform=Platform.TELEGRAM, chat_id="42", chat_type="dm") + ) + + m1_started = asyncio.Event() + m1_cleanup_running = asyncio.Event() + m2_started = asyncio.Event() + m2_cancelled = asyncio.Event() + + async def handler(event): + if event.text == "M1": + m1_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + m1_cleanup_running.set() + # Widen the gather window with a shielded cleanup + # delay so M2 can get injected during it. + await asyncio.shield(asyncio.sleep(0.2)) + raise + else: # M2 — the late arrival + m2_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + m2_cancelled.set() + raise + + adapter._message_handler = handler + + # Spawn M1. + await adapter.handle_message(_event("M1")) + await asyncio.wait_for(m1_started.wait(), timeout=1.0) + + # Kick off shutdown. This will cancel M1 and await its cleanup. + cancel_task = asyncio.create_task(adapter.cancel_background_tasks()) + + # Wait until M1's cleanup is running (inside the shielded sleep). + # This is the race window: cancel_task is awaiting gather, M1 is + # shielded in cleanup, the _active_sessions entry has been cleared + # by M1's own finally. + await asyncio.wait_for(m1_cleanup_running.wait(), timeout=1.0) + + # Clear the active-session entry (M1's finally hasn't fully run yet, + # but in production the platform dispatcher would deliver a new + # message that takes the no-active-session spawn path). For this + # repro, make it deterministic. + adapter._active_sessions.pop(sk, None) + + # Inject late arrival — spawns a fresh _process_message_background + # task and adds it to _background_tasks while cancel_task is still + # in gather. + await adapter.handle_message(_event("M2")) + await asyncio.wait_for(m2_started.wait(), timeout=1.0) + + # Let cancel_task finish. Round 1's gather completes when M1's + # shielded cleanup finishes. Round 2 should pick up M2. + await asyncio.wait_for(cancel_task, timeout=5.0) + + # Assert M2 was drained, not leaked. + assert m2_cancelled.is_set(), ( + "Late-arrival M2 was NOT cancelled by cancel_background_tasks — " + "the re-drain loop is missing and the task leaked" + ) + assert adapter._background_tasks == set() + + +@pytest.mark.asyncio +async def test_cancel_background_tasks_handles_no_tasks(): + """Regression guard: no tasks, no hang, no error.""" + adapter = _make_adapter() + await adapter.cancel_background_tasks() + assert adapter._background_tasks == set() + + +@pytest.mark.asyncio +async def test_cancel_background_tasks_bounded_rounds(): + """Regression guard: the drain loop is bounded — it does not spin + forever even if late-arrival tasks keep getting spawned.""" + adapter = _make_adapter() + + # Single well-behaved task that cancels cleanly — baseline check + # that the loop terminates in one round. + async def quick(): + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + raise + + task = asyncio.create_task(quick()) + adapter._background_tasks.add(task) + + await adapter.cancel_background_tasks() + assert task.done() + assert adapter._background_tasks == set()