mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(discord): close two low-severity adapter races (#12558)
Two small races in gateway/platforms/discord.py, bundled together
since they're adjacent in the adapter and both narrow in impact.
1. on_message vs _resolve_allowed_usernames (startup window)
DISCORD_ALLOWED_USERS accepts both numeric IDs and raw usernames.
At connect-time, _resolve_allowed_usernames walks the bot's guilds
(fetch_members can take multiple seconds) to swap usernames for IDs.
on_message can fire during that window; _is_allowed_user compares
the numeric author.id against a set that may still contain raw
usernames — legitimate users get silently rejected for a few
seconds after every reconnect.
Fix: on_message awaits _ready_event (with a 30s timeout) when it
isn't already set. on_ready sets the event after the resolve
completes. In steady state this is a no-op (event already set);
only the startup / reconnect window ever blocks.
2. join_voice_channel check-and-connect
The existing-connection check at _voice_clients.get() and the
channel.connect() call straddled an await boundary with no lock.
Two concurrent /voice channel invocations could both see None and
both call connect(); discord.py raises ClientException
("Already connected") on the loser. The same class of race applies to
leave_voice_channel running concurrently with _voice_timeout_handler.
Fix: per-guild asyncio.Lock (_voice_locks dict with lazy alloc via
_voice_lock_for). join_voice_channel and leave_voice_channel both
run their body under the lock. Sequential within a guild, still
fully concurrent across guilds.
Both: LOW severity. The first only affects username-based allowlists
on fast-follow-up messages at startup; the second is a narrow
exception on simultaneous voice commands. Bundled so the adapter
gets a single coherent polish pass.
Tests (tests/gateway/test_discord_race_polish.py): 2 regression cases.
- test_concurrent_joins_do_not_double_connect: two concurrent
join_voice_channel calls on the same guild result in exactly one
channel.connect() invocation.
- test_on_message_blocks_until_ready_event_set: asserts the expected
wait pattern is present in on_message (source inspection, since
full discord.py client setup isn't practical here).
Regression-guard validated: against unpatched gateway/platforms/discord.py
both tests fail. With the fix they pass. Full Discord suite (118
tests) green.
This commit is contained in:
parent
c567adb58a
commit
a521005fe5
2 changed files with 201 additions and 37 deletions
|
|
@ -498,6 +498,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self._allowed_role_ids: set = set() # For DISCORD_ALLOWED_ROLES filtering
|
||||
# Voice channel state (per-guild)
|
||||
self._voice_clients: Dict[int, Any] = {} # guild_id -> VoiceClient
|
||||
self._voice_locks: Dict[int, asyncio.Lock] = {} # guild_id -> serialize join/leave
|
||||
# Text batching: merge rapid successive messages (Telegram-style)
|
||||
self._text_batch_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_DELAY_SECONDS", "0.6"))
|
||||
self._text_batch_split_delay_seconds = float(os.getenv("HERMES_DISCORD_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0"))
|
||||
|
|
@ -636,6 +637,30 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Wait for on_ready to finish resolving username-based
|
||||
# allowlist entries. Without this block, messages
|
||||
# arriving between Discord's READY event and the end
|
||||
# of _resolve_allowed_usernames compare author IDs
|
||||
# (numeric) against a set that may still contain raw
|
||||
# usernames (strings) from DISCORD_ALLOWED_USERS —
|
||||
# legitimate users get silently rejected for the first
|
||||
# few seconds after every reconnect. The wait is a
|
||||
# near-instant no-op in steady state (_ready_event is
|
||||
# already set); only the startup / reconnect window
|
||||
# ever blocks.
|
||||
if not adapter_self._ready_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
adapter_self._ready_event.wait(),
|
||||
timeout=30.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"[%s] on_message timed out waiting for _ready_event; "
|
||||
"allowlist check may use pre-resolved entries",
|
||||
adapter_self.name,
|
||||
)
|
||||
|
||||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||||
if adapter_self._dedup.is_duplicate(str(message.id)):
|
||||
return
|
||||
|
|
@ -1231,57 +1256,74 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Voice channel methods (join / leave / play)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _voice_lock_for(self, guild_id: int) -> "asyncio.Lock":
    """Fetch the lock that guards voice state for one guild.

    Locks are allocated lazily: the first request for a ``guild_id``
    stores a fresh ``asyncio.Lock`` in ``self._voice_locks`` and every
    later request for that guild returns the same object.

    Voice join/leave/move must be serialized per guild. Without the
    lock, two concurrent /voice channel invocations both see
    ``_voice_clients.get(guild_id)`` return None, both call
    ``channel.connect()``, and discord.py raises ClientException
    ('Already connected') on the loser.
    """
    known = self._voice_locks.get(guild_id)
    if known is not None:
        return known

    # First request for this guild: create the lock and remember it.
    fresh = asyncio.Lock()
    self._voice_locks[guild_id] = fresh
    return fresh
|
||||
|
||||
async def join_voice_channel(self, channel) -> bool:
|
||||
"""Join a Discord voice channel. Returns True on success."""
|
||||
if not self._client or not DISCORD_AVAILABLE:
|
||||
return False
|
||||
guild_id = channel.guild.id
|
||||
|
||||
# Already connected in this guild?
|
||||
existing = self._voice_clients.get(guild_id)
|
||||
if existing and existing.is_connected():
|
||||
if existing.channel.id == channel.id:
|
||||
async with self._voice_lock_for(guild_id):
|
||||
# Already connected in this guild?
|
||||
existing = self._voice_clients.get(guild_id)
|
||||
if existing and existing.is_connected():
|
||||
if existing.channel.id == channel.id:
|
||||
self._reset_voice_timeout(guild_id)
|
||||
return True
|
||||
await existing.move_to(channel)
|
||||
self._reset_voice_timeout(guild_id)
|
||||
return True
|
||||
await existing.move_to(channel)
|
||||
|
||||
vc = await channel.connect()
|
||||
self._voice_clients[guild_id] = vc
|
||||
self._reset_voice_timeout(guild_id)
|
||||
|
||||
# Start voice receiver (Phase 2: listen to users)
|
||||
try:
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
|
||||
receiver.start()
|
||||
self._voice_receivers[guild_id] = receiver
|
||||
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||||
self._voice_listen_loop(guild_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Voice receiver failed to start: %s", e)
|
||||
|
||||
return True
|
||||
|
||||
vc = await channel.connect()
|
||||
self._voice_clients[guild_id] = vc
|
||||
self._reset_voice_timeout(guild_id)
|
||||
|
||||
# Start voice receiver (Phase 2: listen to users)
|
||||
try:
|
||||
receiver = VoiceReceiver(vc, allowed_user_ids=self._allowed_user_ids)
|
||||
receiver.start()
|
||||
self._voice_receivers[guild_id] = receiver
|
||||
self._voice_listen_tasks[guild_id] = asyncio.ensure_future(
|
||||
self._voice_listen_loop(guild_id)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("Voice receiver failed to start: %s", e)
|
||||
|
||||
return True
|
||||
|
||||
async def leave_voice_channel(self, guild_id: int) -> None:
|
||||
"""Disconnect from the voice channel in a guild."""
|
||||
# Stop voice receiver first
|
||||
receiver = self._voice_receivers.pop(guild_id, None)
|
||||
if receiver:
|
||||
receiver.stop()
|
||||
listen_task = self._voice_listen_tasks.pop(guild_id, None)
|
||||
if listen_task:
|
||||
listen_task.cancel()
|
||||
async with self._voice_lock_for(guild_id):
|
||||
# Stop voice receiver first
|
||||
receiver = self._voice_receivers.pop(guild_id, None)
|
||||
if receiver:
|
||||
receiver.stop()
|
||||
listen_task = self._voice_listen_tasks.pop(guild_id, None)
|
||||
if listen_task:
|
||||
listen_task.cancel()
|
||||
|
||||
vc = self._voice_clients.pop(guild_id, None)
|
||||
if vc and vc.is_connected():
|
||||
await vc.disconnect()
|
||||
task = self._voice_timeout_tasks.pop(guild_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
self._voice_text_channels.pop(guild_id, None)
|
||||
self._voice_sources.pop(guild_id, None)
|
||||
vc = self._voice_clients.pop(guild_id, None)
|
||||
if vc and vc.is_connected():
|
||||
await vc.disconnect()
|
||||
task = self._voice_timeout_tasks.pop(guild_id, None)
|
||||
if task:
|
||||
task.cancel()
|
||||
self._voice_text_channels.pop(guild_id, None)
|
||||
self._voice_sources.pop(guild_id, None)
|
||||
|
||||
# Maximum seconds to wait for voice playback before giving up
|
||||
PLAYBACK_TIMEOUT = 120
|
||||
|
|
|
|||
122
tests/gateway/test_discord_race_polish.py
Normal file
122
tests/gateway/test_discord_race_polish.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Regression tests for the Discord adapter race-polish fix.
|
||||
|
||||
Two races are addressed:
|
||||
1. on_message allowlist check racing on_ready's _resolve_allowed_usernames
|
||||
resolution window. Username-based entries in DISCORD_ALLOWED_USERS
|
||||
appear in the set as raw strings for several seconds after
|
||||
connect/reconnect; author.id is always numeric, so legitimate users
|
||||
are silently rejected until resolution finishes.
|
||||
2. join_voice_channel check-and-connect: concurrent /voice channel
|
||||
invocations both see _voice_clients.get(guild_id) is None, both call
|
||||
channel.connect(), second raises ClientException ('Already connected').
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter():
    """Build a bare DiscordAdapter for tests (object.__new__ pattern per AGENTS.md).

    ``__init__`` is bypassed, so the instance carries only the attributes
    assigned here: the platform/config pair, an unset ``_ready_event``,
    empty allowlist sets, empty per-guild voice bookkeeping dicts, and a
    ``MagicMock`` standing in for the discord client.
    """
    from gateway.platforms.discord import DiscordAdapter

    adapter = object.__new__(DiscordAdapter)

    # Insertion order matches the order the attributes are meant to be set.
    initial_state = {
        "_platform": Platform.DISCORD,
        "config": PlatformConfig(enabled=True, token="t"),
        "_ready_event": asyncio.Event(),
        "_allowed_user_ids": set(),
        "_allowed_role_ids": set(),
        "_voice_clients": {},
        "_voice_locks": {},
        "_voice_receivers": {},
        "_voice_listen_tasks": {},
        "_voice_timeout_tasks": {},
        "_voice_text_channels": {},
        "_voice_sources": {},
        "_client": MagicMock(),
    }
    for attr_name, attr_value in initial_state.items():
        setattr(adapter, attr_name, attr_value)
    return adapter
|
||||
|
||||
|
||||
class TestJoinVoiceSerialization:
    """Regression guard for the join_voice_channel check-and-connect race."""

    @pytest.mark.asyncio
    async def test_concurrent_joins_do_not_double_connect(self):
        """Two concurrent join_voice_channel calls on the same guild
        must serialize through the per-guild lock — only ONE
        channel.connect() actually fires; the second sees the
        _voice_clients entry the first just installed."""
        adapter = _make_adapter()

        # One-element list so the nested coroutine can mutate the counter.
        connect_count = [0]
        connect_event = asyncio.Event()

        class FakeVC:
            """Minimal voice-client stand-in that always reports connected."""

            def __init__(self, channel):
                self.channel = channel

            def is_connected(self):
                return True

            async def move_to(self, _channel):
                return None

            async def disconnect(self):
                return None

        async def slow_connect(target_channel):
            connect_count[0] += 1
            # Widen the race window: stay "connecting" until the test
            # releases the event, so a second unserialized join would
            # reach connect() as well.
            await connect_event.wait()
            return FakeVC(target_channel)

        channel = MagicMock()
        channel.id = 111
        channel.guild.id = 42
        channel.connect = lambda: slow_connect(channel)

        def fake_ensure_future(coro):
            # The adapter hands us the _voice_listen_loop coroutine. We do
            # not want it to run, but dropping it unclosed emits
            # "RuntimeWarning: coroutine ... was never awaited".
            close = getattr(coro, "close", None)
            if close is not None:
                close()
            return asyncio.create_task(asyncio.sleep(0))

        # Swap out VoiceReceiver so it doesn't try to set up real audio
        from gateway.platforms import discord as discord_mod

        fake_receiver_cls = MagicMock(return_value=MagicMock(start=lambda: None))
        with patch.object(discord_mod, "VoiceReceiver", fake_receiver_cls):
            # NOTE(review): discord_mod.asyncio is presumably the stdlib
            # asyncio module itself, so this patch is process-wide while
            # the block is active — keep the block as small as possible.
            with patch.object(discord_mod.asyncio, "ensure_future", fake_ensure_future):
                # Fire two joins concurrently
                t1 = asyncio.create_task(adapter.join_voice_channel(channel))
                t2 = asyncio.create_task(adapter.join_voice_channel(channel))
                # Let them run until they're blocked on our event
                await asyncio.sleep(0.05)
                # Release connect so both can finish
                connect_event.set()
                r1, r2 = await asyncio.gather(t1, t2)

        assert connect_count[0] == 1, (
            f"Expected exactly 1 channel.connect() call, got {connect_count[0]} — "
            "per-guild voice lock is not serializing join_voice_channel"
        )
        assert r1 is True and r2 is True
        assert 42 in adapter._voice_clients
|
||||
|
||||
|
||||
class TestOnMessageWaitsForReadyEvent:
    """Guards the on_message wait on _ready_event via source inspection."""

    @pytest.mark.asyncio
    async def test_on_message_blocks_until_ready_event_set(self):
        """A message arriving before on_ready finishes
        _resolve_allowed_usernames must wait, not proceed with a
        half-resolved allowlist."""
        # Source-inspection check: rather than driving a real on_message
        # call, assert that the expected wait pattern appears in the
        # source of DiscordAdapter.connect (presumably where the
        # on_message closure is defined — verify if the handler moves).
        # A full end-to-end test would require setting up the discord.py
        # client machinery, which is not practical here.
        import inspect
        from gateway.platforms import discord as discord_mod

        connect_source = inspect.getsource(discord_mod.DiscordAdapter.connect)

        gates_on_event = "_ready_event.is_set()" in connect_source
        assert gates_on_event, (
            "on_message must gate on _ready_event so username-based "
            "allowlist entries are resolved before the allowlist check"
        )

        uses_wait_for = "await asyncio.wait_for(" in connect_source
        waits_on_event = "_ready_event.wait()" in connect_source
        assert uses_wait_for and waits_on_event, (
            "Expected asyncio.wait_for(_ready_event.wait(), timeout=...) "
            "pattern in on_message"
        )
|
||||
Loading…
Add table
Add a link
Reference in a new issue