fix(gateway): avoid Matrix pending invite boot loops

This commit is contained in:
konsisumer 2026-06-20 22:01:44 +02:00 committed by Teknium
parent a1ac6baac4
commit 11b0be8d15
2 changed files with 177 additions and 10 deletions

View file

@ -809,6 +809,7 @@ class MatrixAdapter(BasePlatformAdapter):
self._client: Any = None # mautrix.client.Client
self._crypto_db: Any = None # mautrix.util.async_db.Database
self._sync_task: Optional[asyncio.Task] = None
self._invite_join_tasks: Dict[str, asyncio.Task] = {}
self._closing = False
self._startup_ts: float = 0.0
# Clock-skew detection: count grace-check drops that happen well
@ -1447,7 +1448,7 @@ class MatrixAdapter(BasePlatformAdapter):
await self._dispatch_sync(sync_data)
except Exception as exc:
logger.warning("Matrix: initial sync event dispatch error: %s", exc)
await self._join_pending_invites(sync_data)
self._schedule_pending_invite_joins(sync_data)
else:
logger.warning(
"Matrix: initial sync returned unexpected type %s",
@ -1479,6 +1480,14 @@ class MatrixAdapter(BasePlatformAdapter):
except (asyncio.CancelledError, Exception):
pass
invite_join_tasks = list(self._invite_join_tasks.values())
for task in invite_join_tasks:
if not task.done():
task.cancel()
if invite_join_tasks:
await asyncio.gather(*invite_join_tasks, return_exceptions=True)
self._invite_join_tasks.clear()
redaction_tasks = list(self._reaction_redaction_tasks)
for task in redaction_tasks:
if not task.done():
@ -2217,7 +2226,10 @@ class MatrixAdapter(BasePlatformAdapter):
await self._dispatch_sync(sync_data)
except Exception as exc:
logger.warning("Matrix: sync event dispatch error: %s", exc)
await self._join_pending_invites(sync_data)
self._schedule_pending_invite_joins(sync_data)
# Let freshly scheduled invite joins start before the next
# sync iteration without waiting for slow or stuck joins.
await asyncio.sleep(0)
except asyncio.CancelledError:
return
@ -2881,7 +2893,7 @@ class MatrixAdapter(BasePlatformAdapter):
"Matrix: invited to %s — joining",
room_id,
)
await self._join_room_by_id(room_id)
self._schedule_invite_join(room_id)
async def _join_room_by_id(self, room_id: str) -> bool:
"""Join a room by ID and refresh local caches on success."""
@ -2901,7 +2913,25 @@ class MatrixAdapter(BasePlatformAdapter):
logger.warning("Matrix: error joining %s: %s", room_id, exc)
return False
async def _join_pending_invites(self, sync_data: Dict[str, Any]) -> None:
def _schedule_invite_join(self, room_id: str) -> None:
"""Schedule an invite join without blocking sync or gateway readiness."""
if not room_id or room_id in self._joined_rooms:
return
existing = self._invite_join_tasks.get(room_id)
if existing and not existing.done():
return
async def _join_invite() -> None:
try:
await asyncio.wait_for(self._join_room_by_id(room_id), timeout=45.0)
except asyncio.TimeoutError:
logger.warning("Matrix: timed out joining invite %s", room_id)
finally:
self._invite_join_tasks.pop(room_id, None)
self._invite_join_tasks[room_id] = asyncio.create_task(_join_invite())
def _schedule_pending_invite_joins(self, sync_data: Dict[str, Any]) -> None:
"""Join rooms still present in rooms.invite after sync processing."""
rooms = sync_data.get("rooms", {}) if isinstance(sync_data, dict) else {}
invites = rooms.get("invite", {})
@ -2911,7 +2941,7 @@ class MatrixAdapter(BasePlatformAdapter):
if room_id in self._joined_rooms:
continue
logger.info("Matrix: reconciling pending invite for %s", room_id)
await self._join_room_by_id(str(room_id))
self._schedule_invite_join(str(room_id))
# ------------------------------------------------------------------
# Reactions (send, receive, processing lifecycle)

View file

@ -1441,6 +1441,68 @@ class TestMatrixAccessTokenAuth:
await adapter.disconnect()
@pytest.mark.asyncio
async def test_connect_does_not_wait_for_stuck_pending_invite(self):
"""A stale pending invite must not keep the Matrix platform unready."""
# Scenario: the mocked sync response lists one room under rooms.invite and
# the mocked join_room call never returns.  connect() must still finish
# within the 1 second limit enforced by asyncio.wait_for below.
from plugins.platforms.matrix.adapter import MatrixAdapter
config = PlatformConfig(
enabled=True,
token="syt_test_access_token",
extra={
"homeserver": "https://matrix.example.org",
"user_id": "@bot:example.org",
},
)
adapter = MatrixAdapter(config)
fake_mautrix_mods = _make_fake_mautrix()
# Set as soon as the fake join is entered, so the test can prove the join
# was actually started.
join_started = asyncio.Event()
async def _stuck_join_room(*args, **kwargs):
join_started.set()
# Waits on a fresh Event that nothing ever sets: the join hangs until the
# surrounding task is cancelled.
await asyncio.Event().wait()
mock_client = MagicMock()
mock_client.mxid = "@bot:example.org"
mock_client.device_id = None
mock_client.state_store = MagicMock()
mock_client.sync_store = MagicMock()
mock_client.sync_store.put_next_batch = AsyncMock()
mock_client.crypto = None
mock_client.whoami = AsyncMock(
return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")
)
# Sync payload: no joined rooms, one pending invite.
mock_client.sync = AsyncMock(
return_value={
"rooms": {
"join": {},
"invite": {"!dead:example.org": {}},
},
"next_batch": "s1234",
}
)
mock_client.join_room = AsyncMock(side_effect=_stuck_join_room)
mock_client.add_event_handler = MagicMock()
mock_client.add_dispatcher = MagicMock()
mock_client.handle_sync = MagicMock(return_value=[])
mock_client.api = MagicMock()
mock_client.api.token = "syt_test_access_token"
mock_client.api.session = MagicMock()
mock_client.api.session.close = AsyncMock()
# Every Client(...) construction in the fake mautrix module yields mock_client.
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
# _refresh_dm_cache and _sync_loop are replaced with AsyncMocks for the
# duration of the patches.
with patch.dict("sys.modules", fake_mautrix_mods):
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
# connect() has to return True within 1 second despite the hanging join.
assert await asyncio.wait_for(adapter.connect(), timeout=1) is True
# The join was started and is still tracked as an in-flight task.
await asyncio.wait_for(join_started.wait(), timeout=1)
assert "!dead:example.org" in adapter._invite_join_tasks
# After disconnect() the task registry must be empty again.
await adapter.disconnect()
assert adapter._invite_join_tasks == {}
class TestDeviceKeyReVerification:
@pytest.mark.asyncio
@ -1858,6 +1920,10 @@ class TestMatrixSyncLoop:
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
await adapter._sync_loop()
tasks = list(adapter._invite_join_tasks.values())
if tasks:
await asyncio.gather(*tasks)
fake_client.join_room.assert_awaited_once()
assert "!joined:example.org" in adapter._joined_rooms
assert "!invited:example.org" in adapter._joined_rooms
@ -2123,6 +2189,47 @@ class TestMatrixSyncLoop:
assert len(captured) == 1
@pytest.mark.asyncio
async def test_pending_invite_join_does_not_block_sync_loop(self):
    """The sync loop must return even while an invite join hangs forever."""
    dead_room = "!dead:example.org"

    adapter = _make_adapter()
    adapter._closing = False

    async def _single_sync(**kwargs):
        # Marks the adapter as closing while serving the response
        # (presumably what makes the loop stop after one pass — confirm
        # against _sync_loop).
        adapter._closing = True
        return {"rooms": {"invite": {dead_room: {}}}, "next_batch": "s1234"}

    join_entered = asyncio.Event()

    async def _never_finishing_join(*args, **kwargs):
        join_entered.set()
        # Nothing ever sets this Event, so the join hangs until cancelled.
        await asyncio.Event().wait()

    sync_store = MagicMock()
    sync_store.get_next_batch = AsyncMock(return_value=None)
    sync_store.put_next_batch = AsyncMock()

    client = MagicMock()
    client.sync = AsyncMock(side_effect=_single_sync)
    client.join_room = AsyncMock(side_effect=_never_finishing_join)
    client.sync_store = sync_store
    client.handle_sync = MagicMock(return_value=[])
    adapter._client = client

    await adapter._sync_loop()
    await asyncio.wait_for(join_entered.wait(), timeout=1)

    # The join was attempted exactly once and is still pending.
    assert dead_room not in adapter._joined_rooms
    assert dead_room in adapter._invite_join_tasks
    client.join_room.assert_awaited_once()

    # disconnect() must leave no tracked join tasks behind.
    await adapter.disconnect()
    assert adapter._invite_join_tasks == {}
class TestMatrixUploadAndSend:
@pytest.mark.asyncio
@ -3270,6 +3377,7 @@ class TestMatrixImageOnlyMediaNormalization:
@pytest.mark.asyncio
async def test_external_media_download_rejects_oversized_content_length(self, monkeypatch):
import aiohttp
import tools.url_safety as url_safety
class _Content:
async def iter_chunked(self, _size):
@ -3302,6 +3410,11 @@ class TestMatrixImageOnlyMediaNormalization:
self.adapter._max_media_bytes = 10
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
monkeypatch.setattr(
url_safety,
"is_safe_url",
lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png",
)
with pytest.raises(ValueError, match="exceeds Matrix limit"):
await self.adapter._download_external_media_with_cap(
@ -3311,6 +3424,7 @@ class TestMatrixImageOnlyMediaNormalization:
@pytest.mark.asyncio
async def test_external_media_download_rejects_oversized_stream(self, monkeypatch):
import aiohttp
import tools.url_safety as url_safety
class _Content:
async def iter_chunked(self, _size):
@ -3345,6 +3459,11 @@ class TestMatrixImageOnlyMediaNormalization:
self.adapter._max_media_bytes = 10
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
monkeypatch.setattr(
url_safety,
"is_safe_url",
lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png",
)
with pytest.raises(ValueError, match="exceeds Matrix limit"):
await self.adapter._download_external_media_with_cap(
@ -3354,6 +3473,7 @@ class TestMatrixImageOnlyMediaNormalization:
@pytest.mark.asyncio
async def test_external_media_download_rejects_unsafe_redirect(self, monkeypatch):
import aiohttp
import tools.url_safety as url_safety
class _Content:
async def iter_chunked(self, _size):
@ -3385,6 +3505,11 @@ class TestMatrixImageOnlyMediaNormalization:
return _Response()
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
monkeypatch.setattr(
url_safety,
"is_safe_url",
lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png",
)
with pytest.raises(ValueError, match="unsafe redirect"):
await self.adapter._download_external_media_with_cap(
@ -3401,6 +3526,7 @@ class TestMatrixImageOnlyMediaNormalization:
@pytest.mark.asyncio
async def test_external_media_download_rejects_non_image_content_type(self, monkeypatch):
import aiohttp
import tools.url_safety as url_safety
class _Content:
async def iter_chunked(self, _size):
@ -3432,6 +3558,7 @@ class TestMatrixImageOnlyMediaNormalization:
return _Response()
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
with pytest.raises(ValueError, match="not an image"):
await self.adapter._download_external_media_with_cap(
@ -3439,14 +3566,16 @@ class TestMatrixImageOnlyMediaNormalization:
)
@pytest.mark.asyncio
async def test_send_image_failure_log_redacts_signed_url(self, caplog):
async def test_send_image_failure_log_redacts_signed_url(self, caplog, monkeypatch):
from gateway.platforms.base import SendResult
import tools.url_safety as url_safety
signed_url = "https://example.com/image.png?signature=secret-token#frag"
self.adapter._download_external_media_with_cap = AsyncMock(
side_effect=ValueError("download failed")
)
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
await self.adapter.send_image("!room:example.org", signed_url)
@ -3455,14 +3584,16 @@ class TestMatrixImageOnlyMediaNormalization:
assert "#frag" not in caplog.text
@pytest.mark.asyncio
async def test_send_image_failure_response_does_not_expose_signed_url_query(self):
async def test_send_image_failure_response_does_not_expose_signed_url_query(self, monkeypatch):
from gateway.platforms.base import SendResult
import tools.url_safety as url_safety
signed_url = "https://example.com/image.png?signature=secret-token"
self.adapter._download_external_media_with_cap = AsyncMock(
side_effect=ValueError("download failed")
)
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
await self.adapter.send_image("!room:example.org", signed_url)
@ -3473,14 +3604,16 @@ class TestMatrixImageOnlyMediaNormalization:
assert "source URL was not shown" in sent_text
@pytest.mark.asyncio
async def test_send_image_failure_response_does_not_expose_signed_url_fragment(self):
async def test_send_image_failure_response_does_not_expose_signed_url_fragment(self, monkeypatch):
from gateway.platforms.base import SendResult
import tools.url_safety as url_safety
signed_url = "https://example.com/image.png#fragment-secret"
self.adapter._download_external_media_with_cap = AsyncMock(
side_effect=ValueError("download failed")
)
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
await self.adapter.send_image("!room:example.org", signed_url)
@ -3491,14 +3624,16 @@ class TestMatrixImageOnlyMediaNormalization:
assert "source URL was not shown" in sent_text
@pytest.mark.asyncio
async def test_send_image_failure_response_preserves_caption(self):
async def test_send_image_failure_response_preserves_caption(self, monkeypatch):
from gateway.platforms.base import SendResult
import tools.url_safety as url_safety
signed_url = "https://example.com/image.png?signature=secret-token#fragment"
self.adapter._download_external_media_with_cap = AsyncMock(
side_effect=ValueError("download failed")
)
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
await self.adapter.send_image(
"!room:example.org",
@ -3514,14 +3649,16 @@ class TestMatrixImageOnlyMediaNormalization:
assert signed_url not in sent_text
@pytest.mark.asyncio
async def test_send_image_failure_log_still_redacts_signed_url(self, caplog):
async def test_send_image_failure_log_still_redacts_signed_url(self, caplog, monkeypatch):
from gateway.platforms.base import SendResult
import tools.url_safety as url_safety
signed_url = "https://example.com/image.png?signature=secret-token#fragment"
self.adapter._download_external_media_with_cap = AsyncMock(
side_effect=ValueError("download failed")
)
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
await self.adapter.send_image("!room:example.org", signed_url)