mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
The blanket _(.+?)_ and __(.+?)__ patterns incorrectly consumed snake_case identifiers like send_as_bot and user_id. Add lookbehind/lookahead boundaries so underscores adjacent to alphanumeric characters are not treated as markdown formatting. Same fix already applied and tested in the CLI renderer; this addresses the gateway/platforms/helpers.py copy. Supersedes #15076.
264 lines
9.1 KiB
Python
264 lines
9.1 KiB
Python
"""Shared helper classes for gateway platform adapters.
|
|
|
|
Extracts common patterns that were duplicated across 5-7 adapters:
|
|
message deduplication, text batch aggregation, markdown stripping,
|
|
and thread participation tracking.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Dict, Optional
|
|
|
|
if TYPE_CHECKING:
|
|
from gateway.platforms.base import BasePlatformAdapter, MessageEvent
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ─── Message Deduplication ────────────────────────────────────────────────────
|
|
|
|
|
|
class MessageDeduplicator:
    """TTL-based message deduplication cache with a hard size bound.

    Replaces the identical ``_seen_messages`` / ``_is_duplicate()`` pattern
    previously duplicated in discord, slack, dingtalk, wecom, weixin,
    mattermost, and feishu adapters.

    An id is reported as a duplicate when it was recorded less than
    ``ttl_seconds`` ago. The cache never holds more than ``max_size`` ids
    after an insert: expired ids are dropped first, and if that is not
    enough the oldest-inserted ids are evicted.

    Usage::

        self._dedup = MessageDeduplicator()

        # In message handler:
        if self._dedup.is_duplicate(msg_id):
            return
    """

    def __init__(self, max_size: int = 2000, ttl_seconds: float = 300):
        # msg_id -> time.time() at which the id was recorded. Dicts keep
        # insertion order, so iteration runs from oldest insert to newest.
        self._seen: Dict[str, float] = {}
        self._max_size = max_size
        self._ttl = ttl_seconds

    def is_duplicate(self, msg_id: str) -> bool:
        """Return True if *msg_id* was already seen within the TTL window.

        A falsy *msg_id* (empty string, ``None``) is never a duplicate and is
        not recorded. Any other id that is not a live duplicate is recorded
        with the current time before returning False.
        """
        if not msg_id:
            return False
        now = time.time()
        seen_at = self._seen.get(msg_id)
        if seen_at is not None:
            if now - seen_at < self._ttl:
                return True
            # Entry has expired: remove it so the re-insert below moves the
            # id to the newest end of the dict.
            del self._seen[msg_id]
        self._seen[msg_id] = now
        if len(self._seen) > self._max_size:
            self._prune(now)
        return False

    def _prune(self, now: float) -> None:
        """Drop expired entries, then evict oldest inserts down to max_size."""
        cutoff = now - self._ttl
        self._seen = {k: v for k, v in self._seen.items() if v > cutoff}
        # Dropping expired ids alone leaves the cache unbounded when more
        # than max_size distinct ids arrive inside one TTL window, so enforce
        # the cap by discarding from the oldest end.
        overflow = len(self._seen) - self._max_size
        if overflow > 0:
            for stale_id in list(self._seen)[:overflow]:
                del self._seen[stale_id]

    def clear(self):
        """Clear all tracked messages."""
        self._seen.clear()
|
|
|
|
|
|
# ─── Text Batch Aggregation ──────────────────────────────────────────────────
|
|
|
|
|
|
class TextBatchAggregator:
    """Aggregates rapid-fire text events into single messages.

    Replaces the ``_enqueue_text_event`` / ``_flush_text_batch`` pattern
    previously duplicated in telegram, discord, matrix, wecom, and feishu.

    Each call to :meth:`enqueue` merges the event into the pending batch for
    its key and restarts that key's flush timer. When the timer expires the
    merged event is passed to ``handler``. An event that arrives while a
    previous batch for the same key is already being handled starts a new
    batch; it does not interrupt the handler call in progress.

    Usage::

        self._text_batcher = TextBatchAggregator(
            handler=self._message_handler,
            batch_delay=0.6,
            split_threshold=1900,
        )

        # In message dispatch:
        if msg_type == MessageType.TEXT and self._text_batcher.is_enabled():
            self._text_batcher.enqueue(event, session_key)
            return
    """

    def __init__(
        self,
        handler,
        *,
        batch_delay: float = 0.6,
        split_delay: float = 2.0,
        split_threshold: int = 4000,
    ):
        """Configure the aggregator.

        Args:
            handler: Awaitable callable invoked with the merged event.
            batch_delay: Seconds to wait after the last chunk before flushing.
            split_delay: Delay used instead of ``batch_delay`` when the last
                chunk is at least ``split_threshold`` characters long.
            split_threshold: Chunk length at or above which ``split_delay``
                applies.
        """
        self._handler = handler
        self._batch_delay = batch_delay
        self._split_delay = split_delay
        self._split_threshold = split_threshold
        # key -> merged event still waiting for its timer.
        self._pending: Dict[str, "MessageEvent"] = {}
        # key -> flush task that is still in its delay phase.
        self._pending_tasks: Dict[str, asyncio.Task] = {}
        # Flush tasks that are past the delay and awaiting the handler.
        self._dispatching: set = set()

    def is_enabled(self) -> bool:
        """Return True if batching is active (delay > 0)."""
        return self._batch_delay > 0

    def enqueue(self, event: "MessageEvent", key: str) -> None:
        """Add *event* to the pending batch for *key*.

        Must be called with a running event loop, because it schedules the
        flush via ``asyncio.create_task``.
        """
        chunk_len = len(event.text or "")
        existing = self._pending.get(key)
        if existing is None:
            event._last_chunk_len = chunk_len  # type: ignore[attr-defined]
            self._pending[key] = event
        else:
            # Treat a missing text as empty so the merged message never
            # contains the literal string "None".
            existing.text = f"{existing.text or ''}\n{event.text or ''}"
            existing._last_chunk_len = chunk_len  # type: ignore[attr-defined]

        # Cancel prior flush timer, start a new one. Only tasks still in
        # their delay phase are registered here, so a handler call that is
        # already running is never cancelled by a late chunk.
        prior = self._pending_tasks.get(key)
        if prior and not prior.done():
            prior.cancel()
        self._pending_tasks[key] = asyncio.create_task(self._flush(key))

    async def _flush(self, key: str) -> None:
        """Wait then dispatch the batched event for *key*."""
        current_task = self._pending_tasks.get(key)
        pending = self._pending.get(key)
        last_len = getattr(pending, "_last_chunk_len", 0) if pending else 0

        # Use longer delay when the last chunk looks like a split message
        delay = self._split_delay if last_len >= self._split_threshold else self._batch_delay
        await asyncio.sleep(delay)

        event = self._pending.pop(key, None)
        # The timer phase is over: deregister before awaiting the handler so
        # enqueue() cannot cancel this task while it is dispatching.
        if self._pending_tasks.get(key) is current_task:
            self._pending_tasks.pop(key, None)
        if not event:
            return

        if current_task is not None:
            # Keep the task reachable so cancel_all() can still stop it.
            self._dispatching.add(current_task)
        try:
            await self._handler(event)
        except Exception:
            logger.exception("[TextBatchAggregator] Error dispatching batched event for %s", key)
        finally:
            self._dispatching.discard(current_task)

    def cancel_all(self) -> None:
        """Cancel all flush tasks (waiting or dispatching) and drop pending events."""
        for task in (*self._pending_tasks.values(), *self._dispatching):
            if not task.done():
                task.cancel()
        self._pending_tasks.clear()
        self._dispatching.clear()
        self._pending.clear()
|
|
|
|
|
|
# ─── Markdown Stripping ──────────────────────────────────────────────────────
|
|
|
|
# Pre-compiled regexes for performance.
#
# Emphasis delimiters are only treated as markdown when they hug the text
# they wrap. For the star forms, the character after the opener and the
# character before the closer must be neither whitespace nor another star,
# so arithmetic ("2 * 3 * 4", "2 ** 3") and "* item" bullets are left alone.
# For the underscore forms, the delimiters must not touch alphanumerics on
# the outside, so snake_case identifiers are left alone.
_RE_BOLD = re.compile(r"\*\*(?=[^*\s])(.+?)(?<=[^*\s])\*\*", re.DOTALL)
_RE_ITALIC_STAR = re.compile(r"\*(?=[^*\s])(.+?)(?<=[^*\s])\*", re.DOTALL)
_RE_BOLD_UNDER = re.compile(r"(?<![a-zA-Z0-9])__(?=[^_\s])(.+?)(?<=[^_])__(?![a-zA-Z0-9])", re.DOTALL)
_RE_ITALIC_UNDER = re.compile(r"(?<![a-zA-Z0-9])_(?=[^_\s])(.+?)(?<=[^_])_(?![a-zA-Z0-9])", re.DOTALL)
# Matches only the fence marker (with optional language tag), not the body.
_RE_CODE_BLOCK = re.compile(r"```[a-zA-Z0-9_+-]*\n?")
_RE_INLINE_CODE = re.compile(r"`(.+?)`")
_RE_HEADING = re.compile(r"^#{1,6}\s+", re.MULTILINE)
_RE_LINK = re.compile(r"\[([^\]]+)\]\([^\)]+\)")
_RE_MULTI_NEWLINE = re.compile(r"\n{3,}")


def strip_markdown(text: str) -> str:
    """Strip markdown formatting for plain-text platforms (SMS, iMessage, etc.).

    Replaces the identical ``_strip_markdown()`` functions previously
    duplicated in sms.py, bluebubbles.py, and feishu.py.

    Removes bold/italic markers (star and underscore forms), code-fence
    lines, inline-code backticks, heading hashes, and link targets (keeping
    the link label), then collapses runs of three or more newlines to two
    and trims surrounding whitespace.

    Args:
        text: Markdown-formatted text. Empty or ``None`` yields ``""``.

    Returns:
        The text with the markers above removed.

    Note:
        Emphasis is stripped before code fences, so emphasis-like sequences
        inside a fenced block are processed too.
    """
    if not text:
        return ""
    # Bold runs before italic so "**x**" is not half-consumed as "*x*".
    text = _RE_BOLD.sub(r"\1", text)
    text = _RE_ITALIC_STAR.sub(r"\1", text)
    text = _RE_BOLD_UNDER.sub(r"\1", text)
    text = _RE_ITALIC_UNDER.sub(r"\1", text)
    # Fences are removed before inline code so "```" is not read as backticks.
    text = _RE_CODE_BLOCK.sub("", text)
    text = _RE_INLINE_CODE.sub(r"\1", text)
    text = _RE_HEADING.sub("", text)
    text = _RE_LINK.sub(r"\1", text)
    text = _RE_MULTI_NEWLINE.sub("\n\n", text)
    return text.strip()
|
|
|
|
|
|
# ─── Thread Participation Tracking ───────────────────────────────────────────
|
|
|
|
|
|
class ThreadParticipationTracker:
    """Persistent tracking of threads the bot has participated in.

    Replaces the identical ``_load/_save_participated_threads`` +
    ``_mark_thread_participated`` pattern previously duplicated in
    discord.py and matrix.py.

    Thread ids are stored as a JSON list in
    ``<hermes home>/<platform>_threads.json``, oldest first. When more than
    ``max_tracked`` ids are known, the oldest ones are dropped on save.

    Usage::

        self._threads = ThreadParticipationTracker("discord")

        # Check membership:
        if thread_id in self._threads:
            ...

        # Mark participation:
        self._threads.mark(thread_id)
    """

    # Default cap on the number of remembered thread ids.
    _MAX_TRACKED = 500

    def __init__(self, platform_name: str, max_tracked: int = _MAX_TRACKED):
        self._platform = platform_name
        self._max_tracked = max_tracked
        # Thread ids in the order they were marked (oldest first). A set has
        # no order, so this list is what decides which ids get evicted.
        self._order: list = []
        # Same ids as ``_order``, kept as a set for fast membership tests.
        self._threads: set = self._load()

    def _state_path(self) -> Path:
        """Return the JSON state file path for this platform."""
        # Imported lazily, as in the original code.
        from hermes_constants import get_hermes_home
        return get_hermes_home() / f"{self._platform}_threads.json"

    def _load(self) -> set:
        """Read persisted thread ids; return an empty set if none are usable.

        Also refreshes ``self._order`` with the ids in stored order.
        """
        self._order = []
        path = self._state_path()
        if path.exists():
            try:
                stored = json.loads(path.read_text(encoding="utf-8"))
                if not isinstance(stored, list):
                    raise ValueError("thread state is not a JSON list")
                # dict.fromkeys de-duplicates while keeping first-seen order.
                self._order = list(dict.fromkeys(stored))
            except Exception:
                # Best effort: an unreadable state file must not stop the
                # adapter from starting, but leave a trace instead of
                # silently discarding the history.
                self._order = []
                logger.warning(
                    "[ThreadParticipationTracker] Ignoring unreadable state file %s",
                    path,
                    exc_info=True,
                )
        return set(self._order)

    def _save(self) -> None:
        """Trim to ``max_tracked`` (dropping oldest ids) and write the state file."""
        path = self._state_path()
        path.parent.mkdir(parents=True, exist_ok=True)
        overflow = len(self._order) - self._max_tracked
        if overflow > 0:
            # Evict from the front so the most recently marked threads,
            # including the one that triggered this save, are kept.
            del self._order[:overflow]
            self._threads = set(self._order)
        path.write_text(json.dumps(self._order), encoding="utf-8")

    def mark(self, thread_id: str) -> None:
        """Mark *thread_id* as participated and persist."""
        if thread_id not in self._threads:
            self._threads.add(thread_id)
            self._order.append(thread_id)
            self._save()

    def __contains__(self, thread_id: str) -> bool:
        return thread_id in self._threads

    def clear(self) -> None:
        """Forget all threads in memory (the state file is not rewritten)."""
        self._threads.clear()
        self._order.clear()
|
|
|
|
|
|
# ─── Phone Number Redaction ──────────────────────────────────────────────────
|
|
|
|
|
|
def redact_phone(phone: str) -> str:
    """Mask the middle of a phone number so it is safe to log.

    Replaces the identical ``_redact_phone()`` functions in signal.py,
    sms.py, and bluebubbles.py.

    Returns ``"<none>"`` for an empty value and ``"****"`` for values of at
    most four characters. Longer values keep their first and last two
    characters (length 5-8) or first and last four characters (length 9+),
    with ``"****"`` in between.
    """
    if not phone:
        return "<none>"

    size = len(phone)
    if size <= 4:
        # Too short to reveal anything without exposing the whole value.
        return "****"

    # Short numbers expose two characters per side, longer ones four.
    visible = 2 if size <= 8 else 4
    return phone[:visible] + "****" + phone[-visible:]
|