mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-30 11:52:04 +00:00
fix(matrix): record DM rooms in m.direct on invite to prevent group misclassification
Rebase onto plugins/platforms/matrix/adapter.py (code moved from gateway/platforms/matrix.py). Same logic: _on_invite checks is_direct on invite events and calls _record_dm_room to persist in m.direct account data. Fixes #44679
This commit is contained in:
parent
fde1c8570f
commit
14baeefe1d
2 changed files with 325 additions and 5 deletions
|
|
@ -2885,15 +2885,27 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
await self.handle_message(msg_event)
|
||||
|
||||
async def _on_invite(self, event: Any) -> None:
|
||||
"""Auto-join rooms when invited."""
|
||||
"""Auto-join rooms when invited, recording DM rooms in m.direct."""
|
||||
|
||||
room_id = str(getattr(event, "room_id", ""))
|
||||
content = getattr(event, "content", None)
|
||||
is_direct = bool(getattr(content, "is_direct", False))
|
||||
inviter = str(getattr(event, "sender", ""))
|
||||
|
||||
logger.info(
|
||||
"Matrix: invited to %s — joining",
|
||||
"Matrix: invited to %s — joining (is_direct=%s)",
|
||||
room_id,
|
||||
is_direct,
|
||||
)
|
||||
# When the invite declares this as a DM, record it in m.direct after
|
||||
# the (non-blocking) join completes so that _resolve_room_identity
|
||||
# treats it correctly even when the bot account has no prior DM
|
||||
# history. The join itself stays off the sync path.
|
||||
self._schedule_invite_join(
|
||||
room_id,
|
||||
is_direct=is_direct and bool(inviter),
|
||||
inviter=inviter,
|
||||
)
|
||||
self._schedule_invite_join(room_id)
|
||||
|
||||
async def _join_room_by_id(self, room_id: str) -> bool:
|
||||
"""Join a room by ID and refresh local caches on success."""
|
||||
|
|
@ -2913,7 +2925,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
logger.warning("Matrix: error joining %s: %s", room_id, exc)
|
||||
return False
|
||||
|
||||
def _schedule_invite_join(self, room_id: str) -> None:
|
||||
def _schedule_invite_join(
|
||||
self,
|
||||
room_id: str,
|
||||
*,
|
||||
is_direct: bool = False,
|
||||
inviter: str = "",
|
||||
) -> None:
|
||||
"""Schedule an invite join without blocking sync or gateway readiness."""
|
||||
if not room_id or room_id in self._joined_rooms:
|
||||
return
|
||||
|
|
@ -2923,7 +2941,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
|
||||
async def _join_invite() -> None:
|
||||
try:
|
||||
await asyncio.wait_for(self._join_room_by_id(room_id), timeout=45.0)
|
||||
joined = await asyncio.wait_for(
|
||||
self._join_room_by_id(room_id), timeout=45.0
|
||||
)
|
||||
# Persist the DM signal from the invite once the join lands,
|
||||
# so m.direct is authoritative even on a fresh bot account.
|
||||
if joined and is_direct and inviter:
|
||||
await self._record_dm_room(room_id, inviter)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Matrix: timed out joining invite %s", room_id)
|
||||
finally:
|
||||
|
|
@ -3779,6 +3803,47 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
self._room_identities.clear()
|
||||
self._room_identity_cached_at.clear()
|
||||
|
||||
async def _record_dm_room(self, room_id: str, inviter: str) -> None:
|
||||
"""Persist a room as DM in m.direct account data after an invite.
|
||||
|
||||
When the bot account has never been used for DMs, ``m.direct`` is
|
||||
absent (404). This method fetches the current mapping (if any),
|
||||
appends *room_id* under the *inviter*'s entry, and writes it back
|
||||
so that subsequent ``_refresh_dm_cache`` calls treat the room as a
|
||||
DM without requiring manual ``m.direct`` setup.
|
||||
"""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
dm_data: Dict[str, list] = {}
|
||||
try:
|
||||
resp = await self._client.get_account_data("m.direct")
|
||||
if hasattr(resp, "content") and isinstance(resp.content, dict):
|
||||
dm_data = resp.content
|
||||
elif isinstance(resp, dict):
|
||||
dm_data = resp
|
||||
except Exception:
|
||||
pass # m.direct doesn't exist yet — start fresh
|
||||
|
||||
rooms_for_user = dm_data.get(inviter, [])
|
||||
if not isinstance(rooms_for_user, list):
|
||||
rooms_for_user = []
|
||||
if room_id not in rooms_for_user:
|
||||
rooms_for_user.append(room_id)
|
||||
dm_data[inviter] = rooms_for_user
|
||||
try:
|
||||
await self._client.set_account_data("m.direct", dm_data)
|
||||
logger.info(
|
||||
"Matrix: recorded %s as DM room (inviter=%s)", room_id, inviter
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: failed to update m.direct: %s", exc)
|
||||
|
||||
# Update local cache so _resolve_room_identity sees it immediately.
|
||||
self._dm_rooms[room_id] = True
|
||||
self._room_identities.pop(room_id, None)
|
||||
self._room_identity_cached_at.pop(room_id, None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Mention detection helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
|
|
|||
255
tests/gateway/test_matrix_dm_invite_recording.py
Normal file
255
tests/gateway/test_matrix_dm_invite_recording.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
"""Tests for Matrix DM room recording on invite (issue #44679).
|
||||
|
||||
When the bot's Matrix account has no ``m.direct`` account data (common for
|
||||
accounts created solely for Hermes), DM rooms are silently treated as groups.
|
||||
This tests the fix that records DM rooms in ``m.direct`` when the invite
|
||||
event carries ``is_direct: true``.
|
||||
"""
|
||||
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import PlatformConfig
|
||||
|
||||
|
||||
def _make_adapter(tmp_path=None):
|
||||
"""Create a MatrixAdapter with mocked config."""
|
||||
from plugins.platforms.matrix.adapter import MatrixAdapter
|
||||
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="syt_test_token",
|
||||
extra={
|
||||
"homeserver": "https://matrix.example.org",
|
||||
"user_id": "@hermes:example.org",
|
||||
},
|
||||
)
|
||||
adapter = MatrixAdapter(config)
|
||||
adapter._text_batch_delay_seconds = 0
|
||||
adapter.handle_message = AsyncMock()
|
||||
adapter._startup_ts = time.time() - 10
|
||||
return adapter
|
||||
|
||||
|
||||
def _make_invite_event(
|
||||
room_id="!dm_room:example.org",
|
||||
sender="@alice:example.org",
|
||||
is_direct=True,
|
||||
):
|
||||
"""Create a fake invite event with is_direct in content."""
|
||||
content = SimpleNamespace(is_direct=is_direct)
|
||||
return SimpleNamespace(
|
||||
room_id=room_id,
|
||||
sender=sender,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _on_invite DM recording
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOnInviteRecordsDM:
|
||||
"""_on_invite schedules a join that records the DM when is_direct is True.
|
||||
|
||||
The join itself is non-blocking (``_schedule_invite_join`` spawns a task),
|
||||
so these tests drive ``_on_invite`` and then await the scheduled task to
|
||||
observe its side effects.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def _drain_invite_tasks(adapter):
|
||||
"""Await any tasks _schedule_invite_join spawned."""
|
||||
tasks = list(adapter._invite_join_tasks.values())
|
||||
for task in tasks:
|
||||
await task
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dm_invite_records_room(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._join_room_by_id = AsyncMock(return_value=True)
|
||||
adapter._record_dm_room = AsyncMock()
|
||||
|
||||
event = _make_invite_event(is_direct=True, sender="@alice:example.org")
|
||||
await adapter._on_invite(event)
|
||||
await self._drain_invite_tasks(adapter)
|
||||
|
||||
adapter._join_room_by_id.assert_awaited_once_with("!dm_room:example.org")
|
||||
adapter._record_dm_room.assert_awaited_once_with(
|
||||
"!dm_room:example.org", "@alice:example.org"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_dm_invite_does_not_record(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._join_room_by_id = AsyncMock(return_value=True)
|
||||
adapter._record_dm_room = AsyncMock()
|
||||
|
||||
event = _make_invite_event(is_direct=False)
|
||||
await adapter._on_invite(event)
|
||||
await self._drain_invite_tasks(adapter)
|
||||
|
||||
adapter._join_room_by_id.assert_awaited_once()
|
||||
adapter._record_dm_room.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_is_direct_does_not_record(self):
|
||||
"""Invite events without is_direct attribute should not trigger recording."""
|
||||
adapter = _make_adapter()
|
||||
adapter._join_room_by_id = AsyncMock(return_value=True)
|
||||
adapter._record_dm_room = AsyncMock()
|
||||
|
||||
event = SimpleNamespace(
|
||||
room_id="!room:example.org",
|
||||
sender="@alice:example.org",
|
||||
content=SimpleNamespace(), # no is_direct attr
|
||||
)
|
||||
await adapter._on_invite(event)
|
||||
await self._drain_invite_tasks(adapter)
|
||||
|
||||
adapter._record_dm_room.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_join_failure_does_not_record(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._join_room_by_id = AsyncMock(return_value=False)
|
||||
adapter._record_dm_room = AsyncMock()
|
||||
|
||||
event = _make_invite_event(is_direct=True)
|
||||
await adapter._on_invite(event)
|
||||
await self._drain_invite_tasks(adapter)
|
||||
|
||||
adapter._record_dm_room.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_inviter_does_not_record(self):
|
||||
adapter = _make_adapter()
|
||||
adapter._join_room_by_id = AsyncMock(return_value=True)
|
||||
adapter._record_dm_room = AsyncMock()
|
||||
|
||||
event = SimpleNamespace(
|
||||
room_id="!room:example.org",
|
||||
sender="",
|
||||
content=SimpleNamespace(is_direct=True),
|
||||
)
|
||||
await adapter._on_invite(event)
|
||||
await self._drain_invite_tasks(adapter)
|
||||
|
||||
adapter._record_dm_room.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _record_dm_room
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecordDMRoom:
|
||||
"""_record_dm_room should update m.direct account data and local cache."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_creates_m_direct_when_absent(self):
|
||||
"""When m.direct doesn't exist (404), creates it from scratch."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = MagicMock()
|
||||
adapter._client.get_account_data = AsyncMock(side_effect=Exception("M_NOT_FOUND"))
|
||||
adapter._client.set_account_data = AsyncMock()
|
||||
|
||||
await adapter._record_dm_room("!new:example.org", "@alice:example.org")
|
||||
|
||||
adapter._client.set_account_data.assert_awaited_once_with(
|
||||
"m.direct", {"@alice:example.org": ["!new:example.org"]}
|
||||
)
|
||||
assert adapter._dm_rooms.get("!new:example.org") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_appends_to_existing_m_direct(self):
|
||||
"""When m.direct exists with other rooms, appends the new room."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = MagicMock()
|
||||
existing_data = {"@bob:example.org": ["!old:example.org"]}
|
||||
adapter._client.get_account_data = AsyncMock(return_value=existing_data)
|
||||
adapter._client.set_account_data = AsyncMock()
|
||||
|
||||
await adapter._record_dm_room("!new:example.org", "@alice:example.org")
|
||||
|
||||
expected = {
|
||||
"@bob:example.org": ["!old:example.org"],
|
||||
"@alice:example.org": ["!new:example.org"],
|
||||
}
|
||||
adapter._client.set_account_data.assert_awaited_once_with("m.direct", expected)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_duplicate_room_in_m_direct(self):
|
||||
"""If room is already in m.direct, does not append again."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = MagicMock()
|
||||
existing_data = {"@alice:example.org": ["!room:example.org"]}
|
||||
adapter._client.get_account_data = AsyncMock(return_value=existing_data)
|
||||
adapter._client.set_account_data = AsyncMock()
|
||||
|
||||
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
|
||||
|
||||
adapter._client.set_account_data.assert_not_awaited()
|
||||
assert adapter._dm_rooms.get("!room:example.org") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_failure_is_handled_gracefully(self):
|
||||
"""If set_account_data fails, local cache is still updated."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = MagicMock()
|
||||
adapter._client.get_account_data = AsyncMock(side_effect=Exception("not found"))
|
||||
adapter._client.set_account_data = AsyncMock(
|
||||
side_effect=Exception("M_FORBIDDEN")
|
||||
)
|
||||
|
||||
# Should not raise
|
||||
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
|
||||
|
||||
# Local cache updated despite server error
|
||||
assert adapter._dm_rooms.get("!room:example.org") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clears_room_identity_cache(self):
|
||||
"""After recording a DM, room identity cache should be invalidated."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = MagicMock()
|
||||
adapter._client.get_account_data = AsyncMock(side_effect=Exception("404"))
|
||||
adapter._client.set_account_data = AsyncMock()
|
||||
|
||||
adapter._room_identities["!room:example.org"] = "stale"
|
||||
adapter._room_identity_cached_at["!room:example.org"] = time.monotonic()
|
||||
|
||||
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
|
||||
|
||||
assert "!room:example.org" not in adapter._room_identities
|
||||
assert "!room:example.org" not in adapter._room_identity_cached_at
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_is_noop(self):
|
||||
"""If _client is None, does nothing."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = None
|
||||
|
||||
# Should not raise
|
||||
await adapter._record_dm_room("!room:example.org", "@alice:example.org")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_m_direct_response_with_content_attr(self):
|
||||
"""get_account_data may return an object with .content attribute."""
|
||||
adapter = _make_adapter()
|
||||
adapter._client = MagicMock()
|
||||
resp = SimpleNamespace(content={"@bob:example.org": ["!old:example.org"]})
|
||||
adapter._client.get_account_data = AsyncMock(return_value=resp)
|
||||
adapter._client.set_account_data = AsyncMock()
|
||||
|
||||
await adapter._record_dm_room("!new:example.org", "@alice:example.org")
|
||||
|
||||
expected = {
|
||||
"@bob:example.org": ["!old:example.org"],
|
||||
"@alice:example.org": ["!new:example.org"],
|
||||
}
|
||||
adapter._client.set_account_data.assert_awaited_once_with("m.direct", expected)
|
||||
Loading…
Add table
Add a link
Reference in a new issue