diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index 89203d48a2..055c6e65f2 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -273,6 +273,14 @@ class MatrixAdapter(BasePlatformAdapter): client.add_event_callback(self._on_room_message_media, nio.RoomMessageAudio) client.add_event_callback(self._on_room_message_media, nio.RoomMessageVideo) client.add_event_callback(self._on_room_message_media, nio.RoomMessageFile) + for encrypted_media_cls in ( + getattr(nio, "RoomEncryptedImage", None), + getattr(nio, "RoomEncryptedAudio", None), + getattr(nio, "RoomEncryptedVideo", None), + getattr(nio, "RoomEncryptedFile", None), + ): + if encrypted_media_cls is not None: + client.add_event_callback(self._on_room_message_media, encrypted_media_cls) client.add_event_callback(self._on_invite, nio.InviteMemberEvent) # If E2EE: handle encrypted events. @@ -1025,47 +1033,122 @@ class MatrixAdapter(BasePlatformAdapter): # Use the MIME type from the event's content info when available, # falling back to category-level MIME types for downstream matching # (gateway/run.py checks startswith("image/"), startswith("audio/"), etc.) - content_info = getattr(event, "content", {}) if isinstance(getattr(event, "content", None), dict) else {} - event_mimetype = (content_info.get("info") or {}).get("mimetype", "") + source_content = getattr(event, "source", {}).get("content", {}) + if not isinstance(source_content, dict): + source_content = {} + event_content = getattr(event, "content", {}) + if not isinstance(event_content, dict): + event_content = {} + content_info = event_content.get("info") if isinstance(event_content, dict) else {} + if not isinstance(content_info, dict) or not content_info: + content_info = source_content.get("info", {}) if isinstance(source_content, dict) else {} + event_mimetype = ( + (content_info.get("mimetype") if isinstance(content_info, dict) else None) + or getattr(event, "mimetype", "") + or "" + ) + # For encrypted media, the URL may be in file.url instead of event.url. + file_content = source_content.get("file", {}) if isinstance(source_content, dict) else {} + if not url and isinstance(file_content, dict): + url = file_content.get("url", "") or "" + if url and url.startswith("mxc://"): + http_url = self._mxc_to_http(url) + media_type = "application/octet-stream" msg_type = MessageType.DOCUMENT + is_encrypted_image = isinstance(event, getattr(nio, "RoomEncryptedImage", ())) + is_encrypted_audio = isinstance(event, getattr(nio, "RoomEncryptedAudio", ())) + is_encrypted_video = isinstance(event, getattr(nio, "RoomEncryptedVideo", ())) + is_encrypted_file = isinstance(event, getattr(nio, "RoomEncryptedFile", ())) + is_encrypted_media = any((is_encrypted_image, is_encrypted_audio, is_encrypted_video, is_encrypted_file)) is_voice_message = False - - if isinstance(event, nio.RoomMessageImage): + + if isinstance(event, nio.RoomMessageImage) or is_encrypted_image: msg_type = MessageType.PHOTO media_type = event_mimetype or "image/png" - elif isinstance(event, nio.RoomMessageAudio): - # Check for MSC3245 voice flag: org.matrix.msc3245.voice: {} - source_content = getattr(event, "source", {}).get("content", {}) + elif isinstance(event, nio.RoomMessageAudio) or is_encrypted_audio: if source_content.get("org.matrix.msc3245.voice") is not None: is_voice_message = True msg_type = MessageType.VOICE else: msg_type = MessageType.AUDIO media_type = event_mimetype or "audio/ogg" - elif isinstance(event, nio.RoomMessageVideo): + elif isinstance(event, nio.RoomMessageVideo) or is_encrypted_video: msg_type = MessageType.VIDEO media_type = event_mimetype or "video/mp4" elif event_mimetype: media_type = event_mimetype - # For images, download and cache locally so vision tools can access them. - # Matrix MXC URLs require authentication, so direct URL access fails. + # Cache media locally when downstream tools need a real file path: + # - photos (vision tools can't access MXC URLs) + # - voice messages (transcription tools need local files) + # - any encrypted media (HTTP fallback would point at ciphertext) cached_path = None - if msg_type == MessageType.PHOTO and url: + should_cache_locally = ( + msg_type == MessageType.PHOTO or is_voice_message or is_encrypted_media + ) + if should_cache_locally and url: try: - ext_map = { - "image/jpeg": ".jpg", "image/png": ".png", - "image/gif": ".gif", "image/webp": ".webp", - } - ext = ext_map.get(event_mimetype, ".jpg") - download_resp = await self._client.download(url) - if isinstance(download_resp, nio.DownloadResponse): - from gateway.platforms.base import cache_image_from_bytes - cached_path = cache_image_from_bytes(download_resp.body, ext=ext) - logger.info("[Matrix] Cached user image at %s", cached_path) + if is_voice_message: + download_resp = await self._client.download(mxc=url) + else: + download_resp = await self._client.download(url) + file_bytes = getattr(download_resp, "body", None) + if file_bytes is not None: + if is_encrypted_media: + from nio.crypto.attachments import decrypt_attachment + + hashes_value = getattr(event, "hashes", None) + if hashes_value is None and isinstance(file_content, dict): + hashes_value = file_content.get("hashes") + hash_value = hashes_value.get("sha256") if isinstance(hashes_value, dict) else None + + key_value = getattr(event, "key", None) + if key_value is None and isinstance(file_content, dict): + key_value = file_content.get("key") + if isinstance(key_value, dict): + key_value = key_value.get("k") + + iv_value = getattr(event, "iv", None) + if iv_value is None and isinstance(file_content, dict): + iv_value = file_content.get("iv") + + if key_value and hash_value and iv_value: + file_bytes = decrypt_attachment(file_bytes, key_value, hash_value, iv_value) + else: + logger.warning( + "[Matrix] Encrypted media event missing decryption metadata for %s", + event.event_id, + ) + file_bytes = None + + if file_bytes is not None: + from gateway.platforms.base import ( + cache_audio_from_bytes, + cache_document_from_bytes, + cache_image_from_bytes, + ) + + if msg_type == MessageType.PHOTO: + ext_map = { + "image/jpeg": ".jpg", + "image/png": ".png", + "image/gif": ".gif", + "image/webp": ".webp", + } + ext = ext_map.get(media_type, ".jpg") + cached_path = cache_image_from_bytes(file_bytes, ext=ext) + logger.info("[Matrix] Cached user image at %s", cached_path) + elif msg_type in (MessageType.AUDIO, MessageType.VOICE): + ext = Path(body or ("voice.ogg" if is_voice_message else "audio.ogg")).suffix or ".ogg" + cached_path = cache_audio_from_bytes(file_bytes, ext=ext) + else: + filename = body or ( + "video.mp4" if msg_type == MessageType.VIDEO else "document" + ) + cached_path = cache_document_from_bytes(file_bytes, filename) except Exception as e: - logger.warning("[Matrix] Failed to cache image: %s", e) + logger.warning("[Matrix] Failed to cache media: %s", e) is_dm = self._dm_rooms.get(room.room_id, False) if not is_dm and room.member_count == 2: @@ -1073,7 +1156,6 @@ class MatrixAdapter(BasePlatformAdapter): chat_type = "dm" if is_dm else "group" # Thread/reply detection. - source_content = getattr(event, "source", {}).get("content", {}) relates_to = source_content.get("m.relates_to", {}) thread_id = None if relates_to.get("rel_type") == "m.thread": @@ -1103,31 +1185,6 @@ class MatrixAdapter(BasePlatformAdapter): thread_id = event.event_id self._track_thread(thread_id) - # For voice messages, cache audio locally for transcription tools. - # Use the authenticated nio client to download (Matrix requires auth for media). - media_urls = [http_url] if http_url else None - media_types = [media_type] if http_url else None - - if is_voice_message and url and url.startswith("mxc://"): - try: - import nio - from gateway.platforms.base import cache_audio_from_bytes - - resp = await self._client.download(mxc=url) - if isinstance(resp, nio.MemoryDownloadResponse): - # Extract extension from mimetype or default to .ogg - ext = ".ogg" - if media_type and "/" in media_type: - subtype = media_type.split("/")[1] - ext = f".{subtype}" if subtype else ".ogg" - local_path = cache_audio_from_bytes(resp.body, ext) - media_urls = [local_path] - logger.debug("Matrix: cached voice message to %s", local_path) - else: - logger.warning("Matrix: failed to download voice: %s", getattr(resp, "message", resp)) - except Exception as e: - logger.warning("Matrix: failed to cache voice message, using HTTP URL: %s", e) - source = self.build_source( chat_id=room.room_id, chat_type=chat_type, @@ -1136,9 +1193,8 @@ class MatrixAdapter(BasePlatformAdapter): thread_id=thread_id, ) - # Use cached local path for images (voice messages already handled above). - if cached_path: - media_urls = [cached_path] + allow_http_fallback = bool(http_url) and not is_encrypted_media + media_urls = [cached_path] if cached_path else ([http_url] if allow_http_fallback else None) media_types = [media_type] if media_urls else None msg_event = MessageEvent( diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index 9912eef00b..5e2c7c3579 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -993,3 +993,358 @@ class TestMatrixKeyExportImport: # Should not have tried to export assert not hasattr(fake_client, "export_keys") or \ not fake_client.export_keys.called + + +# --------------------------------------------------------------------------- +# E2EE: Encrypted media +# --------------------------------------------------------------------------- + +class TestMatrixEncryptedMedia: + @pytest.mark.asyncio + async def test_connect_registers_callbacks_for_encrypted_media_events(self): + from gateway.platforms.matrix import MatrixAdapter + + config = PlatformConfig( + enabled=True, + token="syt_te...oken", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + "encryption": True, + }, + ) + adapter = MatrixAdapter(config) + + class FakeWhoamiResponse: + def __init__(self, user_id, device_id): + self.user_id = user_id + self.device_id = device_id + + class FakeSyncResponse: + def __init__(self): + self.rooms = MagicMock(join={}) + + class FakeRoomMessageText: ... + class FakeRoomMessageImage: ... + class FakeRoomMessageAudio: ... + class FakeRoomMessageVideo: ... + class FakeRoomMessageFile: ... + class FakeRoomEncryptedImage: ... + class FakeRoomEncryptedAudio: ... + class FakeRoomEncryptedVideo: ... + class FakeRoomEncryptedFile: ... + class FakeInviteMemberEvent: ... + class FakeMegolmEvent: ... + + fake_client = MagicMock() + fake_client.whoami = AsyncMock(return_value=FakeWhoamiResponse("@bot:example.org", "DEV123")) + fake_client.sync = AsyncMock(return_value=FakeSyncResponse()) + fake_client.keys_upload = AsyncMock() + fake_client.keys_query = AsyncMock() + fake_client.keys_claim = AsyncMock() + fake_client.send_to_device_messages = AsyncMock(return_value=[]) + fake_client.get_users_for_key_claiming = MagicMock(return_value={}) + fake_client.close = AsyncMock() + fake_client.add_event_callback = MagicMock() + fake_client.rooms = {} + fake_client.account_data = {} + fake_client.olm = object() + fake_client.should_upload_keys = False + fake_client.should_query_keys = False + fake_client.should_claim_keys = False + fake_client.restore_login = MagicMock(side_effect=lambda u, d, t: None) + + fake_nio = MagicMock() + fake_nio.AsyncClient = MagicMock(return_value=fake_client) + fake_nio.WhoamiResponse = FakeWhoamiResponse + fake_nio.SyncResponse = FakeSyncResponse + fake_nio.LoginResponse = type("LoginResponse", (), {}) + fake_nio.RoomMessageText = FakeRoomMessageText + fake_nio.RoomMessageImage = FakeRoomMessageImage + fake_nio.RoomMessageAudio = FakeRoomMessageAudio + fake_nio.RoomMessageVideo = FakeRoomMessageVideo + fake_nio.RoomMessageFile = FakeRoomMessageFile + fake_nio.RoomEncryptedImage = FakeRoomEncryptedImage + fake_nio.RoomEncryptedAudio = FakeRoomEncryptedAudio + fake_nio.RoomEncryptedVideo = FakeRoomEncryptedVideo + fake_nio.RoomEncryptedFile = FakeRoomEncryptedFile + fake_nio.InviteMemberEvent = FakeInviteMemberEvent + fake_nio.MegolmEvent = FakeMegolmEvent + + with patch.dict("sys.modules", {"nio": fake_nio}): + with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): + with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): + assert await adapter.connect() is True + + callback_classes = [call.args[1] for call in fake_client.add_event_callback.call_args_list] + assert FakeRoomEncryptedImage in callback_classes + assert FakeRoomEncryptedAudio in callback_classes + assert FakeRoomEncryptedVideo in callback_classes + assert FakeRoomEncryptedFile in callback_classes + + await adapter.disconnect() + + @pytest.mark.asyncio + async def test_on_room_message_media_decrypts_encrypted_image_and_passes_local_path(self): + from nio.crypto.attachments import encrypt_attachment + + adapter = _make_adapter() + adapter._user_id = "@bot:example.org" + adapter._startup_ts = 0.0 + adapter._dm_rooms = {} + adapter.handle_message = AsyncMock() + + plaintext = b"\x89PNG\r\n\x1a\n" + b"\x00" * 32 + ciphertext, keys = encrypt_attachment(plaintext) + + class FakeRoomEncryptedImage: + def __init__(self): + self.sender = "@alice:example.org" + self.event_id = "$img1" + self.server_timestamp = 0 + self.body = "screenshot.png" + self.url = "mxc://example.org/media123" + self.key = keys["key"]["k"] + self.hashes = keys["hashes"] + self.iv = keys["iv"] + self.mimetype = "image/png" + self.source = { + "content": { + "body": "screenshot.png", + "info": {"mimetype": "image/png"}, + "file": { + "url": self.url, + "key": keys["key"], + "hashes": keys["hashes"], + "iv": keys["iv"], + }, + } + } + + class FakeDownloadResponse: + def __init__(self, body): + self.body = body + + fake_client = MagicMock() + fake_client.download = AsyncMock(return_value=FakeDownloadResponse(ciphertext)) + adapter._client = fake_client + + fake_nio = MagicMock() + fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) + fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) + fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) + fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + fake_nio.RoomEncryptedImage = FakeRoomEncryptedImage + fake_nio.RoomEncryptedAudio = type("RoomEncryptedAudio", (), {}) + fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) + fake_nio.RoomEncryptedFile = type("RoomEncryptedFile", (), {}) + + room = MagicMock(room_id="!room:example.org", member_count=2, users={}) + event = FakeRoomEncryptedImage() + + with patch.dict("sys.modules", {"nio": fake_nio}): + with patch("gateway.platforms.base.cache_image_from_bytes", return_value="/tmp/cached-image.png") as cache_mock: + await adapter._on_room_message_media(room, event) + + cache_mock.assert_called_once_with(plaintext, ext=".png") + msg_event = adapter.handle_message.await_args.args[0] + assert msg_event.message_type.name == "PHOTO" + assert msg_event.media_urls == ["/tmp/cached-image.png"] + assert msg_event.media_types == ["image/png"] + + @pytest.mark.asyncio + async def test_on_room_message_media_decrypts_encrypted_voice_and_caches_audio(self): + from nio.crypto.attachments import encrypt_attachment + + adapter = _make_adapter() + adapter._user_id = "@bot:example.org" + adapter._startup_ts = 0.0 + adapter._dm_rooms = {} + adapter.handle_message = AsyncMock() + + plaintext = b"OggS" + b"\x00" * 32 + ciphertext, keys = encrypt_attachment(plaintext) + + class FakeRoomEncryptedAudio: + def __init__(self): + self.sender = "@alice:example.org" + self.event_id = "$voice1" + self.server_timestamp = 0 + self.body = "voice.ogg" + self.url = "mxc://example.org/voice123" + self.key = keys["key"]["k"] + self.hashes = keys["hashes"] + self.iv = keys["iv"] + self.mimetype = "audio/ogg" + self.source = { + "content": { + "body": "voice.ogg", + "info": {"mimetype": "audio/ogg"}, + "org.matrix.msc3245.voice": {}, + "file": { + "url": self.url, + "key": keys["key"], + "hashes": keys["hashes"], + "iv": keys["iv"], + }, + } + } + + class FakeDownloadResponse: + def __init__(self, body): + self.body = body + + fake_client = MagicMock() + fake_client.download = AsyncMock(return_value=FakeDownloadResponse(ciphertext)) + adapter._client = fake_client + + fake_nio = MagicMock() + fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) + fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) + fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) + fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + fake_nio.RoomEncryptedImage = type("RoomEncryptedImage", (), {}) + fake_nio.RoomEncryptedAudio = FakeRoomEncryptedAudio + fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) + fake_nio.RoomEncryptedFile = type("RoomEncryptedFile", (), {}) + + room = MagicMock(room_id="!room:example.org", member_count=2, users={}) + event = FakeRoomEncryptedAudio() + + with patch.dict("sys.modules", {"nio": fake_nio}): + with patch("gateway.platforms.base.cache_audio_from_bytes", return_value="/tmp/cached-voice.ogg") as cache_mock: + await adapter._on_room_message_media(room, event) + + cache_mock.assert_called_once_with(plaintext, ext=".ogg") + msg_event = adapter.handle_message.await_args.args[0] + assert msg_event.message_type.name == "VOICE" + assert msg_event.media_urls == ["/tmp/cached-voice.ogg"] + assert msg_event.media_types == ["audio/ogg"] + + @pytest.mark.asyncio + async def test_on_room_message_media_decrypts_encrypted_file_and_caches_document(self): + from nio.crypto.attachments import encrypt_attachment + + adapter = _make_adapter() + adapter._user_id = "@bot:example.org" + adapter._startup_ts = 0.0 + adapter._dm_rooms = {} + adapter.handle_message = AsyncMock() + + plaintext = b"hello from encrypted document" + ciphertext, keys = encrypt_attachment(plaintext) + + class FakeRoomEncryptedFile: + def __init__(self): + self.sender = "@alice:example.org" + self.event_id = "$file1" + self.server_timestamp = 0 + self.body = "notes.txt" + self.url = "mxc://example.org/file123" + self.key = keys["key"] + self.hashes = keys["hashes"] + self.iv = keys["iv"] + self.mimetype = "text/plain" + self.source = { + "content": { + "body": "notes.txt", + "info": {"mimetype": "text/plain"}, + "file": { + "url": self.url, + "key": keys["key"], + "hashes": keys["hashes"], + "iv": keys["iv"], + }, + } + } + + class FakeDownloadResponse: + def __init__(self, body): + self.body = body + + fake_client = MagicMock() + fake_client.download = AsyncMock(return_value=FakeDownloadResponse(ciphertext)) + adapter._client = fake_client + + fake_nio = MagicMock() + fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) + fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) + fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) + fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + fake_nio.RoomEncryptedImage = type("RoomEncryptedImage", (), {}) + fake_nio.RoomEncryptedAudio = type("RoomEncryptedAudio", (), {}) + fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) + fake_nio.RoomEncryptedFile = FakeRoomEncryptedFile + + room = MagicMock(room_id="!room:example.org", member_count=2, users={}) + event = FakeRoomEncryptedFile() + + with patch.dict("sys.modules", {"nio": fake_nio}): + with patch("gateway.platforms.base.cache_document_from_bytes", return_value="/tmp/cached-notes.txt") as cache_mock: + await adapter._on_room_message_media(room, event) + + cache_mock.assert_called_once_with(plaintext, "notes.txt") + msg_event = adapter.handle_message.await_args.args[0] + assert msg_event.message_type.name == "DOCUMENT" + assert msg_event.media_urls == ["/tmp/cached-notes.txt"] + assert msg_event.media_types == ["text/plain"] + + @pytest.mark.asyncio + async def test_on_room_message_media_does_not_emit_ciphertext_url_when_encrypted_media_decryption_fails(self): + adapter = _make_adapter() + adapter._user_id = "@bot:example.org" + adapter._startup_ts = 0.0 + adapter._dm_rooms = {} + adapter.handle_message = AsyncMock() + + class FakeRoomEncryptedImage: + def __init__(self): + self.sender = "@alice:example.org" + self.event_id = "$img2" + self.server_timestamp = 0 + self.body = "broken.png" + self.url = "mxc://example.org/media999" + self.key = {"k": "broken"} + self.hashes = {"sha256": "broken"} + self.iv = "broken" + self.mimetype = "image/png" + self.source = { + "content": { + "body": "broken.png", + "info": {"mimetype": "image/png"}, + "file": { + "url": self.url, + "key": self.key, + "hashes": self.hashes, + "iv": self.iv, + }, + } + } + + class FakeDownloadResponse: + def __init__(self, body): + self.body = body + + fake_client = MagicMock() + fake_client.download = AsyncMock(return_value=FakeDownloadResponse(b"ciphertext")) + adapter._client = fake_client + + fake_nio = MagicMock() + fake_nio.RoomMessageImage = type("RoomMessageImage", (), {}) + fake_nio.RoomMessageAudio = type("RoomMessageAudio", (), {}) + fake_nio.RoomMessageVideo = type("RoomMessageVideo", (), {}) + fake_nio.RoomMessageFile = type("RoomMessageFile", (), {}) + fake_nio.RoomEncryptedImage = FakeRoomEncryptedImage + fake_nio.RoomEncryptedAudio = type("RoomEncryptedAudio", (), {}) + fake_nio.RoomEncryptedVideo = type("RoomEncryptedVideo", (), {}) + fake_nio.RoomEncryptedFile = type("RoomEncryptedFile", (), {}) + + room = MagicMock(room_id="!room:example.org", member_count=2, users={}) + event = FakeRoomEncryptedImage() + + with patch.dict("sys.modules", {"nio": fake_nio}): + await adapter._on_room_message_media(room, event) + + msg_event = adapter.handle_message.await_args.args[0] + assert not msg_event.media_urls + assert not msg_event.media_types