mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(discord): voice session continuity and signal handler thread safety
- Store source metadata on /voice channel join so voice input shares the same session as the linked text channel conversation
- Treat voice-linked text channels as free-response (skip @mention and auto-thread) while voice is active
- Scope the voice-linked exemption to the exact bound channel, not sibling threads
- Guard signal handler registration in start_gateway() for non-main threads (prevents RuntimeError when gateway runs in a daemon thread)
- Clean up _voice_sources on leave_voice_channel

Salvaged from PR #3475 by twilwa (Modal runtime portions excluded).
This commit is contained in:
parent
381810ad50
commit
3a64348772
4 changed files with 130 additions and 19 deletions
|
|
@ -442,6 +442,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
self._pending_text_batches: Dict[str, MessageEvent] = {}
|
||||
self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id
|
||||
self._voice_sources: Dict[int, Dict[str, Any]] = {} # guild_id -> linked text channel source metadata
|
||||
self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task
|
||||
# Phase 2: voice listening
|
||||
self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver
|
||||
|
|
@ -1045,6 +1046,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if task:
|
||||
task.cancel()
|
||||
self._voice_text_channels.pop(guild_id, None)
|
||||
self._voice_sources.pop(guild_id, None)
|
||||
|
||||
# Maximum seconds to wait for voice playback before giving up
|
||||
PLAYBACK_TIMEOUT = 120
|
||||
|
|
@ -2244,6 +2246,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
thread_id = str(message.channel.id)
|
||||
parent_channel_id = self._get_parent_channel_id(message.channel)
|
||||
|
||||
is_voice_linked_channel = False
|
||||
if not isinstance(message.channel, discord.DMChannel):
|
||||
channel_ids = {str(message.channel.id)}
|
||||
if parent_channel_id:
|
||||
|
|
@ -2270,7 +2273,12 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
channel_ids.add(parent_channel_id)
|
||||
|
||||
require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no")
|
||||
is_free_channel = bool(channel_ids & free_channels)
|
||||
# Voice-linked text channels act as free-response while voice is active.
|
||||
# Only the exact bound channel gets the exemption, not sibling threads.
|
||||
voice_linked_ids = {str(ch_id) for ch_id in self._voice_text_channels.values()}
|
||||
current_channel_id = str(message.channel.id)
|
||||
is_voice_linked_channel = current_channel_id in voice_linked_ids
|
||||
is_free_channel = bool(channel_ids & free_channels) or is_voice_linked_channel
|
||||
|
||||
# Skip the mention check if the message is in a thread where
|
||||
# the bot has previously participated (auto-created or replied in).
|
||||
|
|
@ -2294,7 +2302,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()}
|
||||
skip_thread = bool(channel_ids & no_thread_channels)
|
||||
auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes")
|
||||
if auto_thread and not skip_thread:
|
||||
if auto_thread and not skip_thread and not is_voice_linked_channel:
|
||||
thread = await self._auto_create_thread(message)
|
||||
if thread:
|
||||
is_thread = True
|
||||
|
|
|
|||
|
|
@ -4927,6 +4927,8 @@ class GatewayRunner:
|
|||
|
||||
if success:
|
||||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||
if hasattr(adapter, "_voice_sources"):
|
||||
adapter._voice_sources[guild_id] = event.source.to_dict()
|
||||
self._voice_mode[event.source.chat_id] = "all"
|
||||
self._save_voice_modes()
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
|
||||
|
|
@ -4987,14 +4989,23 @@ class GatewayRunner:
|
|||
if not text_ch_id:
|
||||
return
|
||||
|
||||
# Build source — reuse the linked text channel's metadata when available
|
||||
# so voice input shares the same session as the bound text conversation.
|
||||
source_data = getattr(adapter, "_voice_sources", {}).get(guild_id)
|
||||
if source_data:
|
||||
source = SessionSource.from_dict(source_data)
|
||||
source.user_id = str(user_id)
|
||||
source.user_name = str(user_id)
|
||||
else:
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id=str(text_ch_id),
|
||||
user_id=str(user_id),
|
||||
user_name=str(user_id),
|
||||
chat_type="channel",
|
||||
)
|
||||
|
||||
# Check authorization before processing voice input
|
||||
source = SessionSource(
|
||||
platform=Platform.DISCORD,
|
||||
chat_id=str(text_ch_id),
|
||||
user_id=str(user_id),
|
||||
user_name=str(user_id),
|
||||
chat_type="channel",
|
||||
)
|
||||
if not self._is_user_authorized(source):
|
||||
logger.debug("Unauthorized voice input from user %d, ignoring", user_id)
|
||||
return
|
||||
|
|
@ -8891,16 +8902,19 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool =
|
|||
runner.request_restart(detached=False, via_service=True)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, shutdown_signal_handler)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
if hasattr(signal, "SIGUSR1"):
|
||||
try:
|
||||
loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
try:
|
||||
loop.add_signal_handler(sig, shutdown_signal_handler)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
if hasattr(signal, "SIGUSR1"):
|
||||
try:
|
||||
loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler)
|
||||
except NotImplementedError:
|
||||
pass
|
||||
else:
|
||||
logger.info("Skipping signal handlers (not running in main thread).")
|
||||
|
||||
# Start the gateway
|
||||
success = await runner.start()
|
||||
|
|
|
|||
|
|
@ -359,3 +359,44 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp
|
|||
await adapter._handle_message(message)
|
||||
|
||||
assert "777" in adapter._threads
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_voice_linked_channel_skips_mention_requirement_and_auto_thread(adapter, monkeypatch):
|
||||
"""Active voice-linked text channels should behave like free-response channels."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False)
|
||||
|
||||
adapter._voice_text_channels[111] = 789
|
||||
adapter._auto_create_thread = AsyncMock()
|
||||
|
||||
message = make_message(
|
||||
channel=FakeTextChannel(channel_id=789),
|
||||
content="follow-up from voice text chat",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter._auto_create_thread.assert_not_awaited()
|
||||
adapter.handle_message.assert_awaited_once()
|
||||
event = adapter.handle_message.await_args.args[0]
|
||||
assert event.text == "follow-up from voice text chat"
|
||||
assert event.source.chat_type == "group"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_voice_linked_parent_thread_still_requires_mention(adapter, monkeypatch):
|
||||
"""Threads under a voice-linked channel should still require @mention."""
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False)
|
||||
|
||||
adapter._voice_text_channels[111] = 789
|
||||
message = make_message(
|
||||
channel=FakeThread(channel_id=790, parent=FakeTextChannel(channel_id=789)),
|
||||
content="thread reply without mention",
|
||||
)
|
||||
|
||||
await adapter._handle_message(message)
|
||||
|
||||
adapter.handle_message.assert_not_awaited()
|
||||
|
|
|
|||
|
|
@ -417,6 +417,7 @@ class TestDiscordPlayTtsSkip:
|
|||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
|
@ -702,13 +703,18 @@ class TestVoiceChannelCommands:
|
|||
mock_adapter.join_voice_channel = AsyncMock(return_value=True)
|
||||
mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel)
|
||||
mock_adapter._voice_text_channels = {}
|
||||
mock_adapter._voice_sources = {}
|
||||
mock_adapter._voice_input_callback = None
|
||||
event = self._make_discord_event()
|
||||
event.source.chat_type = "group"
|
||||
event.source.chat_name = "Hermes Server / #general"
|
||||
runner.adapters[event.source.platform] = mock_adapter
|
||||
result = await runner._handle_voice_channel_join(event)
|
||||
assert "joined" in result.lower()
|
||||
assert "General" in result
|
||||
assert runner._voice_mode["123"] == "all"
|
||||
assert mock_adapter._voice_sources[111]["chat_id"] == "123"
|
||||
assert mock_adapter._voice_sources[111]["chat_type"] == "group"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_failure(self, runner):
|
||||
|
|
@ -815,6 +821,7 @@ class TestVoiceChannelCommands:
|
|||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_adapter._voice_sources = {}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
|
@ -828,12 +835,45 @@ class TestVoiceChannelCommands:
|
|||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "channel"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_reuses_bound_source_metadata(self, runner):
|
||||
"""Voice input should share the linked text channel session metadata."""
|
||||
from gateway.config import Platform
|
||||
|
||||
bound_source = SessionSource(
|
||||
chat_id="123",
|
||||
chat_name="Hermes Server / #general",
|
||||
chat_type="group",
|
||||
user_id="user1",
|
||||
user_name="user1",
|
||||
platform=Platform.DISCORD,
|
||||
)
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_adapter._voice_sources = {111: bound_source.to_dict()}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
mock_adapter.handle_message = AsyncMock()
|
||||
runner.adapters[Platform.DISCORD] = mock_adapter
|
||||
|
||||
await runner._handle_voice_channel_input(111, 42, "Hello from VC")
|
||||
|
||||
mock_adapter.handle_message.assert_called_once()
|
||||
event = mock_adapter.handle_message.call_args[0][0]
|
||||
assert event.source.chat_id == "123"
|
||||
assert event.source.chat_type == "group"
|
||||
assert event.source.chat_name == "Hermes Server / #general"
|
||||
assert event.source.user_id == "42"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_input_posts_transcript_in_text_channel(self, runner):
|
||||
"""Voice input sends transcript message to text channel."""
|
||||
from gateway.config import Platform
|
||||
mock_adapter = AsyncMock()
|
||||
mock_adapter._voice_text_channels = {111: 123}
|
||||
mock_adapter._voice_sources = {}
|
||||
mock_channel = AsyncMock()
|
||||
mock_adapter._client = MagicMock()
|
||||
mock_adapter._client.get_channel = MagicMock(return_value=mock_channel)
|
||||
|
|
@ -892,6 +932,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
adapter._client = MagicMock()
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
|
@ -926,6 +967,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
mock_vc.disconnect = AsyncMock()
|
||||
adapter._voice_clients[111] = mock_vc
|
||||
adapter._voice_text_channels[111] = 123
|
||||
adapter._voice_sources[111] = {"chat_id": "123", "chat_type": "group"}
|
||||
|
||||
mock_receiver = MagicMock()
|
||||
adapter._voice_receivers[111] = mock_receiver
|
||||
|
|
@ -944,6 +986,7 @@ class TestDiscordVoiceChannelMethods:
|
|||
mock_timeout.cancel.assert_called_once()
|
||||
assert 111 not in adapter._voice_clients
|
||||
assert 111 not in adapter._voice_text_channels
|
||||
assert 111 not in adapter._voice_sources
|
||||
assert 111 not in adapter._voice_receivers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1670,6 +1713,7 @@ class TestVoiceTimeoutCleansRunnerState:
|
|||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
|
@ -1759,6 +1803,7 @@ class TestPlaybackTimeout:
|
|||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_timeout_tasks = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
|
@ -1939,6 +1984,7 @@ class TestVoiceChannelAwareness:
|
|||
adapter = object.__new__(DiscordAdapter)
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._client = MagicMock()
|
||||
adapter._client.user = SimpleNamespace(id=99999, name="HermesBot")
|
||||
|
|
@ -2408,6 +2454,7 @@ class TestVoiceTTSPlayback:
|
|||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
return adapter
|
||||
|
||||
|
|
@ -2587,6 +2634,7 @@ class TestUDPKeepalive:
|
|||
adapter.config = config
|
||||
adapter._voice_clients = {}
|
||||
adapter._voice_text_channels = {}
|
||||
adapter._voice_sources = {}
|
||||
adapter._voice_receivers = {}
|
||||
adapter._voice_listen_tasks = {}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue