diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 91e360e7f4c..766f3541aa5 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -472,6 +472,7 @@ def is_host_excluded_by_no_proxy(hostname: str, no_proxy_value: str | None = Non return False +import dataclasses from dataclasses import dataclass, field from datetime import datetime from pathlib import Path @@ -1561,6 +1562,10 @@ class BasePlatformAdapter(ABC): self.config = config self.platform = platform self._message_handler: Optional[MessageHandler] = None + # Optional hook (e.g. Telegram DM topic recovery) that rewrites + # ``event.source.thread_id`` before session keying. Returns the + # corrected thread_id or None to leave the source untouched. + self._topic_recovery_fn: Optional[Callable[[Any], Optional[str]]] = None self._running = False self._fatal_error_code: Optional[str] = None self._fatal_error_message: Optional[str] = None @@ -1816,6 +1821,40 @@ class BasePlatformAdapter(ABC): """ self._message_handler = handler + def set_topic_recovery_fn( + self, + fn: Optional[Callable[[Any], Optional[str]]], + ) -> None: + """Install a thread_id-recovery hook (Telegram DM topic mode). + + The hook is called with ``event.source`` before session keying; + a non-None return value replaces ``source.thread_id``. Pass + ``None`` to clear the hook. + """ + # Guard against subclasses that initialize via ``object.__new__`` in + # tests and never run ``BasePlatformAdapter.__init__``. + self._topic_recovery_fn = fn # type: ignore[attr-defined] + + def _apply_topic_recovery(self, event: MessageEvent) -> None: + """Rewrite ``event.source.thread_id`` in place if the hook returns one.""" + recover = getattr(self, "_topic_recovery_fn", None) + if recover is None: + return + source = getattr(event, "source", None) + if source is None: + return + try: + recovered = recover(source) + except Exception: + logger.debug("topic recovery hook failed", exc_info=True) + return + if recovered is None or str(recovered) == str(source.thread_id or ""): + return + try: + event.source = dataclasses.replace(source, thread_id=str(recovered)) + except Exception: + logger.debug("topic recovery rewrite failed", exc_info=True) + def set_busy_session_handler(self, handler: Optional[Callable[[MessageEvent, str], Awaitable[bool]]]) -> None: """Set an optional handler for messages arriving during active sessions.""" self._busy_session_handler = handler @@ -3332,7 +3371,12 @@ class BasePlatformAdapter(ABC): return coerce_plaintext_gateway_command(event) - + + # Rewrite ``event.source.thread_id`` via the installed recovery hook + # (Telegram DM topic mode) so the session key, guard checks, and + # downstream delivery all agree on the same lane. + self._apply_topic_recovery(event) + session_key = build_session_key( event.source, group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 624c83cbc3e..daaf3fb4d1d 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -5027,38 +5027,20 @@ class TelegramAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ def _text_batch_key(self, event: MessageEvent) -> str: - """Session-scoped key for text message batching.""" + """Session-scoped key for text message batching. + + Applies the installed topic-recovery hook first so DM-topic batches + coalesce on (and dispatch to) the recovered lane rather than the + raw inbound ``message_thread_id`` Telegram may have attached. + """ from gateway.session import build_session_key - source = self._normalize_text_batch_source(event) + self._apply_topic_recovery(event) return build_session_key( - source, + event.source, group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False), ) - def _normalize_text_batch_source(self, event: MessageEvent): - """Apply runner-side Telegram DM topic recovery before batching.""" - source = getattr(event, "source", None) - if source is None: - return source - runner = getattr(getattr(self, "_message_handler", None), "__self__", None) - recover_fn = getattr(runner, "_recover_telegram_topic_thread_id", None) - if not callable(recover_fn): - return source - try: - recovered = recover_fn(source) - except Exception: - logger.debug("telegram text batch recovery failed", exc_info=True) - return source - if recovered is None or str(recovered) == str(source.thread_id or ""): - return source - normalized = dataclasses.replace(source, thread_id=str(recovered)) - try: - event.source = normalized - except Exception: - pass - return normalized - def _enqueue_text_event(self, event: MessageEvent) -> None: """Buffer a text event and reset the flush timer. diff --git a/gateway/run.py b/gateway/run.py index f59aa4109b4..e30845affa2 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -4159,6 +4159,7 @@ class GatewayRunner: adapter.set_fatal_error_handler(self._handle_adapter_fatal_error) adapter.set_session_store(self.session_store) adapter.set_busy_session_handler(self._handle_active_session_busy_message) + adapter.set_topic_recovery_fn(self._recover_telegram_topic_thread_id) adapter._busy_text_mode = self._busy_text_mode # Try to connect @@ -5864,6 +5865,7 @@ class GatewayRunner: adapter.set_fatal_error_handler(self._handle_adapter_fatal_error) adapter.set_session_store(self.session_store) adapter.set_busy_session_handler(self._handle_active_session_busy_message) + adapter.set_topic_recovery_fn(self._recover_telegram_topic_thread_id) adapter._busy_text_mode = self._busy_text_mode success = await self._connect_adapter_with_timeout(adapter, platform) diff --git a/tests/gateway/test_telegram_text_batching.py b/tests/gateway/test_telegram_text_batching.py index e68a679a3eb..4dd99f780fb 100644 --- a/tests/gateway/test_telegram_text_batching.py +++ b/tests/gateway/test_telegram_text_batching.py @@ -126,16 +126,9 @@ class TestTextBatching: async def test_dm_topic_batching_recovers_thread_before_keying(self): """DM-topic text batches should use the recovered topic lane.""" adapter = _make_adapter() - - class _Runner: - def _recover_telegram_topic_thread_id(self, source): - return "222" if str(source.thread_id or "") == "1" else None - - async def _handle_message(self, _event): - return None - - runner = _Runner() - adapter._message_handler = runner._handle_message + adapter.set_topic_recovery_fn( + lambda source: "222" if str(source.thread_id or "") == "1" else None + ) event = MessageEvent( text="hello from DM topic", message_type=MessageType.TEXT, @@ -150,29 +143,20 @@ class TestTextBatching: adapter._enqueue_text_event(event) - recovered_key = build_session_key( - SimpleNamespace( - platform=Platform.TELEGRAM, - chat_id="12345", - chat_type="dm", - thread_id="222", - ), - group_sessions_per_user=True, - thread_sessions_per_user=False, - ) - stale_key = build_session_key( - SimpleNamespace( - platform=Platform.TELEGRAM, - chat_id="12345", - chat_type="dm", - thread_id="1", - ), - group_sessions_per_user=True, - thread_sessions_per_user=False, - ) + def _key(thread_id: str) -> str: + return build_session_key( + SimpleNamespace( + platform=Platform.TELEGRAM, + chat_id="12345", + chat_type="dm", + thread_id=thread_id, + ), + group_sessions_per_user=True, + thread_sessions_per_user=False, + ) - assert recovered_key in adapter._pending_text_batches - assert stale_key not in adapter._pending_text_batches + assert _key("222") in adapter._pending_text_batches + assert _key("1") not in adapter._pending_text_batches assert event.source.thread_id == "222" await asyncio.sleep(0.2)