fix(gateway): namespace voice mode state by platform to prevent cross-platform collision (#12542)

This commit is contained in:
Tranquil-Flow 2026-04-20 01:42:53 +00:00 committed by Teknium
parent 519faa6e76
commit 40164ba12b
No known key found for this signature in database
2 changed files with 281 additions and 18 deletions

View file

@ -786,6 +786,10 @@ class GatewayRunner:
_VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json"
def _voice_key(self, platform: Platform, chat_id: str) -> str:
"""Return a platform-namespaced key for voice mode state."""
return f"{platform.value}:{chat_id}"
def _load_voice_modes(self) -> Dict[str, str]:
try:
data = json.loads(self._VOICE_MODE_PATH.read_text())
@ -796,11 +800,21 @@ class GatewayRunner:
return {}
valid_modes = {"off", "voice_only", "all"}
return {
str(chat_id): mode
for chat_id, mode in data.items()
if mode in valid_modes
}
result = {}
for chat_id, mode in data.items():
if mode not in valid_modes:
continue
key = str(chat_id)
# Skip legacy unprefixed keys (warn and skip)
if ":" not in key:
logger.warning(
"Skipping legacy unprefixed voice mode key %r during migration. "
"Re-enable voice mode on that chat to rebuild the prefixed key.",
key,
)
continue
result[key] = mode
return result
def _save_voice_modes(self) -> None:
try:
@ -826,9 +840,14 @@ class GatewayRunner:
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
if not isinstance(disabled_chats, set):
return
platform = getattr(adapter, "platform", None)
if not isinstance(platform, Platform):
return
disabled_chats.clear()
prefix = f"{platform.value}:"
disabled_chats.update(
chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
key[len(prefix):] for key, mode in self._voice_mode.items()
if mode == "off" and key.startswith(prefix)
)
async def _safe_adapter_disconnect(self, adapter, platform) -> None:
@ -5781,11 +5800,13 @@ class GatewayRunner:
"""Handle /voice [on|off|tts|channel|leave|status] command."""
args = event.get_command_args().strip().lower()
chat_id = event.source.chat_id
platform = event.source.platform
voice_key = self._voice_key(platform, chat_id)
adapter = self.adapters.get(event.source.platform)
adapter = self.adapters.get(platform)
if args in ("on", "enable"):
self._voice_mode[chat_id] = "voice_only"
self._voice_mode[voice_key] = "voice_only"
self._save_voice_modes()
if adapter:
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
@ -5795,13 +5816,13 @@ class GatewayRunner:
"Use /voice tts to get voice replies for all messages."
)
elif args in ("off", "disable"):
self._voice_mode[chat_id] = "off"
self._voice_mode[voice_key] = "off"
self._save_voice_modes()
if adapter:
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
return "Voice mode disabled. Text-only replies."
elif args == "tts":
self._voice_mode[chat_id] = "all"
self._voice_mode[voice_key] = "all"
self._save_voice_modes()
if adapter:
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
@ -5814,7 +5835,7 @@ class GatewayRunner:
elif args == "leave":
return await self._handle_voice_channel_leave(event)
elif args == "status":
mode = self._voice_mode.get(chat_id, "off")
mode = self._voice_mode.get(voice_key, "off")
labels = {
"off": "Off (text only)",
"voice_only": "On (voice reply to voice messages)",
@ -5838,15 +5859,15 @@ class GatewayRunner:
return f"Voice mode: {labels.get(mode, mode)}"
else:
# Toggle: off → on, on/all → off
current = self._voice_mode.get(chat_id, "off")
current = self._voice_mode.get(voice_key, "off")
if current == "off":
self._voice_mode[chat_id] = "voice_only"
self._voice_mode[voice_key] = "voice_only"
self._save_voice_modes()
if adapter:
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
return "Voice mode enabled."
else:
self._voice_mode[chat_id] = "off"
self._voice_mode[voice_key] = "off"
self._save_voice_modes()
if adapter:
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
@ -5892,7 +5913,7 @@ class GatewayRunner:
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
if hasattr(adapter, "_voice_sources"):
adapter._voice_sources[guild_id] = event.source.to_dict()
self._voice_mode[event.source.chat_id] = "all"
self._voice_mode[self._voice_key(Platform.DISCORD, event.source.chat_id)] = "all"
self._save_voice_modes()
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
return (
@ -5919,7 +5940,7 @@ class GatewayRunner:
except Exception as e:
logger.warning("Error leaving voice channel: %s", e)
# Always clean up state even if leave raised an exception
self._voice_mode[event.source.chat_id] = "off"
self._voice_mode[self._voice_key(Platform.DISCORD, event.source.chat_id)] = "off"
self._save_voice_modes()
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
if hasattr(adapter, "_voice_input_callback"):
@ -5931,7 +5952,7 @@ class GatewayRunner:
Cleans up runner-side voice_mode state that the adapter cannot reach.
"""
self._voice_mode[chat_id] = "off"
self._voice_mode[self._voice_key(Platform.DISCORD, chat_id)] = "off"
self._save_voice_modes()
adapter = self.adapters.get(Platform.DISCORD)
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
@ -6017,7 +6038,7 @@ class GatewayRunner:
return False
chat_id = event.source.chat_id
voice_mode = self._voice_mode.get(chat_id, "off")
voice_mode = self._voice_mode.get(self._voice_key(event.source.platform, chat_id), "off")
is_voice_input = (event.message_type == MessageType.VOICE)
should = (

View file

@ -0,0 +1,242 @@
"""Tests for voice mode platform isolation (bug #12542).
Voice mode state stored as {chat_id: mode} without a platform namespace
caused collisions: Telegram chat '123' and Slack chat '123' shared the
same key. The fix prefixes keys with platform value: 'telegram:123' vs
'slack:123'.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from gateway.config import Platform
from gateway.run import GatewayRunner
class TestVoiceKeyHelper:
"""Test the _voice_key helper method."""
def test_voice_key_format(self):
"""_voice_key returns 'platform:chat_id' format."""
runner = _make_runner()
assert runner._voice_key(Platform.TELEGRAM, "123") == "telegram:123"
assert runner._voice_key(Platform.SLACK, "456") == "slack:456"
assert runner._voice_key(Platform.DISCORD, "789") == "discord:789"
def test_voice_key_different_platforms_same_chat_id(self):
"""Same chat_id on different platforms yields different keys."""
runner = _make_runner()
key_telegram = runner._voice_key(Platform.TELEGRAM, "123")
key_slack = runner._voice_key(Platform.SLACK, "123")
key_discord = runner._voice_key(Platform.DISCORD, "123")
assert key_telegram != key_slack
assert key_slack != key_discord
assert key_telegram == "telegram:123"
assert key_slack == "slack:123"
assert key_discord == "discord:123"
class TestVoiceModePlatformIsolation:
"""Test that voice mode state is isolated by platform."""
def test_telegram_and_slack_voice_mode_independent(self):
"""Setting voice mode for Telegram chat '123' does not affect Slack chat '123'."""
runner = _make_runner()
# Enable voice mode for Telegram chat '123'
runner._voice_mode[runner._voice_key(Platform.TELEGRAM, "123")] = "all"
# Enable voice mode for Slack chat '123' to a different mode
runner._voice_mode[runner._voice_key(Platform.SLACK, "123")] = "voice_only"
# Verify they are independent
assert runner._voice_mode.get(runner._voice_key(Platform.TELEGRAM, "123")) == "all"
assert runner._voice_mode.get(runner._voice_key(Platform.SLACK, "123")) == "voice_only"
# Disabling Telegram should not affect Slack
runner._voice_mode[runner._voice_key(Platform.TELEGRAM, "123")] = "off"
assert runner._voice_mode.get(runner._voice_key(Platform.TELEGRAM, "123")) == "off"
assert runner._voice_mode.get(runner._voice_key(Platform.SLACK, "123")) == "voice_only"
def test_legacy_key_collision_bug(self):
"""Demonstrates the pre-fix bug: same key without platform prefix collides.
This test documents the original bug behavior. After the fix, keys are
properly namespaced, so this scenario cannot occur in the fixed code.
The test shows that if two platforms shared the same raw chat_id as key,
they would overwrite each other.
"""
runner = _make_runner()
# Simulate legacy behavior where keys were just chat_id (no platform prefix)
# In the fixed code this cannot happen because _voice_key is always used,
# but this test shows WHY the fix was needed.
legacy_key = "123" # No platform prefix
runner._voice_mode[legacy_key] = "all"
# If Slack also used "123" as key, it would overwrite
runner._voice_mode[legacy_key] = "voice_only"
# Both platforms would see the same value (last write wins)
assert runner._voice_mode[legacy_key] == "voice_only"
# The fix prevents this by using platform-prefixed keys
class TestLegacyKeyMigration:
"""Test migration of legacy unprefixed keys in _load_voice_modes."""
def test_load_voice_modes_skips_legacy_keys(self):
"""_load_voice_modes skips keys without ':' prefix and logs a warning."""
runner = _make_runner()
# Simulate legacy persisted data with unprefixed keys
legacy_data = {
"123": "all",
"456": "voice_only",
# Also includes a properly prefixed key (from after the fix)
"telegram:789": "off",
}
with tempfile.TemporaryDirectory() as tmpdir:
voice_path = Path(tmpdir) / "gateway_voice_mode.json"
voice_path.write_text(json.dumps(legacy_data))
with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
with patch("gateway.run.logger") as mock_logger:
result = runner._load_voice_modes()
# Legacy keys without ':' should be skipped
assert "123" not in result
assert "456" not in result
# Prefixed key should be preserved
assert result.get("telegram:789") == "off"
# Warning should be logged for each legacy key
assert mock_logger.warning.called
warning_calls = [str(call) for call in mock_logger.warning.call_args_list]
assert any("Skipping legacy unprefixed voice mode key" in str(c) for c in warning_calls)
def test_load_voice_modes_preserves_prefixed_keys(self):
"""_load_voice_modes correctly loads platform-prefixed keys."""
runner = _make_runner()
persisted_data = {
"telegram:123": "all",
"slack:456": "voice_only",
"discord:789": "off",
}
with tempfile.TemporaryDirectory() as tmpdir:
voice_path = Path(tmpdir) / "gateway_voice_mode.json"
voice_path.write_text(json.dumps(persisted_data))
with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
result = runner._load_voice_modes()
assert result.get("telegram:123") == "all"
assert result.get("slack:456") == "voice_only"
assert result.get("discord:789") == "off"
def test_load_voice_modes_invalid_modes_filtered(self):
"""_load_voice_modes filters out invalid mode values."""
runner = _make_runner()
data = {
"telegram:123": "all",
"telegram:456": "invalid_mode",
"telegram:789": "voice_only",
}
with tempfile.TemporaryDirectory() as tmpdir:
voice_path = Path(tmpdir) / "gateway_voice_mode.json"
voice_path.write_text(json.dumps(data))
with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
result = runner._load_voice_modes()
assert result.get("telegram:123") == "all"
assert "telegram:456" not in result
assert result.get("telegram:789") == "voice_only"
class TestSyncVoiceModeStateToAdapter:
"""Test _sync_voice_mode_state_to_adapter filters by platform."""
def test_sync_only_includes_platform_chats(self):
"""Only chats matching the adapter's platform are synced."""
runner = _make_runner()
# Set up voice mode state with multiple platforms
runner._voice_mode = {
"telegram:123": "off", # Should sync
"telegram:456": "all", # Should NOT sync (mode is not "off")
"slack:123": "off", # Should NOT sync (different platform)
"discord:789": "off", # Should NOT sync (different platform)
}
# Create a mock Telegram adapter
mock_adapter = MagicMock()
mock_adapter.platform = Platform.TELEGRAM
mock_adapter._auto_tts_disabled_chats = set()
runner._sync_voice_mode_state_to_adapter(mock_adapter)
# Only telegram:123 should be in disabled_chats (mode="off" for telegram)
assert mock_adapter._auto_tts_disabled_chats == {"123"}
def test_sync_clears_existing_state(self):
"""_sync_voice_mode_state_to_adapter clears existing disabled_chats first."""
runner = _make_runner()
runner._voice_mode = {
"telegram:123": "off",
}
mock_adapter = MagicMock()
mock_adapter.platform = Platform.TELEGRAM
mock_adapter._auto_tts_disabled_chats = {"old_chat_id", "another_old"}
runner._sync_voice_mode_state_to_adapter(mock_adapter)
# Old entries should be cleared
assert mock_adapter._auto_tts_disabled_chats == {"123"}
def test_sync_returns_early_without_platform(self):
"""_sync_voice_mode_state_to_adapter returns early if adapter has no platform."""
runner = _make_runner()
runner._voice_mode = {"telegram:123": "off"}
mock_adapter = MagicMock()
mock_adapter.platform = None
mock_adapter._auto_tts_disabled_chats = {"old"}
runner._sync_voice_mode_state_to_adapter(mock_adapter)
# disabled_chats should not be modified
assert mock_adapter._auto_tts_disabled_chats == {"old"}
def test_sync_returns_early_without_auto_tts_disabled_chats(self):
"""_sync_voice_mode_state_to_adapter returns early if adapter lacks _auto_tts_disabled_chats."""
runner = _make_runner()
runner._voice_mode = {"telegram:123": "off"}
mock_adapter = MagicMock(spec=[]) # No _auto_tts_disabled_chats attribute
# Should not raise
runner._sync_voice_mode_state_to_adapter(mock_adapter)
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _make_runner() -> GatewayRunner:
"""Create a minimal GatewayRunner for testing."""
with patch("gateway.run.GatewayRunner._load_voice_modes", return_value={}):
runner = GatewayRunner.__new__(GatewayRunner)
runner._voice_mode = {}
runner.adapters = {}
return runner