Fix Telegram DM topic text batch keying

This commit is contained in:
LeonSGP43 2026-05-27 11:31:28 +08:00 committed by Teknium
parent 90f0f32eae
commit 5407d25599
2 changed files with 86 additions and 1 deletions

View file

@ -5029,12 +5029,36 @@ class TelegramAdapter(BasePlatformAdapter):
def _text_batch_key(self, event: MessageEvent) -> str:
"""Session-scoped key for text message batching."""
from gateway.session import build_session_key
source = self._normalize_text_batch_source(event)
return build_session_key(
event.source,
source,
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
)
def _normalize_text_batch_source(self, event: MessageEvent):
"""Apply runner-side Telegram DM topic recovery before batching."""
source = getattr(event, "source", None)
if source is None:
return source
runner = getattr(getattr(self, "_message_handler", None), "__self__", None)
recover_fn = getattr(runner, "_recover_telegram_topic_thread_id", None)
if not callable(recover_fn):
return source
try:
recovered = recover_fn(source)
except Exception:
logger.debug("telegram text batch recovery failed", exc_info=True)
return source
if recovered is None or str(recovered) == str(source.thread_id or ""):
return source
normalized = dataclasses.replace(source, thread_id=str(recovered))
try:
event.source = normalized
except Exception:
pass
return normalized
def _enqueue_text_event(self, event: MessageEvent) -> None:
"""Buffer a text event and reset the flush timer.

View file

@ -6,12 +6,14 @@ from the same session and aggregate them before dispatching.
"""
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import Platform, PlatformConfig
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
from gateway.session import build_session_key
def _make_adapter():
@ -119,3 +121,62 @@ class TestTextBatching:
assert len(adapter._pending_text_batches) == 0
assert len(adapter._pending_text_batch_tasks) == 0
@pytest.mark.asyncio
async def test_dm_topic_batching_recovers_thread_before_keying(self):
"""DM-topic text batches should use the recovered topic lane."""
adapter = _make_adapter()
class _Runner:
def _recover_telegram_topic_thread_id(self, source):
return "222" if str(source.thread_id or "") == "1" else None
async def _handle_message(self, _event):
return None
runner = _Runner()
adapter._message_handler = runner._handle_message
event = MessageEvent(
text="hello from DM topic",
message_type=MessageType.TEXT,
source=SessionSource(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
user_id="user-1",
thread_id="1",
),
)
adapter._enqueue_text_event(event)
recovered_key = build_session_key(
SimpleNamespace(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
thread_id="222",
),
group_sessions_per_user=True,
thread_sessions_per_user=False,
)
stale_key = build_session_key(
SimpleNamespace(
platform=Platform.TELEGRAM,
chat_id="12345",
chat_type="dm",
thread_id="1",
),
group_sessions_per_user=True,
thread_sessions_per_user=False,
)
assert recovered_key in adapter._pending_text_batches
assert stale_key not in adapter._pending_text_batches
assert event.source.thread_id == "222"
await asyncio.sleep(0.2)
adapter.handle_message.assert_called_once()
dispatched = adapter.handle_message.call_args[0][0]
assert dispatched.source.thread_id == "222"