diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 87a765882a..0d44309126 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -35,6 +35,19 @@ from typing import Any, Dict, Optional, Set from html import escape as _html_escape +from mautrix.types import ( + ContentURI, + EventID, + EventType, + PaginationDirection, + PresenceState, + RoomCreatePreset, + RoomID, + SyncToken, + TrustState, + UserID, +) + from gateway.config import Platform, PlatformConfig from gateway.platforms.base import ( BasePlatformAdapter, @@ -54,6 +67,7 @@ MAX_MESSAGE_LENGTH = 4000 # Uses get_hermes_home() so each profile gets its own Matrix store. from hermes_constants import get_hermes_dir as _get_hermes_dir _STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store") +_CRYPTO_PICKLE_PATH = _STORE_DIR / "crypto_store.pickle" # Grace period: ignore messages older than this many seconds before startup. _STARTUP_GRACE_SECONDS = 5 @@ -169,12 +183,17 @@ class MatrixAdapter(BasePlatformAdapter): self._bot_participated_threads: set = self._load_participated_threads() self._MAX_TRACKED_THREADS = 500 + # Mention/thread gating — parsed once from env vars. + self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") + free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") + self._free_rooms: Set[str] = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} + self._auto_thread: bool = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") + self._dm_mention_threads: bool = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes") + # Reactions: configurable via MATRIX_REACTIONS (default: true). self._reactions_enabled: bool = os.getenv( "MATRIX_REACTIONS", "true" ).lower() not in ("false", "0", "no") - # Tracks the reaction event_id for in-progress (eyes) reactions. - # Key: (room_id, message_event_id) → reaction_event_id (for the eyes reaction). self._pending_reactions: dict[tuple[str, str], str] = {} # Text batching: merge rapid successive messages (Telegram-style). @@ -206,7 +225,6 @@ class MatrixAdapter(BasePlatformAdapter): from mautrix.api import HTTPAPI from mautrix.client import Client from mautrix.client.state_store import MemoryStateStore, MemorySyncStore - from mautrix.types import EventType, UserID if not self._homeserver: logger.error("Matrix: homeserver URL not configured") @@ -262,6 +280,7 @@ class MatrixAdapter(BasePlatformAdapter): "Matrix: whoami failed — check MATRIX_ACCESS_TOKEN and MATRIX_HOMESERVER: %s", exc, ) + await api.session.close() return False elif self._password and self._user_id: try: @@ -271,15 +290,16 @@ class MatrixAdapter(BasePlatformAdapter): device_name="Hermes Agent", device_id=self._device_id or None, ) - # login() stores the token automatically. if resp and hasattr(resp, "device_id"): client.device_id = resp.device_id logger.info("Matrix: logged in as %s", self._user_id) except Exception as exc: logger.error("Matrix: login failed — %s", exc) + await api.session.close() return False else: logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD") + await api.session.close() return False # Set up E2EE if requested. @@ -298,7 +318,7 @@ class MatrixAdapter(BasePlatformAdapter): crypto_store = MemoryCryptoStore() # Restore persisted crypto state from a previous run. - pickle_path = _STORE_DIR / "crypto_store.pickle" + pickle_path = _CRYPTO_PICKLE_PATH if pickle_path.exists(): try: import pickle @@ -314,7 +334,6 @@ class MatrixAdapter(BasePlatformAdapter): # Set trust policy: accept unverified devices so senders # share Megolm session keys with us automatically. - from mautrix.types import TrustState olm.share_keys_min_trust = TrustState.UNVERIFIED olm.send_keys_min_trust = TrustState.UNVERIFIED @@ -392,7 +411,7 @@ class MatrixAdapter(BasePlatformAdapter): import pickle crypto_store = self._client.crypto.crypto_store _STORE_DIR.mkdir(parents=True, exist_ok=True) - pickle_path = _STORE_DIR / "crypto_store.pickle" + pickle_path = _CRYPTO_PICKLE_PATH with open(pickle_path, "wb") as f: pickle.dump(crypto_store, f) logger.info("Matrix: persisted E2EE crypto store to %s", pickle_path) @@ -416,7 +435,6 @@ class MatrixAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send a message to a Matrix room.""" - from mautrix.types import EventType, RoomID if not content: return SendResult(success=True) @@ -492,30 +510,12 @@ class MatrixAdapter(BasePlatformAdapter): async def get_chat_info(self, chat_id: str) -> Dict[str, Any]: """Return room name and type (dm/group).""" name = chat_id - chat_type = "group" + chat_type = "dm" if await self._is_dm_room(chat_id) else "group" if self._client: - # Try state store for member count. - state_store = getattr(self._client, "state_store", None) - if state_store: - try: - members = await state_store.get_members( - chat_id, - ) - if members and len(members) == 2: - chat_type = "dm" - except Exception: - pass - - # Use DM cache. - if self._dm_rooms.get(chat_id, False): - chat_type = "dm" - - # Try to get room name from state. try: - from mautrix.types import EventType as ET, RoomID name_evt = await self._client.get_state_event( - RoomID(chat_id), ET.ROOM_NAME, + RoomID(chat_id), EventType.ROOM_NAME, ) if name_evt and hasattr(name_evt, "name") and name_evt.name: name = name_evt.name @@ -534,7 +534,6 @@ class MatrixAdapter(BasePlatformAdapter): """Send a typing indicator.""" if self._client: try: - from mautrix.types import RoomID await self._client.set_typing(RoomID(chat_id), timeout=30000) except Exception: pass @@ -543,7 +542,6 @@ class MatrixAdapter(BasePlatformAdapter): self, chat_id: str, message_id: str, content: str ) -> SendResult: """Edit an existing message (via m.replace).""" - from mautrix.types import EventType, RoomID formatted = self.format_message(content) msg_content: Dict[str, Any] = { @@ -683,7 +681,6 @@ class MatrixAdapter(BasePlatformAdapter): is_voice: bool = False, ) -> SendResult: """Upload bytes to Matrix and send as a media message.""" - from mautrix.types import EventType, RoomID # Upload to homeserver. try: @@ -866,10 +863,8 @@ class MatrixAdapter(BasePlatformAdapter): return # Startup grace: ignore old messages from initial sync. - event_ts = getattr(event, "timestamp", 0) / 1000.0 if getattr(event, "timestamp", 0) else 0 - # Also check server_timestamp for compatibility. - if not event_ts: - event_ts = getattr(event, "server_timestamp", 0) / 1000.0 if getattr(event, "server_timestamp", 0) else 0 + raw_ts = getattr(event, "timestamp", None) or getattr(event, "server_timestamp", None) or 0 + event_ts = raw_ts / 1000.0 if raw_ts else 0.0 if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS: return @@ -907,6 +902,68 @@ class MatrixAdapter(BasePlatformAdapter): elif msgtype in ("m.text", "m.notice"): await self._handle_text_message(room_id, sender, event_id, event_ts, source_content, relates_to) + async def _resolve_message_context( + self, + room_id: str, + sender: str, + event_id: str, + body: str, + source_content: dict, + relates_to: dict, + ) -> Optional[tuple]: + """Shared mention/thread/DM gating for text and media handlers. + + Returns (body, is_dm, chat_type, thread_id, display_name, source) + or None if the message should be dropped (mention gating). + """ + is_dm = await self._is_dm_room(room_id) + chat_type = "dm" if is_dm else "group" + + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + + formatted_body = source_content.get("formatted_body") + is_mentioned = self._is_bot_mentioned(body, formatted_body) + + # Require-mention gating. + if not is_dm: + is_free_room = room_id in self._free_rooms + in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) + if self._require_mention and not is_free_room and not in_bot_thread: + if not is_mentioned: + return None + + # DM mention-thread. + if is_dm and not thread_id and self._dm_mention_threads and is_mentioned: + thread_id = event_id + self._track_thread(thread_id) + + # Strip mention from body. + if is_mentioned: + body = self._strip_mention(body) + + # Auto-thread. + if not is_dm and not thread_id and self._auto_thread: + thread_id = event_id + self._track_thread(thread_id) + + display_name = await self._get_display_name(room_id, sender) + source = self.build_source( + chat_id=room_id, + chat_type=chat_type, + user_id=sender, + user_name=display_name, + thread_id=thread_id, + ) + + if thread_id: + self._track_thread(thread_id) + + self._background_read_receipt(room_id, event_id) + + return body, is_dm, chat_type, thread_id, display_name, source + async def _handle_text_message( self, room_id: str, @@ -921,45 +978,12 @@ class MatrixAdapter(BasePlatformAdapter): if not body: return - # Determine chat type. - is_dm = await self._is_dm_room(room_id) - chat_type = "dm" if is_dm else "group" - - # Thread support. - thread_id = None - if relates_to.get("rel_type") == "m.thread": - thread_id = relates_to.get("event_id") - - # Require-mention gating. - if not is_dm: - free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") - free_rooms = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} - require_mention = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - is_free_room = room_id in free_rooms - in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) - - formatted_body = source_content.get("formatted_body") - if require_mention and not is_free_room and not in_bot_thread: - if not self._is_bot_mentioned(body, formatted_body): - return - - # DM mention-thread. - if is_dm and not thread_id: - dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes") - if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")): - thread_id = event_id - self._track_thread(thread_id) - - # Strip mention from body. - if self._is_bot_mentioned(body, source_content.get("formatted_body")): - body = self._strip_mention(body) - - # Auto-thread. - if not is_dm and not thread_id: - auto_thread = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") - if auto_thread: - thread_id = event_id - self._track_thread(thread_id) + ctx = await self._resolve_message_context( + room_id, sender, event_id, body, source_content, relates_to, + ) + if ctx is None: + return + body, is_dm, chat_type, thread_id, display_name, source = ctx # Reply-to detection. reply_to = None @@ -983,20 +1007,10 @@ class MatrixAdapter(BasePlatformAdapter): stripped.append(line) body = "\n".join(stripped) if stripped else body - # Message type. msg_type = MessageType.TEXT if body.startswith(("!", "/")): msg_type = MessageType.COMMAND - display_name = await self._get_display_name(room_id, sender) - source = self.build_source( - chat_id=room_id, - chat_type=chat_type, - user_id=sender, - user_name=display_name, - thread_id=thread_id, - ) - msg_event = MessageEvent( text=body, message_type=msg_type, @@ -1006,13 +1020,6 @@ class MatrixAdapter(BasePlatformAdapter): reply_to_message_id=reply_to, ) - if thread_id: - self._track_thread(thread_id) - - # Acknowledge receipt (fire-and-forget). - self._background_read_receipt(room_id, event_id) - - # Batch plain text messages — commands dispatch immediately. if msg_type == MessageType.TEXT and self._text_batch_delay_seconds > 0: self._enqueue_text_event(msg_event) else: @@ -1079,7 +1086,6 @@ class MatrixAdapter(BasePlatformAdapter): ) if should_cache_locally and url: try: - from mautrix.types import ContentURI file_bytes = await self._client.download_media(ContentURI(url)) if file_bytes is not None: if is_encrypted_media: @@ -1131,53 +1137,12 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as e: logger.warning("[Matrix] Failed to cache media: %s", e) - is_dm = await self._is_dm_room(room_id) - chat_type = "dm" if is_dm else "group" - - # Thread/reply detection. - thread_id = None - if relates_to.get("rel_type") == "m.thread": - thread_id = relates_to.get("event_id") - - # Require-mention gating (media messages). - if not is_dm: - free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") - free_rooms = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} - require_mention = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") - is_free_room = room_id in free_rooms - in_bot_thread = bool(thread_id and thread_id in self._bot_participated_threads) - - if require_mention and not is_free_room and not in_bot_thread: - formatted_body = source_content.get("formatted_body") - if not self._is_bot_mentioned(body, formatted_body): - return - - # DM mention-thread. - if is_dm and not thread_id: - dm_mention_threads = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes") - if dm_mention_threads and self._is_bot_mentioned(body, source_content.get("formatted_body")): - thread_id = event_id - self._track_thread(thread_id) - - # Strip mention from body. - if self._is_bot_mentioned(body, source_content.get("formatted_body")): - body = self._strip_mention(body) - - # Auto-thread. - if not is_dm and not thread_id: - auto_thread = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") - if auto_thread: - thread_id = event_id - self._track_thread(thread_id) - - display_name = await self._get_display_name(room_id, sender) - source = self.build_source( - chat_id=room_id, - chat_type=chat_type, - user_id=sender, - user_name=display_name, - thread_id=thread_id, + ctx = await self._resolve_message_context( + room_id, sender, event_id, body, source_content, relates_to, ) + if ctx is None: + return + body, is_dm, chat_type, thread_id, display_name, source = ctx allow_http_fallback = bool(http_url) and not is_encrypted_media media_urls = [cached_path] if cached_path else ([http_url] if allow_http_fallback else None) @@ -1193,11 +1158,6 @@ class MatrixAdapter(BasePlatformAdapter): media_types=media_types, ) - if thread_id: - self._track_thread(thread_id) - - self._background_read_receipt(room_id, event_id) - await self.handle_message(msg_event) async def _on_encrypted_event(self, event: Any) -> None: @@ -1219,7 +1179,6 @@ class MatrixAdapter(BasePlatformAdapter): async def _on_invite(self, event: Any) -> None: """Auto-join rooms when invited.""" - from mautrix.types import RoomID room_id = str(getattr(event, "room_id", "")) @@ -1245,7 +1204,6 @@ class MatrixAdapter(BasePlatformAdapter): """Send an emoji reaction to a message in a room. Returns the reaction event_id on success, None on failure. """ - from mautrix.types import EventType, RoomID if not self._client: return None @@ -1409,7 +1367,6 @@ class MatrixAdapter(BasePlatformAdapter): if not self._client: return False try: - from mautrix.types import EventID, RoomID await self._client.set_read_markers( RoomID(room_id), fully_read_event=EventID(event_id), @@ -1432,7 +1389,6 @@ class MatrixAdapter(BasePlatformAdapter): if not self._client: return False try: - from mautrix.types import EventID, RoomID await self._client.redact( RoomID(room_id), EventID(event_id), reason=reason or None, ) @@ -1456,7 +1412,6 @@ class MatrixAdapter(BasePlatformAdapter): if not self._client: return [] try: - from mautrix.types import PaginationDirection, RoomID, SyncToken resp = await self._client.get_messages( RoomID(room_id), direction=PaginationDirection.BACKWARD, @@ -1505,7 +1460,6 @@ class MatrixAdapter(BasePlatformAdapter): if not self._client: return None try: - from mautrix.types import RoomCreatePreset, UserID preset_enum = { "private_chat": RoomCreatePreset.PRIVATE, "public_chat": RoomCreatePreset.PUBLIC, @@ -1532,7 +1486,6 @@ class MatrixAdapter(BasePlatformAdapter): if not self._client: return False try: - from mautrix.types import RoomID, UserID await self._client.invite_user(RoomID(room_id), UserID(user_id)) logger.info("Matrix: invited %s to %s", user_id, room_id) return True @@ -1554,7 +1507,6 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: invalid presence state %r", state) return False try: - from mautrix.types import PresenceState presence_map = { "online": PresenceState.ONLINE, "offline": PresenceState.OFFLINE, @@ -1574,19 +1526,14 @@ class MatrixAdapter(BasePlatformAdapter): # Emote & notice message types # ------------------------------------------------------------------ - async def send_emote( - self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None, + async def _send_simple_message( + self, chat_id: str, text: str, msgtype: str, ) -> SendResult: - """Send an emote message (/me style action).""" - from mautrix.types import EventType, RoomID - + """Send a simple message (emote, notice) with optional HTML formatting.""" if not self._client or not text: return SendResult(success=False, error="No client or empty text") - msg_content: Dict[str, Any] = { - "msgtype": "m.emote", - "body": text, - } + msg_content: Dict[str, Any] = {"msgtype": msgtype, "body": text} html = self._markdown_to_html(text) if html and html != text: msg_content["format"] = "org.matrix.custom.html" @@ -1600,31 +1547,17 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as exc: return SendResult(success=False, error=str(exc)) + async def send_emote( + self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None, + ) -> SendResult: + """Send an emote message (/me style action).""" + return await self._send_simple_message(chat_id, text, "m.emote") + async def send_notice( self, chat_id: str, text: str, metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send a notice message (bot-appropriate, non-alerting).""" - from mautrix.types import EventType, RoomID - - if not self._client or not text: - return SendResult(success=False, error="No client or empty text") - - msg_content: Dict[str, Any] = { - "msgtype": "m.notice", - "body": text, - } - html = self._markdown_to_html(text) - if html and html != text: - msg_content["format"] = "org.matrix.custom.html" - msg_content["formatted_body"] = html - - try: - event_id = await self._client.send_message_event( - RoomID(chat_id), EventType.ROOM_MESSAGE, msg_content, - ) - return SendResult(success=True, message_id=str(event_id)) - except Exception as exc: - return SendResult(success=False, error=str(exc)) + return await self._send_simple_message(chat_id, text, "m.notice") # ------------------------------------------------------------------ # Helpers