mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-29 06:31:32 +00:00
Fix Telegram DM topic text batch keying
This commit is contained in:
parent
90f0f32eae
commit
5407d25599
2 changed files with 86 additions and 1 deletions
|
|
@ -5029,12 +5029,36 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
def _text_batch_key(self, event: MessageEvent) -> str:
|
||||
"""Session-scoped key for text message batching."""
|
||||
from gateway.session import build_session_key
|
||||
source = self._normalize_text_batch_source(event)
|
||||
return build_session_key(
|
||||
event.source,
|
||||
source,
|
||||
group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True),
|
||||
thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
def _normalize_text_batch_source(self, event: MessageEvent):
|
||||
"""Apply runner-side Telegram DM topic recovery before batching."""
|
||||
source = getattr(event, "source", None)
|
||||
if source is None:
|
||||
return source
|
||||
runner = getattr(getattr(self, "_message_handler", None), "__self__", None)
|
||||
recover_fn = getattr(runner, "_recover_telegram_topic_thread_id", None)
|
||||
if not callable(recover_fn):
|
||||
return source
|
||||
try:
|
||||
recovered = recover_fn(source)
|
||||
except Exception:
|
||||
logger.debug("telegram text batch recovery failed", exc_info=True)
|
||||
return source
|
||||
if recovered is None or str(recovered) == str(source.thread_id or ""):
|
||||
return source
|
||||
normalized = dataclasses.replace(source, thread_id=str(recovered))
|
||||
try:
|
||||
event.source = normalized
|
||||
except Exception:
|
||||
pass
|
||||
return normalized
|
||||
|
||||
def _enqueue_text_event(self, event: MessageEvent) -> None:
|
||||
"""Buffer a text event and reset the flush timer.
|
||||
|
||||
|
|
|
|||
|
|
@ -6,12 +6,14 @@ from the same session and aggregate them before dispatching.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent, MessageType, SessionSource
|
||||
from gateway.session import build_session_key
|
||||
|
||||
|
||||
def _make_adapter():
|
||||
|
|
@ -119,3 +121,62 @@ class TestTextBatching:
|
|||
|
||||
assert len(adapter._pending_text_batches) == 0
|
||||
assert len(adapter._pending_text_batch_tasks) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_topic_batching_recovers_thread_before_keying(self):
|
||||
"""DM-topic text batches should use the recovered topic lane."""
|
||||
adapter = _make_adapter()
|
||||
|
||||
class _Runner:
|
||||
def _recover_telegram_topic_thread_id(self, source):
|
||||
return "222" if str(source.thread_id or "") == "1" else None
|
||||
|
||||
async def _handle_message(self, _event):
|
||||
return None
|
||||
|
||||
runner = _Runner()
|
||||
adapter._message_handler = runner._handle_message
|
||||
event = MessageEvent(
|
||||
text="hello from DM topic",
|
||||
message_type=MessageType.TEXT,
|
||||
source=SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
thread_id="1",
|
||||
),
|
||||
)
|
||||
|
||||
adapter._enqueue_text_event(event)
|
||||
|
||||
recovered_key = build_session_key(
|
||||
SimpleNamespace(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id="222",
|
||||
),
|
||||
group_sessions_per_user=True,
|
||||
thread_sessions_per_user=False,
|
||||
)
|
||||
stale_key = build_session_key(
|
||||
SimpleNamespace(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="12345",
|
||||
chat_type="dm",
|
||||
thread_id="1",
|
||||
),
|
||||
group_sessions_per_user=True,
|
||||
thread_sessions_per_user=False,
|
||||
)
|
||||
|
||||
assert recovered_key in adapter._pending_text_batches
|
||||
assert stale_key not in adapter._pending_text_batches
|
||||
assert event.source.thread_id == "222"
|
||||
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
adapter.handle_message.assert_called_once()
|
||||
dispatched = adapter.handle_message.call_args[0][0]
|
||||
assert dispatched.source.thread_id == "222"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue