diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 300fc49c04f..624c83cbc3e 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -5029,12 +5029,36 @@ class TelegramAdapter(BasePlatformAdapter): def _text_batch_key(self, event: MessageEvent) -> str: """Session-scoped key for text message batching.""" from gateway.session import build_session_key + source = self._normalize_text_batch_source(event) return build_session_key( - event.source, + source, group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False), ) + def _normalize_text_batch_source(self, event: MessageEvent): + """Apply runner-side Telegram DM topic recovery before batching.""" + source = getattr(event, "source", None) + if source is None: + return source + runner = getattr(getattr(self, "_message_handler", None), "__self__", None) + recover_fn = getattr(runner, "_recover_telegram_topic_thread_id", None) + if not callable(recover_fn): + return source + try: + recovered = recover_fn(source) + except Exception: + logger.debug("telegram text batch recovery failed", exc_info=True) + return source + if recovered is None or str(recovered) == str(source.thread_id or ""): + return source + normalized = dataclasses.replace(source, thread_id=str(recovered)) + try: + event.source = normalized + except Exception: + pass + return normalized + def _enqueue_text_event(self, event: MessageEvent) -> None: """Buffer a text event and reset the flush timer. diff --git a/tests/gateway/test_telegram_text_batching.py b/tests/gateway/test_telegram_text_batching.py index 14c3f0dd67e..e68a679a3eb 100644 --- a/tests/gateway/test_telegram_text_batching.py +++ b/tests/gateway/test_telegram_text_batching.py @@ -6,12 +6,14 @@ from the same session and aggregate them before dispatching. """ import asyncio +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest from gateway.config import Platform, PlatformConfig from gateway.platforms.base import MessageEvent, MessageType, SessionSource +from gateway.session import build_session_key def _make_adapter(): @@ -119,3 +121,62 @@ class TestTextBatching: assert len(adapter._pending_text_batches) == 0 assert len(adapter._pending_text_batch_tasks) == 0 + + @pytest.mark.asyncio + async def test_dm_topic_batching_recovers_thread_before_keying(self): + """DM-topic text batches should use the recovered topic lane.""" + adapter = _make_adapter() + + class _Runner: + def _recover_telegram_topic_thread_id(self, source): + return "222" if str(source.thread_id or "") == "1" else None + + async def _handle_message(self, _event): + return None + + runner = _Runner() + adapter._message_handler = runner._handle_message + event = MessageEvent( + text="hello from DM topic", + message_type=MessageType.TEXT, + source=SessionSource( + platform=Platform.TELEGRAM, + chat_id="12345", + chat_type="dm", + user_id="user-1", + thread_id="1", + ), + ) + + adapter._enqueue_text_event(event) + + recovered_key = build_session_key( + SimpleNamespace( + platform=Platform.TELEGRAM, + chat_id="12345", + chat_type="dm", + thread_id="222", + ), + group_sessions_per_user=True, + thread_sessions_per_user=False, + ) + stale_key = build_session_key( + SimpleNamespace( + platform=Platform.TELEGRAM, + chat_id="12345", + chat_type="dm", + thread_id="1", + ), + group_sessions_per_user=True, + thread_sessions_per_user=False, + ) + + assert recovered_key in adapter._pending_text_batches + assert stale_key not in adapter._pending_text_batches + assert event.source.thread_id == "222" + + await asyncio.sleep(0.2) + + adapter.handle_message.assert_called_once() + dispatched = adapter.handle_message.call_args[0][0] + assert dispatched.source.thread_id == "222"