From 11b0be8d15fc18f6a6741317cbb3f197ac7df989 Mon Sep 17 00:00:00 2001 From: konsisumer Date: Sat, 20 Jun 2026 22:01:44 +0200 Subject: [PATCH] fix(gateway): avoid Matrix pending invite boot loops --- plugins/platforms/matrix/adapter.py | 40 +++++++- tests/gateway/test_matrix.py | 147 +++++++++++++++++++++++++++- 2 files changed, 177 insertions(+), 10 deletions(-) diff --git a/plugins/platforms/matrix/adapter.py b/plugins/platforms/matrix/adapter.py index 86dc729bd2a..14e640e1b7c 100644 --- a/plugins/platforms/matrix/adapter.py +++ b/plugins/platforms/matrix/adapter.py @@ -809,6 +809,7 @@ class MatrixAdapter(BasePlatformAdapter): self._client: Any = None # mautrix.client.Client self._crypto_db: Any = None # mautrix.util.async_db.Database self._sync_task: Optional[asyncio.Task] = None + self._invite_join_tasks: Dict[str, asyncio.Task] = {} self._closing = False self._startup_ts: float = 0.0 # Clock-skew detection: count grace-check drops that happen well @@ -1447,7 +1448,7 @@ class MatrixAdapter(BasePlatformAdapter): await self._dispatch_sync(sync_data) except Exception as exc: logger.warning("Matrix: initial sync event dispatch error: %s", exc) - await self._join_pending_invites(sync_data) + self._schedule_pending_invite_joins(sync_data) else: logger.warning( "Matrix: initial sync returned unexpected type %s", @@ -1479,6 +1480,14 @@ class MatrixAdapter(BasePlatformAdapter): except (asyncio.CancelledError, Exception): pass + invite_join_tasks = list(self._invite_join_tasks.values()) + for task in invite_join_tasks: + if not task.done(): + task.cancel() + if invite_join_tasks: + await asyncio.gather(*invite_join_tasks, return_exceptions=True) + self._invite_join_tasks.clear() + redaction_tasks = list(self._reaction_redaction_tasks) for task in redaction_tasks: if not task.done(): @@ -2217,7 +2226,10 @@ class MatrixAdapter(BasePlatformAdapter): await self._dispatch_sync(sync_data) except Exception as exc: logger.warning("Matrix: sync event dispatch error: %s", exc) - await self._join_pending_invites(sync_data) + self._schedule_pending_invite_joins(sync_data) + # Let freshly scheduled invite joins start before the next + # sync iteration without waiting for slow or stuck joins. + await asyncio.sleep(0) except asyncio.CancelledError: return @@ -2881,7 +2893,7 @@ class MatrixAdapter(BasePlatformAdapter): "Matrix: invited to %s — joining", room_id, ) - await self._join_room_by_id(room_id) + self._schedule_invite_join(room_id) async def _join_room_by_id(self, room_id: str) -> bool: """Join a room by ID and refresh local caches on success.""" @@ -2901,7 +2913,25 @@ class MatrixAdapter(BasePlatformAdapter): logger.warning("Matrix: error joining %s: %s", room_id, exc) return False - async def _join_pending_invites(self, sync_data: Dict[str, Any]) -> None: + def _schedule_invite_join(self, room_id: str) -> None: + """Schedule an invite join without blocking sync or gateway readiness.""" + if not room_id or room_id in self._joined_rooms: + return + existing = self._invite_join_tasks.get(room_id) + if existing and not existing.done(): + return + + async def _join_invite() -> None: + try: + await asyncio.wait_for(self._join_room_by_id(room_id), timeout=45.0) + except asyncio.TimeoutError: + logger.warning("Matrix: timed out joining invite %s", room_id) + finally: + self._invite_join_tasks.pop(room_id, None) + + self._invite_join_tasks[room_id] = asyncio.create_task(_join_invite()) + + def _schedule_pending_invite_joins(self, sync_data: Dict[str, Any]) -> None: """Join rooms still present in rooms.invite after sync processing.""" rooms = sync_data.get("rooms", {}) if isinstance(sync_data, dict) else {} invites = rooms.get("invite", {}) @@ -2911,7 +2941,7 @@ class MatrixAdapter(BasePlatformAdapter): if room_id in self._joined_rooms: continue logger.info("Matrix: reconciling pending invite for %s", room_id) - await self._join_room_by_id(str(room_id)) + self._schedule_invite_join(str(room_id)) # ------------------------------------------------------------------ # Reactions (send, receive, processing lifecycle) diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py index ac78ecc526c..43aea0d0e39 100644 --- a/tests/gateway/test_matrix.py +++ b/tests/gateway/test_matrix.py @@ -1441,6 +1441,68 @@ class TestMatrixAccessTokenAuth: await adapter.disconnect() + @pytest.mark.asyncio + async def test_connect_does_not_wait_for_stuck_pending_invite(self): + """A stale pending invite must not keep the Matrix platform unready.""" + from plugins.platforms.matrix.adapter import MatrixAdapter + + config = PlatformConfig( + enabled=True, + token="syt_test_access_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + }, + ) + adapter = MatrixAdapter(config) + + fake_mautrix_mods = _make_fake_mautrix() + join_started = asyncio.Event() + + async def _stuck_join_room(*args, **kwargs): + join_started.set() + await asyncio.Event().wait() + + mock_client = MagicMock() + mock_client.mxid = "@bot:example.org" + mock_client.device_id = None + mock_client.state_store = MagicMock() + mock_client.sync_store = MagicMock() + mock_client.sync_store.put_next_batch = AsyncMock() + mock_client.crypto = None + mock_client.whoami = AsyncMock( + return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123") + ) + mock_client.sync = AsyncMock( + return_value={ + "rooms": { + "join": {}, + "invite": {"!dead:example.org": {}}, + }, + "next_batch": "s1234", + } + ) + mock_client.join_room = AsyncMock(side_effect=_stuck_join_room) + mock_client.add_event_handler = MagicMock() + mock_client.add_dispatcher = MagicMock() + mock_client.handle_sync = MagicMock(return_value=[]) + mock_client.api = MagicMock() + mock_client.api.token = "syt_test_access_token" + mock_client.api.session = MagicMock() + mock_client.api.session.close = AsyncMock() + fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client) + + with patch.dict("sys.modules", fake_mautrix_mods): + with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): + with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)): + assert await asyncio.wait_for(adapter.connect(), timeout=1) is True + + await asyncio.wait_for(join_started.wait(), timeout=1) + assert "!dead:example.org" in adapter._invite_join_tasks + + await adapter.disconnect() + assert adapter._invite_join_tasks == {} + class TestDeviceKeyReVerification: @pytest.mark.asyncio @@ -1858,6 +1920,10 @@ class TestMatrixSyncLoop: with patch.object(adapter, "_refresh_dm_cache", AsyncMock()): await adapter._sync_loop() + tasks = list(adapter._invite_join_tasks.values()) + if tasks: + await asyncio.gather(*tasks) + fake_client.join_room.assert_awaited_once() assert "!joined:example.org" in adapter._joined_rooms assert "!invited:example.org" in adapter._joined_rooms @@ -2123,6 +2189,47 @@ class TestMatrixSyncLoop: assert len(captured) == 1 + @pytest.mark.asyncio + async def test_pending_invite_join_does_not_block_sync_loop(self): + """Dead invite joins should not make sync look like a gateway failure.""" + adapter = _make_adapter() + adapter._closing = False + + async def _sync_once(**kwargs): + adapter._closing = True + return { + "rooms": { + "invite": {"!dead:example.org": {}}, + }, + "next_batch": "s1234", + } + + join_started = asyncio.Event() + + async def _stuck_join_room(*args, **kwargs): + join_started.set() + await asyncio.Event().wait() + + mock_sync_store = MagicMock() + mock_sync_store.get_next_batch = AsyncMock(return_value=None) + mock_sync_store.put_next_batch = AsyncMock() + + fake_client = MagicMock() + fake_client.sync = AsyncMock(side_effect=_sync_once) + fake_client.join_room = AsyncMock(side_effect=_stuck_join_room) + fake_client.sync_store = mock_sync_store + fake_client.handle_sync = MagicMock(return_value=[]) + adapter._client = fake_client + + await adapter._sync_loop() + await asyncio.wait_for(join_started.wait(), timeout=1) + + assert "!dead:example.org" not in adapter._joined_rooms + assert "!dead:example.org" in adapter._invite_join_tasks + fake_client.join_room.assert_awaited_once() + + await adapter.disconnect() + assert adapter._invite_join_tasks == {} class TestMatrixUploadAndSend: @pytest.mark.asyncio @@ -3270,6 +3377,7 @@ class TestMatrixImageOnlyMediaNormalization: @pytest.mark.asyncio async def test_external_media_download_rejects_oversized_content_length(self, monkeypatch): import aiohttp + import tools.url_safety as url_safety class _Content: async def iter_chunked(self, _size): @@ -3302,6 +3410,11 @@ class TestMatrixImageOnlyMediaNormalization: self.adapter._max_media_bytes = 10 monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session()) + monkeypatch.setattr( + url_safety, + "is_safe_url", + lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png", + ) with pytest.raises(ValueError, match="exceeds Matrix limit"): await self.adapter._download_external_media_with_cap( @@ -3311,6 +3424,7 @@ class TestMatrixImageOnlyMediaNormalization: @pytest.mark.asyncio async def test_external_media_download_rejects_oversized_stream(self, monkeypatch): import aiohttp + import tools.url_safety as url_safety class _Content: async def iter_chunked(self, _size): @@ -3345,6 +3459,11 @@ class TestMatrixImageOnlyMediaNormalization: self.adapter._max_media_bytes = 10 monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session()) + monkeypatch.setattr( + url_safety, + "is_safe_url", + lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png", + ) with pytest.raises(ValueError, match="exceeds Matrix limit"): await self.adapter._download_external_media_with_cap( @@ -3354,6 +3473,7 @@ class TestMatrixImageOnlyMediaNormalization: @pytest.mark.asyncio async def test_external_media_download_rejects_unsafe_redirect(self, monkeypatch): import aiohttp + import tools.url_safety as url_safety class _Content: async def iter_chunked(self, _size): @@ -3385,6 +3505,11 @@ class TestMatrixImageOnlyMediaNormalization: return _Response() monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session()) + monkeypatch.setattr( + url_safety, + "is_safe_url", + lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png", + ) with pytest.raises(ValueError, match="unsafe redirect"): await self.adapter._download_external_media_with_cap( @@ -3401,6 +3526,7 @@ class TestMatrixImageOnlyMediaNormalization: @pytest.mark.asyncio async def test_external_media_download_rejects_non_image_content_type(self, monkeypatch): import aiohttp + import tools.url_safety as url_safety class _Content: async def iter_chunked(self, _size): @@ -3432,6 +3558,7 @@ class TestMatrixImageOnlyMediaNormalization: return _Response() monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session()) + monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True) with pytest.raises(ValueError, match="not an image"): await self.adapter._download_external_media_with_cap( @@ -3439,14 +3566,16 @@ class TestMatrixImageOnlyMediaNormalization: ) @pytest.mark.asyncio - async def test_send_image_failure_log_redacts_signed_url(self, caplog): + async def test_send_image_failure_log_redacts_signed_url(self, caplog, monkeypatch): from gateway.platforms.base import SendResult + import tools.url_safety as url_safety signed_url = "https://example.com/image.png?signature=secret-token#frag" self.adapter._download_external_media_with_cap = AsyncMock( side_effect=ValueError("download failed") ) self.adapter.send = AsyncMock(return_value=SendResult(success=True)) + monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True) await self.adapter.send_image("!room:example.org", signed_url) @@ -3455,14 +3584,16 @@ class TestMatrixImageOnlyMediaNormalization: assert "#frag" not in caplog.text @pytest.mark.asyncio - async def test_send_image_failure_response_does_not_expose_signed_url_query(self): + async def test_send_image_failure_response_does_not_expose_signed_url_query(self, monkeypatch): from gateway.platforms.base import SendResult + import tools.url_safety as url_safety signed_url = "https://example.com/image.png?signature=secret-token" self.adapter._download_external_media_with_cap = AsyncMock( side_effect=ValueError("download failed") ) self.adapter.send = AsyncMock(return_value=SendResult(success=True)) + monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True) await self.adapter.send_image("!room:example.org", signed_url) @@ -3473,14 +3604,16 @@ class TestMatrixImageOnlyMediaNormalization: assert "source URL was not shown" in sent_text @pytest.mark.asyncio - async def test_send_image_failure_response_does_not_expose_signed_url_fragment(self): + async def test_send_image_failure_response_does_not_expose_signed_url_fragment(self, monkeypatch): from gateway.platforms.base import SendResult + import tools.url_safety as url_safety signed_url = "https://example.com/image.png#fragment-secret" self.adapter._download_external_media_with_cap = AsyncMock( side_effect=ValueError("download failed") ) self.adapter.send = AsyncMock(return_value=SendResult(success=True)) + monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True) await self.adapter.send_image("!room:example.org", signed_url) @@ -3491,14 +3624,16 @@ class TestMatrixImageOnlyMediaNormalization: assert "source URL was not shown" in sent_text @pytest.mark.asyncio - async def test_send_image_failure_response_preserves_caption(self): + async def test_send_image_failure_response_preserves_caption(self, monkeypatch): from gateway.platforms.base import SendResult + import tools.url_safety as url_safety signed_url = "https://example.com/image.png?signature=secret-token#fragment" self.adapter._download_external_media_with_cap = AsyncMock( side_effect=ValueError("download failed") ) self.adapter.send = AsyncMock(return_value=SendResult(success=True)) + monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True) await self.adapter.send_image( "!room:example.org", @@ -3514,14 +3649,16 @@ class TestMatrixImageOnlyMediaNormalization: assert signed_url not in sent_text @pytest.mark.asyncio - async def test_send_image_failure_log_still_redacts_signed_url(self, caplog): + async def test_send_image_failure_log_still_redacts_signed_url(self, caplog, monkeypatch): from gateway.platforms.base import SendResult + import tools.url_safety as url_safety signed_url = "https://example.com/image.png?signature=secret-token#fragment" self.adapter._download_external_media_with_cap = AsyncMock( side_effect=ValueError("download failed") ) self.adapter.send = AsyncMock(return_value=SendResult(success=True)) + monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True) await self.adapter.send_image("!room:example.org", signed_url)