diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index b1585637f..1ec831b66 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -3265,7 +3265,20 @@ class DiscordAdapter(BasePlatformAdapter): "[Discord] Flushing text batch %s (%d chars)", key, len(event.text or ""), ) - await self.handle_message(event) + # Shield the downstream dispatch so that a subsequent chunk + # arriving while handle_message is mid-flight cannot cancel + # the running agent turn. _enqueue_text_event always cancels + # the prior flush task when a new chunk lands; without this + # shield, CancelledError would propagate from our task down + # into handle_message → the agent's streaming request, + # aborting the response the user was waiting on. The new + # chunk is handled by the fresh flush task regardless. + await asyncio.shield(self.handle_message(event)) + except asyncio.CancelledError: + # Only reached if cancel landed before the pop — the shielded + # handle_message is unaffected either way. Let the task exit + # cleanly so the finally block cleans up. + pass finally: if self._pending_text_batch_tasks.get(key) is current_task: self._pending_text_batch_tasks.pop(key, None) diff --git a/tests/gateway/test_text_batching.py b/tests/gateway/test_text_batching.py index 56bc602ef..1ad89ffd0 100644 --- a/tests/gateway/test_text_batching.py +++ b/tests/gateway/test_text_batching.py @@ -148,6 +148,70 @@ class TestDiscordTextBatching: await asyncio.sleep(0.25) adapter.handle_message.assert_called_once() + @pytest.mark.asyncio + async def test_shield_protects_handle_message_from_cancel(self): + """Regression guard: a follow-up chunk arriving while + handle_message is mid-flight must NOT cancel the running + dispatch. _enqueue_text_event fires prior_task.cancel() on + every new chunk; without asyncio.shield around handle_message + the cancel propagates into the agent's streaming request and + aborts the response. + """ + adapter = _make_discord_adapter() + + handle_started = asyncio.Event() + release_handle = asyncio.Event() + first_handle_cancelled = asyncio.Event() + first_handle_completed = asyncio.Event() + call_count = [0] + + async def slow_handle(event): + call_count[0] += 1 + # Only the first call (batch 1) is the one we're protecting. + if call_count[0] == 1: + handle_started.set() + try: + await release_handle.wait() + first_handle_completed.set() + except asyncio.CancelledError: + first_handle_cancelled.set() + raise + # Second call (batch 2) returns immediately — not the subject + # of this test. + + adapter.handle_message = slow_handle + + # Prime batch 1 and wait for it to land inside handle_message. + adapter._enqueue_text_event(_make_event("batch 1", Platform.DISCORD)) + await asyncio.wait_for(handle_started.wait(), timeout=1.0) + + # A new chunk arrives — _enqueue_text_event fires + # prior_task.cancel() on batch 1's flush task, which is + # currently awaiting inside handle_message. + adapter._enqueue_text_event(_make_event("batch 2 follow-up", Platform.DISCORD)) + + # Let the cancel propagate. + await asyncio.sleep(0.05) + + # CRITICAL ASSERTION: batch 1's handle_message must NOT have + # been cancelled. Without asyncio.shield this assertion fails + # because CancelledError propagates from the flush task's + # `await self.handle_message(event)` into slow_handle. + assert not first_handle_cancelled.is_set(), ( + "handle_message for batch 1 was cancelled by a follow-up " + "chunk — asyncio.shield is missing or broken" + ) + + # Release batch 1's handle_message and let it complete. + release_handle.set() + await asyncio.wait_for(first_handle_completed.wait(), timeout=1.0) + assert first_handle_completed.is_set() + + # Cleanup + for task in list(adapter._pending_text_batch_tasks.values()): + task.cancel() + await asyncio.sleep(0.01) + # ===================================================================== # Matrix text batching