diff --git a/plugins/platforms/matrix/adapter.py b/plugins/platforms/matrix/adapter.py index 14e640e1b7c..428b6b6afe0 100644 --- a/plugins/platforms/matrix/adapter.py +++ b/plugins/platforms/matrix/adapter.py @@ -2885,15 +2885,27 @@ class MatrixAdapter(BasePlatformAdapter): await self.handle_message(msg_event) async def _on_invite(self, event: Any) -> None: - """Auto-join rooms when invited.""" + """Auto-join rooms when invited, recording DM rooms in m.direct.""" room_id = str(getattr(event, "room_id", "")) + content = getattr(event, "content", None) + is_direct = bool(getattr(content, "is_direct", False)) + inviter = str(getattr(event, "sender", "")) logger.info( - "Matrix: invited to %s — joining", + "Matrix: invited to %s — joining (is_direct=%s)", room_id, + is_direct, + ) + # When the invite declares this as a DM, record it in m.direct after + # the (non-blocking) join completes so that _resolve_room_identity + # treats it correctly even when the bot account has no prior DM + # history. The join itself stays off the sync path. + self._schedule_invite_join( + room_id, + is_direct=is_direct and bool(inviter), + inviter=inviter, ) - self._schedule_invite_join(room_id) async def _join_room_by_id(self, room_id: str) -> bool: """Join a room by ID and refresh local caches on success.""" @@ -2913,7 +2925,13 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: error joining %s: %s", room_id, exc) return False - def _schedule_invite_join(self, room_id: str) -> None: + def _schedule_invite_join( + self, + room_id: str, + *, + is_direct: bool = False, + inviter: str = "", + ) -> None: """Schedule an invite join without blocking sync or gateway readiness.""" if not room_id or room_id in self._joined_rooms: return @@ -2923,7 +2941,13 @@ class MatrixAdapter(BasePlatformAdapter): async def _join_invite() -> None: try: - await asyncio.wait_for(self._join_room_by_id(room_id), timeout=45.0) + joined = await asyncio.wait_for( + self._join_room_by_id(room_id), timeout=45.0 + ) + # Persist the DM signal from the invite once the join lands, + # so m.direct is authoritative even on a fresh bot account. + if joined and is_direct and inviter: + await self._record_dm_room(room_id, inviter) except asyncio.TimeoutError: logger.warning("Matrix: timed out joining invite %s", room_id) finally: @@ -3779,6 +3803,47 @@ class MatrixAdapter(BasePlatformAdapter): self._room_identities.clear() self._room_identity_cached_at.clear() + async def _record_dm_room(self, room_id: str, inviter: str) -> None: + """Persist a room as DM in m.direct account data after an invite. + + When the bot account has never been used for DMs, ``m.direct`` is + absent (404). This method fetches the current mapping (if any), + appends *room_id* under the *inviter*'s entry, and writes it back + so that subsequent ``_refresh_dm_cache`` calls treat the room as a + DM without requiring manual ``m.direct`` setup. + """ + if not self._client: + return + + dm_data: Dict[str, list] = {} + try: + resp = await self._client.get_account_data("m.direct") + if hasattr(resp, "content") and isinstance(resp.content, dict): + dm_data = resp.content + elif isinstance(resp, dict): + dm_data = resp + except Exception: + pass # m.direct doesn't exist yet — start fresh + + rooms_for_user = dm_data.get(inviter, []) + if not isinstance(rooms_for_user, list): + rooms_for_user = [] + if room_id not in rooms_for_user: + rooms_for_user.append(room_id) + dm_data[inviter] = rooms_for_user + try: + await self._client.set_account_data("m.direct", dm_data) + logger.info( + "Matrix: recorded %s as DM room (inviter=%s)", room_id, inviter + ) + except Exception as exc: + logger.warning("Matrix: failed to update m.direct: %s", exc) + + # Update local cache so _resolve_room_identity sees it immediately. + self._dm_rooms[room_id] = True + self._room_identities.pop(room_id, None) + self._room_identity_cached_at.pop(room_id, None) + # ------------------------------------------------------------------ # Mention detection helpers # ------------------------------------------------------------------ diff --git a/tests/gateway/test_matrix_dm_invite_recording.py b/tests/gateway/test_matrix_dm_invite_recording.py new file mode 100644 index 00000000000..77d9ae56bf1 --- /dev/null +++ b/tests/gateway/test_matrix_dm_invite_recording.py @@ -0,0 +1,255 @@ +"""Tests for Matrix DM room recording on invite (issue #44679). + +When the bot's Matrix account has no ``m.direct`` account data (common for +accounts created solely for Hermes), DM rooms are silently treated as groups. +This tests the fix that records DM rooms in ``m.direct`` when the invite +event carries ``is_direct: true``. +""" + +import time +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from gateway.config import PlatformConfig + + +def _make_adapter(tmp_path=None): + """Create a MatrixAdapter with mocked config.""" + from plugins.platforms.matrix.adapter import MatrixAdapter + + config = PlatformConfig( + enabled=True, + token="syt_test_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@hermes:example.org", + }, + ) + adapter = MatrixAdapter(config) + adapter._text_batch_delay_seconds = 0 + adapter.handle_message = AsyncMock() + adapter._startup_ts = time.time() - 10 + return adapter + + +def _make_invite_event( + room_id="!dm_room:example.org", + sender="@alice:example.org", + is_direct=True, +): + """Create a fake invite event with is_direct in content.""" + content = SimpleNamespace(is_direct=is_direct) + return SimpleNamespace( + room_id=room_id, + sender=sender, + content=content, + ) + + +# --------------------------------------------------------------------------- +# _on_invite DM recording +# --------------------------------------------------------------------------- + + +class TestOnInviteRecordsDM: + """_on_invite schedules a join that records the DM when is_direct is True. + + The join itself is non-blocking (``_schedule_invite_join`` spawns a task), + so these tests drive ``_on_invite`` and then await the scheduled task to + observe its side effects. + """ + + @staticmethod + async def _drain_invite_tasks(adapter): + """Await any tasks _schedule_invite_join spawned.""" + tasks = list(adapter._invite_join_tasks.values()) + for task in tasks: + await task + + @pytest.mark.asyncio + async def test_dm_invite_records_room(self): + adapter = _make_adapter() + adapter._join_room_by_id = AsyncMock(return_value=True) + adapter._record_dm_room = AsyncMock() + + event = _make_invite_event(is_direct=True, sender="@alice:example.org") + await adapter._on_invite(event) + await self._drain_invite_tasks(adapter) + + adapter._join_room_by_id.assert_awaited_once_with("!dm_room:example.org") + adapter._record_dm_room.assert_awaited_once_with( + "!dm_room:example.org", "@alice:example.org" + ) + + @pytest.mark.asyncio + async def test_non_dm_invite_does_not_record(self): + adapter = _make_adapter() + adapter._join_room_by_id = AsyncMock(return_value=True) + adapter._record_dm_room = AsyncMock() + + event = _make_invite_event(is_direct=False) + await adapter._on_invite(event) + await self._drain_invite_tasks(adapter) + + adapter._join_room_by_id.assert_awaited_once() + adapter._record_dm_room.assert_not_awaited() + + @pytest.mark.asyncio + async def test_missing_is_direct_does_not_record(self): + """Invite events without is_direct attribute should not trigger recording.""" + adapter = _make_adapter() + adapter._join_room_by_id = AsyncMock(return_value=True) + adapter._record_dm_room = AsyncMock() + + event = SimpleNamespace( + room_id="!room:example.org", + sender="@alice:example.org", + content=SimpleNamespace(), # no is_direct attr + ) + await adapter._on_invite(event) + await self._drain_invite_tasks(adapter) + + adapter._record_dm_room.assert_not_awaited() + + @pytest.mark.asyncio + async def test_join_failure_does_not_record(self): + adapter = _make_adapter() + adapter._join_room_by_id = AsyncMock(return_value=False) + adapter._record_dm_room = AsyncMock() + + event = _make_invite_event(is_direct=True) + await adapter._on_invite(event) + await self._drain_invite_tasks(adapter) + + adapter._record_dm_room.assert_not_awaited() + + @pytest.mark.asyncio + async def test_empty_inviter_does_not_record(self): + adapter = _make_adapter() + adapter._join_room_by_id = AsyncMock(return_value=True) + adapter._record_dm_room = AsyncMock() + + event = SimpleNamespace( + room_id="!room:example.org", + sender="", + content=SimpleNamespace(is_direct=True), + ) + await adapter._on_invite(event) + await self._drain_invite_tasks(adapter) + + adapter._record_dm_room.assert_not_awaited() + + +# --------------------------------------------------------------------------- +# _record_dm_room +# --------------------------------------------------------------------------- + + +class TestRecordDMRoom: + """_record_dm_room should update m.direct account data and local cache.""" + + @pytest.mark.asyncio + async def test_creates_m_direct_when_absent(self): + """When m.direct doesn't exist (404), creates it from scratch.""" + adapter = _make_adapter() + adapter._client = MagicMock() + adapter._client.get_account_data = AsyncMock(side_effect=Exception("M_NOT_FOUND")) + adapter._client.set_account_data = AsyncMock() + + await adapter._record_dm_room("!new:example.org", "@alice:example.org") + + adapter._client.set_account_data.assert_awaited_once_with( + "m.direct", {"@alice:example.org": ["!new:example.org"]} + ) + assert adapter._dm_rooms.get("!new:example.org") is True + + @pytest.mark.asyncio + async def test_appends_to_existing_m_direct(self): + """When m.direct exists with other rooms, appends the new room.""" + adapter = _make_adapter() + adapter._client = MagicMock() + existing_data = {"@bob:example.org": ["!old:example.org"]} + adapter._client.get_account_data = AsyncMock(return_value=existing_data) + adapter._client.set_account_data = AsyncMock() + + await adapter._record_dm_room("!new:example.org", "@alice:example.org") + + expected = { + "@bob:example.org": ["!old:example.org"], + "@alice:example.org": ["!new:example.org"], + } + adapter._client.set_account_data.assert_awaited_once_with("m.direct", expected) + + @pytest.mark.asyncio + async def test_no_duplicate_room_in_m_direct(self): + """If room is already in m.direct, does not append again.""" + adapter = _make_adapter() + adapter._client = MagicMock() + existing_data = {"@alice:example.org": ["!room:example.org"]} + adapter._client.get_account_data = AsyncMock(return_value=existing_data) + adapter._client.set_account_data = AsyncMock() + + await adapter._record_dm_room("!room:example.org", "@alice:example.org") + + adapter._client.set_account_data.assert_not_awaited() + assert adapter._dm_rooms.get("!room:example.org") is True + + @pytest.mark.asyncio + async def test_set_failure_is_handled_gracefully(self): + """If set_account_data fails, local cache is still updated.""" + adapter = _make_adapter() + adapter._client = MagicMock() + adapter._client.get_account_data = AsyncMock(side_effect=Exception("not found")) + adapter._client.set_account_data = AsyncMock( + side_effect=Exception("M_FORBIDDEN") + ) + + # Should not raise + await adapter._record_dm_room("!room:example.org", "@alice:example.org") + + # Local cache updated despite server error + assert adapter._dm_rooms.get("!room:example.org") is True + + @pytest.mark.asyncio + async def test_clears_room_identity_cache(self): + """After recording a DM, room identity cache should be invalidated.""" + adapter = _make_adapter() + adapter._client = MagicMock() + adapter._client.get_account_data = AsyncMock(side_effect=Exception("404")) + adapter._client.set_account_data = AsyncMock() + + adapter._room_identities["!room:example.org"] = "stale" + adapter._room_identity_cached_at["!room:example.org"] = time.monotonic() + + await adapter._record_dm_room("!room:example.org", "@alice:example.org") + + assert "!room:example.org" not in adapter._room_identities + assert "!room:example.org" not in adapter._room_identity_cached_at + + @pytest.mark.asyncio + async def test_no_client_is_noop(self): + """If _client is None, does nothing.""" + adapter = _make_adapter() + adapter._client = None + + # Should not raise + await adapter._record_dm_room("!room:example.org", "@alice:example.org") + + @pytest.mark.asyncio + async def test_m_direct_response_with_content_attr(self): + """get_account_data may return an object with .content attribute.""" + adapter = _make_adapter() + adapter._client = MagicMock() + resp = SimpleNamespace(content={"@bob:example.org": ["!old:example.org"]}) + adapter._client.get_account_data = AsyncMock(return_value=resp) + adapter._client.set_account_data = AsyncMock() + + await adapter._record_dm_room("!new:example.org", "@alice:example.org") + + expected = { + "@bob:example.org": ["!old:example.org"], + "@alice:example.org": ["!new:example.org"], + } + adapter._client.set_account_data.assert_awaited_once_with("m.direct", expected)