diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index f7483397f95..e11a6093319 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -2697,6 +2697,8 @@ class DiscordAdapter(BasePlatformAdapter): await asyncio.sleep(8) except asyncio.CancelledError: pass + finally: + self._typing_tasks.pop(chat_id, None) self._typing_tasks[chat_id] = asyncio.create_task(_typing_loop()) diff --git a/tests/gateway/test_discord_send.py b/tests/gateway/test_discord_send.py index 89be6885a9c..03f442a3b88 100644 --- a/tests/gateway/test_discord_send.py +++ b/tests/gateway/test_discord_send.py @@ -1,3 +1,4 @@ +import asyncio from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock import sys @@ -386,3 +387,61 @@ async def test_forum_post_file_creation_failure(): assert result.success is False assert "missing perms" in (result.error or "") + + +# --------------------------------------------------------------------------- +# Typing indicator task lifecycle +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_typing_task_removed_after_api_error(): + """When typing API call fails, stale task must be removed so typing can restart.""" + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + adapter._client = MagicMock() + adapter._client.http = MagicMock() + adapter._client.http.request = AsyncMock(side_effect=Exception("rate limited")) + adapter._typing_tasks = {} + + await adapter.send_typing("12345") + await asyncio.sleep(0.1) + + assert "12345" not in adapter._typing_tasks, \ + "Stale task should be removed after API error" + + +@pytest.mark.asyncio +async def test_typing_restartable_after_error(): + """After a typing error, send_typing should start a new task (not blocked by stale entry).""" + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + adapter._client = MagicMock() + adapter._client.http = MagicMock() + adapter._typing_tasks = {} + + # First call fails + adapter._client.http.request = AsyncMock(side_effect=Exception("503")) + await adapter.send_typing("12345") + await asyncio.sleep(0.1) + + # Second call should work + adapter._client.http.request = AsyncMock() + await adapter.send_typing("12345") + + assert "12345" in adapter._typing_tasks, \ + "Should restart typing after previous failure" + + +@pytest.mark.asyncio +async def test_typing_stop_cleans_up(): + """stop_typing should remove the task from _typing_tasks.""" + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + adapter._client = MagicMock() + adapter._client.http = MagicMock() + adapter._client.http.request = AsyncMock() + adapter._typing_tasks = {} + + await adapter.send_typing("12345") + assert "12345" in adapter._typing_tasks + + await adapter.stop_typing("12345") + assert "12345" not in adapter._typing_tasks