fix(matrix): record DM rooms in m.direct on invite to prevent group misclassification

Rebase onto plugins/platforms/matrix/adapter.py (code moved from
gateway/platforms/matrix.py). Same logic: _on_invite checks is_direct
on invite events and calls _record_dm_room to persist in m.direct
account data.

Fixes #44679
This commit is contained in:
liuhao1024 2026-06-21 08:29:01 +08:00 committed by Teknium
parent fde1c8570f
commit 14baeefe1d
2 changed files with 325 additions and 5 deletions

View file

@ -2885,15 +2885,27 @@ class MatrixAdapter(BasePlatformAdapter):
await self.handle_message(msg_event)
async def _on_invite(self, event: Any) -> None:
"""Auto-join rooms when invited."""
"""Auto-join rooms when invited, recording DM rooms in m.direct."""
room_id = str(getattr(event, "room_id", ""))
content = getattr(event, "content", None)
is_direct = bool(getattr(content, "is_direct", False))
inviter = str(getattr(event, "sender", ""))
logger.info(
"Matrix: invited to %s — joining",
"Matrix: invited to %s — joining (is_direct=%s)",
room_id,
is_direct,
)
# When the invite declares this as a DM, record it in m.direct after
# the (non-blocking) join completes so that _resolve_room_identity
# treats it correctly even when the bot account has no prior DM
# history. The join itself stays off the sync path.
self._schedule_invite_join(
room_id,
is_direct=is_direct and bool(inviter),
inviter=inviter,
)
self._schedule_invite_join(room_id)
async def _join_room_by_id(self, room_id: str) -> bool:
"""Join a room by ID and refresh local caches on success."""
@ -2913,7 +2925,13 @@ class MatrixAdapter(BasePlatformAdapter):
logger.warning("Matrix: error joining %s: %s", room_id, exc)
return False
def _schedule_invite_join(self, room_id: str) -> None:
def _schedule_invite_join(
self,
room_id: str,
*,
is_direct: bool = False,
inviter: str = "",
) -> None:
"""Schedule an invite join without blocking sync or gateway readiness."""
if not room_id or room_id in self._joined_rooms:
return
@ -2923,7 +2941,13 @@ class MatrixAdapter(BasePlatformAdapter):
async def _join_invite() -> None:
try:
await asyncio.wait_for(self._join_room_by_id(room_id), timeout=45.0)
joined = await asyncio.wait_for(
self._join_room_by_id(room_id), timeout=45.0
)
# Persist the DM signal from the invite once the join lands,
# so m.direct is authoritative even on a fresh bot account.
if joined and is_direct and inviter:
await self._record_dm_room(room_id, inviter)
except asyncio.TimeoutError:
logger.warning("Matrix: timed out joining invite %s", room_id)
finally:
@ -3779,6 +3803,47 @@ class MatrixAdapter(BasePlatformAdapter):
self._room_identities.clear()
self._room_identity_cached_at.clear()
async def _record_dm_room(self, room_id: str, inviter: str) -> None:
"""Persist a room as DM in m.direct account data after an invite.
When the bot account has never been used for DMs, ``m.direct`` is
absent (404). This method fetches the current mapping (if any),
appends *room_id* under the *inviter*'s entry, and writes it back
so that subsequent ``_refresh_dm_cache`` calls treat the room as a
DM without requiring manual ``m.direct`` setup.
"""
if not self._client:
return
dm_data: Dict[str, list] = {}
try:
resp = await self._client.get_account_data("m.direct")
if hasattr(resp, "content") and isinstance(resp.content, dict):
dm_data = resp.content
elif isinstance(resp, dict):
dm_data = resp
except Exception:
pass # m.direct doesn't exist yet — start fresh
rooms_for_user = dm_data.get(inviter, [])
if not isinstance(rooms_for_user, list):
rooms_for_user = []
if room_id not in rooms_for_user:
rooms_for_user.append(room_id)
dm_data[inviter] = rooms_for_user
try:
await self._client.set_account_data("m.direct", dm_data)
logger.info(
"Matrix: recorded %s as DM room (inviter=%s)", room_id, inviter
)
except Exception as exc:
logger.warning("Matrix: failed to update m.direct: %s", exc)
# Update local cache so _resolve_room_identity sees it immediately.
self._dm_rooms[room_id] = True
self._room_identities.pop(room_id, None)
self._room_identity_cached_at.pop(room_id, None)
# ------------------------------------------------------------------
# Mention detection helpers
# ------------------------------------------------------------------

View file

@ -0,0 +1,255 @@
"""Tests for Matrix DM room recording on invite (issue #44679).
When the bot's Matrix account has no ``m.direct`` account data (common for
accounts created solely for Hermes), DM rooms are silently treated as groups.
This tests the fix that records DM rooms in ``m.direct`` when the invite
event carries ``is_direct: true``.
"""
import time
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import PlatformConfig
def _make_adapter(tmp_path=None):
"""Create a MatrixAdapter with mocked config."""
from plugins.platforms.matrix.adapter import MatrixAdapter
config = PlatformConfig(
enabled=True,
token="syt_test_token",
extra={
"homeserver": "https://matrix.example.org",
"user_id": "@hermes:example.org",
},
)
adapter = MatrixAdapter(config)
adapter._text_batch_delay_seconds = 0
adapter.handle_message = AsyncMock()
adapter._startup_ts = time.time() - 10
return adapter
def _make_invite_event(
room_id="!dm_room:example.org",
sender="@alice:example.org",
is_direct=True,
):
"""Create a fake invite event with is_direct in content."""
content = SimpleNamespace(is_direct=is_direct)
return SimpleNamespace(
room_id=room_id,
sender=sender,
content=content,
)
# ---------------------------------------------------------------------------
# _on_invite DM recording
# ---------------------------------------------------------------------------
class TestOnInviteRecordsDM:
"""_on_invite schedules a join that records the DM when is_direct is True.
The join itself is non-blocking (``_schedule_invite_join`` spawns a task),
so these tests drive ``_on_invite`` and then await the scheduled task to
observe its side effects.
"""
@staticmethod
async def _drain_invite_tasks(adapter):
"""Await any tasks _schedule_invite_join spawned."""
tasks = list(adapter._invite_join_tasks.values())
for task in tasks:
await task
@pytest.mark.asyncio
async def test_dm_invite_records_room(self):
adapter = _make_adapter()
adapter._join_room_by_id = AsyncMock(return_value=True)
adapter._record_dm_room = AsyncMock()
event = _make_invite_event(is_direct=True, sender="@alice:example.org")
await adapter._on_invite(event)
await self._drain_invite_tasks(adapter)
adapter._join_room_by_id.assert_awaited_once_with("!dm_room:example.org")
adapter._record_dm_room.assert_awaited_once_with(
"!dm_room:example.org", "@alice:example.org"
)
@pytest.mark.asyncio
async def test_non_dm_invite_does_not_record(self):
adapter = _make_adapter()
adapter._join_room_by_id = AsyncMock(return_value=True)
adapter._record_dm_room = AsyncMock()
event = _make_invite_event(is_direct=False)
await adapter._on_invite(event)
await self._drain_invite_tasks(adapter)
adapter._join_room_by_id.assert_awaited_once()
adapter._record_dm_room.assert_not_awaited()
@pytest.mark.asyncio
async def test_missing_is_direct_does_not_record(self):
"""Invite events without is_direct attribute should not trigger recording."""
adapter = _make_adapter()
adapter._join_room_by_id = AsyncMock(return_value=True)
adapter._record_dm_room = AsyncMock()
event = SimpleNamespace(
room_id="!room:example.org",
sender="@alice:example.org",
content=SimpleNamespace(), # no is_direct attr
)
await adapter._on_invite(event)
await self._drain_invite_tasks(adapter)
adapter._record_dm_room.assert_not_awaited()
@pytest.mark.asyncio
async def test_join_failure_does_not_record(self):
adapter = _make_adapter()
adapter._join_room_by_id = AsyncMock(return_value=False)
adapter._record_dm_room = AsyncMock()
event = _make_invite_event(is_direct=True)
await adapter._on_invite(event)
await self._drain_invite_tasks(adapter)
adapter._record_dm_room.assert_not_awaited()
@pytest.mark.asyncio
async def test_empty_inviter_does_not_record(self):
adapter = _make_adapter()
adapter._join_room_by_id = AsyncMock(return_value=True)
adapter._record_dm_room = AsyncMock()
event = SimpleNamespace(
room_id="!room:example.org",
sender="",
content=SimpleNamespace(is_direct=True),
)
await adapter._on_invite(event)
await self._drain_invite_tasks(adapter)
adapter._record_dm_room.assert_not_awaited()
# ---------------------------------------------------------------------------
# _record_dm_room
# ---------------------------------------------------------------------------
class TestRecordDMRoom:
"""_record_dm_room should update m.direct account data and local cache."""
@pytest.mark.asyncio
async def test_creates_m_direct_when_absent(self):
"""When m.direct doesn't exist (404), creates it from scratch."""
adapter = _make_adapter()
adapter._client = MagicMock()
adapter._client.get_account_data = AsyncMock(side_effect=Exception("M_NOT_FOUND"))
adapter._client.set_account_data = AsyncMock()
await adapter._record_dm_room("!new:example.org", "@alice:example.org")
adapter._client.set_account_data.assert_awaited_once_with(
"m.direct", {"@alice:example.org": ["!new:example.org"]}
)
assert adapter._dm_rooms.get("!new:example.org") is True
@pytest.mark.asyncio
async def test_appends_to_existing_m_direct(self):
"""When m.direct exists with other rooms, appends the new room."""
adapter = _make_adapter()
adapter._client = MagicMock()
existing_data = {"@bob:example.org": ["!old:example.org"]}
adapter._client.get_account_data = AsyncMock(return_value=existing_data)
adapter._client.set_account_data = AsyncMock()
await adapter._record_dm_room("!new:example.org", "@alice:example.org")
expected = {
"@bob:example.org": ["!old:example.org"],
"@alice:example.org": ["!new:example.org"],
}
adapter._client.set_account_data.assert_awaited_once_with("m.direct", expected)
@pytest.mark.asyncio
async def test_no_duplicate_room_in_m_direct(self):
"""If room is already in m.direct, does not append again."""
adapter = _make_adapter()
adapter._client = MagicMock()
existing_data = {"@alice:example.org": ["!room:example.org"]}
adapter._client.get_account_data = AsyncMock(return_value=existing_data)
adapter._client.set_account_data = AsyncMock()
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
adapter._client.set_account_data.assert_not_awaited()
assert adapter._dm_rooms.get("!room:example.org") is True
@pytest.mark.asyncio
async def test_set_failure_is_handled_gracefully(self):
"""If set_account_data fails, local cache is still updated."""
adapter = _make_adapter()
adapter._client = MagicMock()
adapter._client.get_account_data = AsyncMock(side_effect=Exception("not found"))
adapter._client.set_account_data = AsyncMock(
side_effect=Exception("M_FORBIDDEN")
)
# Should not raise
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
# Local cache updated despite server error
assert adapter._dm_rooms.get("!room:example.org") is True
@pytest.mark.asyncio
async def test_clears_room_identity_cache(self):
"""After recording a DM, room identity cache should be invalidated."""
adapter = _make_adapter()
adapter._client = MagicMock()
adapter._client.get_account_data = AsyncMock(side_effect=Exception("404"))
adapter._client.set_account_data = AsyncMock()
adapter._room_identities["!room:example.org"] = "stale"
adapter._room_identity_cached_at["!room:example.org"] = time.monotonic()
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
assert "!room:example.org" not in adapter._room_identities
assert "!room:example.org" not in adapter._room_identity_cached_at
@pytest.mark.asyncio
async def test_no_client_is_noop(self):
"""If _client is None, does nothing."""
adapter = _make_adapter()
adapter._client = None
# Should not raise
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
@pytest.mark.asyncio
async def test_m_direct_response_with_content_attr(self):
"""get_account_data may return an object with .content attribute."""
adapter = _make_adapter()
adapter._client = MagicMock()
resp = SimpleNamespace(content={"@bob:example.org": ["!old:example.org"]})
adapter._client.get_account_data = AsyncMock(return_value=resp)
adapter._client.set_account_data = AsyncMock()
await adapter._record_dm_room("!new:example.org", "@alice:example.org")
expected = {
"@bob:example.org": ["!old:example.org"],
"@alice:example.org": ["!new:example.org"],
}
adapter._client.set_account_data.assert_awaited_once_with("m.direct", expected)