diff --git a/gateway/config.py b/gateway/config.py index 60387cc84c..935a50d74a 100644 --- a/gateway/config.py +++ b/gateway/config.py @@ -138,6 +138,12 @@ class PlatformConfig: api_key: Optional[str] = None # API key if different from token home_channel: Optional[HomeChannel] = None + # Reply threading mode (Telegram/Slack) + # - "off": Never thread replies to original message + # - "first": Only first chunk threads to user's message (default) + # - "all": All chunks in multi-part replies thread to user's message + reply_to_mode: str = "first" + # Platform-specific settings extra: Dict[str, Any] = field(default_factory=dict) @@ -145,6 +151,7 @@ class PlatformConfig: result = { "enabled": self.enabled, "extra": self.extra, + "reply_to_mode": self.reply_to_mode, } if self.token: result["token"] = self.token @@ -165,6 +172,7 @@ class PlatformConfig: token=data.get("token"), api_key=data.get("api_key"), home_channel=home_channel, + reply_to_mode=data.get("reply_to_mode", "first"), extra=data.get("extra", {}), ) @@ -586,6 +594,13 @@ def _apply_env_overrides(config: GatewayConfig) -> None: config.platforms[Platform.TELEGRAM].enabled = True config.platforms[Platform.TELEGRAM].token = telegram_token + # Reply threading mode for Telegram (off/first/all) + telegram_reply_mode = os.getenv("TELEGRAM_REPLY_TO_MODE", "").lower() + if telegram_reply_mode in ("off", "first", "all"): + if Platform.TELEGRAM not in config.platforms: + config.platforms[Platform.TELEGRAM] = PlatformConfig() + config.platforms[Platform.TELEGRAM].reply_to_mode = telegram_reply_mode + telegram_home = os.getenv("TELEGRAM_HOME_CHANNEL") if telegram_home and Platform.TELEGRAM in config.platforms: config.platforms[Platform.TELEGRAM].home_channel = HomeChannel( diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 6591bdcc21..b65f29ee51 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -115,6 +115,7 @@ class TelegramAdapter(BasePlatformAdapter): super().__init__(config, Platform.TELEGRAM) self._app: Optional[Application] = None self._bot: Optional[Bot] = None + self._reply_to_mode: str = getattr(config, 'reply_to_mode', 'first') or 'first' # Buffer rapid/album photo updates so Telegram image bursts are handled # as a single MessageEvent instead of self-interrupting multiple turns. self._media_batch_delay_seconds = float(os.getenv("HERMES_TELEGRAM_MEDIA_BATCH_DELAY_SECONDS", "0.8")) @@ -442,6 +443,26 @@ class TelegramAdapter(BasePlatformAdapter): self._token_lock_identity = None logger.info("[%s] Disconnected from Telegram", self.name) + def _should_thread_reply(self, reply_to: Optional[str], chunk_index: int) -> bool: + """Determine if this message chunk should thread to the original message. + + Args: + reply_to: The original message ID to reply to + chunk_index: Index of this chunk (0 = first chunk) + + Returns: + True if this chunk should be threaded to the original message + """ + if not reply_to: + return False + mode = self._reply_to_mode + if mode == "off": + return False + elif mode == "all": + return True + else: # "first" (default) + return chunk_index == 0 + async def send( self, chat_id: str, @@ -475,6 +496,9 @@ class TelegramAdapter(BasePlatformAdapter): _NetErr = OSError # type: ignore[misc,assignment] for i, chunk in enumerate(chunks): + should_thread = self._should_thread_reply(reply_to, i) + reply_to_id = int(reply_to) if should_thread else None + msg = None for _send_attempt in range(3): try: @@ -484,7 +508,7 @@ class TelegramAdapter(BasePlatformAdapter): chat_id=int(chat_id), text=chunk, parse_mode=ParseMode.MARKDOWN_V2, - reply_to_message_id=int(reply_to) if reply_to and i == 0 else None, + reply_to_message_id=reply_to_id, message_thread_id=int(thread_id) if thread_id else None, ) except Exception as md_error: @@ -496,7 +520,7 @@ class TelegramAdapter(BasePlatformAdapter): chat_id=int(chat_id), text=plain_chunk, parse_mode=None, - reply_to_message_id=int(reply_to) if reply_to and i == 0 else None, + reply_to_message_id=reply_to_id, message_thread_id=int(thread_id) if thread_id else None, ) else: diff --git a/tests/gateway/test_telegram_reply_mode.py b/tests/gateway/test_telegram_reply_mode.py new file mode 100644 index 0000000000..1ec16b5123 --- /dev/null +++ b/tests/gateway/test_telegram_reply_mode.py @@ -0,0 +1,242 @@ +"""Tests for Telegram reply_to_mode functionality. + +Covers the threading behavior control for multi-chunk replies: +- "off": Never thread replies to original message +- "first": Only first chunk threads (default) +- "all": All chunks thread to original message +""" +import os +import sys +from unittest.mock import MagicMock, AsyncMock, patch + +import pytest + +from gateway.config import PlatformConfig, GatewayConfig, Platform, _apply_env_overrides + + +def _ensure_telegram_mock(): + """Mock the telegram package if it's not installed.""" + if "telegram" in sys.modules and hasattr(sys.modules["telegram"], "__file__"): + return + mod = MagicMock() + mod.ext.ContextTypes.DEFAULT_TYPE = type(None) + mod.constants.ParseMode.MARKDOWN_V2 = "MarkdownV2" + mod.constants.ChatType.GROUP = "group" + mod.constants.ChatType.SUPERGROUP = "supergroup" + mod.constants.ChatType.CHANNEL = "channel" + mod.constants.ChatType.PRIVATE = "private" + for name in ("telegram", "telegram.ext", "telegram.constants"): + sys.modules.setdefault(name, mod) + + +_ensure_telegram_mock() + +from gateway.platforms.telegram import TelegramAdapter # noqa: E402 + + +@pytest.fixture() +def adapter_factory(): + """Factory to create TelegramAdapter with custom reply_to_mode.""" + def create(reply_to_mode: str = "first"): + config = PlatformConfig(enabled=True, token="test-token", reply_to_mode=reply_to_mode) + return TelegramAdapter(config) + return create + + +class TestReplyToModeConfig: + """Tests for reply_to_mode configuration loading.""" + + def test_default_mode_is_first(self, adapter_factory): + adapter = adapter_factory() + assert adapter._reply_to_mode == "first" + + def test_off_mode(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="off") + assert adapter._reply_to_mode == "off" + + def test_first_mode(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="first") + assert adapter._reply_to_mode == "first" + + def test_all_mode(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="all") + assert adapter._reply_to_mode == "all" + + def test_invalid_mode_stored_as_is(self, adapter_factory): + """Invalid modes are stored but _should_thread_reply handles them.""" + adapter = adapter_factory(reply_to_mode="invalid") + assert adapter._reply_to_mode == "invalid" + + def test_none_mode_defaults_to_first(self): + config = PlatformConfig(enabled=True, token="test-token") + adapter = TelegramAdapter(config) + assert adapter._reply_to_mode == "first" + + def test_empty_string_mode_defaults_to_first(self): + config = PlatformConfig(enabled=True, token="test-token", reply_to_mode="") + adapter = TelegramAdapter(config) + assert adapter._reply_to_mode == "first" + + +class TestShouldThreadReply: + """Tests for _should_thread_reply method.""" + + def test_no_reply_to_returns_false(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="first") + assert adapter._should_thread_reply(None, 0) is False + assert adapter._should_thread_reply("", 0) is False + + def test_off_mode_never_threads(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="off") + assert adapter._should_thread_reply("msg-123", 0) is False + assert adapter._should_thread_reply("msg-123", 1) is False + assert adapter._should_thread_reply("msg-123", 5) is False + + def test_first_mode_only_first_chunk(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="first") + assert adapter._should_thread_reply("msg-123", 0) is True + assert adapter._should_thread_reply("msg-123", 1) is False + assert adapter._should_thread_reply("msg-123", 2) is False + assert adapter._should_thread_reply("msg-123", 10) is False + + def test_all_mode_all_chunks(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="all") + assert adapter._should_thread_reply("msg-123", 0) is True + assert adapter._should_thread_reply("msg-123", 1) is True + assert adapter._should_thread_reply("msg-123", 2) is True + assert adapter._should_thread_reply("msg-123", 10) is True + + def test_invalid_mode_falls_back_to_first(self, adapter_factory): + """Invalid mode behaves like 'first' - only first chunk threads.""" + adapter = adapter_factory(reply_to_mode="invalid") + assert adapter._should_thread_reply("msg-123", 0) is True + assert adapter._should_thread_reply("msg-123", 1) is False + + +class TestSendWithReplyToMode: + """Tests for send() method respecting reply_to_mode.""" + + @pytest.mark.asyncio + async def test_off_mode_no_reply_threading(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="off") + adapter._bot = MagicMock() + adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1)) + adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"] + + await adapter.send("12345", "test content", reply_to="999") + + for call in adapter._bot.send_message.call_args_list: + assert call.kwargs.get("reply_to_message_id") is None + + @pytest.mark.asyncio + async def test_first_mode_only_first_chunk_threads(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="first") + adapter._bot = MagicMock() + adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1)) + adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"] + + await adapter.send("12345", "test content", reply_to="999") + + calls = adapter._bot.send_message.call_args_list + assert len(calls) == 3 + assert calls[0].kwargs.get("reply_to_message_id") == 999 + assert calls[1].kwargs.get("reply_to_message_id") is None + assert calls[2].kwargs.get("reply_to_message_id") is None + + @pytest.mark.asyncio + async def test_all_mode_all_chunks_thread(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="all") + adapter._bot = MagicMock() + adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1)) + adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2", "chunk3"] + + await adapter.send("12345", "test content", reply_to="999") + + calls = adapter._bot.send_message.call_args_list + assert len(calls) == 3 + for call in calls: + assert call.kwargs.get("reply_to_message_id") == 999 + + @pytest.mark.asyncio + async def test_no_reply_to_param_no_threading(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="all") + adapter._bot = MagicMock() + adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1)) + adapter.truncate_message = lambda content, max_len: ["chunk1", "chunk2"] + + await adapter.send("12345", "test content", reply_to=None) + + calls = adapter._bot.send_message.call_args_list + for call in calls: + assert call.kwargs.get("reply_to_message_id") is None + + @pytest.mark.asyncio + async def test_single_chunk_respects_mode(self, adapter_factory): + adapter = adapter_factory(reply_to_mode="first") + adapter._bot = MagicMock() + adapter._bot.send_message = AsyncMock(return_value=MagicMock(message_id=1)) + adapter.truncate_message = lambda content, max_len: ["single chunk"] + + await adapter.send("12345", "test", reply_to="999") + + calls = adapter._bot.send_message.call_args_list + assert len(calls) == 1 + assert calls[0].kwargs.get("reply_to_message_id") == 999 + + +class TestConfigSerialization: + """Tests for reply_to_mode serialization.""" + + def test_to_dict_includes_reply_to_mode(self): + config = PlatformConfig(enabled=True, token="test", reply_to_mode="all") + result = config.to_dict() + assert result["reply_to_mode"] == "all" + + def test_from_dict_loads_reply_to_mode(self): + data = {"enabled": True, "token": "test", "reply_to_mode": "off"} + config = PlatformConfig.from_dict(data) + assert config.reply_to_mode == "off" + + def test_from_dict_defaults_to_first(self): + data = {"enabled": True, "token": "test"} + config = PlatformConfig.from_dict(data) + assert config.reply_to_mode == "first" + + +class TestEnvVarOverride: + """Tests for TELEGRAM_REPLY_TO_MODE environment variable override.""" + + def _make_config(self): + config = GatewayConfig() + config.platforms[Platform.TELEGRAM] = PlatformConfig(enabled=True, token="test") + return config + + def test_env_var_sets_off_mode(self): + config = self._make_config() + with patch.dict(os.environ, {"TELEGRAM_REPLY_TO_MODE": "off"}, clear=False): + _apply_env_overrides(config) + assert config.platforms[Platform.TELEGRAM].reply_to_mode == "off" + + def test_env_var_sets_all_mode(self): + config = self._make_config() + with patch.dict(os.environ, {"TELEGRAM_REPLY_TO_MODE": "all"}, clear=False): + _apply_env_overrides(config) + assert config.platforms[Platform.TELEGRAM].reply_to_mode == "all" + + def test_env_var_case_insensitive(self): + config = self._make_config() + with patch.dict(os.environ, {"TELEGRAM_REPLY_TO_MODE": "ALL"}, clear=False): + _apply_env_overrides(config) + assert config.platforms[Platform.TELEGRAM].reply_to_mode == "all" + + def test_env_var_invalid_value_ignored(self): + config = self._make_config() + with patch.dict(os.environ, {"TELEGRAM_REPLY_TO_MODE": "banana"}, clear=False): + _apply_env_overrides(config) + assert config.platforms[Platform.TELEGRAM].reply_to_mode == "first" + + def test_env_var_empty_value_ignored(self): + config = self._make_config() + with patch.dict(os.environ, {"TELEGRAM_REPLY_TO_MODE": ""}, clear=False): + _apply_env_overrides(config) + assert config.platforms[Platform.TELEGRAM].reply_to_mode == "first"