NOTE(review): line structure re-flowed from a whitespace-collapsed copy of this
patch.  The indentation of the wecom.py context lines and the position of the
blank context line in the last hunk are reconstructed, and the index hashes
are carried over from the original patch -- verify against the tree before
applying.
diff --git a/gateway/platforms/wecom.py b/gateway/platforms/wecom.py
index 5aad1e09cc5..1569d5faf52 100644
--- a/gateway/platforms/wecom.py
+++ b/gateway/platforms/wecom.py
@@ -616,6 +616,21 @@ class WeComAdapter(BasePlatformAdapter):
             else:
                 delay = self._text_batch_delay_seconds
             await asyncio.sleep(delay)
+            # Guard against the cancel-delivery race: when the sleep timer
+            # fires just before cancel() is called, CPython sets
+            # Task._must_cancel but cannot cancel the already-done sleep
+            # future, so CancelledError is delivered at the *next* await
+            # (handle_message) rather than here. By that point this task
+            # has already popped the merged event, so the superseding task
+            # sees an empty batch and silently drops the message.
+            # NOTE(review): asyncio raises a pending CancelledError at the
+            # await the task resumes from (this sleep), so confirm the
+            # interleaving above; the registry check does not depend on it.
+            # This check is synchronous — no await between the sleep and
+            # the pop — so no other coroutine can modify the task registry
+            # in between.
+            if self._pending_text_batch_tasks.get(key) is not current_task:
+                return
             event = self._pending_text_batches.pop(key, None)
             if not event:
                 return
diff --git a/tests/gateway/test_wecom.py b/tests/gateway/test_wecom.py
index 7bf56f9d319..02d04daf64e 100644
--- a/tests/gateway/test_wecom.py
+++ b/tests/gateway/test_wecom.py
@@ -1,5 +1,6 @@
 """Tests for the WeCom platform adapter."""
 
+import asyncio
 import base64
 import os
 from pathlib import Path
@@ -831,3 +832,91 @@ class TestWeComZombieSessionFix:
 
         cmd = adapter._send_request.await_args.args[0]
         assert cmd == APP_CMD_SEND
+
+
+class TestTextBatchFlushRace:
+    """Regression tests for the cancel-delivery race in _flush_text_batch.
+
+    When asyncio.sleep() fires and Task.cancel() is called before the task
+    runs, CPython sets _must_cancel but cannot cancel the already-done sleep
+    future. CancelledError is then delivered at the *next* await
+    (handle_message), after the task has already popped the event — the
+    superseding task sees an empty batch and silently drops the message.
+    The fix adds a synchronous task-registry check between the sleep and
+    the pop so a superseded task returns before touching the event.
+    """
+
+    @pytest.mark.asyncio
+    async def test_superseded_task_does_not_pop_or_process_event(self):
+        """A flush task that has been superseded must leave the event in the
+        batch dict for the new task to handle."""
+        from gateway.platforms.base import MessageEvent, MessageType
+        from gateway.platforms.wecom import WeComAdapter
+
+        adapter = WeComAdapter(PlatformConfig(enabled=True))
+        adapter._text_batch_delay_seconds = 0
+
+        key = "test-session"
+        event = MessageEvent(text="hello", message_type=MessageType.TEXT)
+        adapter._pending_text_batches[key] = event
+
+        handle_calls = []
+
+        async def fake_handle(evt):
+            handle_calls.append(evt)
+
+        adapter.handle_message = fake_handle
+
+        # Create T1 and register it.
+        t1 = asyncio.create_task(adapter._flush_text_batch(key))
+        adapter._pending_text_batch_tasks[key] = t1
+
+        # Simulate T2 superseding T1 before T1 wakes from sleep.
+        t2 = asyncio.create_task(asyncio.sleep(9999))
+        adapter._pending_text_batch_tasks[key] = t2
+
+        # Await T1 itself: deterministic, and surfaces errors from the flush.
+        await asyncio.wait_for(t1, timeout=1)
+
+        t2.cancel()
+        try:
+            await t2
+        except asyncio.CancelledError:
+            pass
+
+        # T1 must have returned without processing or removing the event.
+        assert handle_calls == [], "superseded task must not call handle_message"
+        assert adapter._pending_text_batches.get(key) is event, (
+            "superseded task must not pop the event"
+        )
+
+    @pytest.mark.asyncio
+    async def test_active_task_processes_event_normally(self):
+        """When the task is not superseded it must still process the event."""
+        from gateway.platforms.base import MessageEvent, MessageType
+        from gateway.platforms.wecom import WeComAdapter
+
+        adapter = WeComAdapter(PlatformConfig(enabled=True))
+        adapter._text_batch_delay_seconds = 0
+
+        key = "test-session"
+        event = MessageEvent(text="world", message_type=MessageType.TEXT)
+        adapter._pending_text_batches[key] = event
+
+        handle_calls = []
+
+        async def fake_handle(evt):
+            handle_calls.append(evt)
+
+        adapter.handle_message = fake_handle
+
+        t1 = asyncio.create_task(adapter._flush_text_batch(key))
+        adapter._pending_text_batch_tasks[key] = t1
+
+        # No superseding task — T1 should run to completion normally.
+        await asyncio.wait_for(t1, timeout=1)
+
+        assert handle_calls == [event], "active task must call handle_message"
+        assert adapter._pending_text_batches.get(key) is None, (
+            "active task must pop the event after processing"
+        )