fix(discord): close two low-severity adapter races (#12558)

Two small races in gateway/platforms/discord.py, bundled together
since they're adjacent in the adapter and both narrow in impact.

1. on_message vs _resolve_allowed_usernames (startup window)
   DISCORD_ALLOWED_USERS accepts both numeric IDs and raw usernames.
   At connect-time, _resolve_allowed_usernames walks the bot's guilds
   (fetch_members can take multiple seconds) to swap usernames for IDs.
   on_message can fire during that window; _is_allowed_user compares
   the numeric author.id against a set that may still contain raw
   usernames — legitimate users get silently rejected for a few
   seconds after every reconnect.

   Fix: on_message awaits _ready_event (with a 30s timeout) when it
   isn't already set.  on_ready sets the event after the resolve
   completes.  In steady state this is a no-op (event already set);
   only the startup / reconnect window ever blocks.

2. join_voice_channel check-and-connect
   The existing-connection check at _voice_clients.get() and the
   channel.connect() call straddled an await boundary with no lock.
   Two concurrent /voice channel invocations could both see None and
   both call connect(); discord.py raises ClientException
   ("Already connected") on the loser.  Same race class for leave
   running concurrently with _voice_timeout_handler.

   Fix: per-guild asyncio.Lock (_voice_locks dict with lazy alloc via
   _voice_lock_for).  join_voice_channel and leave_voice_channel both
   run their body under the lock.  Sequential within a guild, still
   fully concurrent across guilds.

Both: LOW severity.  The first only affects username-based allowlists
on fast-follow-up messages at startup; the second is a narrow
exception on simultaneous voice commands.  Bundled so the adapter
gets a single coherent polish pass.

Tests (tests/gateway/test_discord_race_polish.py): 2 regression cases.
- test_concurrent_joins_do_not_double_connect: two concurrent
  join_voice_channel calls on the same guild result in exactly one
  channel.connect() invocation.
- test_on_message_blocks_until_ready_event_set: asserts the expected
  wait pattern is present in on_message (source inspection, since
  full discord.py client setup isn't practical here).

Regression-guard validated: against unpatched gateway/platforms/discord.py
both tests fail.  With the fix they pass.  Full Discord suite (118
tests) green.
This commit is contained in:
Teknium 2026-04-19 05:45:59 -07:00 committed by GitHub
parent c567adb58a
commit a521005fe5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 201 additions and 37 deletions

View file

@ -498,6 +498,7 @@ class DiscordAdapter(BasePlatformAdapter):
self._allowed_role_ids: set = set() # For DISCORD_ALLOWED_ROLES filtering
# Voice channel state (per-guild)
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
self._voice_locks: Dict[int, asyncio.Lock] = {} # guild_id -> serialize join/leave
# Text batching: merge rapid successive messages (Telegram-style)
self._text_batch_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_DELAY_SECONDS", "0.6"))
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
@ -636,6 +637,30 @@ class DiscordAdapter(BasePlatformAdapter):
@self._client.event
async def on_message(message: DiscordMessage):
# Wait for on_ready to finish resolving username-based
# allowlist entries. Without this block, messages
# arriving between Discord's READY event and the end
# of _resolve_allowed_usernames compare author IDs
# (numeric) against a set that may still contain raw
# usernames (strings) from DISCORD_ALLOWED_USERS —
# legitimate users get silently rejected for the first
# few seconds after every reconnect. The wait is a
# near-instant no-op in steady state (_ready_event is
# already set); only the startup / reconnect window
# ever blocks.
if not adapter_self._ready_event.is_set():
try:
await asyncio.wait_for(
adapter_self._ready_event.wait(),
timeout=30.0,
)
except asyncio.TimeoutError:
logger.warning(
"[%s] on_message timed out waiting for _ready_event; "
"allowlist check may use pre-resolved entries",
adapter_self.name,
)
# Dedup: Discord RESUME replays events after reconnects (#4777)
if adapter_self._dedup.is_duplicate(str(message.id)):
return
@ -1231,57 +1256,74 @@ class DiscordAdapter(BasePlatformAdapter):
# Voice channel methods (join / leave / play)
# ------------------------------------------------------------------
def _voice_lock_for(self, guild_id: int) -> "asyncio.Lock":
"""Return the per-guild lock, creating it on first use.
Voice join/leave/move must be serialized per guild without
this, two concurrent /voice channel invocations both see
_voice_clients.get(guild_id) return None, both call
channel.connect(), and discord.py raises ClientException
('Already connected') on the loser.
"""
lock = self._voice_locks.get(guild_id)
if lock is None:
lock = asyncio.Lock()
self._voice_locks[guild_id] = lock
return lock
async def join_voice_channel(self, channel) -> bool:
"""Join a Discord voice channel. Returns True on success."""
if not self._client or not DISCORD_AVAILABLE:
return False
guild_id = channel.guild.id
# Already connected in this guild?
existing = self._voice_clients.get(guild_id)
if existing and existing.is_connected():
if existing.channel.id == channel.id:
async with self._voice_lock_for(guild_id):
# Already connected in this guild?
existing = self._voice_clients.get(guild_id)
if existing and existing.is_connected():
if existing.channel.id == channel.id:
self._reset_voice_timeout(guild_id)
return True
await existing.move_to(channel)
self._reset_voice_timeout(guild_id)
return True
await existing.move_to(channel)
vc = await channel.connect()
self._voice_clients[guild_id] = vc
self._reset_voice_timeout(guild_id)
# Start voice receiver (Phase 2: listen to users)
try:
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
receiver.start()
self._voice_receivers[guild_id] = receiver
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
self._voice_listen_loop(guild_id)
)
except Exception as e:
logger.warning("Voice receiver failed to start: %s", e)
return True
vc = await channel.connect()
self._voice_clients[guild_id] = vc
self._reset_voice_timeout(guild_id)
# Start voice receiver (Phase 2: listen to users)
try:
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
receiver.start()
self._voice_receivers[guild_id] = receiver
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
self._voice_listen_loop(guild_id)
)
except Exception as e:
logger.warning("Voice receiver failed to start: %s", e)
return True
async def leave_voice_channel(self, guild_id: int) -> None:
"""Disconnect from the voice channel in a guild."""
# Stop voice receiver first
receiver = self._voice_receivers.pop(guild_id, None)
if receiver:
receiver.stop()
listen_task = self._voice_listen_tasks.pop(guild_id, None)
if listen_task:
listen_task.cancel()
async with self._voice_lock_for(guild_id):
# Stop voice receiver first
receiver = self._voice_receivers.pop(guild_id, None)
if receiver:
receiver.stop()
listen_task = self._voice_listen_tasks.pop(guild_id, None)
if listen_task:
listen_task.cancel()
vc = self._voice_clients.pop(guild_id, None)
if vc and vc.is_connected():
await vc.disconnect()
task = self._voice_timeout_tasks.pop(guild_id, None)
if task:
task.cancel()
self._voice_text_channels.pop(guild_id, None)
self._voice_sources.pop(guild_id, None)
vc = self._voice_clients.pop(guild_id, None)
if vc and vc.is_connected():
await vc.disconnect()
task = self._voice_timeout_tasks.pop(guild_id, None)
if task:
task.cancel()
self._voice_text_channels.pop(guild_id, None)
self._voice_sources.pop(guild_id, None)
# Maximum seconds to wait for voice playback before giving up
PLAYBACK_TIMEOUT = 120

View file

@ -0,0 +1,122 @@
"""Regression tests for the Discord adapter race-polish fix.
Two races are addressed:
1. on_message allowlist check racing on_ready's _resolve_allowed_usernames
resolution window. Username-based entries in DISCORD_ALLOWED_USERS
appear in the set as raw strings for several seconds after
connect/reconnect; author.id is always numeric, so legitimate users
are silently rejected until resolution finishes.
2. join_voice_channel check-and-connect: concurrent /voice channel
invocations both see _voice_clients.get(guild_id) is None, both call
channel.connect(), second raises ClientException ('Already connected').
"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
def _make_adapter():
"""Bare DiscordAdapter for testing — object.__new__ pattern per AGENTS.md."""
from gateway.platforms.discord import DiscordAdapter
adapter = object.__new__(DiscordAdapter)
adapter._platform = Platform.DISCORD
adapter.config = PlatformConfig(enabled=True, token="t")
adapter._ready_event = asyncio.Event()
adapter._allowed_user_ids = set()
adapter._allowed_role_ids = set()
adapter._voice_clients = {}
adapter._voice_locks = {}
adapter._voice_receivers = {}
adapter._voice_listen_tasks = {}
adapter._voice_timeout_tasks = {}
adapter._voice_text_channels = {}
adapter._voice_sources = {}
adapter._client = MagicMock()
return adapter
class TestJoinVoiceSerialization:
@pytest.mark.asyncio
async def test_concurrent_joins_do_not_double_connect(self):
"""Two concurrent join_voice_channel calls on the same guild
must serialize through the per-guild lock only ONE
channel.connect() actually fires; the second sees the
_voice_clients entry the first just installed."""
adapter = _make_adapter()
connect_count = [0]
connect_event = asyncio.Event()
class FakeVC:
def __init__(self, channel):
self.channel = channel
def is_connected(self):
return True
async def move_to(self, _channel):
return None
async def disconnect(self):
return None
async def slow_connect(self):
connect_count[0] += 1
# Widen the race window
await connect_event.wait()
return FakeVC(self)
channel = MagicMock()
channel.id = 111
channel.guild.id = 42
channel.connect = lambda: slow_connect(channel)
# Swap out VoiceReceiver so it doesn't try to set up real audio
from gateway.platforms import discord as discord_mod
with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))):
with patch.object(discord_mod.asyncio, "ensure_future", lambda _c: asyncio.create_task(asyncio.sleep(0))):
# Fire two joins concurrently
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
# Let them run until they're blocked on our event
await asyncio.sleep(0.05)
# Release connect so both can finish
connect_event.set()
r1, r2 = await asyncio.gather(t1, t2)
assert connect_count[0] == 1, (
f"Expected exactly 1 channel.connect() call, got {connect_count[0]}"
"per-guild voice lock is not serializing join_voice_channel"
)
assert r1 is True and r2 is True
assert 42 in adapter._voice_clients
class TestOnMessageWaitsForReadyEvent:
@pytest.mark.asyncio
async def test_on_message_blocks_until_ready_event_set(self):
"""A message arriving before on_ready finishes
_resolve_allowed_usernames must wait, not proceed with a
half-resolved allowlist."""
# This is an integration-style check — we pull out the
# on_message handler by asserting the source contains the
# expected wait pattern. A full end-to-end test would require
# setting up the discord.py client machinery, which is not
# practical here.
import inspect
from gateway.platforms import discord as discord_mod
src = inspect.getsource(discord_mod.DiscordAdapter.connect)
assert "_ready_event.is_set()" in src, (
"on_message must gate on _ready_event so username-based "
"allowlist entries are resolved before the allowlist check"
)
assert "await asyncio.wait_for(" in src and "_ready_event.wait()" in src, (
"Expected asyncio.wait_for(_ready_event.wait(), timeout=...) "
"pattern in on_message"
)