"""Shared helper classes for gateway platform adapters. Extracts common patterns that were duplicated across 5-7 adapters: message deduplication, text batch aggregation, markdown stripping, and thread participation tracking. """ import asyncio import json import logging import re import time from pathlib import Path from typing import TYPE_CHECKING, Dict, Optional if TYPE_CHECKING: from gateway.platforms.base import BasePlatformAdapter, MessageEvent logger = logging.getLogger(__name__) # ─── Message Deduplication ──────────────────────────────────────────────────── class MessageDeduplicator: """TTL-based message deduplication cache. Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern previously duplicated in discord, slack, dingtalk, wecom, weixin, mattermost, and feishu adapters. Usage:: self._dedup = MessageDeduplicator() # In message handler: if self._dedup.is_duplicate(msg_id): return """ def __init__(self, max_size: int = 2000, ttl_seconds: float = 300): self._seen: Dict[str, float] = {} self._max_size = max_size self._ttl = ttl_seconds def is_duplicate(self, msg_id: str) -> bool: """Return True if *msg_id* was already seen within the TTL window.""" if not msg_id: return False now = time.time() if msg_id in self._seen: if now - self._seen[msg_id] < self._ttl: return True # Entry has expired — remove it and treat as new del self._seen[msg_id] self._seen[msg_id] = now if len(self._seen) > self._max_size: cutoff = now - self._ttl self._seen = {k: v for k, v in self._seen.items() if v > cutoff} return False def clear(self): """Clear all tracked messages.""" self._seen.clear() # ─── Text Batch Aggregation ────────────────────────────────────────────────── class TextBatchAggregator: """Aggregates rapid-fire text events into single messages. Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern previously duplicated in telegram, discord, matrix, wecom, and feishu. Usage:: self._text_batcher = TextBatchAggregator( handler=self._message_handler, batch_delay=0.6, split_threshold=1900, ) # In message dispatch: if msg_type == MessageType.TEXT and self._text_batcher.is_enabled(): self._text_batcher.enqueue(event, session_key) return """ def __init__( self, handler, *, batch_delay: float = 0.6, split_delay: float = 2.0, split_threshold: int = 4000, ): self._handler = handler self._batch_delay = batch_delay self._split_delay = split_delay self._split_threshold = split_threshold self._pending: Dict[str, "MessageEvent"] = {} self._pending_tasks: Dict[str, asyncio.Task] = {} def is_enabled(self) -> bool: """Return True if batching is active (delay > 0).""" return self._batch_delay > 0 def enqueue(self, event: "MessageEvent", key: str) -> None: """Add *event* to the pending batch for *key*.""" chunk_len = len(event.text or "") existing = self._pending.get(key) if not existing: event._last_chunk_len = chunk_len # type: ignore[attr-defined] self._pending[key] = event else: existing.text = f"{existing.text}\n{event.text}" existing._last_chunk_len = chunk_len # type: ignore[attr-defined] # Cancel prior flush timer, start a new one prior = self._pending_tasks.get(key) if prior and not prior.done(): prior.cancel() self._pending_tasks[key] = asyncio.create_task(self._flush(key)) async def _flush(self, key: str) -> None: """Wait then dispatch the batched event for *key*.""" current_task = self._pending_tasks.get(key) pending = self._pending.get(key) last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0 # Use longer delay when the last chunk looks like a split message delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay await asyncio.sleep(delay) event = self._pending.pop(key, None) if event: try: await self._handler(event) except Exception: logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key) if self._pending_tasks.get(key) is current_task: self._pending_tasks.pop(key, None) def cancel_all(self) -> None: """Cancel all pending flush tasks.""" for task in self._pending_tasks.values(): if not task.done(): task.cancel() self._pending_tasks.clear() self._pending.clear() # ─── Markdown Stripping ────────────────────────────────────────────────────── # Pre-compiled regexes for performance _RE_BOLD = re.compile(r"\*\*(.+?)\*\*", re.DOTALL) _RE_ITALIC_STAR = re.compile(r"\*(.+?)\*", re.DOTALL) _RE_BOLD_UNDER = re.compile(r"__(.+?)__", re.DOTALL) _RE_ITALIC_UNDER = re.compile(r"_(.+?)_", re.DOTALL) _RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?") _RE_INLINE_CODE = re.compile(r"`(.+?)`") _RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE) _RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)") _RE_MULTI_NEWLINE = re.compile(r"\n{3,}") def strip_markdown(text: str) -> str: """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.). Replaces the identical ``_strip_markdown()`` functions previously duplicated in sms.py, bluebubbles.py, and feishu.py. """ text = _RE_BOLD.sub(r"\1", text) text = _RE_ITALIC_STAR.sub(r"\1", text) text = _RE_BOLD_UNDER.sub(r"\1", text) text = _RE_ITALIC_UNDER.sub(r"\1", text) text = _RE_CODE_BLOCK.sub("", text) text = _RE_INLINE_CODE.sub(r"\1", text) text = _RE_HEADING.sub("", text) text = _RE_LINK.sub(r"\1", text) text = _RE_MULTI_NEWLINE.sub("\n\n", text) return text.strip() # ─── Thread Participation Tracking ─────────────────────────────────────────── class ThreadParticipationTracker: """Persistent tracking of threads the bot has participated in. Replaces the identical ``_load/_save_participated_threads`` + ``_mark_thread_participated`` pattern previously duplicated in discord.py and matrix.py. Usage:: self._threads = ThreadParticipationTracker("discord") # Check membership: if thread_id in self._threads: ... # Mark participation: self._threads.mark(thread_id) """ _MAX_TRACKED = 500 def __init__(self, platform_name: str, max_tracked: int = 500): self._platform = platform_name self._max_tracked = max_tracked self._threads: set = self._load() def _state_path(self) -> Path: from hermes_constants import get_hermes_home return get_hermes_home() / f"{self._platform}_threads.json" def _load(self) -> set: path = self._state_path() if path.exists(): try: return set(json.loads(path.read_text(encoding="utf-8"))) except Exception: pass return set() def _save(self) -> None: path = self._state_path() path.parent.mkdir(parents=True, exist_ok=True) thread_list = list(self._threads) if len(thread_list) > self._max_tracked: thread_list = thread_list[-self._max_tracked:] self._threads = set(thread_list) path.write_text(json.dumps(thread_list), encoding="utf-8") def mark(self, thread_id: str) -> None: """Mark *thread_id* as participated and persist.""" if thread_id not in self._threads: self._threads.add(thread_id) self._save() def __contains__(self, thread_id: str) -> bool: return thread_id in self._threads def clear(self) -> None: self._threads.clear() # ─── Phone Number Redaction ────────────────────────────────────────────────── def redact_phone(phone: str) -> str: """Redact a phone number for logging, preserving country code and last 4. Replaces the identical ``_redact_phone()`` functions in signal.py, sms.py, and bluebubbles.py. """ if not phone: return "" if len(phone) <= 8: return phone[:2] + "****" + phone[-2:] if len(phone) > 4 else "****" return phone[:4] + "****" + phone[-4:]