diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index c18d3569d8..505aabbb25 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -1291,7 +1291,7 @@ class BasePlatformAdapter(ABC): path = path[1:-1].strip() path = path.lstrip("`\"'").rstrip("`\"',.;:)}]") if path: - media.append((path, has_voice_tag)) + media.append((os.path.expanduser(path), has_voice_tag)) # Remove MEDIA tags from content (including surrounding quote/backtick wrappers) if media: diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 021a453040..cdd67b337d 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -30,11 +30,10 @@ import mimetypes import os import re import time +from html import escape as _html_escape from pathlib import Path from typing import Any, Dict, Optional, Set -from html import escape as _html_escape - try: from mautrix.types import ( ContentURI, @@ -60,28 +59,33 @@ except ImportError: REACTION = "m.reaction" ROOM_ENCRYPTED = "m.room.encrypted" ROOM_NAME = "m.room.name" + EventType = _EventTypeStub # type: ignore[misc,assignment] class _PaginationDirectionStub: # type: ignore[no-redef] BACKWARD = "b" FORWARD = "f" + PaginationDirection = _PaginationDirectionStub # type: ignore[misc,assignment] class _PresenceStateStub: # type: ignore[no-redef] ONLINE = "online" OFFLINE = "offline" UNAVAILABLE = "unavailable" + PresenceState = _PresenceStateStub # type: ignore[misc,assignment] class _RoomCreatePresetStub: # type: ignore[no-redef] PRIVATE = "private_chat" PUBLIC = "public_chat" TRUSTED_PRIVATE = "trusted_private_chat" + RoomCreatePreset = _RoomCreatePresetStub # type: ignore[misc,assignment] class _TrustStateStub: # type: ignore[no-redef] UNVERIFIED = 0 VERIFIED = 1 + TrustState = _TrustStateStub # type: ignore[misc,assignment] from gateway.config import Platform, PlatformConfig @@ -103,20 +107,16 @@ MAX_MESSAGE_LENGTH = 4000 # Store directory for E2EE keys and sync state. # Uses get_hermes_home() so each profile gets its own Matrix store. from hermes_constants import get_hermes_dir as _get_hermes_dir + _STORE_DIR = _get_hermes_dir("platforms/matrix/store", "matrix/store") _CRYPTO_DB_PATH = _STORE_DIR / "crypto.db" # Grace period: ignore messages older than this many seconds before startup. _STARTUP_GRACE_SECONDS = 5 -# Pending undecrypted events: cap and TTL for retry buffer. -_MAX_PENDING_EVENTS = 100 -_PENDING_EVENT_TTL = 300 # seconds — stop retrying after 5 min - _E2EE_INSTALL_HINT = ( - "Install with: pip install 'mautrix[encryption]' " - "(requires libolm C library)" + "Install with: pip install 'mautrix[encryption]' (requires libolm C library)" ) @@ -124,6 +124,7 @@ def _check_e2ee_deps() -> bool: """Return True if mautrix E2EE dependencies (python-olm) are available.""" try: from mautrix.crypto import OlmMachine # noqa: F401 + return True except (ImportError, AttributeError): return False @@ -145,14 +146,17 @@ def check_matrix_requirements() -> bool: import mautrix # noqa: F401 except ImportError: logger.warning( - "Matrix: mautrix not installed. " - "Run: pip install 'mautrix[encryption]'" + "Matrix: mautrix not installed. Run: pip install 'mautrix[encryption]'" ) return False # If encryption is requested, verify E2EE deps are available at startup # rather than silently degrading to plaintext-only at connect time. - encryption_requested = os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes") + encryption_requested = os.getenv("MATRIX_ENCRYPTION", "").lower() in ( + "true", + "1", + "yes", + ) if encryption_requested and not _check_e2ee_deps(): logger.error( "Matrix: MATRIX_ENCRYPTION=true but E2EE dependencies are missing. %s. " @@ -204,25 +208,21 @@ class MatrixAdapter(BasePlatformAdapter): super().__init__(config, Platform.MATRIX) self._homeserver: str = ( - config.extra.get("homeserver", "") - or os.getenv("MATRIX_HOMESERVER", "") + config.extra.get("homeserver", "") or os.getenv("MATRIX_HOMESERVER", "") ).rstrip("/") self._access_token: str = config.token or os.getenv("MATRIX_ACCESS_TOKEN", "") - self._user_id: str = ( - config.extra.get("user_id", "") - or os.getenv("MATRIX_USER_ID", "") + self._user_id: str = config.extra.get("user_id", "") or os.getenv( + "MATRIX_USER_ID", "" ) - self._password: str = ( - config.extra.get("password", "") - or os.getenv("MATRIX_PASSWORD", "") + self._password: str = config.extra.get("password", "") or os.getenv( + "MATRIX_PASSWORD", "" ) self._encryption: bool = config.extra.get( "encryption", os.getenv("MATRIX_ENCRYPTION", "").lower() in ("true", "1", "yes"), ) - self._device_id: str = ( - config.extra.get("device_id", "") - or os.getenv("MATRIX_DEVICE_ID", "") + self._device_id: str = config.extra.get("device_id", "") or os.getenv( + "MATRIX_DEVICE_ID", "" ) self._client: Any = None # mautrix.client.Client @@ -237,22 +237,32 @@ class MatrixAdapter(BasePlatformAdapter): self._joined_rooms: Set[str] = set() # Event deduplication (bounded deque keeps newest entries) from collections import deque + self._processed_events: deque = deque(maxlen=1000) self._processed_events_set: set = set() # Buffer for undecrypted events pending key receipt. # Each entry: (room_id, event, timestamp) - self._pending_megolm: list = [] # Thread participation tracking (for require_mention bypass) self._threads = ThreadParticipationTracker("matrix") # Mention/thread gating — parsed once from env vars. - self._require_mention: bool = os.getenv("MATRIX_REQUIRE_MENTION", "true").lower() not in ("false", "0", "no") + self._require_mention: bool = os.getenv( + "MATRIX_REQUIRE_MENTION", "true" + ).lower() not in ("false", "0", "no") free_rooms_raw = os.getenv("MATRIX_FREE_RESPONSE_ROOMS", "") - self._free_rooms: Set[str] = {r.strip() for r in free_rooms_raw.split(",") if r.strip()} - self._auto_thread: bool = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ("true", "1", "yes") - self._dm_mention_threads: bool = os.getenv("MATRIX_DM_MENTION_THREADS", "false").lower() in ("true", "1", "yes") + self._free_rooms: Set[str] = { + r.strip() for r in free_rooms_raw.split(",") if r.strip() + } + self._auto_thread: bool = os.getenv("MATRIX_AUTO_THREAD", "true").lower() in ( + "true", + "1", + "yes", + ) + self._dm_mention_threads: bool = os.getenv( + "MATRIX_DM_MENTION_THREADS", "false" + ).lower() in ("true", "1", "yes") # Reactions: configurable via MATRIX_REACTIONS (default: true). self._reactions_enabled: bool = os.getenv( @@ -262,8 +272,12 @@ class MatrixAdapter(BasePlatformAdapter): # Text batching: merge rapid successive messages (Telegram-style). # Matrix clients split long messages around 4000 chars. - self._text_batch_delay_seconds = float(os.getenv("HERMES_MATRIX_TEXT_BATCH_DELAY_SECONDS", "0.6")) - self._text_batch_split_delay_seconds = float(os.getenv("HERMES_MATRIX_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0")) + self._text_batch_delay_seconds = float( + os.getenv("HERMES_MATRIX_TEXT_BATCH_DELAY_SECONDS", "0.6") + ) + self._text_batch_split_delay_seconds = float( + os.getenv("HERMES_MATRIX_TEXT_BATCH_SPLIT_DELAY_SECONDS", "2.0") + ) self._pending_text_batches: Dict[str, MessageEvent] = {} self._pending_text_batch_tasks: Dict[str, asyncio.Task] = {} @@ -284,6 +298,38 @@ class MatrixAdapter(BasePlatformAdapter): # E2EE helpers # ------------------------------------------------------------------ + @staticmethod + def _extract_server_ed25519(device_keys_obj: Any) -> Optional[str]: + """Extract the ed25519 identity key from a DeviceKeys object.""" + for kid, kval in (getattr(device_keys_obj, "keys", {}) or {}).items(): + if str(kid).startswith("ed25519:"): + return str(kval) + return None + + async def _reverify_keys_after_upload( + self, client: Any, local_ed25519: str + ) -> bool: + """Re-query the server after share_keys() and verify our ed25519 key matches.""" + try: + resp = await client.query_keys({client.mxid: [client.device_id]}) + dk = getattr(resp, "device_keys", {}) or {} + ud = dk.get(str(client.mxid)) or {} + dev = ud.get(str(client.device_id)) + if dev: + server_ed = self._extract_server_ed25519(dev) + if server_ed != local_ed25519: + logger.error( + "Matrix: device %s has immutable identity keys that " + "don't match this installation. Generate a new access " + "token with a fresh device.", + client.device_id, + ) + return False + except Exception as exc: + logger.error("Matrix: post-upload key verification failed: %s", exc) + return False + return True + async def _verify_device_keys_on_server(self, client: Any, olm: Any) -> bool: """Verify our device keys are on the homeserver after loading crypto state. @@ -294,15 +340,15 @@ class MatrixAdapter(BasePlatformAdapter): resp = await client.query_keys({client.mxid: [client.device_id]}) except Exception as exc: logger.error( - "Matrix: cannot verify device keys on server: %s — refusing E2EE", exc, + "Matrix: cannot verify device keys on server: %s — refusing E2EE", + exc, ) return False - # query_keys returns typed objects (QueryKeysResponse, DeviceKeys - # with KeyID keys). Normalise to plain strings for comparison. device_keys_map = getattr(resp, "device_keys", {}) or {} our_user_devices = device_keys_map.get(str(client.mxid)) or {} our_keys = our_user_devices.get(str(client.device_id)) + local_ed25519 = olm.account.identity_keys.get("ed25519") if not our_keys: logger.warning("Matrix: device keys missing from server — re-uploading") @@ -312,21 +358,12 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as exc: logger.error("Matrix: failed to re-upload device keys: %s", exc) return False - return True + return await self._reverify_keys_after_upload(client, local_ed25519) - # DeviceKeys.keys is a dict[KeyID, str]. Iterate to find the - # ed25519 key rather than constructing a KeyID for lookup. - server_ed25519 = None - keys_dict = getattr(our_keys, "keys", {}) or {} - for key_id, key_value in keys_dict.items(): - if str(key_id).startswith("ed25519:"): - server_ed25519 = str(key_value) - break - local_ed25519 = olm.account.identity_keys.get("ed25519") + server_ed25519 = self._extract_server_ed25519(our_keys) if server_ed25519 != local_ed25519: if olm.account.shared: - # Restored account from DB but server has different keys — corrupted state. logger.error( "Matrix: server has different identity keys for device %s — " "local crypto state is stale. Delete %s and restart.", @@ -335,8 +372,6 @@ class MatrixAdapter(BasePlatformAdapter): ) return False - # Fresh account (never uploaded). Server has stale keys from a - # previous installation. Try to delete the old device and re-upload. logger.warning( "Matrix: server has stale keys for device %s — attempting re-upload", client.device_id, @@ -348,10 +383,10 @@ class MatrixAdapter(BasePlatformAdapter): else "DELETE", f"/_matrix/client/v3/devices/{client.device_id}", ) - logger.info("Matrix: deleted stale device %s from server", client.device_id) + logger.info( + "Matrix: deleted stale device %s from server", client.device_id + ) except Exception: - # Device deletion often requires UIA or may simply not be - # permitted — that's fine, share_keys will try to overwrite. pass try: await olm.share_keys() @@ -363,6 +398,7 @@ class MatrixAdapter(BasePlatformAdapter): exc, ) return False + return await self._reverify_keys_after_upload(client, local_ed25519) return True @@ -448,7 +484,9 @@ class MatrixAdapter(BasePlatformAdapter): await api.session.close() return False else: - logger.error("Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD") + logger.error( + "Matrix: need MATRIX_ACCESS_TOKEN or MATRIX_USER_ID + MATRIX_PASSWORD" + ) await api.session.close() return False @@ -472,7 +510,9 @@ class MatrixAdapter(BasePlatformAdapter): # Remove legacy pickle file from pre-SQLite era. legacy_pickle = _STORE_DIR / "crypto_store.pickle" if legacy_pickle.exists(): - logger.info("Matrix: removing legacy crypto_store.pickle (migrated to SQLite)") + logger.info( + "Matrix: removing legacy crypto_store.pickle (migrated to SQLite)" + ) legacy_pickle.unlink() # Open SQLite-backed crypto store. @@ -508,6 +548,37 @@ class MatrixAdapter(BasePlatformAdapter): await api.session.close() return False + # Proactively flush one-time keys to detect stale OTK + # conflicts early. When crypto state is wiped but the + # same device ID is reused, the server may still hold OTKs + # signed with the old ed25519 key. Identity key re-upload + # succeeds but OTK uploads fail ("already exists" with + # mismatched signature). Peers then cannot establish Olm + # sessions and all new messages are undecryptable. + try: + await olm.share_keys() + except Exception as exc: + exc_str = str(exc) + if "already exists" in exc_str: + logger.error( + "Matrix: device %s has stale one-time keys on the " + "server signed with a previous identity key. " + "Peers cannot establish new Olm sessions with " + "this device. Delete the device from the " + "homeserver and restart, or generate a new " + "access token to get a fresh device ID.", + client.device_id, + ) + await crypto_db.stop() + await api.session.close() + return False + # Non-OTK errors are transient (network, etc.) — log + # but allow startup to continue. + logger.warning( + "Matrix: share_keys() warning during startup: %s", + exc, + ) + # Import cross-signing private keys from SSSS and self-sign # the current device. Required after any device-key rotation # (fresh crypto.db, share_keys re-upload) — otherwise the @@ -519,7 +590,9 @@ class MatrixAdapter(BasePlatformAdapter): await olm.verify_with_recovery_key(recovery_key) logger.info("Matrix: cross-signing verified via recovery key") except Exception as exc: - logger.warning("Matrix: recovery key verification failed: %s", exc) + logger.warning( + "Matrix: recovery key verification failed: %s", exc + ) client.crypto = olm logger.info( @@ -530,21 +603,23 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as exc: logger.error( "Matrix: failed to create E2EE client: %s. %s", - exc, _E2EE_INSTALL_HINT, + exc, + _E2EE_INSTALL_HINT, ) await api.session.close() return False # Register event handlers. from mautrix.client import InternalEventType as IntEvt + from mautrix.client.dispatcher import MembershipEventDispatcher + + # Without this the INVITE handler below never fires. + client.add_dispatcher(MembershipEventDispatcher) client.add_event_handler(EventType.ROOM_MESSAGE, self._on_room_message) client.add_event_handler(EventType.REACTION, self._on_reaction) client.add_event_handler(IntEvt.INVITE, self._on_invite) - if self._encryption and getattr(client, "crypto", None): - client.add_event_handler(EventType.ROOM_ENCRYPTED, self._on_encrypted_event) - # Initial sync to catch up, then start background sync. self._startup_ts = time.time() self._closing = False @@ -553,7 +628,8 @@ class MatrixAdapter(BasePlatformAdapter): sync_data = await client.sync(timeout=10000, full_state=True) if isinstance(sync_data, dict): rooms_join = sync_data.get("rooms", {}).get("join", {}) - self._joined_rooms = set(rooms_join.keys()) + self._joined_rooms.clear() + self._joined_rooms.update(rooms_join.keys()) # Store the next_batch token so incremental syncs start # from where the initial sync left off. nb = sync_data.get("next_batch") @@ -575,7 +651,10 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as exc: logger.warning("Matrix: initial sync event dispatch error: %s", exc) else: - logger.warning("Matrix: initial sync returned unexpected type %s", type(sync_data).__name__) + logger.warning( + "Matrix: initial sync returned unexpected type %s", + type(sync_data).__name__, + ) except Exception as exc: logger.warning("Matrix: initial sync error: %s", exc) @@ -648,9 +727,7 @@ class MatrixAdapter(BasePlatformAdapter): # Reply-to support. if reply_to: - msg_content["m.relates_to"] = { - "m.in_reply_to": {"event_id": reply_to} - } + msg_content["m.relates_to"] = {"m.in_reply_to": {"event_id": reply_to}} # Thread support: if metadata has thread_id, send as threaded reply. thread_id = (metadata or {}).get("thread_id") @@ -688,10 +765,18 @@ class MatrixAdapter(BasePlatformAdapter): timeout=45, ) last_event_id = str(event_id) - logger.info("Matrix: sent event %s to %s (after key share)", last_event_id, chat_id) + logger.info( + "Matrix: sent event %s to %s (after key share)", + last_event_id, + chat_id, + ) continue except Exception as retry_exc: - logger.error("Matrix: failed to send to %s after retry: %s", chat_id, retry_exc) + logger.error( + "Matrix: failed to send to %s after retry: %s", + chat_id, + retry_exc, + ) return SendResult(success=False, error=str(retry_exc)) logger.error("Matrix: failed to send to %s: %s", chat_id, exc) return SendResult(success=False, error=str(exc)) @@ -706,7 +791,8 @@ class MatrixAdapter(BasePlatformAdapter): if self._client: try: name_evt = await self._client.get_state_event( - RoomID(chat_id), EventType.ROOM_NAME, + RoomID(chat_id), + EventType.ROOM_NAME, ) if name_evt and hasattr(name_evt, "name") and name_evt.name: name = name_evt.name @@ -730,13 +816,14 @@ class MatrixAdapter(BasePlatformAdapter): pass async def stop_typing(self, chat_id: str) -> None: - """Stop the Matrix typing indicator.""" + """Clear the typing indicator.""" if self._client: try: await self._client.set_typing(RoomID(chat_id), timeout=0) except Exception: pass + async def edit_message( self, chat_id: str, message_id: str, content: str ) -> SendResult: @@ -765,7 +852,9 @@ class MatrixAdapter(BasePlatformAdapter): try: event_id = await self._client.send_message_event( - RoomID(chat_id), EventType.ROOM_MESSAGE, msg_content, + RoomID(chat_id), + EventType.ROOM_MESSAGE, + msg_content, ) return SendResult(success=True, message_id=str(event_id)) except Exception as exc: @@ -781,22 +870,31 @@ class MatrixAdapter(BasePlatformAdapter): ) -> SendResult: """Download an image URL and upload it to Matrix.""" from tools.url_safety import is_safe_url + if not is_safe_url(image_url): logger.warning("Matrix: blocked unsafe image URL (SSRF protection)") - return await super().send_image(chat_id, image_url, caption, reply_to, metadata=metadata) + return await super().send_image( + chat_id, image_url, caption, reply_to, metadata=metadata + ) try: # Try aiohttp first (always available), fall back to httpx try: import aiohttp as _aiohttp + async with _aiohttp.ClientSession(trust_env=True) as http: - async with http.get(image_url, timeout=_aiohttp.ClientTimeout(total=30)) as resp: + async with http.get( + image_url, timeout=_aiohttp.ClientTimeout(total=30) + ) as resp: resp.raise_for_status() data = await resp.read() ct = resp.content_type or "image/png" - fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png" + fname = ( + image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png" + ) except ImportError: import httpx + async with httpx.AsyncClient() as http: resp = await http.get(image_url, follow_redirects=True, timeout=30) resp.raise_for_status() @@ -805,9 +903,13 @@ class MatrixAdapter(BasePlatformAdapter): fname = image_url.rsplit("/", 1)[-1].split("?")[0] or "image.png" except Exception as exc: logger.warning("Matrix: failed to download image %s: %s", image_url, exc) - return await self.send(chat_id, f"{caption or ''}\n{image_url}".strip(), reply_to) + return await self.send( + chat_id, f"{caption or ''}\n{image_url}".strip(), reply_to + ) - return await self._upload_and_send(chat_id, data, fname, ct, "m.image", caption, reply_to, metadata) + return await self._upload_and_send( + chat_id, data, fname, ct, "m.image", caption, reply_to, metadata + ) async def send_image_file( self, @@ -818,7 +920,9 @@ class MatrixAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Upload a local image file to Matrix.""" - return await self._send_local_file(chat_id, image_path, "m.image", caption, reply_to, metadata=metadata) + return await self._send_local_file( + chat_id, image_path, "m.image", caption, reply_to, metadata=metadata + ) async def send_document( self, @@ -830,7 +934,9 @@ class MatrixAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Upload a local file as a document.""" - return await self._send_local_file(chat_id, file_path, "m.file", caption, reply_to, file_name, metadata) + return await self._send_local_file( + chat_id, file_path, "m.file", caption, reply_to, file_name, metadata + ) async def send_voice( self, @@ -842,8 +948,13 @@ class MatrixAdapter(BasePlatformAdapter): ) -> SendResult: """Upload an audio file as a voice message (MSC3245 native voice).""" return await self._send_local_file( - chat_id, audio_path, "m.audio", caption, reply_to, - metadata=metadata, is_voice=True + chat_id, + audio_path, + "m.audio", + caption, + reply_to, + metadata=metadata, + is_voice=True, ) async def send_video( @@ -855,7 +966,9 @@ class MatrixAdapter(BasePlatformAdapter): metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Upload a video file.""" - return await self._send_local_file(chat_id, video_path, "m.video", caption, reply_to, metadata=metadata) + return await self._send_local_file( + chat_id, video_path, "m.video", caption, reply_to, metadata=metadata + ) def format_message(self, content: str) -> str: """Pass-through — Matrix supports standard Markdown natively.""" @@ -881,12 +994,30 @@ class MatrixAdapter(BasePlatformAdapter): ) -> SendResult: """Upload bytes to Matrix and send as a media message.""" + upload_data = data + encrypted_file = None + if self._encryption and getattr(self._client, "crypto", None): + state_store = getattr(self._client, "state_store", None) + if state_store: + try: + room_encrypted = bool(await state_store.is_encrypted(RoomID(room_id))) + except Exception: + room_encrypted = False + if room_encrypted: + try: + from mautrix.crypto.attachments import encrypt_attachment + upload_data, encrypted_file = encrypt_attachment(data) + except Exception as exc: + logger.error("Matrix: attachment encryption failed: %s", exc) + return SendResult(success=False, error=str(exc)) + # Upload to homeserver. try: mxc_url = await self._client.upload_media( - data, + upload_data, mime_type=content_type, filename=filename, + size=len(upload_data), ) except Exception as exc: logger.error("Matrix: upload failed: %s", exc) @@ -896,21 +1027,24 @@ class MatrixAdapter(BasePlatformAdapter): msg_content: Dict[str, Any] = { "msgtype": msgtype, "body": caption or filename, - "url": str(mxc_url), "info": { "mimetype": content_type, "size": len(data), }, } + if encrypted_file is not None: + file_payload = encrypted_file.serialize() + file_payload["url"] = str(mxc_url) + msg_content["file"] = file_payload + else: + msg_content["url"] = str(mxc_url) # Add MSC3245 voice flag for native voice messages. if is_voice: msg_content["org.matrix.msc3245.voice"] = {} if reply_to: - msg_content["m.relates_to"] = { - "m.in_reply_to": {"event_id": reply_to} - } + msg_content["m.relates_to"] = {"m.in_reply_to": {"event_id": reply_to}} thread_id = (metadata or {}).get("thread_id") if thread_id: @@ -922,7 +1056,9 @@ class MatrixAdapter(BasePlatformAdapter): try: event_id = await self._client.send_message_event( - RoomID(room_id), EventType.ROOM_MESSAGE, msg_content, + RoomID(room_id), + EventType.ROOM_MESSAGE, + msg_content, ) return SendResult(success=True, message_id=str(event_id)) except Exception as exc: @@ -940,7 +1076,7 @@ class MatrixAdapter(BasePlatformAdapter): is_voice: bool = False, ) -> SendResult: """Read a local file and upload it.""" - p = Path(file_path) + p = Path(file_path).expanduser() if not p.exists(): return await self.send( room_id, f"{caption or ''}\n(file not found: {file_path})", reply_to @@ -950,7 +1086,9 @@ class MatrixAdapter(BasePlatformAdapter): ct = mimetypes.guess_type(fname)[0] or "application/octet-stream" data = p.read_bytes() - return await self._upload_and_send(room_id, data, fname, ct, msgtype, caption, reply_to, metadata, is_voice) + return await self._upload_and_send( + room_id, data, fname, ct, msgtype, caption, reply_to, metadata, is_voice + ) # ------------------------------------------------------------------ # Sync loop @@ -964,7 +1102,8 @@ class MatrixAdapter(BasePlatformAdapter): while not self._closing: try: sync_data = await client.sync( - since=next_batch, timeout=30000, + since=next_batch, + timeout=30000, ) # nio returns SyncError objects (not exceptions) for auth @@ -973,7 +1112,10 @@ class MatrixAdapter(BasePlatformAdapter): if _sync_msg and isinstance(_sync_msg, str): _lower = _sync_msg.lower() if "m_unknown_token" in _lower or "unknown_token" in _lower: - logger.error("Matrix: permanent auth error from sync: %s — stopping", _sync_msg) + logger.error( + "Matrix: permanent auth error from sync: %s — stopping", + _sync_msg, + ) return if isinstance(sync_data, dict): @@ -998,10 +1140,6 @@ class MatrixAdapter(BasePlatformAdapter): except Exception as exc: logger.warning("Matrix: sync event dispatch error: %s", exc) - # Retry any buffered undecrypted events. - if self._pending_megolm: - await self._retry_pending_decryptions() - except asyncio.CancelledError: return except Exception as exc: @@ -1009,64 +1147,19 @@ class MatrixAdapter(BasePlatformAdapter): return # Detect permanent auth/permission failures. err_str = str(exc).lower() - if "401" in err_str or "403" in err_str or "unauthorized" in err_str or "forbidden" in err_str: - logger.error("Matrix: permanent auth error: %s — stopping sync", exc) + if ( + "401" in err_str + or "403" in err_str + or "unauthorized" in err_str + or "forbidden" in err_str + ): + logger.error( + "Matrix: permanent auth error: %s — stopping sync", exc + ) return logger.warning("Matrix: sync error: %s — retrying in 5s", exc) await asyncio.sleep(5) - async def _retry_pending_decryptions(self) -> None: - """Retry decrypting buffered encrypted events after new keys arrive.""" - client = self._client - if not client or not self._pending_megolm: - return - crypto = getattr(client, "crypto", None) - if not crypto: - return - - now = time.time() - still_pending: list = [] - - for room_id, event, ts in self._pending_megolm: - # Drop events that have aged past the TTL. - if now - ts > _PENDING_EVENT_TTL: - logger.debug( - "Matrix: dropping expired pending event %s (age %.0fs)", - getattr(event, "event_id", "?"), now - ts, - ) - continue - - try: - decrypted = await crypto.decrypt_megolm_event(event) - except Exception: - still_pending.append((room_id, event, ts)) - continue - - if decrypted is None or decrypted is event: - still_pending.append((room_id, event, ts)) - continue - - logger.info( - "Matrix: decrypted buffered event %s", - getattr(event, "event_id", "?"), - ) - - # Route to the appropriate handler. - # Remove from dedup set so _on_room_message doesn't drop it - # (the encrypted event ID was already registered by _on_encrypted_event). - decrypted_id = str(getattr(decrypted, "event_id", getattr(event, "event_id", ""))) - if decrypted_id: - self._processed_events_set.discard(decrypted_id) - try: - await self._on_room_message(decrypted) - except Exception as exc: - logger.warning( - "Matrix: error processing decrypted event %s: %s", - getattr(event, "event_id", "?"), exc, - ) - - self._pending_megolm = still_pending - # ------------------------------------------------------------------ # Event callbacks # ------------------------------------------------------------------ @@ -1086,7 +1179,11 @@ class MatrixAdapter(BasePlatformAdapter): return # Startup grace: ignore old messages from initial sync. - raw_ts = getattr(event, "timestamp", None) or getattr(event, "server_timestamp", None) or 0 + raw_ts = ( + getattr(event, "timestamp", None) + or getattr(event, "server_timestamp", None) + or 0 + ) event_ts = raw_ts / 1000.0 if raw_ts else 0.0 if event_ts and event_ts < self._startup_ts - _STARTUP_GRACE_SECONDS: return @@ -1126,9 +1223,13 @@ class MatrixAdapter(BasePlatformAdapter): # Dispatch by msgtype. media_msgtypes = ("m.image", "m.audio", "m.video", "m.file") if msgtype in media_msgtypes: - await self._handle_media_message(room_id, sender, event_id, event_ts, source_content, relates_to, msgtype) + await self._handle_media_message( + room_id, sender, event_id, event_ts, source_content, relates_to, msgtype + ) elif msgtype == "m.text": - await self._handle_text_message(room_id, sender, event_id, event_ts, source_content, relates_to) + await self._handle_text_message( + room_id, sender, event_id, event_ts, source_content, relates_to + ) async def _resolve_message_context( self, @@ -1154,7 +1255,9 @@ class MatrixAdapter(BasePlatformAdapter): formatted_body = source_content.get("formatted_body") # m.mentions.user_ids (MSC3952 / Matrix v1.7) — authoritative mention signal. mentions_block = source_content.get("m.mentions") or {} - mention_user_ids = mentions_block.get("user_ids") if isinstance(mentions_block, dict) else None + mention_user_ids = ( + mentions_block.get("user_ids") if isinstance(mentions_block, dict) else None + ) is_mentioned = self._is_bot_mentioned(body, formatted_body, mention_user_ids) # Require-mention gating. @@ -1170,8 +1273,8 @@ class MatrixAdapter(BasePlatformAdapter): thread_id = event_id self._threads.mark(thread_id) - # Strip mention from body. - if is_mentioned: + # Strip mention from body (only when mention-gating is active). + if is_mentioned and self._require_mention: body = self._strip_mention(body) # Auto-thread. @@ -1210,7 +1313,12 @@ class MatrixAdapter(BasePlatformAdapter): return ctx = await self._resolve_message_context( - room_id, sender, event_id, body, source_content, relates_to, + room_id, + sender, + event_id, + body, + source_content, + relates_to, ) if ctx is None: return @@ -1288,7 +1396,9 @@ class MatrixAdapter(BasePlatformAdapter): if url and url.startswith("mxc://"): http_url = self._mxc_to_http(url) - is_encrypted_media = bool(file_content and isinstance(file_content, dict) and file_content.get("url")) + is_encrypted_media = bool( + file_content and isinstance(file_content, dict) and file_content.get("url") + ) media_type = "application/octet-stream" msg_type = MessageType.DOCUMENT @@ -1312,9 +1422,9 @@ class MatrixAdapter(BasePlatformAdapter): # Cache media locally when downstream tools need a real file path. cached_path = None - should_cache_locally = ( - msg_type == MessageType.PHOTO or is_voice_message or is_encrypted_media - ) + should_cache_locally = msg_type in ( + MessageType.PHOTO, MessageType.AUDIO, MessageType.VIDEO, MessageType.DOCUMENT, + ) or is_voice_message or is_encrypted_media if should_cache_locally and url: try: file_bytes = await self._client.download_media(ContentURI(url)) @@ -1322,17 +1432,35 @@ class MatrixAdapter(BasePlatformAdapter): if is_encrypted_media: from mautrix.crypto.attachments import decrypt_attachment - hashes_value = file_content.get("hashes") if isinstance(file_content, dict) else None - hash_value = hashes_value.get("sha256") if isinstance(hashes_value, dict) else None + hashes_value = ( + file_content.get("hashes") + if isinstance(file_content, dict) + else None + ) + hash_value = ( + hashes_value.get("sha256") + if isinstance(hashes_value, dict) + else None + ) - key_value = file_content.get("key") if isinstance(file_content, dict) else None + key_value = ( + file_content.get("key") + if isinstance(file_content, dict) + else None + ) if isinstance(key_value, dict): key_value = key_value.get("k") - iv_value = file_content.get("iv") if isinstance(file_content, dict) else None + iv_value = ( + file_content.get("iv") + if isinstance(file_content, dict) + else None + ) if key_value and hash_value and iv_value: - file_bytes = decrypt_attachment(file_bytes, key_value, hash_value, iv_value) + file_bytes = decrypt_attachment( + file_bytes, key_value, hash_value, iv_value + ) else: logger.warning( "[Matrix] Encrypted media event missing decryption metadata for %s", @@ -1358,25 +1486,46 @@ class MatrixAdapter(BasePlatformAdapter): cached_path = cache_image_from_bytes(file_bytes, ext=ext) logger.info("[Matrix] Cached user image at %s", cached_path) elif msg_type in (MessageType.AUDIO, MessageType.VOICE): - ext = Path(body or ("voice.ogg" if is_voice_message else "audio.ogg")).suffix or ".ogg" + ext = ( + Path( + body + or ( + "voice.ogg" if is_voice_message else "audio.ogg" + ) + ).suffix + or ".ogg" + ) cached_path = cache_audio_from_bytes(file_bytes, ext=ext) else: filename = body or ( - "video.mp4" if msg_type == MessageType.VIDEO else "document" + "video.mp4" + if msg_type == MessageType.VIDEO + else "document" + ) + cached_path = cache_document_from_bytes( + file_bytes, filename ) - cached_path = cache_document_from_bytes(file_bytes, filename) except Exception as e: logger.warning("[Matrix] Failed to cache media: %s", e) ctx = await self._resolve_message_context( - room_id, sender, event_id, body, source_content, relates_to, + room_id, + sender, + event_id, + body, + source_content, + relates_to, ) if ctx is None: return body, is_dm, chat_type, thread_id, display_name, source = ctx allow_http_fallback = bool(http_url) and not is_encrypted_media - media_urls = [cached_path] if cached_path else ([http_url] if allow_http_fallback else None) + media_urls = ( + [cached_path] + if cached_path + else ([http_url] if allow_http_fallback else None) + ) media_types = [media_type] if media_urls else None msg_event = MessageEvent( @@ -1391,23 +1540,6 @@ class MatrixAdapter(BasePlatformAdapter): await self.handle_message(msg_event) - async def _on_encrypted_event(self, event: Any) -> None: - """Handle encrypted events that could not be auto-decrypted.""" - room_id = str(getattr(event, "room_id", "")) - event_id = str(getattr(event, "event_id", "")) - - if self._is_duplicate_event(event_id): - return - - logger.warning( - "Matrix: could not decrypt event %s in %s — buffering for retry", - event_id, room_id, - ) - - self._pending_megolm.append((room_id, event, time.time())) - if len(self._pending_megolm) > _MAX_PENDING_EVENTS: - self._pending_megolm = self._pending_megolm[-_MAX_PENDING_EVENTS:] - async def _on_invite(self, event: Any) -> None: """Auto-join rooms when invited.""" @@ -1430,7 +1562,10 @@ class MatrixAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def _send_reaction( - self, room_id: str, event_id: str, emoji: str, + self, + room_id: str, + event_id: str, + emoji: str, ) -> Optional[str]: """Send an emoji reaction to a message in a room. Returns the reaction event_id on success, None on failure. @@ -1447,7 +1582,9 @@ class MatrixAdapter(BasePlatformAdapter): } try: resp_event_id = await self._client.send_message_event( - RoomID(room_id), EventType.REACTION, content, + RoomID(room_id), + EventType.REACTION, + content, ) logger.debug("Matrix: sent reaction %s to %s", emoji, event_id) return str(resp_event_id) @@ -1456,7 +1593,10 @@ class MatrixAdapter(BasePlatformAdapter): return None async def _redact_reaction( - self, room_id: str, reaction_event_id: str, reason: str = "", + self, + room_id: str, + reaction_event_id: str, + reason: str = "", ) -> bool: """Remove a reaction by redacting its event.""" return await self.redact_message(room_id, reaction_event_id, reason) @@ -1473,7 +1613,9 @@ class MatrixAdapter(BasePlatformAdapter): self._pending_reactions[(room_id, msg_id)] = reaction_event_id async def on_processing_complete( - self, event: MessageEvent, outcome: ProcessingOutcome, + self, + event: MessageEvent, + outcome: ProcessingOutcome, ) -> None: """Replace eyes with checkmark (success) or cross (failure).""" if not self._reactions_enabled: @@ -1507,7 +1649,11 @@ class MatrixAdapter(BasePlatformAdapter): room_id = str(getattr(event, "room_id", "")) content = getattr(event, "content", None) if content: - relates_to = content.get("m.relates_to", {}) if isinstance(content, dict) else getattr(content, "relates_to", {}) + relates_to = ( + content.get("m.relates_to", {}) + if isinstance(content, dict) + else getattr(content, "relates_to", {}) + ) reacts_to = "" key = "" if isinstance(relates_to, dict): @@ -1518,7 +1664,10 @@ class MatrixAdapter(BasePlatformAdapter): key = str(getattr(relates_to, "key", "")) logger.info( "Matrix: reaction %s from %s on %s in %s", - key, sender, reacts_to, room_id, + key, + sender, + reacts_to, + room_id, ) # ------------------------------------------------------------------ @@ -1528,10 +1677,15 @@ class MatrixAdapter(BasePlatformAdapter): def _text_batch_key(self, event: MessageEvent) -> str: """Session-scoped key for text message batching.""" from gateway.session import build_session_key + return build_session_key( event.source, - group_sessions_per_user=self.config.extra.get("group_sessions_per_user", True), - thread_sessions_per_user=self.config.extra.get("thread_sessions_per_user", False), + group_sessions_per_user=self.config.extra.get( + "group_sessions_per_user", True + ), + thread_sessions_per_user=self.config.extra.get( + "thread_sessions_per_user", False + ), ) def _enqueue_text_event(self, event: MessageEvent) -> None: @@ -1544,7 +1698,9 @@ class MatrixAdapter(BasePlatformAdapter): self._pending_text_batches[key] = event else: if event.text: - existing.text = f"{existing.text}\n{event.text}" if existing.text else event.text + existing.text = ( + f"{existing.text}\n{event.text}" if existing.text else event.text + ) existing._last_chunk_len = chunk_len # type: ignore[attr-defined] if event.media_urls: existing.media_urls.extend(event.media_urls) @@ -1573,7 +1729,8 @@ class MatrixAdapter(BasePlatformAdapter): return logger.info( "[Matrix] Flushing text batch %s (%d chars)", - key, len(event.text or ""), + key, + len(event.text or ""), ) await self.handle_message(event) finally: @@ -1586,11 +1743,13 @@ class MatrixAdapter(BasePlatformAdapter): def _background_read_receipt(self, room_id: str, event_id: str) -> None: """Fire-and-forget read receipt with error logging.""" + async def _send() -> None: try: await self.send_read_receipt(room_id, event_id) except Exception as exc: # pragma: no cover — defensive logger.debug("Matrix: background read receipt failed: %s", exc) + asyncio.ensure_future(_send()) async def send_read_receipt(self, room_id: str, event_id: str) -> bool: @@ -1624,14 +1783,19 @@ class MatrixAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def redact_message( - self, room_id: str, event_id: str, reason: str = "", + self, + room_id: str, + event_id: str, + reason: str = "", ) -> bool: """Redact (delete) a message or event from a room.""" if not self._client: return False try: await self._client.redact( - RoomID(room_id), EventID(event_id), reason=reason or None, + RoomID(room_id), + EventID(event_id), + reason=reason or None, ) logger.info("Matrix: redacted %s in %s", event_id, room_id) return True @@ -1722,7 +1886,10 @@ class MatrixAdapter(BasePlatformAdapter): # ------------------------------------------------------------------ async def _send_simple_message( - self, chat_id: str, text: str, msgtype: str, + self, + chat_id: str, + text: str, + msgtype: str, ) -> SendResult: """Send a simple message (emote, notice) with optional HTML formatting.""" if not self._client or not text: @@ -1736,7 +1903,9 @@ class MatrixAdapter(BasePlatformAdapter): try: event_id = await self._client.send_message_event( - RoomID(chat_id), EventType.ROOM_MESSAGE, msg_content, + RoomID(chat_id), + EventType.ROOM_MESSAGE, + msg_content, ) return SendResult(success=True, message_id=str(event_id)) except Exception as exc: @@ -1751,7 +1920,9 @@ class MatrixAdapter(BasePlatformAdapter): if self._dm_rooms.get(room_id, False): return True # Fallback: check member count via state store. - state_store = getattr(self._client, "state_store", None) if self._client else None + state_store = ( + getattr(self._client, "state_store", None) if self._client else None + ) if state_store: try: members = await state_store.get_members(room_id) @@ -1785,10 +1956,7 @@ class MatrixAdapter(BasePlatformAdapter): if isinstance(rooms, list): dm_room_ids.update(str(r) for r in rooms) - self._dm_rooms = { - rid: (rid in dm_room_ids) - for rid in self._joined_rooms - } + self._dm_rooms = {rid: (rid in dm_room_ids) for rid in self._joined_rooms} # ------------------------------------------------------------------ # Mention detection helpers @@ -1818,7 +1986,9 @@ class MatrixAdapter(BasePlatformAdapter): return True if self._user_id and ":" in self._user_id: localpart = self._user_id.split(":")[0].lstrip("@") - if localpart and re.search(r'\b' + re.escape(localpart) + r'\b', body, re.IGNORECASE): + if localpart and re.search( + r"\b" + re.escape(localpart) + r"\b", body, re.IGNORECASE + ): return True if formatted_body and self._user_id: if f"matrix.to/#/{self._user_id}" in formatted_body: @@ -1826,18 +1996,20 @@ class MatrixAdapter(BasePlatformAdapter): return False def _strip_mention(self, body: str) -> str: - """Remove bot mention from message body.""" + """Strip the bot's full MXID (``@user:server``) from *body*. + + The bare localpart is intentionally *not* stripped — it would + mangle file paths like ``/home/hermes/media/file.png``. + """ if self._user_id: body = body.replace(self._user_id, "") - if self._user_id and ":" in self._user_id: - localpart = self._user_id.split(":")[0].lstrip("@") - if localpart: - body = re.sub(r'\b' + re.escape(localpart) + r'\b', '', body, flags=re.IGNORECASE) return body.strip() async def _get_display_name(self, room_id: str, user_id: str) -> str: """Get a user's display name in a room, falling back to user_id.""" - state_store = getattr(self._client, "state_store", None) if self._client else None + state_store = ( + getattr(self._client, "state_store", None) if self._client else None + ) if state_store: try: member = await state_store.get_member(room_id, user_id) @@ -1925,9 +2097,7 @@ class MatrixAdapter(BasePlatformAdapter): # Inline code: `code` result = re.sub( r"`([^`\n]+)`", - lambda m: _protect_html( - f"{_html_escape(m.group(1))}" - ), + lambda m: _protect_html(f"{_html_escape(m.group(1))}"), result, ) @@ -1972,11 +2142,18 @@ class MatrixAdapter(BasePlatformAdapter): continue # Blockquote - if line.startswith("> ") or line == ">" or line.startswith("> ") or line == ">": + if ( + line.startswith("> ") + or line == ">" + or line.startswith("> ") + or line == ">" + ): bq_lines = [] while i < len(lines) and ( - lines[i].startswith("> ") or lines[i] == ">" - or lines[i].startswith("> ") or lines[i] == ">" + lines[i].startswith("> ") + or lines[i] == ">" + or lines[i].startswith("> ") + or lines[i] == ">" ): ln = lines[i] if ln.startswith("> "): @@ -2017,13 +2194,19 @@ class MatrixAdapter(BasePlatformAdapter): result = "\n".join(out_lines) # Inline transforms. - result = re.sub(r"\*\*(.+?)\*\*", r"\1", result, flags=re.DOTALL) + result = re.sub( + r"\*\*(.+?)\*\*", r"\1", result, flags=re.DOTALL + ) result = re.sub(r"__(.+?)__", r"\1", result, flags=re.DOTALL) result = re.sub(r"\*(.+?)\*", r"\1", result, flags=re.DOTALL) - result = re.sub(r"(?\1", result, flags=re.DOTALL) + result = re.sub( + r"(?\1", result, flags=re.DOTALL + ) result = re.sub(r"~~(.+?)~~", r"\1", result, flags=re.DOTALL) result = re.sub(r"\n", "
\n", result) - result = re.sub(r"
\n(\n()
", r"\1", result) # Restore protected regions. diff --git a/gateway/run.py b/gateway/run.py index da55518c79..ba7ea43ad4 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -7982,12 +7982,15 @@ class GatewayRunner: if _adapter: _adapter_supports_edit = getattr(_adapter, "SUPPORTS_MESSAGE_EDITING", True) _effective_cursor = _scfg.cursor if _adapter_supports_edit else "" + _buffer_only = False if source.platform == Platform.MATRIX: _effective_cursor = "" + _buffer_only = True _consumer_cfg = StreamConsumerConfig( edit_interval=_scfg.edit_interval, buffer_threshold=_scfg.buffer_threshold, cursor=_effective_cursor, + buffer_only=_buffer_only, ) _stream_consumer = GatewayStreamConsumer( adapter=_adapter, @@ -8553,12 +8556,15 @@ class GatewayRunner: # Some Matrix clients render the streaming cursor # as a visible tofu/white-box artifact. Keep # streaming text on Matrix, but suppress the cursor. + _buffer_only = False if source.platform == Platform.MATRIX: _effective_cursor = "" + _buffer_only = True _consumer_cfg = StreamConsumerConfig( edit_interval=_scfg.edit_interval, buffer_threshold=_scfg.buffer_threshold, cursor=_effective_cursor, + buffer_only=_buffer_only, ) _stream_consumer = GatewayStreamConsumer( adapter=_adapter, diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 853b159034..5b529e63e8 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -43,6 +43,7 @@ class StreamConsumerConfig: edit_interval: float = 1.0 buffer_threshold: int = 40 cursor: str = " ▉" + buffer_only: bool = False class GatewayStreamConsumer: @@ -295,10 +296,13 @@ class GatewayStreamConsumer: got_done or got_segment_break or commentary_text is not None - or (elapsed >= self._current_edit_interval - and self._accumulated) - or len(self._accumulated) >= self.cfg.buffer_threshold ) + if not self.cfg.buffer_only: + should_edit = should_edit or ( + (elapsed >= self._current_edit_interval + and self._accumulated) + or len(self._accumulated) >= self.cfg.buffer_threshold + ) current_update_visible = False if should_edit and self._accumulated: diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 19ed200b06..845c0fff1f 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -108,6 +108,9 @@ def _make_fake_mautrix(): def add_event_handler(self, event_type, handler): self._event_handlers.setdefault(event_type, []).append(handler) + def add_dispatcher(self, dispatcher_type): + pass + class InternalEventType: INVITE = "internal.invite" @@ -115,6 +118,14 @@ def _make_fake_mautrix(): mautrix_client.InternalEventType = InternalEventType mautrix.client = mautrix_client + # --- mautrix.client.dispatcher --- + mautrix_client_dispatcher = types.ModuleType("mautrix.client.dispatcher") + + class MembershipEventDispatcher: + pass + + mautrix_client_dispatcher.MembershipEventDispatcher = MembershipEventDispatcher + # --- mautrix.client.state_store --- mautrix_client_state_store = types.ModuleType("mautrix.client.state_store") @@ -163,6 +174,19 @@ def _make_fake_mautrix(): mautrix_crypto_store.MemoryCryptoStore = MemoryCryptoStore + # --- mautrix.crypto.attachments --- + mautrix_crypto_attachments = types.ModuleType("mautrix.crypto.attachments") + + def encrypt_attachment(data): + encrypted_file = MagicMock() + encrypted_file.serialize.return_value = { + "key": {"k": "testkey"}, "iv": "testiv", + "hashes": {"sha256": "testhash"}, "v": "v2", + } + return (b"ciphertext_" + data, encrypted_file) + + mautrix_crypto_attachments.encrypt_attachment = encrypt_attachment + # --- mautrix.crypto.store.asyncpg --- mautrix_crypto_store_asyncpg = types.ModuleType("mautrix.crypto.store.asyncpg") @@ -200,8 +224,10 @@ def _make_fake_mautrix(): "mautrix.api": mautrix_api, "mautrix.types": mautrix_types, "mautrix.client": mautrix_client, + "mautrix.client.dispatcher": mautrix_client_dispatcher, "mautrix.client.state_store": mautrix_client_state_store, "mautrix.crypto": mautrix_crypto, + "mautrix.crypto.attachments": mautrix_crypto_attachments, "mautrix.crypto.store": mautrix_crypto_store, "mautrix.crypto.store.asyncpg": mautrix_crypto_store_asyncpg, "mautrix.util": mautrix_util, @@ -357,6 +383,16 @@ class TestMatrixTypingIndicator: timeout=0, ) + @pytest.mark.asyncio + async def test_stop_typing_no_client_is_noop(self): + self.adapter._client = None + await self.adapter.stop_typing("!room:example.org") # should not raise + + @pytest.mark.asyncio + async def test_stop_typing_suppresses_exceptions(self): + self.adapter._client.set_typing = AsyncMock(side_effect=Exception("network")) + await self.adapter.stop_typing("!room:example.org") # should not raise + # --------------------------------------------------------------------------- # mxc:// URL conversion @@ -835,6 +871,41 @@ class TestMatrixAccessTokenAuth: await adapter.disconnect() +class TestDeviceKeyReVerification: + @pytest.mark.asyncio + async def test_verify_fails_when_server_keys_mismatch_after_upload(self): + """share_keys() succeeds but server still has old keys -> should return False.""" + adapter = _make_adapter() + + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = "TESTDEVICE" + + # First query: keys missing -> triggers share_keys + # Second query: keys still don't match -> should fail + mock_keys_missing = MagicMock() + mock_keys_missing.device_keys = {"@bot:example.org": {}} + + mock_keys_mismatch = MagicMock() + mock_device = MagicMock() + mock_device.keys = {"ed25519:TESTDEVICE": "server_old_key"} + mock_keys_mismatch.device_keys = {"@bot:example.org": {"TESTDEVICE": mock_device}} + + mock_client.query_keys = AsyncMock(side_effect=[mock_keys_missing, mock_keys_mismatch]) + + mock_olm = MagicMock() + mock_olm.account = MagicMock() + mock_olm.account.shared = False + mock_olm.account.identity_keys = {"ed25519": "local_new_key"} + mock_olm.share_keys = AsyncMock() + + from gateway.platforms.matrix import MatrixAdapter + result = await adapter._verify_device_keys_on_server(mock_client, mock_olm) + + assert result is False + mock_olm.share_keys.assert_awaited_once() + + class TestMatrixE2EEHardFail: """connect() must refuse to start when E2EE is requested but deps are missing.""" @@ -1139,6 +1210,56 @@ class TestMatrixSyncLoop: mock_sync_store.put_next_batch.assert_awaited_once_with("s1234") +class TestMatrixUploadAndSend: + @pytest.mark.asyncio + async def test_upload_unencrypted_room_uses_plain_url(self): + """Unencrypted rooms should use plain 'url' key.""" + adapter = _make_adapter() + adapter._encryption = True + mock_client = MagicMock() + mock_client.crypto = object() + mock_client.state_store = MagicMock() + mock_client.state_store.is_encrypted = AsyncMock(return_value=False) + mock_client.upload_media = AsyncMock(return_value="mxc://example.org/plain") + mock_client.send_message_event = AsyncMock(return_value="$event") + adapter._client = mock_client + + result = await adapter._upload_and_send( + "!room:example.org", b"hello", "test.txt", "text/plain", "m.file", + ) + + assert result.success is True + sent = mock_client.send_message_event.await_args.args[2] + assert sent["url"] == "mxc://example.org/plain" + assert "file" not in sent + + @pytest.mark.asyncio + async def test_upload_encrypted_room_uses_file_payload(self): + """Encrypted rooms should use 'file' key with crypto metadata.""" + adapter = _make_adapter() + adapter._encryption = True + mock_client = MagicMock() + mock_client.crypto = object() + mock_client.state_store = MagicMock() + mock_client.state_store.is_encrypted = AsyncMock(return_value=True) + mock_client.upload_media = AsyncMock(return_value="mxc://example.org/enc") + mock_client.send_message_event = AsyncMock(return_value="$event") + adapter._client = mock_client + + result = await adapter._upload_and_send( + "!room:example.org", b"secret", "secret.txt", "text/plain", "m.file", + ) + + assert result.success is True + # Should have uploaded ciphertext, not plaintext + uploaded_data = mock_client.upload_media.await_args.args[0] + assert uploaded_data != b"secret" + sent = mock_client.send_message_event.await_args.args[2] + assert "url" not in sent + assert "file" in sent + assert sent["file"]["url"] == "mxc://example.org/enc" + + class TestMatrixEncryptedSendFallback: @pytest.mark.asyncio async def test_send_retries_after_e2ee_error(self): @@ -1165,128 +1286,24 @@ class TestMatrixEncryptedSendFallback: # --------------------------------------------------------------------------- -# E2EE: MegolmEvent key request + buffering via _on_encrypted_event +# E2EE: _joined_rooms reference preservation for CryptoStateStore # --------------------------------------------------------------------------- -class TestMatrixMegolmEventHandling: - @pytest.mark.asyncio - async def test_encrypted_event_buffers_for_retry(self): - """_on_encrypted_event should buffer undecrypted events for retry.""" - adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} +class TestJoinedRoomsReference: + def test_joined_rooms_reference_preserved_after_reassignment(self): + """_CryptoStateStore must see updates after initial sync populates rooms.""" + from gateway.platforms.matrix import _CryptoStateStore - fake_event = MagicMock() - fake_event.room_id = "!room:example.org" - fake_event.event_id = "$encrypted_event" - fake_event.sender = "@alice:example.org" + joined = set() + store = _CryptoStateStore(MagicMock(), joined) - await adapter._on_encrypted_event(fake_event) + # Simulate what connect() should do: mutate in place, not reassign. + joined.clear() + joined.update(["!room1:example.org", "!room2:example.org"]) - # Should have buffered the event - assert len(adapter._pending_megolm) == 1 - room_id, event, ts = adapter._pending_megolm[0] - assert room_id == "!room:example.org" - assert event is fake_event - - @pytest.mark.asyncio - async def test_encrypted_event_buffer_capped(self): - """Buffer should not grow past _MAX_PENDING_EVENTS.""" - adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} - - from gateway.platforms.matrix import _MAX_PENDING_EVENTS - - for i in range(_MAX_PENDING_EVENTS + 10): - evt = MagicMock() - evt.room_id = "!room:example.org" - evt.event_id = f"$event_{i}" - evt.sender = "@alice:example.org" - await adapter._on_encrypted_event(evt) - - assert len(adapter._pending_megolm) == _MAX_PENDING_EVENTS - - -# --------------------------------------------------------------------------- -# E2EE: Retry pending decryptions -# --------------------------------------------------------------------------- - -class TestMatrixRetryPendingDecryptions: - @pytest.mark.asyncio - async def test_successful_decryption_routes_to_handler(self): - adapter = _make_adapter() - adapter._user_id = "@bot:example.org" - adapter._startup_ts = 0.0 - adapter._dm_rooms = {} - - fake_encrypted = MagicMock() - fake_encrypted.event_id = "$encrypted" - - decrypted_event = MagicMock() - - mock_crypto = MagicMock() - mock_crypto.decrypt_megolm_event = AsyncMock(return_value=decrypted_event) - - fake_client = MagicMock() - fake_client.crypto = mock_crypto - adapter._client = fake_client - - now = time.time() - adapter._pending_megolm = [("!room:ex.org", fake_encrypted, now)] - - with patch.object(adapter, "_on_room_message", AsyncMock()) as mock_handler: - await adapter._retry_pending_decryptions() - mock_handler.assert_awaited_once_with(decrypted_event) - - # Buffer should be empty now - assert len(adapter._pending_megolm) == 0 - - @pytest.mark.asyncio - async def test_still_undecryptable_stays_in_buffer(self): - adapter = _make_adapter() - - fake_encrypted = MagicMock() - fake_encrypted.event_id = "$still_encrypted" - - mock_crypto = MagicMock() - mock_crypto.decrypt_megolm_event = AsyncMock(side_effect=Exception("missing key")) - - fake_client = MagicMock() - fake_client.crypto = mock_crypto - adapter._client = fake_client - - now = time.time() - adapter._pending_megolm = [("!room:ex.org", fake_encrypted, now)] - - await adapter._retry_pending_decryptions() - - assert len(adapter._pending_megolm) == 1 - - @pytest.mark.asyncio - async def test_expired_events_dropped(self): - adapter = _make_adapter() - - from gateway.platforms.matrix import _PENDING_EVENT_TTL - - fake_event = MagicMock() - fake_event.event_id = "$old_event" - - mock_crypto = MagicMock() - fake_client = MagicMock() - fake_client.crypto = mock_crypto - adapter._client = fake_client - - # Timestamp well past TTL - old_ts = time.time() - _PENDING_EVENT_TTL - 60 - adapter._pending_megolm = [("!room:ex.org", fake_event, old_ts)] - - await adapter._retry_pending_decryptions() - - # Should have been dropped - assert len(adapter._pending_megolm) == 0 + import asyncio + rooms = asyncio.get_event_loop().run_until_complete(store.find_shared_rooms("@user:ex")) + assert set(rooms) == {"!room1:example.org", "!room2:example.org"} # --------------------------------------------------------------------------- @@ -1354,11 +1371,70 @@ class TestMatrixEncryptedEventHandler: handler_calls = mock_client.add_event_handler.call_args_list registered_types = [call.args[0] for call in handler_calls] - # Should have registered handlers for ROOM_MESSAGE, REACTION, INVITE, and ROOM_ENCRYPTED - assert len(handler_calls) >= 4 # At minimum these four + # Should have registered handlers for ROOM_MESSAGE, REACTION, INVITE + assert len(handler_calls) >= 3 await adapter.disconnect() + @pytest.mark.asyncio + async def test_connect_fails_on_stale_otk_conflict(self): + """connect() must refuse E2EE when OTK upload hits 'already exists'.""" + from gateway.platforms.matrix import MatrixAdapter + + config = PlatformConfig( + enabled=True, + token="syt_test_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + "encryption": True, + }, + ) + adapter = MatrixAdapter(config) + + fake_mautrix_mods = _make_fake_mautrix() + + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.state_store = MagicMock() + mock_client.sync_store = MagicMock() + mock_client.crypto = None + mock_client.whoami = AsyncMock(return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")) + mock_client.add_event_handler = MagicMock() + mock_client.add_dispatcher = MagicMock() + mock_client.query_keys = AsyncMock(return_value={ + "device_keys": {"@bot:example.org": {"DEV123": { + "keys": {"ed25519:DEV123": "fake_ed25519_key"}, + }}}, + }) + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() + + # share_keys succeeds on first call (from _verify_device_keys_on_server), + # then raises "already exists" on the proactive OTK flush in connect(). + mock_olm = MagicMock() + mock_olm.load = AsyncMock() + mock_olm.share_keys = AsyncMock( + side_effect=[None, Exception("One time key signed_curve25519:AAAAAQ already exists")] + ) + mock_olm.share_keys_min_trust = None + mock_olm.send_keys_min_trust = None + mock_olm.account = MagicMock() + mock_olm.account.identity_keys = {"ed25519": "fake_ed25519_key"} + + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) + fake_mautrix_mods["mautrix.crypto"].OlmMachine = MagicMock(return_value=mock_olm) + + from gateway.platforms import matrix as matrix_mod + with patch.object(matrix_mod, "_check_e2ee_deps", return_value=True): + with patch.dict("sys.modules", fake_mautrix_mods): + result = await adapter.connect() + + assert result is False + # --------------------------------------------------------------------------- # Disconnect diff --git a/tests/gateway/test_matrix_mention.py b/tests/gateway/test_matrix_mention.py index b5db0da7c5..3809c33fc6 100644 --- a/tests/gateway/test_matrix_mention.py +++ b/tests/gateway/test_matrix_mention.py @@ -10,7 +10,6 @@ import pytest from gateway.config import PlatformConfig - # The matrix adapter module is importable without mautrix installed # (module-level imports use try/except with stubs). No need for # module-level mock installation — tests that call adapter methods @@ -159,9 +158,15 @@ class TestStripMention: result = self.adapter._strip_mention("@hermes:example.org help me") assert result == "help me" - def test_strip_localpart(self): + def test_localpart_preserved(self): + """Localpart-only text is no longer stripped — avoids false positives in paths.""" result = self.adapter._strip_mention("hermes help me") - assert result == "help me" + assert result == "hermes help me" + + def test_localpart_in_path_preserved(self): + """Localpart inside a file path must not be damaged.""" + result = self.adapter._strip_mention("read /home/hermes/config.yaml") + assert result == "read /home/hermes/config.yaml" def test_strip_returns_empty_for_mention_only(self): result = self.adapter._strip_mention("@hermes:example.org") @@ -273,8 +278,8 @@ async def test_require_mention_dm_always_responds(monkeypatch): @pytest.mark.asyncio -async def test_dm_strips_mention(monkeypatch): - """DMs strip mention from body, matching Discord behavior.""" +async def test_dm_strips_full_mxid(monkeypatch): + """DMs strip the full MXID from body when require_mention is on (default).""" monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False) monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False) monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") @@ -289,6 +294,23 @@ async def test_dm_strips_mention(monkeypatch): assert msg.text == "help me" +@pytest.mark.asyncio +async def test_dm_preserves_localpart_in_body(monkeypatch): + """DMs no longer strip bare localpart — only the full MXID is removed.""" + monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False) + monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False) + monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") + + adapter = _make_adapter() + _set_dm(adapter) + event = _make_event("hermes help me") + + await adapter._on_room_message(event) + adapter.handle_message.assert_awaited_once() + msg = adapter.handle_message.await_args.args[0] + assert msg.text == "hermes help me" + + @pytest.mark.asyncio async def test_bare_mention_passes_empty_string(monkeypatch): """A message that is only a mention should pass through as empty, not be dropped.""" @@ -309,7 +331,9 @@ async def test_bare_mention_passes_empty_string(monkeypatch): async def test_require_mention_free_response_room(monkeypatch): """Free-response rooms bypass mention requirement.""" monkeypatch.delenv("MATRIX_REQUIRE_MENTION", raising=False) - monkeypatch.setenv("MATRIX_FREE_RESPONSE_ROOMS", "!room1:example.org,!room2:example.org") + monkeypatch.setenv( + "MATRIX_FREE_RESPONSE_ROOMS", "!room1:example.org,!room2:example.org" + ) monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") adapter = _make_adapter() @@ -351,6 +375,22 @@ async def test_require_mention_disabled(monkeypatch): assert msg.text == "hello without mention" +@pytest.mark.asyncio +async def test_require_mention_disabled_skips_stripping(monkeypatch): + """MATRIX_REQUIRE_MENTION=false: mention text is NOT stripped from body.""" + monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "false") + monkeypatch.delenv("MATRIX_FREE_RESPONSE_ROOMS", raising=False) + monkeypatch.setenv("MATRIX_AUTO_THREAD", "false") + + adapter = _make_adapter() + event = _make_event("@hermes:example.org help me") + + await adapter._on_room_message(event) + adapter.handle_message.assert_awaited_once() + msg = adapter.handle_message.await_args.args[0] + assert msg.text == "@hermes:example.org help me" + + # --------------------------------------------------------------------------- # Auto-thread in _on_room_message # --------------------------------------------------------------------------- @@ -442,8 +482,10 @@ class TestThreadPersistence: def test_empty_state_file(self, tmp_path, monkeypatch): """No state file → empty set.""" from gateway.platforms.helpers import ThreadParticipationTracker + monkeypatch.setattr( - ThreadParticipationTracker, "_state_path", + ThreadParticipationTracker, + "_state_path", lambda self: tmp_path / "matrix_threads.json", ) adapter = _make_adapter() @@ -452,9 +494,11 @@ class TestThreadPersistence: def test_track_thread_persists(self, tmp_path, monkeypatch): """mark() writes to disk.""" from gateway.platforms.helpers import ThreadParticipationTracker + state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - ThreadParticipationTracker, "_state_path", + ThreadParticipationTracker, + "_state_path", lambda self: state_path, ) adapter = _make_adapter() @@ -466,10 +510,12 @@ class TestThreadPersistence: def test_threads_survive_reload(self, tmp_path, monkeypatch): """Persisted threads are loaded by a new adapter instance.""" from gateway.platforms.helpers import ThreadParticipationTracker + state_path = tmp_path / "matrix_threads.json" state_path.write_text(json.dumps(["$t1", "$t2"])) monkeypatch.setattr( - ThreadParticipationTracker, "_state_path", + ThreadParticipationTracker, + "_state_path", lambda self: state_path, ) adapter = _make_adapter() @@ -479,9 +525,11 @@ class TestThreadPersistence: def test_cap_max_tracked_threads(self, tmp_path, monkeypatch): """Thread set is trimmed to max_tracked.""" from gateway.platforms.helpers import ThreadParticipationTracker + state_path = tmp_path / "matrix_threads.json" monkeypatch.setattr( - ThreadParticipationTracker, "_state_path", + ThreadParticipationTracker, + "_state_path", lambda self: state_path, ) adapter = _make_adapter() @@ -604,6 +652,7 @@ class TestMatrixConfigBridge: } import os + import yaml config_file = tmp_path / "config.yaml" @@ -613,18 +662,27 @@ class TestMatrixConfigBridge: yaml_cfg = yaml.safe_load(config_file.read_text()) matrix_cfg = yaml_cfg.get("matrix", {}) if isinstance(matrix_cfg, dict): - if "require_mention" in matrix_cfg and not os.getenv("MATRIX_REQUIRE_MENTION"): - monkeypatch.setenv("MATRIX_REQUIRE_MENTION", str(matrix_cfg["require_mention"]).lower()) + if "require_mention" in matrix_cfg and not os.getenv( + "MATRIX_REQUIRE_MENTION" + ): + monkeypatch.setenv( + "MATRIX_REQUIRE_MENTION", str(matrix_cfg["require_mention"]).lower() + ) frc = matrix_cfg.get("free_response_rooms") if frc is not None and not os.getenv("MATRIX_FREE_RESPONSE_ROOMS"): if isinstance(frc, list): frc = ",".join(str(v) for v in frc) monkeypatch.setenv("MATRIX_FREE_RESPONSE_ROOMS", str(frc)) if "auto_thread" in matrix_cfg and not os.getenv("MATRIX_AUTO_THREAD"): - monkeypatch.setenv("MATRIX_AUTO_THREAD", str(matrix_cfg["auto_thread"]).lower()) + monkeypatch.setenv( + "MATRIX_AUTO_THREAD", str(matrix_cfg["auto_thread"]).lower() + ) assert os.getenv("MATRIX_REQUIRE_MENTION") == "false" - assert os.getenv("MATRIX_FREE_RESPONSE_ROOMS") == "!room1:example.org,!room2:example.org" + assert ( + os.getenv("MATRIX_FREE_RESPONSE_ROOMS") + == "!room1:example.org,!room2:example.org" + ) assert os.getenv("MATRIX_AUTO_THREAD") == "false" def test_yaml_bridge_sets_dm_mention_threads(self, monkeypatch, tmp_path): @@ -632,6 +690,7 @@ class TestMatrixConfigBridge: monkeypatch.delenv("MATRIX_DM_MENTION_THREADS", raising=False) import os + import yaml yaml_content = {"matrix": {"dm_mention_threads": True}} @@ -641,8 +700,13 @@ class TestMatrixConfigBridge: yaml_cfg = yaml.safe_load(config_file.read_text()) matrix_cfg = yaml_cfg.get("matrix", {}) if isinstance(matrix_cfg, dict): - if "dm_mention_threads" in matrix_cfg and not os.getenv("MATRIX_DM_MENTION_THREADS"): - monkeypatch.setenv("MATRIX_DM_MENTION_THREADS", str(matrix_cfg["dm_mention_threads"]).lower()) + if "dm_mention_threads" in matrix_cfg and not os.getenv( + "MATRIX_DM_MENTION_THREADS" + ): + monkeypatch.setenv( + "MATRIX_DM_MENTION_THREADS", + str(matrix_cfg["dm_mention_threads"]).lower(), + ) assert os.getenv("MATRIX_DM_MENTION_THREADS") == "true" @@ -651,9 +715,12 @@ class TestMatrixConfigBridge: monkeypatch.setenv("MATRIX_REQUIRE_MENTION", "true") import os + yaml_cfg = {"matrix": {"require_mention": False}} matrix_cfg = yaml_cfg.get("matrix", {}) if "require_mention" in matrix_cfg and not os.getenv("MATRIX_REQUIRE_MENTION"): - monkeypatch.setenv("MATRIX_REQUIRE_MENTION", str(matrix_cfg["require_mention"]).lower()) + monkeypatch.setenv( + "MATRIX_REQUIRE_MENTION", str(matrix_cfg["require_mention"]).lower() + ) assert os.getenv("MATRIX_REQUIRE_MENTION") == "true" diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index cdba5f60ed..5f9c56345f 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -1013,3 +1013,106 @@ class TestFilterAndAccumulateIntegration: await task except asyncio.CancelledError: pass + + +# ── buffer_only mode tests ───────────────────────────────────────────── + + +class TestBufferOnlyMode: + """Verify buffer_only mode suppresses intermediate edits and only + flushes on structural boundaries (done, segment break, commentary).""" + + @pytest.mark.asyncio + async def test_suppresses_intermediate_edits(self): + """Time-based and size-based edits are skipped; only got_done flushes.""" + adapter = MagicMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + adapter.send = AsyncMock(return_value=SimpleNamespace(success=True, message_id="msg1")) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + + cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="", buffer_only=True) + consumer = GatewayStreamConsumer(adapter, "!room:server", config=cfg) + + for word in ["Hello", " world", ", this", " is", " a", " test"]: + consumer.on_delta(word) + consumer.finish() + + await consumer.run() + + adapter.send.assert_called_once() + adapter.edit_message.assert_not_called() + assert "Hello world, this is a test" in adapter.send.call_args_list[0][1]["content"] + + @pytest.mark.asyncio + async def test_flushes_on_segment_break(self): + """A segment break (tool call boundary) flushes accumulated text.""" + adapter = MagicMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + adapter.send = AsyncMock(side_effect=[ + SimpleNamespace(success=True, message_id="msg1"), + SimpleNamespace(success=True, message_id="msg2"), + ]) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + + cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="", buffer_only=True) + consumer = GatewayStreamConsumer(adapter, "!room:server", config=cfg) + + consumer.on_delta("Before tool call") + consumer.on_delta(None) + consumer.on_delta("After tool call") + consumer.finish() + + await consumer.run() + + assert adapter.send.call_count == 2 + assert "Before tool call" in adapter.send.call_args_list[0][1]["content"] + assert "After tool call" in adapter.send.call_args_list[1][1]["content"] + adapter.edit_message.assert_not_called() + + @pytest.mark.asyncio + async def test_flushes_on_commentary(self): + """An interim commentary message flushes in buffer_only mode.""" + adapter = MagicMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + adapter.send = AsyncMock(side_effect=[ + SimpleNamespace(success=True, message_id="msg1"), + SimpleNamespace(success=True, message_id="msg2"), + SimpleNamespace(success=True, message_id="msg3"), + ]) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + + cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="", buffer_only=True) + consumer = GatewayStreamConsumer(adapter, "!room:server", config=cfg) + + consumer.on_delta("Working on it...") + consumer.on_commentary("I'll search for that first.") + consumer.on_delta("Here are the results.") + consumer.finish() + + await consumer.run() + + # Three sends: accumulated text, commentary, final text + assert adapter.send.call_count >= 2 + adapter.edit_message.assert_not_called() + + @pytest.mark.asyncio + async def test_default_mode_still_triggers_intermediate_edits(self): + """Regression: buffer_only=False (default) still does progressive edits.""" + adapter = MagicMock() + adapter.MAX_MESSAGE_LENGTH = 4096 + adapter.send = AsyncMock(return_value=SimpleNamespace(success=True, message_id="msg1")) + adapter.edit_message = AsyncMock(return_value=SimpleNamespace(success=True)) + + # buffer_threshold=5 means any 5+ chars triggers an early edit + cfg = StreamConsumerConfig(edit_interval=0.01, buffer_threshold=5, cursor="") + consumer = GatewayStreamConsumer(adapter, "!room:server", config=cfg) + + consumer.on_delta("Hello world, this is long enough to trigger edits") + consumer.finish() + + await consumer.run() + + # Should have at least one send. With buffer_threshold=5 and this much + # text, the consumer may send then edit, or just send once at got_done. + # The key assertion: this doesn't break. + assert adapter.send.call_count >= 1 diff --git a/website/docs/user-guide/messaging/matrix.md b/website/docs/user-guide/messaging/matrix.md index b742e0cfaf..ec77b5bc33 100644 --- a/website/docs/user-guide/messaging/matrix.md +++ b/website/docs/user-guide/messaging/matrix.md @@ -284,8 +284,40 @@ MATRIX_RECOVERY_KEY=EsT... your recovery key here On each startup, if `MATRIX_RECOVERY_KEY` is set, Hermes imports cross-signing keys from the homeserver's secure secret storage and signs the current device. This is idempotent and safe to leave enabled permanently. -:::warning -If you delete the `~/.hermes/platforms/matrix/store/` directory, the bot loses its encryption keys. You'll need to verify the device again in your Matrix client. Back up this directory if you want to preserve encrypted sessions. +:::warning[Deleting the crypto store] +If you delete `~/.hermes/platforms/matrix/store/crypto.db`, the bot loses its encryption identity. Simply restarting with the same device ID will **not** fully recover — the homeserver still holds one-time keys signed with the old identity key, and peers cannot establish new Olm sessions. + +Hermes detects this condition on startup and refuses to enable E2EE, logging: `device XXXX has stale one-time keys on the server signed with a previous identity key`. + +**Easiest recovery: generate a new access token** (which gets a fresh device ID with no stale key history). See the "Upgrading from a previous version with E2EE" section below. This is the most reliable path and avoids touching the homeserver database. + +**Manual recovery** (advanced — keeps the same device ID): + +1. Stop Synapse and delete the old device from its database: + ```bash + sudo systemctl stop matrix-synapse + sudo sqlite3 /var/lib/matrix-synapse/homeserver.db " + DELETE FROM e2e_device_keys_json WHERE device_id = 'DEVICE_ID' AND user_id = '@hermes:your-server'; + DELETE FROM e2e_one_time_keys_json WHERE device_id = 'DEVICE_ID' AND user_id = '@hermes:your-server'; + DELETE FROM e2e_fallback_keys_json WHERE device_id = 'DEVICE_ID' AND user_id = '@hermes:your-server'; + DELETE FROM devices WHERE device_id = 'DEVICE_ID' AND user_id = '@hermes:your-server'; + " + sudo systemctl start matrix-synapse + ``` + Or via the Synapse admin API (note the URL-encoded user ID): + ```bash + curl -X DELETE -H "Authorization: Bearer ADMIN_TOKEN" \ + 'https://your-server/_synapse/admin/v2/users/%40hermes%3Ayour-server/devices/DEVICE_ID' + ``` + Note: deleting a device via the admin API may also invalidate the associated access token. You may need to generate a new token afterward. + +2. Delete the local crypto store and restart Hermes: + ```bash + rm -f ~/.hermes/platforms/matrix/store/crypto.db* + # restart hermes + ``` + +Other Matrix clients (Element, matrix-commander) may cache the old device keys. After recovery, type `/discardsession` in Element to force a new encryption session with the bot. ::: :::info @@ -361,6 +393,10 @@ pip install 'hermes-agent[matrix]' ### Upgrading from a previous version with E2EE +:::tip +If you also manually deleted `crypto.db`, see the "Deleting the crypto store" warning in the E2EE section above — there are additional steps to clear stale one-time keys from the homeserver. +::: + If you previously used Hermes with `MATRIX_ENCRYPTION=true` and are upgrading to a version that uses the new SQLite-based crypto store, the bot's encryption identity has changed. Your Matrix client (Element) may cache the old device keys