diff --git a/gateway/run.py b/gateway/run.py index 9d3f1019e8b..aafb7ac6338 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3543,9 +3543,13 @@ class GatewayRunner: ) continue - # Include thread_id if present so the message lands in the - # correct forum topic / thread. - metadata = {"thread_id": thread_id} if thread_id else None + metadata = self._thread_metadata_for_target( + platform, + chat_id, + thread_id, + chat_type=getattr(source, "chat_type", None) if source is not None else None, + adapter=adapter, + ) result = await adapter.send(chat_id, msg, metadata=metadata) if result is not None and getattr(result, "success", True) is False: @@ -3591,7 +3595,12 @@ class GatewayRunner: continue try: - metadata = {"thread_id": home.thread_id} if home.thread_id else None + metadata = self._thread_metadata_for_target( + platform, + home.chat_id, + home.thread_id, + adapter=adapter, + ) if metadata: result = await adapter.send(str(home.chat_id), msg, metadata=metadata) else: @@ -10352,6 +10361,7 @@ class GatewayRunner: notify_data = { "platform": event.source.platform.value if event.source.platform else None, "chat_id": event.source.chat_id, + "chat_type": event.source.chat_type, } if event.source.thread_id: notify_data["thread_id"] = event.source.thread_id @@ -14197,13 +14207,34 @@ class GatewayRunner: reply_to_message_id: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """Build the metadata dict platforms need for thread-aware replies.""" - thread_id = getattr(source, "thread_id", None) + return self._thread_metadata_for_target( + getattr(source, "platform", None), + getattr(source, "chat_id", None), + getattr(source, "thread_id", None), + chat_type=getattr(source, "chat_type", None), + reply_to_message_id=reply_to_message_id or getattr(source, "message_id", None), + ) + + def _thread_metadata_for_target( + self, + platform: Optional[Platform], + chat_id: Optional[str], + thread_id: Optional[str], + *, + chat_type: Optional[str] = None, + reply_to_message_id: Optional[str] = None, + adapter: Optional[Any] = None, + ) -> Optional[Dict[str, Any]]: + """Build thread metadata for synthetic sends that only have routing state.""" if thread_id is None: return None metadata: Dict[str, Any] = {"thread_id": thread_id} - if ( - getattr(source, "platform", None) == Platform.TELEGRAM - and getattr(source, "chat_type", None) == "dm" + if self._is_telegram_dm_topic_target( + platform, + chat_id, + thread_id, + chat_type=chat_type, + adapter=adapter, ): metadata["telegram_dm_topic_reply_fallback"] = True # Telegram DM topic lanes need direct_messages_topic_id in metadata @@ -14212,11 +14243,32 @@ class GatewayRunner: tid = str(thread_id) if tid and tid not in {"", "1"}: metadata["direct_messages_topic_id"] = tid - anchor = reply_to_message_id or getattr(source, "message_id", None) - if anchor is not None: - metadata["telegram_reply_to_message_id"] = str(anchor) + if reply_to_message_id is not None: + metadata["telegram_reply_to_message_id"] = str(reply_to_message_id) return metadata + @staticmethod + def _is_telegram_dm_topic_target( + platform: Optional[Platform], + chat_id: Optional[str], + thread_id: Optional[str], + *, + chat_type: Optional[str] = None, + adapter: Optional[Any] = None, + ) -> bool: + """Return True when a target is a Telegram private DM topic lane.""" + if platform != Platform.TELEGRAM or thread_id is None: + return False + if chat_type == "dm": + return True + get_dm_topic_info = getattr(adapter, "_get_dm_topic_info", None) + if callable(get_dm_topic_info) and chat_id: + try: + return bool(get_dm_topic_info(str(chat_id), str(thread_id))) + except Exception: + logger.debug("Failed to inspect Telegram DM topic metadata", exc_info=True) + return False + @staticmethod def _reply_anchor_for_event(event: MessageEvent) -> Optional[str]: """Return the platform-specific reply anchor for GatewayRunner sends.""" @@ -14425,6 +14477,7 @@ class GatewayRunner: pending = { "platform": event.source.platform.value, "chat_id": event.source.chat_id, + "chat_type": event.source.chat_type, "user_id": event.source.user_id, "session_key": session_key, "timestamp": datetime.now().isoformat(), @@ -14575,12 +14628,19 @@ class GatewayRunner: pending = json.loads(path.read_text()) platform_str = pending.get("platform") chat_id = pending.get("chat_id") + chat_type = pending.get("chat_type") session_key = pending.get("session_key") thread_id = pending.get("thread_id") - metadata = {"thread_id": thread_id} if thread_id else None if platform_str and chat_id: platform = Platform(platform_str) adapter = self.adapters.get(platform) + metadata = self._thread_metadata_for_target( + platform, + chat_id, + thread_id, + chat_type=chat_type, + adapter=adapter, + ) # Fallback session key if not stored (old pending files) if not session_key: session_key = f"{platform_str}:{chat_id}" @@ -14784,6 +14844,7 @@ class GatewayRunner: pending = json.loads(claimed_path.read_text()) platform_str = pending.get("platform") chat_id = pending.get("chat_id") + chat_type = pending.get("chat_type") thread_id = pending.get("thread_id") if not exit_code_path.exists(): @@ -14806,7 +14867,13 @@ class GatewayRunner: adapter = self.adapters.get(platform) if adapter and chat_id: - metadata = {"thread_id": thread_id} if thread_id else None + metadata = self._thread_metadata_for_target( + platform, + chat_id, + thread_id, + chat_type=chat_type, + adapter=adapter, + ) # Strip ANSI escape codes for clean display output = re.sub(r'\x1b\[[0-9;]*m', '', output).strip() if output: @@ -14848,6 +14915,7 @@ class GatewayRunner: data = json.loads(notify_path.read_text()) platform_str = data.get("platform") chat_id = data.get("chat_id") + chat_type = data.get("chat_type") thread_id = data.get("thread_id") if not platform_str or not chat_id: @@ -14870,7 +14938,13 @@ class GatewayRunner: ) return None - metadata = {"thread_id": thread_id} if thread_id else None + metadata = self._thread_metadata_for_target( + platform, + chat_id, + thread_id, + chat_type=chat_type, + adapter=adapter, + ) result = await adapter.send( str(chat_id), "♻ Gateway restarted successfully. Your session continues.", @@ -14934,7 +15008,12 @@ class GatewayRunner: continue try: - metadata = {"thread_id": home.thread_id} if home.thread_id else None + metadata = self._thread_metadata_for_target( + platform, + home.chat_id, + home.thread_id, + adapter=adapter, + ) if metadata: result = await adapter.send(str(home.chat_id), message, metadata=metadata) else: diff --git a/tests/gateway/test_restart_notification.py b/tests/gateway/test_restart_notification.py index e7a931f8f8a..6abfaac3577 100644 --- a/tests/gateway/test_restart_notification.py +++ b/tests/gateway/test_restart_notification.py @@ -59,6 +59,7 @@ async def test_restart_command_writes_notify_file(tmp_path, monkeypatch): data = json.loads(notify_path.read_text()) assert data["platform"] == "telegram" assert data["chat_id"] == "42" + assert data["chat_type"] == "dm" assert "thread_id" not in data # no thread → omitted @@ -112,8 +113,7 @@ async def test_restart_command_preserves_thread_id(tmp_path, monkeypatch): runner, _adapter = make_restart_runner() runner.request_restart = MagicMock(return_value=True) - source = make_restart_source(chat_id="99") - source.thread_id = "topic_7" + source = make_restart_source(chat_id="99", thread_id="777") event = MessageEvent( text="/restart", @@ -125,7 +125,8 @@ async def test_restart_command_preserves_thread_id(tmp_path, monkeypatch): await runner._handle_restart_command(event) data = json.loads((tmp_path / ".restart_notify.json").read_text()) - assert data["thread_id"] == "topic_7" + assert data["chat_type"] == "dm" + assert data["thread_id"] == "777" @pytest.mark.asyncio @@ -258,17 +259,22 @@ async def test_send_home_channel_startup_notification_preserves_thread_metadata( platform=Platform.TELEGRAM, chat_id="parent-42", name="Ops Topic", - thread_id="topic-7", + thread_id="777", ) + adapter._get_dm_topic_info = MagicMock(return_value={"name": "Ops Topic"}) adapter.send = AsyncMock(return_value=SendResult(success=True, message_id="home")) delivered = await runner._send_home_channel_startup_notifications() - assert delivered == {("telegram", "parent-42", "topic-7")} + assert delivered == {("telegram", "parent-42", "777")} adapter.send.assert_called_once_with( "parent-42", "♻️ Gateway online — Hermes is back and ready.", - metadata={"thread_id": "topic-7"}, + metadata={ + "thread_id": "777", + "telegram_dm_topic_reply_fallback": True, + "direct_messages_topic_id": "777", + }, ) @@ -373,7 +379,8 @@ async def test_send_restart_notification_with_thread(tmp_path, monkeypatch): notify_path.write_text(json.dumps({ "platform": "telegram", "chat_id": "99", - "thread_id": "topic_7", + "chat_type": "dm", + "thread_id": "777", })) runner, adapter = make_restart_runner() @@ -381,9 +388,13 @@ async def test_send_restart_notification_with_thread(tmp_path, monkeypatch): delivered_target = await runner._send_restart_notification() - assert delivered_target == ("telegram", "99", "topic_7") + assert delivered_target == ("telegram", "99", "777") call_args = adapter.send.call_args - assert call_args[1]["metadata"] == {"thread_id": "topic_7"} + assert call_args[1]["metadata"] == { + "thread_id": "777", + "telegram_dm_topic_reply_fallback": True, + "direct_messages_topic_id": "777", + } assert not notify_path.exists() diff --git a/tests/gateway/test_update_command.py b/tests/gateway/test_update_command.py index 154603898b3..e3f74694bdf 100644 --- a/tests/gateway/test_update_command.py +++ b/tests/gateway/test_update_command.py @@ -210,6 +210,7 @@ class TestHandleUpdateCommand: data = json.loads(pending_path.read_text()) assert data["platform"] == "telegram" assert data["chat_id"] == "99999" + assert data["chat_type"] == "dm" assert "timestamp" in data assert not (hermes_home / ".update_exit_code").exists() @@ -469,6 +470,7 @@ class TestSendUpdateNotification: pending = { "platform": "telegram", "chat_id": "67890", + "chat_type": "dm", "thread_id": "777", "user_id": "12345", } @@ -482,7 +484,11 @@ class TestSendUpdateNotification: with patch("gateway.run._hermes_home", hermes_home): await runner._send_update_notification() - assert mock_adapter.send.call_args.kwargs["metadata"] == {"thread_id": "777"} + assert mock_adapter.send.call_args.kwargs["metadata"] == { + "thread_id": "777", + "telegram_dm_topic_reply_fallback": True, + "direct_messages_topic_id": "777", + } @pytest.mark.asyncio async def test_strips_ansi_codes(self, tmp_path):