# Patch summary and review notes. This text sits ahead of the first
# "diff --git" header and is not part of any hunk.
#
# gateway/platforms/discord.py
#   * Attribute-initialisation hunk (presumably __init__ -- the enclosing def
#     is not visible): adds self._voice_locks, a dict keyed by guild_id that
#     holds asyncio.Lock objects.
#   * on_message hunk: before the existing dedup check, if _ready_event is not
#     set the handler awaits it through asyncio.wait_for(..., timeout=30.0).
#     On asyncio.TimeoutError it logs a warning and falls through to the
#     existing logic rather than dropping the message.
#   * _voice_lock_for(guild_id): returns the lock stored for guild_id,
#     creating and storing a new asyncio.Lock on first use.
#   * join_voice_channel / leave_voice_channel: the previous statements are
#     re-indented under `async with self._voice_lock_for(guild_id)`. In
#     join_voice_channel the early `return False` guard and the guild_id
#     lookup stay outside the lock.
#
# tests/gateway/test_discord_race_polish.py (new file)
#   * test_concurrent_joins_do_not_double_connect: starts two
#     join_voice_channel tasks for the same guild against a connect() stub
#     that blocks on an asyncio.Event, then asserts connect ran exactly once,
#     both calls returned True, and guild 42 is present in _voice_clients.
#   * test_on_message_blocks_until_ready_event_set: does not execute
#     on_message; it asserts that inspect.getsource(DiscordAdapter.connect)
#     contains the expected substrings.
#
# NOTE(review): the test replaces ensure_future with a lambda that ignores its
#   argument, so whatever self._voice_listen_loop(guild_id) returns is dropped.
#   If that is a coroutine object it is never awaited or closed -- confirm this
#   does not surface as a "never awaited" RuntimeWarning.
# NOTE(review): patch.object targets discord_mod.asyncio; presumably that is
#   the shared stdlib module object, so ensure_future is replaced for every
#   caller of asyncio.ensure_future while the patch is active -- confirm.
# NOTE(review): asyncio.Lock is not reentrant. No path shown in this diff calls
#   join/leave while already holding the same guild's lock, but callers outside
#   the diff (e.g. the voice timeout or listen tasks) should be checked.
# NOTE(review): nothing in these hunks removes entries from _voice_locks.
# NOTE(review): "pre-resolved entries" in the timeout warning appears to mean
#   entries that have not been resolved yet -- confirm the wording.
diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 1ec831b66..fce7ece41 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -498,6 +498,7 @@ class DiscordAdapter(BasePlatformAdapter): self._allowed_role_ids: set = set() # For DISCORD_ALLOWED_ROLES filtering # Voice channel state (per-guild) self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient + self._voice_locks: Dict[int, asyncio.Lock] = {} # guild_id -> serialize join/leave # Text batching: merge rapid successive messages (Telegram-style) self._text_batch_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_DELAY_SECONDS", "0.6")) self._text_batch_split_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) @@ -636,6 +637,30 @@ class DiscordAdapter(BasePlatformAdapter): @self._client.event async def on_message(message: DiscordMessage): + # Wait for on_ready to finish resolving username-based + # allowlist entries. Without this block, messages + # arriving between Discord's READY event and the end + # of _resolve_allowed_usernames compare author IDs + # (numeric) against a set that may still contain raw + # usernames (strings) from DISCORD_ALLOWED_USERS — + # legitimate users get silently rejected for the first + # few seconds after every reconnect. The wait is a + # near-instant no-op in steady state (_ready_event is + # already set); only the startup / reconnect window + # ever blocks. 
+ if not adapter_self._ready_event.is_set(): + try: + await asyncio.wait_for( + adapter_self._ready_event.wait(), + timeout=30.0, + ) + except asyncio.TimeoutError: + logger.warning( + "[%s] on_message timed out waiting for _ready_event; " + "allowlist check may use pre-resolved entries", + adapter_self.name, + ) + # Dedup: Discord RESUME replays events after reconnects (#4777) if adapter_self._dedup.is_duplicate(str(message.id)): return @@ -1231,57 +1256,74 @@ class DiscordAdapter(BasePlatformAdapter): # Voice channel methods (join / leave / play) # ------------------------------------------------------------------ + def _voice_lock_for(self, guild_id: int) -> "asyncio.Lock": + """Return the per-guild lock, creating it on first use. + + Voice join/leave/move must be serialized per guild — without + this, two concurrent /voice channel invocations both see + _voice_clients.get(guild_id) return None, both call + channel.connect(), and discord.py raises ClientException + ('Already connected') on the loser. + """ + lock = self._voice_locks.get(guild_id) + if lock is None: + lock = asyncio.Lock() + self._voice_locks[guild_id] = lock + return lock + async def join_voice_channel(self, channel) -> bool: """Join a Discord voice channel. Returns True on success.""" if not self._client or not DISCORD_AVAILABLE: return False guild_id = channel.guild.id - # Already connected in this guild? - existing = self._voice_clients.get(guild_id) - if existing and existing.is_connected(): - if existing.channel.id == channel.id: + async with self._voice_lock_for(guild_id): + # Already connected in this guild? 
+ existing = self._voice_clients.get(guild_id) + if existing and existing.is_connected(): + if existing.channel.id == channel.id: + self._reset_voice_timeout(guild_id) + return True + await existing.move_to(channel) self._reset_voice_timeout(guild_id) return True - await existing.move_to(channel) + + vc = await channel.connect() + self._voice_clients[guild_id] = vc self._reset_voice_timeout(guild_id) + + # Start voice receiver (Phase 2: listen to users) + try: + receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids) + receiver.start() + self._voice_receivers[guild_id] = receiver + self._voice_listen_tasks[guild_id] = asyncio.ensure_future( + self._voice_listen_loop(guild_id) + ) + except Exception as e: + logger.warning("Voice receiver failed to start: %s", e) + return True - vc = await channel.connect() - self._voice_clients[guild_id] = vc - self._reset_voice_timeout(guild_id) - - # Start voice receiver (Phase 2: listen to users) - try: - receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids) - receiver.start() - self._voice_receivers[guild_id] = receiver - self._voice_listen_tasks[guild_id] = asyncio.ensure_future( - self._voice_listen_loop(guild_id) - ) - except Exception as e: - logger.warning("Voice receiver failed to start: %s", e) - - return True - async def leave_voice_channel(self, guild_id: int) -> None: """Disconnect from the voice channel in a guild.""" - # Stop voice receiver first - receiver = self._voice_receivers.pop(guild_id, None) - if receiver: - receiver.stop() - listen_task = self._voice_listen_tasks.pop(guild_id, None) - if listen_task: - listen_task.cancel() + async with self._voice_lock_for(guild_id): + # Stop voice receiver first + receiver = self._voice_receivers.pop(guild_id, None) + if receiver: + receiver.stop() + listen_task = self._voice_listen_tasks.pop(guild_id, None) + if listen_task: + listen_task.cancel() - vc = self._voice_clients.pop(guild_id, None) - if vc and vc.is_connected(): - await 
vc.disconnect() - task = self._voice_timeout_tasks.pop(guild_id, None) - if task: - task.cancel() - self._voice_text_channels.pop(guild_id, None) - self._voice_sources.pop(guild_id, None) + vc = self._voice_clients.pop(guild_id, None) + if vc and vc.is_connected(): + await vc.disconnect() + task = self._voice_timeout_tasks.pop(guild_id, None) + if task: + task.cancel() + self._voice_text_channels.pop(guild_id, None) + self._voice_sources.pop(guild_id, None) # Maximum seconds to wait for voice playback before giving up PLAYBACK_TIMEOUT = 120 diff --git a/tests/gateway/test_discord_race_polish.py b/tests/gateway/test_discord_race_polish.py new file mode 100644 index 000000000..a0f900aea --- /dev/null +++ b/tests/gateway/test_discord_race_polish.py @@ -0,0 +1,122 @@ +"""Regression tests for the Discord adapter race-polish fix. + +Two races are addressed: +1. on_message allowlist check racing on_ready's _resolve_allowed_usernames + resolution window. Username-based entries in DISCORD_ALLOWED_USERS + appear in the set as raw strings for several seconds after + connect/reconnect; author.id is always numeric, so legitimate users + are silently rejected until resolution finishes. +2. join_voice_channel check-and-connect: concurrent /voice channel + invocations both see _voice_clients.get(guild_id) is None, both call + channel.connect(), second raises ClientException ('Already connected'). 
+""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import Platform, PlatformConfig + + +def _make_adapter(): + """Bare DiscordAdapter for testing — object.__new__ pattern per AGENTS.md.""" + from gateway.platforms.discord import DiscordAdapter + + adapter = object.__new__(DiscordAdapter) + adapter._platform = Platform.DISCORD + adapter.config = PlatformConfig(enabled=True, token="t") + adapter._ready_event = asyncio.Event() + adapter._allowed_user_ids = set() + adapter._allowed_role_ids = set() + adapter._voice_clients = {} + adapter._voice_locks = {} + adapter._voice_receivers = {} + adapter._voice_listen_tasks = {} + adapter._voice_timeout_tasks = {} + adapter._voice_text_channels = {} + adapter._voice_sources = {} + adapter._client = MagicMock() + return adapter + + +class TestJoinVoiceSerialization: + @pytest.mark.asyncio + async def test_concurrent_joins_do_not_double_connect(self): + """Two concurrent join_voice_channel calls on the same guild + must serialize through the per-guild lock — only ONE + channel.connect() actually fires; the second sees the + _voice_clients entry the first just installed.""" + adapter = _make_adapter() + + connect_count = [0] + connect_event = asyncio.Event() + + class FakeVC: + def __init__(self, channel): + self.channel = channel + + def is_connected(self): + return True + + async def move_to(self, _channel): + return None + + async def disconnect(self): + return None + + async def slow_connect(self): + connect_count[0] += 1 + # Widen the race window + await connect_event.wait() + return FakeVC(self) + + channel = MagicMock() + channel.id = 111 + channel.guild.id = 42 + channel.connect = lambda: slow_connect(channel) + + # Swap out VoiceReceiver so it doesn't try to set up real audio + from gateway.platforms import discord as discord_mod + with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))): + with 
patch.object(discord_mod.asyncio, "ensure_future", lambda _c: asyncio.create_task(asyncio.sleep(0))): + # Fire two joins concurrently + t1 = asyncio.create_task(adapter.join_voice_channel(channel)) + t2 = asyncio.create_task(adapter.join_voice_channel(channel)) + # Let them run until they're blocked on our event + await asyncio.sleep(0.05) + # Release connect so both can finish + connect_event.set() + r1, r2 = await asyncio.gather(t1, t2) + + assert connect_count[0] == 1, ( + f"Expected exactly 1 channel.connect() call, got {connect_count[0]} — " + "per-guild voice lock is not serializing join_voice_channel" + ) + assert r1 is True and r2 is True + assert 42 in adapter._voice_clients + + +class TestOnMessageWaitsForReadyEvent: + @pytest.mark.asyncio + async def test_on_message_blocks_until_ready_event_set(self): + """A message arriving before on_ready finishes + _resolve_allowed_usernames must wait, not proceed with a + half-resolved allowlist.""" + # This is an integration-style check — we pull out the + # on_message handler by asserting the source contains the + # expected wait pattern. A full end-to-end test would require + # setting up the discord.py client machinery, which is not + # practical here. + import inspect + from gateway.platforms import discord as discord_mod + + src = inspect.getsource(discord_mod.DiscordAdapter.connect) + assert "_ready_event.is_set()" in src, ( + "on_message must gate on _ready_event so username-based " + "allowlist entries are resolved before the allowlist check" + ) + assert "await asyncio.wait_for(" in src and "_ready_event.wait()" in src, ( + "Expected asyncio.wait_for(_ready_event.wait(), timeout=...) " + "pattern in on_message" + )