refactor(discord): slim down the race-polish fix (#12644)

PR #12558 was heavy for what the fix actually is — essay-length
comments, a dedicated helper method where a setdefault would do, and
a source-inspection test with no real behavior coverage.  The
genuine code change is ~5 lines of new logic (1 field, 2 `async with`
blocks, and an on_ready wait block in on_message).

Trimmed:
- Replaced the 12-line _voice_lock_for helper with a setdefault
  one-liner at each call site (join_voice_channel, leave_voice_channel).
- Collapsed the 12-line comment on on_message's _ready_event wait to
  3 lines.  Dropped the warning log on timeout — pass-on-timeout is
  fine; if on_ready hangs for the full 30-second timeout, the bot is
  already broken and the log wouldn't help.
- Dropped the source-inspection test (greps the module source for
  expected substrings).  It was low-value scaffolding; the
  voice-serialization test covers actual behavior.

Net: -73 lines vs PR #12558.  The same two guarantees are preserved, and
the same test still passes (verified by stashing the fix and confirming
that the test then fails).
This commit is contained in:
Teknium 2026-04-19 11:08:10 -07:00 committed by GitHub
parent 5a23f3291a
commit 7e3b356574
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 50 additions and 123 deletions

View file

@ -637,29 +637,14 @@ class DiscordAdapter(BasePlatformAdapter):
@self._client.event
async def on_message(message: DiscordMessage):
# Wait for on_ready to finish resolving username-based
# allowlist entries. Without this block, messages
# arriving between Discord's READY event and the end
# of _resolve_allowed_usernames compare author IDs
# (numeric) against a set that may still contain raw
# usernames (strings) from DISCORD_ALLOWED_USERS —
# legitimate users get silently rejected for the first
# few seconds after every reconnect. The wait is a
# near-instant no-op in steady state (_ready_event is
# already set); only the startup / reconnect window
# ever blocks.
# Block until _resolve_allowed_usernames has swapped
# any raw usernames in DISCORD_ALLOWED_USERS for numeric
# IDs (otherwise on_message's author.id lookup can miss).
if not adapter_self._ready_event.is_set():
try:
await asyncio.wait_for(
adapter_self._ready_event.wait(),
timeout=30.0,
)
await asyncio.wait_for(adapter_self._ready_event.wait(), timeout=30.0)
except asyncio.TimeoutError:
logger.warning(
"[%s] on_message timed out waiting for _ready_event; "
"allowlist check may use pre-resolved entries",
adapter_self.name,
)
pass
# Dedup: Discord RESUME replays events after reconnects (#4777)
if adapter_self._dedup.is_duplicate(str(message.id)):
@ -1256,28 +1241,13 @@ class DiscordAdapter(BasePlatformAdapter):
# Voice channel methods (join / leave / play)
# ------------------------------------------------------------------
def _voice_lock_for(self, guild_id: int) -> "asyncio.Lock":
"""Return the per-guild lock, creating it on first use.
Voice join/leave/move must be serialized per guild — without
this, two concurrent /voice channel invocations both see
_voice_clients.get(guild_id) return None, both call
channel.connect(), and discord.py raises ClientException
('Already connected') on the loser.
"""
lock = self._voice_locks.get(guild_id)
if lock is None:
lock = asyncio.Lock()
self._voice_locks[guild_id] = lock
return lock
async def join_voice_channel(self, channel) -> bool:
"""Join a Discord voice channel. Returns True on success."""
if not self._client or not DISCORD_AVAILABLE:
return False
guild_id = channel.guild.id
async with self._voice_lock_for(guild_id):
async with self._voice_locks.setdefault(guild_id, asyncio.Lock()):
# Already connected in this guild?
existing = self._voice_clients.get(guild_id)
if existing and existing.is_connected():
@ -1307,7 +1277,7 @@ class DiscordAdapter(BasePlatformAdapter):
async def leave_voice_channel(self, guild_id: int) -> None:
"""Disconnect from the voice channel in a guild."""
async with self._voice_lock_for(guild_id):
async with self._voice_locks.setdefault(guild_id, asyncio.Lock()):
# Stop voice receiver first
receiver = self._voice_receivers.pop(guild_id, None)
if receiver:

View file

@ -1,18 +1,8 @@
"""Regression tests for the Discord adapter race-polish fix.
Two races are addressed:
1. on_message allowlist check racing on_ready's _resolve_allowed_usernames
resolution window. Username-based entries in DISCORD_ALLOWED_USERS
appear in the set as raw strings for several seconds after
connect/reconnect; author.id is always numeric, so legitimate users
are silently rejected until resolution finishes.
2. join_voice_channel check-and-connect: concurrent /voice channel
invocations both see _voice_clients.get(guild_id) is None, both call
channel.connect(), second raises ClientException ('Already connected').
"""
"""Discord adapter race polish: concurrent join_voice_channel must not
double-invoke channel.connect() on the same guild."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import MagicMock, patch
import pytest
@ -20,7 +10,6 @@ from gateway.config import Platform, PlatformConfig
def _make_adapter():
"""Bare DiscordAdapter for testing — object.__new__ pattern per AGENTS.md."""
from gateway.platforms.discord import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
@ -40,83 +29,51 @@ def _make_adapter():
return adapter
class TestJoinVoiceSerialization:
@pytest.mark.asyncio
async def test_concurrent_joins_do_not_double_connect(self):
"""Two concurrent join_voice_channel calls on the same guild
must serialize through the per-guild lock — only ONE
channel.connect() actually fires; the second sees the
_voice_clients entry the first just installed."""
adapter = _make_adapter()
@pytest.mark.asyncio
async def test_concurrent_joins_do_not_double_connect():
"""Two concurrent join_voice_channel calls on the same guild must
serialize through the per-guild lock — only ONE channel.connect()
actually fires; the second sees the _voice_clients entry the first
just installed."""
adapter = _make_adapter()
connect_count = [0]
connect_event = asyncio.Event()
connect_count = [0]
release = asyncio.Event()
class FakeVC:
def __init__(self, channel):
self.channel = channel
class FakeVC:
def __init__(self, channel):
self.channel = channel
def is_connected(self):
return True
def is_connected(self):
return True
async def move_to(self, _channel):
return None
async def move_to(self, _channel):
return None
async def disconnect(self):
return None
async def slow_connect(self):
connect_count[0] += 1
await release.wait()
return FakeVC(self)
async def slow_connect(self):
connect_count[0] += 1
# Widen the race window
await connect_event.wait()
return FakeVC(self)
channel = MagicMock()
channel.id = 111
channel.guild.id = 42
channel.connect = lambda: slow_connect(channel)
channel = MagicMock()
channel.id = 111
channel.guild.id = 42
channel.connect = lambda: slow_connect(channel)
from gateway.platforms import discord as discord_mod
with patch.object(discord_mod, "VoiceReceiver",
MagicMock(return_value=MagicMock(start=lambda: None))):
with patch.object(discord_mod.asyncio, "ensure_future",
lambda _c: asyncio.create_task(asyncio.sleep(0))):
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
await asyncio.sleep(0.05)
release.set()
r1, r2 = await asyncio.gather(t1, t2)
# Swap out VoiceReceiver so it doesn't try to set up real audio
from gateway.platforms import discord as discord_mod
with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))):
with patch.object(discord_mod.asyncio, "ensure_future", lambda _c: asyncio.create_task(asyncio.sleep(0))):
# Fire two joins concurrently
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
# Let them run until they're blocked on our event
await asyncio.sleep(0.05)
# Release connect so both can finish
connect_event.set()
r1, r2 = await asyncio.gather(t1, t2)
assert connect_count[0] == 1, (
f"Expected exactly 1 channel.connect() call, got {connect_count[0]}"
"per-guild voice lock is not serializing join_voice_channel"
)
assert r1 is True and r2 is True
assert 42 in adapter._voice_clients
class TestOnMessageWaitsForReadyEvent:
@pytest.mark.asyncio
async def test_on_message_blocks_until_ready_event_set(self):
"""A message arriving before on_ready finishes
_resolve_allowed_usernames must wait, not proceed with a
half-resolved allowlist."""
# This is an integration-style check — we pull out the
# on_message handler by asserting the source contains the
# expected wait pattern. A full end-to-end test would require
# setting up the discord.py client machinery, which is not
# practical here.
import inspect
from gateway.platforms import discord as discord_mod
src = inspect.getsource(discord_mod.DiscordAdapter.connect)
assert "_ready_event.is_set()" in src, (
"on_message must gate on _ready_event so username-based "
"allowlist entries are resolved before the allowlist check"
)
assert "await asyncio.wait_for(" in src and "_ready_event.wait()" in src, (
"Expected asyncio.wait_for(_ready_event.wait(), timeout=...) "
"pattern in on_message"
)
assert connect_count[0] == 1, (
f"expected 1 channel.connect() call, got {connect_count[0]}"
"per-guild lock is not serializing join_voice_channel"
)
assert r1 is True and r2 is True
assert 42 in adapter._voice_clients