mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
refactor(discord): slim down the race-polish fix (#12644)
PR #12558 was heavy for what the fix actually is — essay-length comments, a dedicated helper method where a setdefault would do, and a source-inspection test with no real behavior coverage. The genuine code change is ~5 lines of new logic (1 field, 2 async with, an on_ready wait block). Trimmed: - Replaced the 12-line _voice_lock_for helper with a setdefault one-liner at each call site (join_voice_channel, leave_voice_channel). - Collapsed the 12-line comment on on_message's _ready_event wait to 3 lines. Dropped the warning log on timeout — pass-on-timeout is fine; if on_ready hangs that long, the bot is already broken and the log wouldn't help. - Dropped the source-inspection test (greps the module source for expected substrings). It was low-value scaffolding; the voice-serialization test covers actual behavior. Net: -73 lines vs PR #12558. Same two guarantees preserved, same test passes (verified by stashing the fix and confirming failure).
This commit is contained in:
parent
5a23f3291a
commit
7e3b356574
2 changed files with 50 additions and 123 deletions
|
|
@ -637,29 +637,14 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
@self._client.event
|
||||
async def on_message(message: DiscordMessage):
|
||||
# Wait for on_ready to finish resolving username-based
|
||||
# allowlist entries. Without this block, messages
|
||||
# arriving between Discord's READY event and the end
|
||||
# of _resolve_allowed_usernames compare author IDs
|
||||
# (numeric) against a set that may still contain raw
|
||||
# usernames (strings) from DISCORD_ALLOWED_USERS —
|
||||
# legitimate users get silently rejected for the first
|
||||
# few seconds after every reconnect. The wait is a
|
||||
# near-instant no-op in steady state (_ready_event is
|
||||
# already set); only the startup / reconnect window
|
||||
# ever blocks.
|
||||
# Block until _resolve_allowed_usernames has swapped
|
||||
# any raw usernames in DISCORD_ALLOWED_USERS for numeric
|
||||
# IDs (otherwise on_message's author.id lookup can miss).
|
||||
if not adapter_self._ready_event.is_set():
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
adapter_self._ready_event.wait(),
|
||||
timeout=30.0,
|
||||
)
|
||||
await asyncio.wait_for(adapter_self._ready_event.wait(), timeout=30.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"[%s] on_message timed out waiting for _ready_event; "
|
||||
"allowlist check may use pre-resolved entries",
|
||||
adapter_self.name,
|
||||
)
|
||||
pass
|
||||
|
||||
# Dedup: Discord RESUME replays events after reconnects (#4777)
|
||||
if adapter_self._dedup.is_duplicate(str(message.id)):
|
||||
|
|
@ -1256,28 +1241,13 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
# Voice channel methods (join / leave / play)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _voice_lock_for(self, guild_id: int) -> "asyncio.Lock":
|
||||
"""Return the per-guild lock, creating it on first use.
|
||||
|
||||
Voice join/leave/move must be serialized per guild — without
|
||||
this, two concurrent /voice channel invocations both see
|
||||
_voice_clients.get(guild_id) return None, both call
|
||||
channel.connect(), and discord.py raises ClientException
|
||||
('Already connected') on the loser.
|
||||
"""
|
||||
lock = self._voice_locks.get(guild_id)
|
||||
if lock is None:
|
||||
lock = asyncio.Lock()
|
||||
self._voice_locks[guild_id] = lock
|
||||
return lock
|
||||
|
||||
async def join_voice_channel(self, channel) -> bool:
|
||||
"""Join a Discord voice channel. Returns True on success."""
|
||||
if not self._client or not DISCORD_AVAILABLE:
|
||||
return False
|
||||
guild_id = channel.guild.id
|
||||
|
||||
async with self._voice_lock_for(guild_id):
|
||||
async with self._voice_locks.setdefault(guild_id, asyncio.Lock()):
|
||||
# Already connected in this guild?
|
||||
existing = self._voice_clients.get(guild_id)
|
||||
if existing and existing.is_connected():
|
||||
|
|
@ -1307,7 +1277,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
async def leave_voice_channel(self, guild_id: int) -> None:
|
||||
"""Disconnect from the voice channel in a guild."""
|
||||
async with self._voice_lock_for(guild_id):
|
||||
async with self._voice_locks.setdefault(guild_id, asyncio.Lock()):
|
||||
# Stop voice receiver first
|
||||
receiver = self._voice_receivers.pop(guild_id, None)
|
||||
if receiver:
|
||||
|
|
|
|||
|
|
@ -1,18 +1,8 @@
|
|||
"""Regression tests for the Discord adapter race-polish fix.
|
||||
|
||||
Two races are addressed:
|
||||
1. on_message allowlist check racing on_ready's _resolve_allowed_usernames
|
||||
resolution window. Username-based entries in DISCORD_ALLOWED_USERS
|
||||
appear in the set as raw strings for several seconds after
|
||||
connect/reconnect; author.id is always numeric, so legitimate users
|
||||
are silently rejected until resolution finishes.
|
||||
2. join_voice_channel check-and-connect: concurrent /voice channel
|
||||
invocations both see _voice_clients.get(guild_id) is None, both call
|
||||
channel.connect(), second raises ClientException ('Already connected').
|
||||
"""
|
||||
"""Discord adapter race polish: concurrent join_voice_channel must not
|
||||
double-invoke channel.connect() on the same guild."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -20,7 +10,6 @@ from gateway.config import Platform, PlatformConfig
|
|||
|
||||
|
||||
def _make_adapter():
|
||||
"""Bare DiscordAdapter for testing — object.__new__ pattern per AGENTS.md."""
|
||||
from gateway.platforms.discord import DiscordAdapter
|
||||
|
||||
adapter = object.__new__(DiscordAdapter)
|
||||
|
|
@ -40,83 +29,51 @@ def _make_adapter():
|
|||
return adapter
|
||||
|
||||
|
||||
class TestJoinVoiceSerialization:
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_joins_do_not_double_connect(self):
|
||||
"""Two concurrent join_voice_channel calls on the same guild
|
||||
must serialize through the per-guild lock — only ONE
|
||||
channel.connect() actually fires; the second sees the
|
||||
_voice_clients entry the first just installed."""
|
||||
adapter = _make_adapter()
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_joins_do_not_double_connect():
|
||||
"""Two concurrent join_voice_channel calls on the same guild must
|
||||
serialize through the per-guild lock — only ONE channel.connect()
|
||||
actually fires; the second sees the _voice_clients entry the first
|
||||
just installed."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
connect_count = [0]
|
||||
connect_event = asyncio.Event()
|
||||
connect_count = [0]
|
||||
release = asyncio.Event()
|
||||
|
||||
class FakeVC:
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
class FakeVC:
|
||||
def __init__(self, channel):
|
||||
self.channel = channel
|
||||
|
||||
def is_connected(self):
|
||||
return True
|
||||
def is_connected(self):
|
||||
return True
|
||||
|
||||
async def move_to(self, _channel):
|
||||
return None
|
||||
async def move_to(self, _channel):
|
||||
return None
|
||||
|
||||
async def disconnect(self):
|
||||
return None
|
||||
async def slow_connect(self):
|
||||
connect_count[0] += 1
|
||||
await release.wait()
|
||||
return FakeVC(self)
|
||||
|
||||
async def slow_connect(self):
|
||||
connect_count[0] += 1
|
||||
# Widen the race window
|
||||
await connect_event.wait()
|
||||
return FakeVC(self)
|
||||
channel = MagicMock()
|
||||
channel.id = 111
|
||||
channel.guild.id = 42
|
||||
channel.connect = lambda: slow_connect(channel)
|
||||
|
||||
channel = MagicMock()
|
||||
channel.id = 111
|
||||
channel.guild.id = 42
|
||||
channel.connect = lambda: slow_connect(channel)
|
||||
from gateway.platforms import discord as discord_mod
|
||||
with patch.object(discord_mod, "VoiceReceiver",
|
||||
MagicMock(return_value=MagicMock(start=lambda: None))):
|
||||
with patch.object(discord_mod.asyncio, "ensure_future",
|
||||
lambda _c: asyncio.create_task(asyncio.sleep(0))):
|
||||
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
|
||||
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
|
||||
await asyncio.sleep(0.05)
|
||||
release.set()
|
||||
r1, r2 = await asyncio.gather(t1, t2)
|
||||
|
||||
# Swap out VoiceReceiver so it doesn't try to set up real audio
|
||||
from gateway.platforms import discord as discord_mod
|
||||
with patch.object(discord_mod, "VoiceReceiver", MagicMock(return_value=MagicMock(start=lambda: None))):
|
||||
with patch.object(discord_mod.asyncio, "ensure_future", lambda _c: asyncio.create_task(asyncio.sleep(0))):
|
||||
# Fire two joins concurrently
|
||||
t1 = asyncio.create_task(adapter.join_voice_channel(channel))
|
||||
t2 = asyncio.create_task(adapter.join_voice_channel(channel))
|
||||
# Let them run until they're blocked on our event
|
||||
await asyncio.sleep(0.05)
|
||||
# Release connect so both can finish
|
||||
connect_event.set()
|
||||
r1, r2 = await asyncio.gather(t1, t2)
|
||||
|
||||
assert connect_count[0] == 1, (
|
||||
f"Expected exactly 1 channel.connect() call, got {connect_count[0]} — "
|
||||
"per-guild voice lock is not serializing join_voice_channel"
|
||||
)
|
||||
assert r1 is True and r2 is True
|
||||
assert 42 in adapter._voice_clients
|
||||
|
||||
|
||||
class TestOnMessageWaitsForReadyEvent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_on_message_blocks_until_ready_event_set(self):
|
||||
"""A message arriving before on_ready finishes
|
||||
_resolve_allowed_usernames must wait, not proceed with a
|
||||
half-resolved allowlist."""
|
||||
# This is an integration-style check — we pull out the
|
||||
# on_message handler by asserting the source contains the
|
||||
# expected wait pattern. A full end-to-end test would require
|
||||
# setting up the discord.py client machinery, which is not
|
||||
# practical here.
|
||||
import inspect
|
||||
from gateway.platforms import discord as discord_mod
|
||||
|
||||
src = inspect.getsource(discord_mod.DiscordAdapter.connect)
|
||||
assert "_ready_event.is_set()" in src, (
|
||||
"on_message must gate on _ready_event so username-based "
|
||||
"allowlist entries are resolved before the allowlist check"
|
||||
)
|
||||
assert "await asyncio.wait_for(" in src and "_ready_event.wait()" in src, (
|
||||
"Expected asyncio.wait_for(_ready_event.wait(), timeout=...) "
|
||||
"pattern in on_message"
|
||||
)
|
||||
assert connect_count[0] == 1, (
|
||||
f"expected 1 channel.connect() call, got {connect_count[0]} — "
|
||||
"per-guild lock is not serializing join_voice_channel"
|
||||
)
|
||||
assert r1 is True and r2 is True
|
||||
assert 42 in adapter._voice_clients
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue