diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 43a9338d78..f92cdf8db0 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -442,6 +442,7 @@ class DiscordAdapter(BasePlatformAdapter): self._pending_text_batches: Dict[str, MessageEvent] = {} self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} self._voice_text_channels: Dict[int, int] = {} # guild_id -> text_channel_id + self._voice_sources: Dict[int, Dict[str, Any]] = {} # guild_id -> linked text channel source metadata self._voice_timeout_tasks: Dict[int, asyncio.Task] = {} # guild_id -> timeout task # Phase 2: voice listening self._voice_receivers: Dict[int, VoiceReceiver] = {} # guild_id -> VoiceReceiver @@ -1045,6 +1046,7 @@ class DiscordAdapter(BasePlatformAdapter): if task: task.cancel() self._voice_text_channels.pop(guild_id, None) + self._voice_sources.pop(guild_id, None) # Maximum seconds to wait for voice playback before giving up PLAYBACK_TIMEOUT = 120 @@ -2244,6 +2246,7 @@ class DiscordAdapter(BasePlatformAdapter): thread_id = str(message.channel.id) parent_channel_id = self._get_parent_channel_id(message.channel) + is_voice_linked_channel = False if not isinstance(message.channel, discord.DMChannel): channel_ids = {str(message.channel.id)} if parent_channel_id: @@ -2270,7 +2273,12 @@ class DiscordAdapter(BasePlatformAdapter): channel_ids.add(parent_channel_id) require_mention = os.getenv("DISCORD_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - is_free_channel = bool(channel_ids & free_channels) + # Voice-linked text channels act as free-response while voice is active. + # Only the exact bound channel gets the exemption, not sibling threads. + voice_linked_ids = {str(ch_id) for ch_id in self._voice_text_channels.values()} + current_channel_id = str(message.channel.id) + is_voice_linked_channel = current_channel_id in voice_linked_ids + is_free_channel = bool(channel_ids & free_channels) or is_voice_linked_channel # Skip the mention check if the message is in a thread where # the bot has previously participated (auto-created or replied in). @@ -2294,7 +2302,7 @@ class DiscordAdapter(BasePlatformAdapter): no_thread_channels = {ch.strip() for ch in no_thread_channels_raw.split(",") if ch.strip()} skip_thread = bool(channel_ids & no_thread_channels) auto_thread = os.getenv("DISCORD_AUTO_THREAD", "true").lower() in ("true", "1", "yes") - if auto_thread and not skip_thread: + if auto_thread and not skip_thread and not is_voice_linked_channel: thread = await self._auto_create_thread(message) if thread: is_thread = True diff --git a/gateway/run.py b/gateway/run.py index b54149d04a..4c30db7db8 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -4927,6 +4927,8 @@ class GatewayRunner: if success: adapter._voice_text_channels[guild_id] = int(event.source.chat_id) + if hasattr(adapter, "_voice_sources"): + adapter._voice_sources[guild_id] = event.source.to_dict() self._voice_mode[event.source.chat_id] = "all" self._save_voice_modes() self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False) @@ -4987,14 +4989,23 @@ class GatewayRunner: if not text_ch_id: return + # Build source — reuse the linked text channel's metadata when available + # so voice input shares the same session as the bound text conversation. + source_data = getattr(adapter, "_voice_sources", {}).get(guild_id) + if source_data: + source = SessionSource.from_dict(source_data) + source.user_id = str(user_id) + source.user_name = str(user_id) + else: + source = SessionSource( + platform=Platform.DISCORD, + chat_id=str(text_ch_id), + user_id=str(user_id), + user_name=str(user_id), + chat_type="channel", + ) + # Check authorization before processing voice input - source = SessionSource( - platform=Platform.DISCORD, - chat_id=str(text_ch_id), - user_id=str(user_id), - user_name=str(user_id), - chat_type="channel", - ) if not self._is_user_authorized(source): logger.debug("Unauthorized voice input from user %d, ignoring", user_id) return @@ -8891,16 +8902,19 @@ async def start_gateway(config: Optional[GatewayConfig] = None, replace: bool = runner.request_restart(detached=False, via_service=True) loop = asyncio.get_event_loop() - for sig in (signal.SIGINT, signal.SIGTERM): - try: - loop.add_signal_handler(sig, shutdown_signal_handler) - except NotImplementedError: - pass - if hasattr(signal, "SIGUSR1"): - try: - loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler) - except NotImplementedError: - pass + if threading.current_thread() is threading.main_thread(): + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, shutdown_signal_handler) + except NotImplementedError: + pass + if hasattr(signal, "SIGUSR1"): + try: + loop.add_signal_handler(signal.SIGUSR1, restart_signal_handler) + except NotImplementedError: + pass + else: + logger.info("Skipping signal handlers (not running in main thread).") # Start the gateway success = await runner.start() diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index 29f65efc67..c2ef286d8e 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -359,3 +359,44 @@ async def test_discord_thread_participation_tracked_on_dispatch(adapter, monkeyp await adapter._handle_message(message) assert "777" in adapter._threads + + +@pytest.mark.asyncio +async def test_discord_voice_linked_channel_skips_mention_requirement_and_auto_thread(adapter, monkeypatch): + """Active voice-linked text channels should behave like free-response channels.""" + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + monkeypatch.delenv("DISCORD_AUTO_THREAD", raising=False) + + adapter._voice_text_channels[111] = 789 + adapter._auto_create_thread = AsyncMock() + + message = make_message( + channel=FakeTextChannel(channel_id=789), + content="follow-up from voice text chat", + ) + + await adapter._handle_message(message) + + adapter._auto_create_thread.assert_not_awaited() + adapter.handle_message.assert_awaited_once() + event = adapter.handle_message.await_args.args[0] + assert event.text == "follow-up from voice text chat" + assert event.source.chat_type == "group" + + +@pytest.mark.asyncio +async def test_discord_voice_linked_parent_thread_still_requires_mention(adapter, monkeypatch): + """Threads under a voice-linked channel should still require @mention.""" + monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") + monkeypatch.delenv("DISCORD_FREE_RESPONSE_CHANNELS", raising=False) + + adapter._voice_text_channels[111] = 789 + message = make_message( + channel=FakeThread(channel_id=790, parent=FakeTextChannel(channel_id=789)), + content="thread reply without mention", + ) + + await adapter._handle_message(message) + + adapter.handle_message.assert_not_awaited() diff --git a/tests/gateway/test_voice_command.py b/tests/gateway/test_voice_command.py index 0638452f0b..f0c3171d6e 100644 --- a/tests/gateway/test_voice_command.py +++ b/tests/gateway/test_voice_command.py @@ -417,6 +417,7 @@ class TestDiscordPlayTtsSkip: adapter.config = config adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_timeout_tasks = {} adapter._voice_receivers = {} adapter._voice_listen_tasks = {} @@ -702,13 +703,18 @@ class TestVoiceChannelCommands: mock_adapter.join_voice_channel = AsyncMock(return_value=True) mock_adapter.get_user_voice_channel = AsyncMock(return_value=mock_channel) mock_adapter._voice_text_channels = {} + mock_adapter._voice_sources = {} mock_adapter._voice_input_callback = None event = self._make_discord_event() + event.source.chat_type = "group" + event.source.chat_name = "Hermes Server / #general" runner.adapters[event.source.platform] = mock_adapter result = await runner._handle_voice_channel_join(event) assert "joined" in result.lower() assert "General" in result assert runner._voice_mode["123"] == "all" + assert mock_adapter._voice_sources[111]["chat_id"] == "123" + assert mock_adapter._voice_sources[111]["chat_type"] == "group" @pytest.mark.asyncio async def test_join_failure(self, runner): @@ -815,6 +821,7 @@ class TestVoiceChannelCommands: from gateway.config import Platform mock_adapter = AsyncMock() mock_adapter._voice_text_channels = {111: 123} + mock_adapter._voice_sources = {} mock_channel = AsyncMock() mock_adapter._client = MagicMock() mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) @@ -828,12 +835,45 @@ class TestVoiceChannelCommands: assert event.source.chat_id == "123" assert event.source.chat_type == "channel" + @pytest.mark.asyncio + async def test_input_reuses_bound_source_metadata(self, runner): + """Voice input should share the linked text channel session metadata.""" + from gateway.config import Platform + + bound_source = SessionSource( + chat_id="123", + chat_name="Hermes Server / #general", + chat_type="group", + user_id="user1", + user_name="user1", + platform=Platform.DISCORD, + ) + + mock_adapter = AsyncMock() + mock_adapter._voice_text_channels = {111: 123} + mock_adapter._voice_sources = {111: bound_source.to_dict()} + mock_channel = AsyncMock() + mock_adapter._client = MagicMock() + mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) + mock_adapter.handle_message = AsyncMock() + runner.adapters[Platform.DISCORD] = mock_adapter + + await runner._handle_voice_channel_input(111, 42, "Hello from VC") + + mock_adapter.handle_message.assert_called_once() + event = mock_adapter.handle_message.call_args[0][0] + assert event.source.chat_id == "123" + assert event.source.chat_type == "group" + assert event.source.chat_name == "Hermes Server / #general" + assert event.source.user_id == "42" + @pytest.mark.asyncio async def test_input_posts_transcript_in_text_channel(self, runner): """Voice input sends transcript message to text channel.""" from gateway.config import Platform mock_adapter = AsyncMock() mock_adapter._voice_text_channels = {111: 123} + mock_adapter._voice_sources = {} mock_channel = AsyncMock() mock_adapter._client = MagicMock() mock_adapter._client.get_channel = MagicMock(return_value=mock_channel) @@ -892,6 +932,7 @@ class TestDiscordVoiceChannelMethods: adapter._client = MagicMock() adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_timeout_tasks = {} adapter._voice_receivers = {} adapter._voice_listen_tasks = {} @@ -926,6 +967,7 @@ class TestDiscordVoiceChannelMethods: mock_vc.disconnect = AsyncMock() adapter._voice_clients[111] = mock_vc adapter._voice_text_channels[111] = 123 + adapter._voice_sources[111] = {"chat_id": "123", "chat_type": "group"} mock_receiver = MagicMock() adapter._voice_receivers[111] = mock_receiver @@ -944,6 +986,7 @@ class TestDiscordVoiceChannelMethods: mock_timeout.cancel.assert_called_once() assert 111 not in adapter._voice_clients assert 111 not in adapter._voice_text_channels + assert 111 not in adapter._voice_sources assert 111 not in adapter._voice_receivers @pytest.mark.asyncio @@ -1670,6 +1713,7 @@ class TestVoiceTimeoutCleansRunnerState: adapter.config = config adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_timeout_tasks = {} adapter._voice_receivers = {} adapter._voice_listen_tasks = {} @@ -1759,6 +1803,7 @@ class TestPlaybackTimeout: adapter.config = config adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_timeout_tasks = {} adapter._voice_receivers = {} adapter._voice_listen_tasks = {} @@ -1939,6 +1984,7 @@ class TestVoiceChannelAwareness: adapter = object.__new__(DiscordAdapter) adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_receivers = {} adapter._client = MagicMock() adapter._client.user = SimpleNamespace(id=99999, name="HermesBot") @@ -2408,6 +2454,7 @@ class TestVoiceTTSPlayback: adapter.config = config adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_receivers = {} return adapter @@ -2587,6 +2634,7 @@ class TestUDPKeepalive: adapter.config = config adapter._voice_clients = {} adapter._voice_text_channels = {} + adapter._voice_sources = {} adapter._voice_receivers = {} adapter._voice_listen_tasks = {}