mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-07-02 12:13:05 +00:00
fix(gateway): avoid Matrix pending invite boot loops
This commit is contained in:
parent
a1ac6baac4
commit
11b0be8d15
2 changed files with 177 additions and 10 deletions
|
|
@ -809,6 +809,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
self._client: Any = None # mautrix.client.Client
|
||||
self._crypto_db: Any = None # mautrix.util.async_db.Database
|
||||
self._sync_task: Optional[asyncio.Task] = None
|
||||
self._invite_join_tasks: Dict[str, asyncio.Task] = {}
|
||||
self._closing = False
|
||||
self._startup_ts: float = 0.0
|
||||
# Clock-skew detection: count grace-check drops that happen well
|
||||
|
|
@ -1447,7 +1448,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
await self._dispatch_sync(sync_data)
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: initial sync event dispatch error: %s", exc)
|
||||
await self._join_pending_invites(sync_data)
|
||||
self._schedule_pending_invite_joins(sync_data)
|
||||
else:
|
||||
logger.warning(
|
||||
"Matrix: initial sync returned unexpected type %s",
|
||||
|
|
@ -1479,6 +1480,14 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
invite_join_tasks = list(self._invite_join_tasks.values())
|
||||
for task in invite_join_tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
if invite_join_tasks:
|
||||
await asyncio.gather(*invite_join_tasks, return_exceptions=True)
|
||||
self._invite_join_tasks.clear()
|
||||
|
||||
redaction_tasks = list(self._reaction_redaction_tasks)
|
||||
for task in redaction_tasks:
|
||||
if not task.done():
|
||||
|
|
@ -2217,7 +2226,10 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
await self._dispatch_sync(sync_data)
|
||||
except Exception as exc:
|
||||
logger.warning("Matrix: sync event dispatch error: %s", exc)
|
||||
await self._join_pending_invites(sync_data)
|
||||
self._schedule_pending_invite_joins(sync_data)
|
||||
# Let freshly scheduled invite joins start before the next
|
||||
# sync iteration without waiting for slow or stuck joins.
|
||||
await asyncio.sleep(0)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
|
|
@ -2881,7 +2893,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
"Matrix: invited to %s — joining",
|
||||
room_id,
|
||||
)
|
||||
await self._join_room_by_id(room_id)
|
||||
self._schedule_invite_join(room_id)
|
||||
|
||||
async def _join_room_by_id(self, room_id: str) -> bool:
|
||||
"""Join a room by ID and refresh local caches on success."""
|
||||
|
|
@ -2901,7 +2913,25 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
logger.warning("Matrix: error joining %s: %s", room_id, exc)
|
||||
return False
|
||||
|
||||
async def _join_pending_invites(self, sync_data: Dict[str, Any]) -> None:
|
||||
def _schedule_invite_join(self, room_id: str) -> None:
|
||||
"""Schedule an invite join without blocking sync or gateway readiness."""
|
||||
if not room_id or room_id in self._joined_rooms:
|
||||
return
|
||||
existing = self._invite_join_tasks.get(room_id)
|
||||
if existing and not existing.done():
|
||||
return
|
||||
|
||||
async def _join_invite() -> None:
|
||||
try:
|
||||
await asyncio.wait_for(self._join_room_by_id(room_id), timeout=45.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Matrix: timed out joining invite %s", room_id)
|
||||
finally:
|
||||
self._invite_join_tasks.pop(room_id, None)
|
||||
|
||||
self._invite_join_tasks[room_id] = asyncio.create_task(_join_invite())
|
||||
|
||||
def _schedule_pending_invite_joins(self, sync_data: Dict[str, Any]) -> None:
|
||||
"""Join rooms still present in rooms.invite after sync processing."""
|
||||
rooms = sync_data.get("rooms", {}) if isinstance(sync_data, dict) else {}
|
||||
invites = rooms.get("invite", {})
|
||||
|
|
@ -2911,7 +2941,7 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
if room_id in self._joined_rooms:
|
||||
continue
|
||||
logger.info("Matrix: reconciling pending invite for %s", room_id)
|
||||
await self._join_room_by_id(str(room_id))
|
||||
self._schedule_invite_join(str(room_id))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Reactions (send, receive, processing lifecycle)
|
||||
|
|
|
|||
|
|
@ -1441,6 +1441,68 @@ class TestMatrixAccessTokenAuth:
|
|||
|
||||
await adapter.disconnect()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_connect_does_not_wait_for_stuck_pending_invite(self):
|
||||
"""A stale pending invite must not keep the Matrix platform unready."""
|
||||
from plugins.platforms.matrix.adapter import MatrixAdapter
|
||||
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
token="syt_test_access_token",
|
||||
extra={
|
||||
"homeserver": "https://matrix.example.org",
|
||||
"user_id": "@bot:example.org",
|
||||
},
|
||||
)
|
||||
adapter = MatrixAdapter(config)
|
||||
|
||||
fake_mautrix_mods = _make_fake_mautrix()
|
||||
join_started = asyncio.Event()
|
||||
|
||||
async def _stuck_join_room(*args, **kwargs):
|
||||
join_started.set()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.mxid = "@bot:example.org"
|
||||
mock_client.device_id = None
|
||||
mock_client.state_store = MagicMock()
|
||||
mock_client.sync_store = MagicMock()
|
||||
mock_client.sync_store.put_next_batch = AsyncMock()
|
||||
mock_client.crypto = None
|
||||
mock_client.whoami = AsyncMock(
|
||||
return_value=MagicMock(user_id="@bot:example.org", device_id="DEV123")
|
||||
)
|
||||
mock_client.sync = AsyncMock(
|
||||
return_value={
|
||||
"rooms": {
|
||||
"join": {},
|
||||
"invite": {"!dead:example.org": {}},
|
||||
},
|
||||
"next_batch": "s1234",
|
||||
}
|
||||
)
|
||||
mock_client.join_room = AsyncMock(side_effect=_stuck_join_room)
|
||||
mock_client.add_event_handler = MagicMock()
|
||||
mock_client.add_dispatcher = MagicMock()
|
||||
mock_client.handle_sync = MagicMock(return_value=[])
|
||||
mock_client.api = MagicMock()
|
||||
mock_client.api.token = "syt_test_access_token"
|
||||
mock_client.api.session = MagicMock()
|
||||
mock_client.api.session.close = AsyncMock()
|
||||
fake_mautrix_mods["mautrix.client"].Client = MagicMock(return_value=mock_client)
|
||||
|
||||
with patch.dict("sys.modules", fake_mautrix_mods):
|
||||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
with patch.object(adapter, "_sync_loop", AsyncMock(return_value=None)):
|
||||
assert await asyncio.wait_for(adapter.connect(), timeout=1) is True
|
||||
|
||||
await asyncio.wait_for(join_started.wait(), timeout=1)
|
||||
assert "!dead:example.org" in adapter._invite_join_tasks
|
||||
|
||||
await adapter.disconnect()
|
||||
assert adapter._invite_join_tasks == {}
|
||||
|
||||
|
||||
class TestDeviceKeyReVerification:
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -1858,6 +1920,10 @@ class TestMatrixSyncLoop:
|
|||
with patch.object(adapter, "_refresh_dm_cache", AsyncMock()):
|
||||
await adapter._sync_loop()
|
||||
|
||||
tasks = list(adapter._invite_join_tasks.values())
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
fake_client.join_room.assert_awaited_once()
|
||||
assert "!joined:example.org" in adapter._joined_rooms
|
||||
assert "!invited:example.org" in adapter._joined_rooms
|
||||
|
|
@ -2123,6 +2189,47 @@ class TestMatrixSyncLoop:
|
|||
|
||||
assert len(captured) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pending_invite_join_does_not_block_sync_loop(self):
|
||||
"""Dead invite joins should not make sync look like a gateway failure."""
|
||||
adapter = _make_adapter()
|
||||
adapter._closing = False
|
||||
|
||||
async def _sync_once(**kwargs):
|
||||
adapter._closing = True
|
||||
return {
|
||||
"rooms": {
|
||||
"invite": {"!dead:example.org": {}},
|
||||
},
|
||||
"next_batch": "s1234",
|
||||
}
|
||||
|
||||
join_started = asyncio.Event()
|
||||
|
||||
async def _stuck_join_room(*args, **kwargs):
|
||||
join_started.set()
|
||||
await asyncio.Event().wait()
|
||||
|
||||
mock_sync_store = MagicMock()
|
||||
mock_sync_store.get_next_batch = AsyncMock(return_value=None)
|
||||
mock_sync_store.put_next_batch = AsyncMock()
|
||||
|
||||
fake_client = MagicMock()
|
||||
fake_client.sync = AsyncMock(side_effect=_sync_once)
|
||||
fake_client.join_room = AsyncMock(side_effect=_stuck_join_room)
|
||||
fake_client.sync_store = mock_sync_store
|
||||
fake_client.handle_sync = MagicMock(return_value=[])
|
||||
adapter._client = fake_client
|
||||
|
||||
await adapter._sync_loop()
|
||||
await asyncio.wait_for(join_started.wait(), timeout=1)
|
||||
|
||||
assert "!dead:example.org" not in adapter._joined_rooms
|
||||
assert "!dead:example.org" in adapter._invite_join_tasks
|
||||
fake_client.join_room.assert_awaited_once()
|
||||
|
||||
await adapter.disconnect()
|
||||
assert adapter._invite_join_tasks == {}
|
||||
|
||||
class TestMatrixUploadAndSend:
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -3270,6 +3377,7 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
@pytest.mark.asyncio
|
||||
async def test_external_media_download_rejects_oversized_content_length(self, monkeypatch):
|
||||
import aiohttp
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
class _Content:
|
||||
async def iter_chunked(self, _size):
|
||||
|
|
@ -3302,6 +3410,11 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
|
||||
self.adapter._max_media_bytes = 10
|
||||
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
|
||||
monkeypatch.setattr(
|
||||
url_safety,
|
||||
"is_safe_url",
|
||||
lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="exceeds Matrix limit"):
|
||||
await self.adapter._download_external_media_with_cap(
|
||||
|
|
@ -3311,6 +3424,7 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
@pytest.mark.asyncio
|
||||
async def test_external_media_download_rejects_oversized_stream(self, monkeypatch):
|
||||
import aiohttp
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
class _Content:
|
||||
async def iter_chunked(self, _size):
|
||||
|
|
@ -3345,6 +3459,11 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
|
||||
self.adapter._max_media_bytes = 10
|
||||
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
|
||||
monkeypatch.setattr(
|
||||
url_safety,
|
||||
"is_safe_url",
|
||||
lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="exceeds Matrix limit"):
|
||||
await self.adapter._download_external_media_with_cap(
|
||||
|
|
@ -3354,6 +3473,7 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
@pytest.mark.asyncio
|
||||
async def test_external_media_download_rejects_unsafe_redirect(self, monkeypatch):
|
||||
import aiohttp
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
class _Content:
|
||||
async def iter_chunked(self, _size):
|
||||
|
|
@ -3385,6 +3505,11 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
return _Response()
|
||||
|
||||
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
|
||||
monkeypatch.setattr(
|
||||
url_safety,
|
||||
"is_safe_url",
|
||||
lambda candidate, **_kwargs: str(candidate) == "https://example.com/image.png",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="unsafe redirect"):
|
||||
await self.adapter._download_external_media_with_cap(
|
||||
|
|
@ -3401,6 +3526,7 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
@pytest.mark.asyncio
|
||||
async def test_external_media_download_rejects_non_image_content_type(self, monkeypatch):
|
||||
import aiohttp
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
class _Content:
|
||||
async def iter_chunked(self, _size):
|
||||
|
|
@ -3432,6 +3558,7 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
return _Response()
|
||||
|
||||
monkeypatch.setattr(aiohttp, "ClientSession", lambda **_kwargs: _Session())
|
||||
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
|
||||
|
||||
with pytest.raises(ValueError, match="not an image"):
|
||||
await self.adapter._download_external_media_with_cap(
|
||||
|
|
@ -3439,14 +3566,16 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_failure_log_redacts_signed_url(self, caplog):
|
||||
async def test_send_image_failure_log_redacts_signed_url(self, caplog, monkeypatch):
|
||||
from gateway.platforms.base import SendResult
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
signed_url = "https://example.com/image.png?signature=secret-token#frag"
|
||||
self.adapter._download_external_media_with_cap = AsyncMock(
|
||||
side_effect=ValueError("download failed")
|
||||
)
|
||||
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
|
||||
|
||||
await self.adapter.send_image("!room:example.org", signed_url)
|
||||
|
||||
|
|
@ -3455,14 +3584,16 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
assert "#frag" not in caplog.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_failure_response_does_not_expose_signed_url_query(self):
|
||||
async def test_send_image_failure_response_does_not_expose_signed_url_query(self, monkeypatch):
|
||||
from gateway.platforms.base import SendResult
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
signed_url = "https://example.com/image.png?signature=secret-token"
|
||||
self.adapter._download_external_media_with_cap = AsyncMock(
|
||||
side_effect=ValueError("download failed")
|
||||
)
|
||||
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
|
||||
|
||||
await self.adapter.send_image("!room:example.org", signed_url)
|
||||
|
||||
|
|
@ -3473,14 +3604,16 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
assert "source URL was not shown" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_failure_response_does_not_expose_signed_url_fragment(self):
|
||||
async def test_send_image_failure_response_does_not_expose_signed_url_fragment(self, monkeypatch):
|
||||
from gateway.platforms.base import SendResult
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
signed_url = "https://example.com/image.png#fragment-secret"
|
||||
self.adapter._download_external_media_with_cap = AsyncMock(
|
||||
side_effect=ValueError("download failed")
|
||||
)
|
||||
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
|
||||
|
||||
await self.adapter.send_image("!room:example.org", signed_url)
|
||||
|
||||
|
|
@ -3491,14 +3624,16 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
assert "source URL was not shown" in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_failure_response_preserves_caption(self):
|
||||
async def test_send_image_failure_response_preserves_caption(self, monkeypatch):
|
||||
from gateway.platforms.base import SendResult
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
signed_url = "https://example.com/image.png?signature=secret-token#fragment"
|
||||
self.adapter._download_external_media_with_cap = AsyncMock(
|
||||
side_effect=ValueError("download failed")
|
||||
)
|
||||
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
|
||||
|
||||
await self.adapter.send_image(
|
||||
"!room:example.org",
|
||||
|
|
@ -3514,14 +3649,16 @@ class TestMatrixImageOnlyMediaNormalization:
|
|||
assert signed_url not in sent_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_image_failure_log_still_redacts_signed_url(self, caplog):
|
||||
async def test_send_image_failure_log_still_redacts_signed_url(self, caplog, monkeypatch):
|
||||
from gateway.platforms.base import SendResult
|
||||
import tools.url_safety as url_safety
|
||||
|
||||
signed_url = "https://example.com/image.png?signature=secret-token#fragment"
|
||||
self.adapter._download_external_media_with_cap = AsyncMock(
|
||||
side_effect=ValueError("download failed")
|
||||
)
|
||||
self.adapter.send = AsyncMock(return_value=SendResult(success=True))
|
||||
monkeypatch.setattr(url_safety, "is_safe_url", lambda *_args, **_kwargs: True)
|
||||
|
||||
await self.adapter.send_image("!room:example.org", signed_url)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue