diff --git a/gateway/run.py b/gateway/run.py
index 9991ecc6e..d8722edcd 100644
--- a/gateway/run.py
+++ b/gateway/run.py
@@ -786,6 +786,10 @@ class GatewayRunner:
 
     _VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json"
 
+    def _voice_key(self, platform: Platform, chat_id: str) -> str:
+        """Return a platform-namespaced key for voice mode state."""
+        return f"{platform.value}:{chat_id}"
+
     def _load_voice_modes(self) -> Dict[str, str]:
         try:
             data = json.loads(self._VOICE_MODE_PATH.read_text())
@@ -796,11 +800,23 @@ class GatewayRunner:
             return {}
 
         valid_modes = {"off", "voice_only", "all"}
-        return {
-            str(chat_id): mode
-            for chat_id, mode in data.items()
-            if mode in valid_modes
-        }
+        result = {}
+        # Prefixes produced by _voice_key(), e.g. "telegram:"; built once for the loop.
+        known_prefixes = tuple(f"{p.value}:" for p in Platform)
+        for chat_id, mode in data.items():
+            if mode not in valid_modes:
+                continue
+            key = str(chat_id)
+            # Skip legacy keys: a bare ":" test would accept raw chat ids containing a colon.
+            if not key.startswith(known_prefixes):
+                logger.warning(
+                    "Skipping legacy unprefixed voice mode key %r during migration. "
+                    "Re-enable voice mode on that chat to rebuild the prefixed key.",
+                    key,
+                )
+                continue
+            result[key] = mode
+        return result
 
     def _save_voice_modes(self) -> None:
         try:
@@ -826,9 +842,14 @@ class GatewayRunner:
         disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
         if not isinstance(disabled_chats, set):
             return
+        platform = getattr(adapter, "platform", None)
+        if not isinstance(platform, Platform):
+            return
         disabled_chats.clear()
+        prefix = f"{platform.value}:"
         disabled_chats.update(
-            chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
+            key[len(prefix):] for key, mode in self._voice_mode.items()
+            if mode == "off" and key.startswith(prefix)
         )
 
     async def _safe_adapter_disconnect(self, adapter, platform) -> None:
@@ -5830,11 +5851,13 @@ class GatewayRunner:
         """Handle /voice [on|off|tts|channel|leave|status] command."""
         args = event.get_command_args().strip().lower()
         chat_id = event.source.chat_id
+        platform = event.source.platform
+        voice_key = self._voice_key(platform, chat_id)
 
-        adapter = self.adapters.get(event.source.platform)
+        adapter = self.adapters.get(platform)
 
         if args in ("on", "enable"):
-            self._voice_mode[chat_id] = "voice_only"
+            self._voice_mode[voice_key] = "voice_only"
             self._save_voice_modes()
             if adapter:
                 self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
@@ -5844,13 +5867,13 @@ class GatewayRunner:
                 "Use /voice tts to get voice replies for all messages."
             )
         elif args in ("off", "disable"):
-            self._voice_mode[chat_id] = "off"
+            self._voice_mode[voice_key] = "off"
             self._save_voice_modes()
             if adapter:
                 self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
             return "Voice mode disabled. Text-only replies."
         elif args == "tts":
-            self._voice_mode[chat_id] = "all"
+            self._voice_mode[voice_key] = "all"
             self._save_voice_modes()
             if adapter:
                 self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
@@ -5863,7 +5886,7 @@ class GatewayRunner:
         elif args == "leave":
             return await self._handle_voice_channel_leave(event)
         elif args == "status":
-            mode = self._voice_mode.get(chat_id, "off")
+            mode = self._voice_mode.get(voice_key, "off")
             labels = {
                 "off": "Off (text only)",
                 "voice_only": "On (voice reply to voice messages)",
@@ -5887,15 +5910,15 @@ class GatewayRunner:
             return f"Voice mode: {labels.get(mode, mode)}"
         else:
             # Toggle: off → on, on/all → off
-            current = self._voice_mode.get(chat_id, "off")
+            current = self._voice_mode.get(voice_key, "off")
             if current == "off":
-                self._voice_mode[chat_id] = "voice_only"
+                self._voice_mode[voice_key] = "voice_only"
                 self._save_voice_modes()
                 if adapter:
                     self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
                 return "Voice mode enabled."
             else:
-                self._voice_mode[chat_id] = "off"
+                self._voice_mode[voice_key] = "off"
                 self._save_voice_modes()
                 if adapter:
                     self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
@@ -5941,7 +5964,7 @@ class GatewayRunner:
             adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
             if hasattr(adapter, "_voice_sources"):
                 adapter._voice_sources[guild_id] = event.source.to_dict()
-            self._voice_mode[event.source.chat_id] = "all"
+            self._voice_mode[self._voice_key(Platform.DISCORD, event.source.chat_id)] = "all"
             self._save_voice_modes()
             self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
             return (
@@ -5968,7 +5991,7 @@ class GatewayRunner:
         except Exception as e:
             logger.warning("Error leaving voice channel: %s", e)
         # Always clean up state even if leave raised an exception
-        self._voice_mode[event.source.chat_id] = "off"
+        self._voice_mode[self._voice_key(Platform.DISCORD, event.source.chat_id)] = "off"
         self._save_voice_modes()
         self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
         if hasattr(adapter, "_voice_input_callback"):
@@ -5980,7 +6003,7 @@ class GatewayRunner:
 
         Cleans up runner-side voice_mode state that the adapter cannot reach.
         """
-        self._voice_mode[chat_id] = "off"
+        self._voice_mode[self._voice_key(Platform.DISCORD, chat_id)] = "off"
         self._save_voice_modes()
         adapter = self.adapters.get(Platform.DISCORD)
         self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
@@ -6066,7 +6089,7 @@ class GatewayRunner:
             return False
 
         chat_id = event.source.chat_id
-        voice_mode = self._voice_mode.get(chat_id, "off")
+        voice_mode = self._voice_mode.get(self._voice_key(event.source.platform, chat_id), "off")
         is_voice_input = (event.message_type == MessageType.VOICE)
 
         should = (
diff --git a/tests/gateway/test_voice_mode_platform_isolation.py b/tests/gateway/test_voice_mode_platform_isolation.py
new file mode 100644
index 000000000..5678c876e
--- /dev/null
+++ b/tests/gateway/test_voice_mode_platform_isolation.py
@@ -0,0 +1,245 @@
+"""Tests for voice mode platform isolation (bug #12542).
+
+Voice mode state stored as {chat_id: mode} without a platform namespace
+caused collisions: Telegram chat '123' and Slack chat '123' shared the
+same key. The fix prefixes keys with platform value: 'telegram:123' vs
+'slack:123'.
+"""
+
+import json
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from gateway.config import Platform
+from gateway.run import GatewayRunner
+
+
+class TestVoiceKeyHelper:
+    """Test the _voice_key helper method."""
+
+    def test_voice_key_format(self):
+        """_voice_key returns 'platform:chat_id' format."""
+        runner = _make_runner()
+        assert runner._voice_key(Platform.TELEGRAM, "123") == "telegram:123"
+        assert runner._voice_key(Platform.SLACK, "456") == "slack:456"
+        assert runner._voice_key(Platform.DISCORD, "789") == "discord:789"
+
+    def test_voice_key_different_platforms_same_chat_id(self):
+        """Same chat_id on different platforms yields different keys."""
+        runner = _make_runner()
+        key_telegram = runner._voice_key(Platform.TELEGRAM, "123")
+        key_slack = runner._voice_key(Platform.SLACK, "123")
+        key_discord = runner._voice_key(Platform.DISCORD, "123")
+        assert key_telegram != key_slack
+        assert key_slack != key_discord
+        assert key_telegram == "telegram:123"
+        assert key_slack == "slack:123"
+        assert key_discord == "discord:123"
+
+
+class TestVoiceModePlatformIsolation:
+    """Test that voice mode state is isolated by platform."""
+
+    def test_telegram_and_slack_voice_mode_independent(self):
+        """Setting voice mode for Telegram chat '123' does not affect Slack chat '123'."""
+        runner = _make_runner()
+
+        # Enable voice mode for Telegram chat '123'
+        runner._voice_mode[runner._voice_key(Platform.TELEGRAM, "123")] = "all"
+        # Enable voice mode for Slack chat '123' to a different mode
+        runner._voice_mode[runner._voice_key(Platform.SLACK, "123")] = "voice_only"
+
+        # Verify they are independent
+        assert runner._voice_mode.get(runner._voice_key(Platform.TELEGRAM, "123")) == "all"
+        assert runner._voice_mode.get(runner._voice_key(Platform.SLACK, "123")) == "voice_only"
+
+        # Disabling Telegram should not affect Slack
+        runner._voice_mode[runner._voice_key(Platform.TELEGRAM, "123")] = "off"
+        assert runner._voice_mode.get(runner._voice_key(Platform.TELEGRAM, "123")) == "off"
+        assert runner._voice_mode.get(runner._voice_key(Platform.SLACK, "123")) == "voice_only"
+
+    def test_legacy_key_collision_bug(self):
+        """Demonstrates the pre-fix bug: same key without platform prefix collides.
+
+        This test documents the original bug behavior. After the fix, keys are
+        properly namespaced, so this scenario cannot occur in the fixed code.
+        The test shows that if two platforms shared the same raw chat_id as key,
+        they would overwrite each other.
+        """
+        runner = _make_runner()
+
+        # Simulate legacy behavior where keys were just chat_id (no platform prefix)
+        # In the fixed code this cannot happen because _voice_key is always used,
+        # but this test shows WHY the fix was needed.
+        legacy_key = "123"  # No platform prefix
+
+        runner._voice_mode[legacy_key] = "all"
+        # If Slack also used "123" as key, it would overwrite
+        runner._voice_mode[legacy_key] = "voice_only"
+
+        # Both platforms would see the same value (last write wins)
+        assert runner._voice_mode[legacy_key] == "voice_only"
+
+        # The fix prevents this by using platform-prefixed keys
+
+
+class TestLegacyKeyMigration:
+    """Test migration of legacy unprefixed keys in _load_voice_modes."""
+
+    def test_load_voice_modes_skips_legacy_keys(self):
+        """_load_voice_modes skips keys without a platform prefix and warns once per key."""
+        runner = _make_runner()
+
+        # Simulate legacy persisted data with unprefixed keys
+        legacy_data = {
+            "123": "all",
+            "456": "voice_only",
+            # Raw chat id that contains ':' but carries no platform prefix
+            "!room:example.org": "all",
+            # Also includes a properly prefixed key (from after the fix)
+            "telegram:789": "off",
+        }
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
+            voice_path.write_text(json.dumps(legacy_data))
+
+            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
+                with patch("gateway.run.logger") as mock_logger:
+                    result = runner._load_voice_modes()
+
+        # Legacy keys without a platform prefix should be skipped
+        assert "123" not in result
+        assert "456" not in result
+        assert "!room:example.org" not in result
+        # Prefixed key should be preserved
+        assert result.get("telegram:789") == "off"
+        # Warning should be logged for each legacy key
+        marker = "Skipping legacy unprefixed voice mode key"
+        legacy_warnings = [c for c in mock_logger.warning.call_args_list if marker in str(c)]
+        assert len(legacy_warnings) == 3
+
+    def test_load_voice_modes_preserves_prefixed_keys(self):
+        """_load_voice_modes correctly loads platform-prefixed keys."""
+        runner = _make_runner()
+
+        persisted_data = {
+            "telegram:123": "all",
+            "slack:456": "voice_only",
+            "discord:789": "off",
+        }
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
+            voice_path.write_text(json.dumps(persisted_data))
+
+            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
+                result = runner._load_voice_modes()
+
+        assert result.get("telegram:123") == "all"
+        assert result.get("slack:456") == "voice_only"
+        assert result.get("discord:789") == "off"
+
+    def test_load_voice_modes_invalid_modes_filtered(self):
+        """_load_voice_modes filters out invalid mode values."""
+        runner = _make_runner()
+
+        data = {
+            "telegram:123": "all",
+            "telegram:456": "invalid_mode",
+            "telegram:789": "voice_only",
+        }
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
+            voice_path.write_text(json.dumps(data))
+
+            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
+                result = runner._load_voice_modes()
+
+        assert result.get("telegram:123") == "all"
+        assert "telegram:456" not in result
+        assert result.get("telegram:789") == "voice_only"
+
+
+class TestSyncVoiceModeStateToAdapter:
+    """Test _sync_voice_mode_state_to_adapter filters by platform."""
+
+    def test_sync_only_includes_platform_chats(self):
+        """Only chats matching the adapter's platform are synced."""
+        runner = _make_runner()
+
+        # Set up voice mode state with multiple platforms
+        runner._voice_mode = {
+            "telegram:123": "off",  # Should sync
+            "telegram:456": "all",  # Should NOT sync (mode is not "off")
+            "slack:123": "off",  # Should NOT sync (different platform)
+            "discord:789": "off",  # Should NOT sync (different platform)
+        }
+
+        # Create a mock Telegram adapter
+        mock_adapter = MagicMock()
+        mock_adapter.platform = Platform.TELEGRAM
+        mock_adapter._auto_tts_disabled_chats = set()
+
+        runner._sync_voice_mode_state_to_adapter(mock_adapter)
+
+        # Only telegram:123 should be in disabled_chats (mode="off" for telegram)
+        assert mock_adapter._auto_tts_disabled_chats == {"123"}
+
+    def test_sync_clears_existing_state(self):
+        """_sync_voice_mode_state_to_adapter clears existing disabled_chats first."""
+        runner = _make_runner()
+
+        runner._voice_mode = {
+            "telegram:123": "off",
+        }
+
+        mock_adapter = MagicMock()
+        mock_adapter.platform = Platform.TELEGRAM
+        mock_adapter._auto_tts_disabled_chats = {"old_chat_id", "another_old"}
+
+        runner._sync_voice_mode_state_to_adapter(mock_adapter)
+
+        # Old entries should be cleared
+        assert mock_adapter._auto_tts_disabled_chats == {"123"}
+
+    def test_sync_returns_early_without_platform(self):
+        """_sync_voice_mode_state_to_adapter returns early if adapter has no platform."""
+        runner = _make_runner()
+        runner._voice_mode = {"telegram:123": "off"}
+
+        mock_adapter = MagicMock()
+        mock_adapter.platform = None
+        mock_adapter._auto_tts_disabled_chats = {"old"}
+
+        runner._sync_voice_mode_state_to_adapter(mock_adapter)
+
+        # disabled_chats should not be modified
+        assert mock_adapter._auto_tts_disabled_chats == {"old"}
+
+    def test_sync_returns_early_without_auto_tts_disabled_chats(self):
+        """_sync_voice_mode_state_to_adapter returns early if adapter lacks _auto_tts_disabled_chats."""
+        runner = _make_runner()
+        runner._voice_mode = {"telegram:123": "off"}
+
+        mock_adapter = MagicMock(spec=[])  # No _auto_tts_disabled_chats attribute
+
+        # Should not raise
+        runner._sync_voice_mode_state_to_adapter(mock_adapter)
+
+
+# ---------------------------------------------------------------------------
+# Helper
+# ---------------------------------------------------------------------------
+
+def _make_runner() -> GatewayRunner:
+    """Create a minimal GatewayRunner for testing."""
+    # __new__ skips __init__, so no persisted state is read and nothing needs patching.
+    runner = GatewayRunner.__new__(GatewayRunner)
+    runner._voice_mode = {}
+    runner.adapters = {}
+    return runner