mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): namespace voice mode state by platform to prevent cross-platform collision (#12542)
This commit is contained in:
parent
519faa6e76
commit
40164ba12b
2 changed files with 281 additions and 18 deletions
|
|
@ -786,6 +786,10 @@ class GatewayRunner:
|
|||
|
||||
_VOICE_MODE_PATH = _hermes_home / "gateway_voice_mode.json"
|
||||
|
||||
def _voice_key(self, platform: Platform, chat_id: str) -> str:
    """Build the voice-mode state key for one chat on one platform.

    The key has the form ``"<platform.value>:<chat_id>"``, so identical
    chat ids coming from different platforms map to distinct entries.
    """
    return "{}:{}".format(platform.value, chat_id)
|
||||
|
||||
def _load_voice_modes(self) -> Dict[str, str]:
|
||||
try:
|
||||
data = json.loads(self._VOICE_MODE_PATH.read_text())
|
||||
|
|
@ -796,11 +800,21 @@ class GatewayRunner:
|
|||
return {}
|
||||
|
||||
valid_modes = {"off", "voice_only", "all"}
|
||||
return {
|
||||
str(chat_id): mode
|
||||
for chat_id, mode in data.items()
|
||||
if mode in valid_modes
|
||||
}
|
||||
result = {}
|
||||
for chat_id, mode in data.items():
|
||||
if mode not in valid_modes:
|
||||
continue
|
||||
key = str(chat_id)
|
||||
# Skip legacy unprefixed keys (warn and skip)
|
||||
if ":" not in key:
|
||||
logger.warning(
|
||||
"Skipping legacy unprefixed voice mode key %r during migration. "
|
||||
"Re-enable voice mode on that chat to rebuild the prefixed key.",
|
||||
key,
|
||||
)
|
||||
continue
|
||||
result[key] = mode
|
||||
return result
|
||||
|
||||
def _save_voice_modes(self) -> None:
|
||||
try:
|
||||
|
|
@ -826,9 +840,14 @@ class GatewayRunner:
|
|||
disabled_chats = getattr(adapter, "_auto_tts_disabled_chats", None)
|
||||
if not isinstance(disabled_chats, set):
|
||||
return
|
||||
platform = getattr(adapter, "platform", None)
|
||||
if not isinstance(platform, Platform):
|
||||
return
|
||||
disabled_chats.clear()
|
||||
prefix = f"{platform.value}:"
|
||||
disabled_chats.update(
|
||||
chat_id for chat_id, mode in self._voice_mode.items() if mode == "off"
|
||||
key[len(prefix):] for key, mode in self._voice_mode.items()
|
||||
if mode == "off" and key.startswith(prefix)
|
||||
)
|
||||
|
||||
async def _safe_adapter_disconnect(self, adapter, platform) -> None:
|
||||
|
|
@ -5781,11 +5800,13 @@ class GatewayRunner:
|
|||
"""Handle /voice [on|off|tts|channel|leave|status] command."""
|
||||
args = event.get_command_args().strip().lower()
|
||||
chat_id = event.source.chat_id
|
||||
platform = event.source.platform
|
||||
voice_key = self._voice_key(platform, chat_id)
|
||||
|
||||
adapter = self.adapters.get(event.source.platform)
|
||||
adapter = self.adapters.get(platform)
|
||||
|
||||
if args in ("on", "enable"):
|
||||
self._voice_mode[chat_id] = "voice_only"
|
||||
self._voice_mode[voice_key] = "voice_only"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||
|
|
@ -5795,13 +5816,13 @@ class GatewayRunner:
|
|||
"Use /voice tts to get voice replies for all messages."
|
||||
)
|
||||
elif args in ("off", "disable"):
|
||||
self._voice_mode[chat_id] = "off"
|
||||
self._voice_mode[voice_key] = "off"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||
return "Voice mode disabled. Text-only replies."
|
||||
elif args == "tts":
|
||||
self._voice_mode[chat_id] = "all"
|
||||
self._voice_mode[voice_key] = "all"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||
|
|
@ -5814,7 +5835,7 @@ class GatewayRunner:
|
|||
elif args == "leave":
|
||||
return await self._handle_voice_channel_leave(event)
|
||||
elif args == "status":
|
||||
mode = self._voice_mode.get(chat_id, "off")
|
||||
mode = self._voice_mode.get(voice_key, "off")
|
||||
labels = {
|
||||
"off": "Off (text only)",
|
||||
"voice_only": "On (voice reply to voice messages)",
|
||||
|
|
@ -5838,15 +5859,15 @@ class GatewayRunner:
|
|||
return f"Voice mode: {labels.get(mode, mode)}"
|
||||
else:
|
||||
# Toggle: off → on, on/all → off
|
||||
current = self._voice_mode.get(chat_id, "off")
|
||||
current = self._voice_mode.get(voice_key, "off")
|
||||
if current == "off":
|
||||
self._voice_mode[chat_id] = "voice_only"
|
||||
self._voice_mode[voice_key] = "voice_only"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=False)
|
||||
return "Voice mode enabled."
|
||||
else:
|
||||
self._voice_mode[chat_id] = "off"
|
||||
self._voice_mode[voice_key] = "off"
|
||||
self._save_voice_modes()
|
||||
if adapter:
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||
|
|
@ -5892,7 +5913,7 @@ class GatewayRunner:
|
|||
adapter._voice_text_channels[guild_id] = int(event.source.chat_id)
|
||||
if hasattr(adapter, "_voice_sources"):
|
||||
adapter._voice_sources[guild_id] = event.source.to_dict()
|
||||
self._voice_mode[event.source.chat_id] = "all"
|
||||
self._voice_mode[self._voice_key(Platform.DISCORD, event.source.chat_id)] = "all"
|
||||
self._save_voice_modes()
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=False)
|
||||
return (
|
||||
|
|
@ -5919,7 +5940,7 @@ class GatewayRunner:
|
|||
except Exception as e:
|
||||
logger.warning("Error leaving voice channel: %s", e)
|
||||
# Always clean up state even if leave raised an exception
|
||||
self._voice_mode[event.source.chat_id] = "off"
|
||||
self._voice_mode[self._voice_key(Platform.DISCORD, event.source.chat_id)] = "off"
|
||||
self._save_voice_modes()
|
||||
self._set_adapter_auto_tts_disabled(adapter, event.source.chat_id, disabled=True)
|
||||
if hasattr(adapter, "_voice_input_callback"):
|
||||
|
|
@ -5931,7 +5952,7 @@ class GatewayRunner:
|
|||
|
||||
Cleans up runner-side voice_mode state that the adapter cannot reach.
|
||||
"""
|
||||
self._voice_mode[chat_id] = "off"
|
||||
self._voice_mode[self._voice_key(Platform.DISCORD, chat_id)] = "off"
|
||||
self._save_voice_modes()
|
||||
adapter = self.adapters.get(Platform.DISCORD)
|
||||
self._set_adapter_auto_tts_disabled(adapter, chat_id, disabled=True)
|
||||
|
|
@ -6017,7 +6038,7 @@ class GatewayRunner:
|
|||
return False
|
||||
|
||||
chat_id = event.source.chat_id
|
||||
voice_mode = self._voice_mode.get(chat_id, "off")
|
||||
voice_mode = self._voice_mode.get(self._voice_key(event.source.platform, chat_id), "off")
|
||||
is_voice_input = (event.message_type == MessageType.VOICE)
|
||||
|
||||
should = (
|
||||
|
|
|
|||
242
tests/gateway/test_voice_mode_platform_isolation.py
Normal file
242
tests/gateway/test_voice_mode_platform_isolation.py
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
"""Tests for voice mode platform isolation (bug #12542).
|
||||
|
||||
Voice mode state stored as {chat_id: mode} without a platform namespace
|
||||
caused collisions: Telegram chat '123' and Slack chat '123' shared the
|
||||
same key. The fix prefixes keys with platform value: 'telegram:123' vs
|
||||
'slack:123'.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform
|
||||
from gateway.run import GatewayRunner
|
||||
|
||||
|
||||
class TestVoiceKeyHelper:
    """Checks for the GatewayRunner._voice_key helper."""

    def test_voice_key_format(self):
        """Keys are rendered as '<platform value>:<chat_id>'."""
        runner = _make_runner()
        expectations = (
            (Platform.TELEGRAM, "123", "telegram:123"),
            (Platform.SLACK, "456", "slack:456"),
            (Platform.DISCORD, "789", "discord:789"),
        )
        for platform, chat_id, expected in expectations:
            assert runner._voice_key(platform, chat_id) == expected

    def test_voice_key_different_platforms_same_chat_id(self):
        """One chat_id used on several platforms produces distinct keys."""
        runner = _make_runner()
        shared_chat_id = "123"

        telegram_key = runner._voice_key(Platform.TELEGRAM, shared_chat_id)
        slack_key = runner._voice_key(Platform.SLACK, shared_chat_id)
        discord_key = runner._voice_key(Platform.DISCORD, shared_chat_id)

        # Neighbouring platforms must not collide ...
        assert telegram_key != slack_key
        assert slack_key != discord_key
        # ... and each key carries its own platform prefix.
        assert (telegram_key, slack_key, discord_key) == (
            "telegram:123",
            "slack:123",
            "discord:123",
        )
|
||||
|
||||
|
||||
class TestVoiceModePlatformIsolation:
    """Voice mode entries for different platforms must not affect each other."""

    def test_telegram_and_slack_voice_mode_independent(self):
        """Changing Telegram chat '123' leaves Slack chat '123' untouched."""
        runner = _make_runner()
        modes = runner._voice_mode
        telegram_key = runner._voice_key(Platform.TELEGRAM, "123")
        slack_key = runner._voice_key(Platform.SLACK, "123")

        # Give the two chats different modes.
        modes[telegram_key] = "all"
        modes[slack_key] = "voice_only"

        # Each platform reads back its own value.
        assert modes.get(telegram_key) == "all"
        assert modes.get(slack_key) == "voice_only"

        # Turning Telegram off must leave the Slack entry as it was.
        modes[telegram_key] = "off"
        assert modes.get(telegram_key) == "off"
        assert modes.get(slack_key) == "voice_only"

    def test_legacy_key_collision_bug(self):
        """Document the pre-fix collision caused by unprefixed keys.

        Before the fix the raw chat_id was the dictionary key, so two
        platforms sharing a chat_id wrote to the same entry and the last
        write won. This test only replays that behaviour on the state dict
        to show why namespaced keys were introduced; it does not go through
        _voice_key.
        """
        runner = _make_runner()

        # Legacy layout: the key is the bare chat_id, with no platform prefix.
        unprefixed_key = "123"

        # First platform stores its mode ...
        runner._voice_mode[unprefixed_key] = "all"
        # ... and a second platform using the same chat_id overwrites it.
        runner._voice_mode[unprefixed_key] = "voice_only"

        # Both platforms would now observe the last value written.
        assert runner._voice_mode[unprefixed_key] == "voice_only"
|
||||
|
||||
|
||||
class TestLegacyKeyMigration:
    """Test how _load_voice_modes treats legacy, prefixed and invalid entries."""

    @staticmethod
    def _load_persisted(runner, payload):
        """Persist *payload* as JSON in a temporary file and load it.

        ``_VOICE_MODE_PATH`` is patched on *runner* only for the duration of
        the call; the temporary directory is removed before returning.

        Returns whatever ``runner._load_voice_modes()`` returns.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            voice_path = Path(tmpdir) / "gateway_voice_mode.json"
            voice_path.write_text(json.dumps(payload))
            with patch.object(runner, "_VOICE_MODE_PATH", voice_path):
                return runner._load_voice_modes()

    def test_load_voice_modes_skips_legacy_keys(self):
        """_load_voice_modes skips keys without ':' and warns once per key."""
        runner = _make_runner()

        # Legacy persisted data: two unprefixed keys plus one entry that
        # already uses the platform-prefixed format.
        legacy_data = {
            "123": "all",
            "456": "voice_only",
            "telegram:789": "off",
        }

        with patch("gateway.run.logger") as mock_logger:
            result = self._load_persisted(runner, legacy_data)

        # Legacy keys without ':' are dropped.
        assert "123" not in result
        assert "456" not in result
        # The prefixed key is preserved.
        assert result.get("telegram:789") == "off"

        # Exactly one warning per legacy key. Filter on the message so that
        # unrelated warnings cannot influence the count.
        legacy_warnings = [
            logged
            for logged in mock_logger.warning.call_args_list
            if "Skipping legacy unprefixed voice mode key" in str(logged)
        ]
        assert len(legacy_warnings) == 2
        # The skipped key is passed as the last positional logging argument.
        warned_keys = set()
        for logged in legacy_warnings:
            positional, _keywords = logged
            warned_keys.add(positional[-1])
        assert warned_keys == {"123", "456"}

    def test_load_voice_modes_preserves_prefixed_keys(self):
        """_load_voice_modes correctly loads platform-prefixed keys."""
        runner = _make_runner()

        persisted_data = {
            "telegram:123": "all",
            "slack:456": "voice_only",
            "discord:789": "off",
        }

        result = self._load_persisted(runner, persisted_data)

        assert result.get("telegram:123") == "all"
        assert result.get("slack:456") == "voice_only"
        assert result.get("discord:789") == "off"

    def test_load_voice_modes_invalid_modes_filtered(self):
        """_load_voice_modes filters out invalid mode values."""
        runner = _make_runner()

        data = {
            "telegram:123": "all",
            "telegram:456": "invalid_mode",
            "telegram:789": "voice_only",
        }

        result = self._load_persisted(runner, data)

        assert result.get("telegram:123") == "all"
        assert "telegram:456" not in result
        assert result.get("telegram:789") == "voice_only"
|
||||
|
||||
|
||||
class TestSyncVoiceModeStateToAdapter:
    """Checks that _sync_voice_mode_state_to_adapter is scoped to one platform."""

    @staticmethod
    def _adapter_with(platform, disabled_chats):
        """Return a MagicMock adapter carrying the two attributes the sync reads."""
        adapter = MagicMock()
        adapter.platform = platform
        adapter._auto_tts_disabled_chats = disabled_chats
        return adapter

    def test_sync_only_includes_platform_chats(self):
        """Only 'off' chats belonging to the adapter's platform are synced."""
        runner = _make_runner()
        runner._voice_mode = {
            "telegram:123": "off",   # synced: right platform, mode "off"
            "telegram:456": "all",   # not synced: mode is not "off"
            "slack:123": "off",      # not synced: other platform
            "discord:789": "off",    # not synced: other platform
        }
        telegram_adapter = self._adapter_with(Platform.TELEGRAM, set())

        runner._sync_voice_mode_state_to_adapter(telegram_adapter)

        # The platform prefix is stripped, leaving the bare chat id.
        assert telegram_adapter._auto_tts_disabled_chats == {"123"}

    def test_sync_clears_existing_state(self):
        """Entries already present in disabled_chats are discarded first."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}
        stale_entries = {"old_chat_id", "another_old"}
        telegram_adapter = self._adapter_with(Platform.TELEGRAM, stale_entries)

        runner._sync_voice_mode_state_to_adapter(telegram_adapter)

        # Stale ids are gone; only the current state remains.
        assert telegram_adapter._auto_tts_disabled_chats == {"123"}

    def test_sync_returns_early_without_platform(self):
        """An adapter whose platform is None is left unmodified."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}
        platformless_adapter = self._adapter_with(None, {"old"})

        runner._sync_voice_mode_state_to_adapter(platformless_adapter)

        # The early return happens before the set is cleared.
        assert platformless_adapter._auto_tts_disabled_chats == {"old"}

    def test_sync_returns_early_without_auto_tts_disabled_chats(self):
        """An adapter lacking _auto_tts_disabled_chats does not cause an error."""
        runner = _make_runner()
        runner._voice_mode = {"telegram:123": "off"}

        # spec=[] makes every attribute lookup fail, including the one for
        # _auto_tts_disabled_chats.
        bare_adapter = MagicMock(spec=[])

        # Passing means the call completes without raising.
        runner._sync_voice_mode_state_to_adapter(bare_adapter)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_runner() -> GatewayRunner:
    """Build a bare GatewayRunner suitable for the voice-mode tests.

    The instance is allocated with ``__new__`` instead of the normal
    constructor call, and only the ``_voice_mode`` and ``adapters``
    attributes are populated (both empty).
    """
    with patch("gateway.run.GatewayRunner._load_voice_modes", return_value={}):
        bare_runner = GatewayRunner.__new__(GatewayRunner)
        bare_runner._voice_mode = {}
        bare_runner.adapters = {}
    return bare_runner
|
||||
Loading…
Add table
Add a link
Reference in a new issue