From 7e3b3565740b4bc2fe62d267f5bb3e894b68bdc5 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sun, 19 Apr 2026 11:08:10 -0700 Subject: [PATCH] refactor(discord): slim down the race-polish fix (#12644) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PR #12558 was heavy for what the fix actually is — essay-length comments, a dedicated helper method where a setdefault would do, and a source-inspection test with no real behavior coverage. The genuine code change is ~5 lines of new logic (1 field, 2 async with, an on_ready wait block). Trimmed: - Replaced the 12-line _voice_lock_for helper with a setdefault one-liner at each call site (join_voice_channel, leave_voice_channel). - Collapsed the 12-line comment on on_message's _ready_event wait to 3 lines. Dropped the warning log on timeout — pass-on-timeout is fine; if on_ready hangs that long, the bot is already broken and the log wouldn't help. - Dropped the source-inspection test (greps the module source for expected substrings). It was low-value scaffolding; the voice-serialization test covers actual behavior. Net: -73 lines vs PR #12558. Same two guarantees preserved, same test passes (verified by stashing the fix and confirming failure). --- gateway/platforms/discord.py | 44 ++------ tests/gateway/test_discord_race_polish.py | 129 ++++++++-------------- 2 files changed, 50 insertions(+), 123 deletions(-) diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index fce7ece41..28286d48c 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -637,29 +637,14 @@ class DiscordAdapter(BasePlatformAdapter): @self._client.event async def on_message(message: DiscordMessage): - # Wait for on_ready to finish resolving username-based - # allowlist entries. Without this block, messages - # arriving between Discord's READY event and the end - # of _resolve_allowed_usernames compare author IDs - # (numeric) against a set that may still contain raw - # usernames (strings) from DISCORD_ALLOWED_USERS — - # legitimate users get silently rejected for the first - # few seconds after every reconnect. The wait is a - # near-instant no-op in steady state (_ready_event is - # already set); only the startup / reconnect window - # ever blocks. + # Block until _resolve_allowed_usernames has swapped + # any raw usernames in DISCORD_ALLOWED_USERS for numeric + # IDs (otherwise on_message's author.id lookup can miss). if not adapter_self._ready_event.is_set(): try: - await asyncio.wait_for( - adapter_self._ready_event.wait(), - timeout=30.0, - ) + await asyncio.wait_for(adapter_self._ready_event.wait(), timeout=30.0) except asyncio.TimeoutError: - logger.warning( - "[%s] on_message timed out waiting for _ready_event; " - "allowlist check may use pre-resolved entries", - adapter_self.name, - ) + pass # Dedup: Discord RESUME replays events after reconnects (#4777) if adapter_self._dedup.is_duplicate(str(message.id)): @@ -1256,28 +1241,13 @@ class DiscordAdapter(BasePlatformAdapter): # Voice channel methods (join / leave / play) # ------------------------------------------------------------------ - def _voice_lock_for(self, guild_id: int) -> "asyncio.Lock": - """Return the per-guild lock, creating it on first use. - - Voice join/leave/move must be serialized per guild — without - this, two concurrent /voice channel invocations both see - _voice_clients.get(guild_id) return None, both call - channel.connect(), and discord.py raises ClientException - ('Already connected') on the loser. - """ - lock = self._voice_locks.get(guild_id) - if lock is None: - lock = asyncio.Lock() - self._voice_locks[guild_id] = lock - return lock - async def join_voice_channel(self, channel) -> bool: """Join a Discord voice channel. Returns True on success.""" if not self._client or not DISCORD_AVAILABLE: return False guild_id = channel.guild.id - async with self._voice_lock_for(guild_id): + async with self._voice_locks.setdefault(guild_id, asyncio.Lock()): # Already connected in this guild? existing = self._voice_clients.get(guild_id) if existing and existing.is_connected(): @@ -1307,7 +1277,7 @@ class DiscordAdapter(BasePlatformAdapter): async def leave_voice_channel(self, guild_id: int) -> None: """Disconnect from the voice channel in a guild.""" - async with self._voice_lock_for(guild_id): + async with self._voice_locks.setdefault(guild_id, asyncio.Lock()): # Stop voice receiver first receiver = self._voice_receivers.pop(guild_id, None) if receiver: diff --git a/tests/gateway/test_discord_race_polish.py b/tests/gateway/test_discord_race_polish.py index a0f900aea..02c927e37 100644 --- a/tests/gateway/test_discord_race_polish.py +++ b/tests/gateway/test_discord_race_polish.py @@ -1,18 +1,8 @@ -"""Regression tests for the Discord adapter race-polish fix. - -Two races are addressed: -1. on_message allowlist check racing on_ready's _resolve_allowed_usernames - resolution window. Username-based entries in DISCORD_ALLOWED_USERS - appear in the set as raw strings for several seconds after - connect/reconnect; author.id is always numeric, so legitimate users - are silently rejected until resolution finishes. -2. join_voice_channel check-and-connect: concurrent /voice channel - invocations both see _voice_clients.get(guild_id) is None, both call - channel.connect(), second raises ClientException ('Already connected'). -""" +"""Discord adapter race polish: concurrent join_voice_channel must not +double-invoke channel.connect() on the same guild.""" import asyncio -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest @@ -20,7 +10,6 @@ from gateway.config import Platform, PlatformConfig def _make_adapter(): - """Bare DiscordAdapter for testing — object.__new__ pattern per AGENTS.md.""" from gateway.platforms.discord import DiscordAdapter adapter = object.__new__(DiscordAdapter) @@ -40,83 +29,51 @@ def _make_adapter(): return adapter -class TestJoinVoiceSerialization: - @pytest.mark.asyncio - async def test_concurrent_joins_do_not_double_connect(self): - """Two concurrent join_voice_channel calls on the same guild - must serialize through the per-guild lock — only ONE - channel.connect() actually fires; the second sees the - _voice_clients entry the first just installed.""" - adapter = _make_adapter() +@pytest.mark.asyncio +async def test_concurrent_joins_do_not_double_connect(): + """Two concurrent join_voice_channel calls on the same guild must + serialize through the per-guild lock — only ONE channel.connect() + actually fires; the second sees the _voice_clients entry the first + just installed.""" + adapter = _make_adapter() - connect_count = [0] - connect_event = asyncio.Event() + connect_count = [0] + release = asyncio.Event() - class FakeVC: - def __init__(self, channel): - self.channel = channel + class FakeVC: + def __init__(self, channel): + self.channel = channel - def is_connected(self): - return True + def is_connected(self): + return True - async def move_to(self, _channel): - return None + async def move_to(self, _channel): + return None - async def disconnect(self): - return None + async def slow_connect(self): + connect_count[0] += 1 + await release.wait() + return FakeVC(self) - async def slow_connect(self): - connect_count[0] += 1 - # Widen the race window - await connect_event.wait() - return FakeVC(self) + channel = MagicMock() + channel.id = 111 + channel.guild.id = 42 + channel.connect = lambda: slow_connect(channel) - channel = MagicMock() - channel.id = 111 - channel.guild.id = 42 - channel.connect = lambda: slow_connect(channel) + from gateway.platforms import discord as discord_mod + with patch.object(discord_mod, "VoiceReceiver", + MagicMock(return_value=MagicMock(start=lambda: None))): + with patch.object(discord_mod.asyncio, "ensure_future", + lambda _c: asyncio.create_task(asyncio.sleep(0))): + t1 = asyncio.create_task(adapter.join_voice_channel(channel)) + t2 = asyncio.create_task(adapter.join_voice_channel(channel)) + await asyncio.sleep(0.05) + release.set() + r1, r2 = await asyncio.gather(t1, t2) - # Swap out VoiceReceiver so it doesn't try to set up real audio - from gateway.platforms import discord as discord_mod - with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))): - with patch.object(discord_mod.asyncio, "ensure_future", lambda _c: asyncio.create_task(asyncio.sleep(0))): - # Fire two joins concurrently - t1 = asyncio.create_task(adapter.join_voice_channel(channel)) - t2 = asyncio.create_task(adapter.join_voice_channel(channel)) - # Let them run until they're blocked on our event - await asyncio.sleep(0.05) - # Release connect so both can finish - connect_event.set() - r1, r2 = await asyncio.gather(t1, t2) - - assert connect_count[0] == 1, ( - f"Expected exactly 1 channel.connect() call, got {connect_count[0]} — " - "per-guild voice lock is not serializing join_voice_channel" - ) - assert r1 is True and r2 is True - assert 42 in adapter._voice_clients - - -class TestOnMessageWaitsForReadyEvent: - @pytest.mark.asyncio - async def test_on_message_blocks_until_ready_event_set(self): - """A message arriving before on_ready finishes - _resolve_allowed_usernames must wait, not proceed with a - half-resolved allowlist.""" - # This is an integration-style check — we pull out the - # on_message handler by asserting the source contains the - # expected wait pattern. A full end-to-end test would require - # setting up the discord.py client machinery, which is not - # practical here. - import inspect - from gateway.platforms import discord as discord_mod - - src = inspect.getsource(discord_mod.DiscordAdapter.connect) - assert "_ready_event.is_set()" in src, ( - "on_message must gate on _ready_event so username-based " - "allowlist entries are resolved before the allowlist check" - ) - assert "await asyncio.wait_for(" in src and "_ready_event.wait()" in src, ( - "Expected asyncio.wait_for(_ready_event.wait(), timeout=...) " - "pattern in on_message" - ) + assert connect_count[0] == 1, ( + f"expected 1 channel.connect() call, got {connect_count[0]} — " + "per-guild lock is not serializing join_voice_channel" + ) + assert r1 is True and r2 is True + assert 42 in adapter._voice_clients