diff --git a/tests/gateway/test_matrix.py b/tests/gateway/test_matrix.py new file mode 100644 index 000000000..31e59caeb --- /dev/null +++ b/tests/gateway/test_matrix.py @@ -0,0 +1,448 @@ +"""Tests for Matrix platform adapter.""" +import json +import re +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from gateway.config import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Platform & Config +# --------------------------------------------------------------------------- + +class TestMatrixPlatformEnum: + def test_matrix_enum_exists(self): + assert Platform.MATRIX.value == "matrix" + + def test_matrix_in_platform_list(self): + platforms = [p.value for p in Platform] + assert "matrix" in platforms + + +class TestMatrixConfigLoading: + def test_apply_env_overrides_with_access_token(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATRIX in config.platforms + mc = config.platforms[Platform.MATRIX] + assert mc.enabled is True + assert mc.token == "syt_abc123" + assert mc.extra.get("homeserver") == "https://matrix.example.org" + + def test_apply_env_overrides_with_password(self, monkeypatch): + monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False) + monkeypatch.setenv("MATRIX_PASSWORD", "secret123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_USER_ID", "@bot:example.org") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATRIX in config.platforms + mc = config.platforms[Platform.MATRIX] + assert mc.enabled is True + assert mc.extra.get("password") == "secret123" + assert mc.extra.get("user_id") == "@bot:example.org" + + def test_matrix_not_loaded_without_creds(self, monkeypatch): + monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False) + monkeypatch.delenv("MATRIX_PASSWORD", raising=False) + monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATRIX not in config.platforms + + def test_matrix_encryption_flag(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_ENCRYPTION", "true") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + mc = config.platforms[Platform.MATRIX] + assert mc.extra.get("encryption") is True + + def test_matrix_encryption_default_off(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.delenv("MATRIX_ENCRYPTION", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + mc = config.platforms[Platform.MATRIX] + assert mc.extra.get("encryption") is False + + def test_matrix_home_room(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_HOME_ROOM", "!room123:example.org") + monkeypatch.setenv("MATRIX_HOME_ROOM_NAME", "Bot Room") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + home = config.get_home_channel(Platform.MATRIX) + assert home is not None + assert home.chat_id == "!room123:example.org" + assert home.name == "Bot Room" + + def test_matrix_user_id_stored_in_extra(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_abc123") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + monkeypatch.setenv("MATRIX_USER_ID", "@hermes:example.org") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + mc = config.platforms[Platform.MATRIX] + assert mc.extra.get("user_id") == "@hermes:example.org" + + +# --------------------------------------------------------------------------- +# Adapter helpers +# --------------------------------------------------------------------------- + +def _make_adapter(): + """Create a MatrixAdapter with mocked config.""" + from gateway.platforms.matrix import MatrixAdapter + config = PlatformConfig( + enabled=True, + token="syt_test_token", + extra={ + "homeserver": "https://matrix.example.org", + "user_id": "@bot:example.org", + }, + ) + adapter = MatrixAdapter(config) + return adapter + + +# --------------------------------------------------------------------------- +# mxc:// URL conversion +# --------------------------------------------------------------------------- + +class TestMatrixMxcToHttp: + def setup_method(self): + self.adapter = _make_adapter() + + def test_basic_mxc_conversion(self): + """mxc://server/media_id should become an authenticated HTTP URL.""" + mxc = "mxc://matrix.org/abc123" + result = self.adapter._mxc_to_http(mxc) + assert result == "https://matrix.example.org/_matrix/client/v1/media/download/matrix.org/abc123" + + def test_mxc_with_different_server(self): + """mxc:// from a different server should still use our homeserver.""" + mxc = "mxc://other.server/media456" + result = self.adapter._mxc_to_http(mxc) + assert result.startswith("https://matrix.example.org/") + assert "other.server/media456" in result + + def test_non_mxc_url_passthrough(self): + """Non-mxc URLs should be returned unchanged.""" + url = "https://example.com/image.png" + assert self.adapter._mxc_to_http(url) == url + + def test_mxc_uses_client_v1_endpoint(self): + """Should use /_matrix/client/v1/media/download/ not the deprecated path.""" + mxc = "mxc://example.com/test123" + result = self.adapter._mxc_to_http(mxc) + assert "/_matrix/client/v1/media/download/" in result + assert "/_matrix/media/v3/download/" not in result + + +# --------------------------------------------------------------------------- +# DM detection +# --------------------------------------------------------------------------- + +class TestMatrixDmDetection: + def setup_method(self): + self.adapter = _make_adapter() + + def test_room_in_m_direct_is_dm(self): + """A room listed in m.direct should be detected as DM.""" + self.adapter._joined_rooms = {"!dm_room:ex.org", "!group_room:ex.org"} + self.adapter._dm_rooms = { + "!dm_room:ex.org": True, + "!group_room:ex.org": False, + } + + assert self.adapter._dm_rooms.get("!dm_room:ex.org") is True + assert self.adapter._dm_rooms.get("!group_room:ex.org") is False + + def test_unknown_room_not_in_cache(self): + """Unknown rooms should not be in the DM cache.""" + self.adapter._dm_rooms = {} + assert self.adapter._dm_rooms.get("!unknown:ex.org") is None + + @pytest.mark.asyncio + async def test_refresh_dm_cache_with_m_direct(self): + """_refresh_dm_cache should populate _dm_rooms from m.direct data.""" + self.adapter._joined_rooms = {"!room_a:ex.org", "!room_b:ex.org", "!room_c:ex.org"} + + mock_client = MagicMock() + mock_resp = MagicMock() + mock_resp.content = { + "@alice:ex.org": ["!room_a:ex.org"], + "@bob:ex.org": ["!room_b:ex.org"], + } + mock_client.get_account_data = AsyncMock(return_value=mock_resp) + self.adapter._client = mock_client + + await self.adapter._refresh_dm_cache() + + assert self.adapter._dm_rooms["!room_a:ex.org"] is True + assert self.adapter._dm_rooms["!room_b:ex.org"] is True + assert self.adapter._dm_rooms["!room_c:ex.org"] is False + + +# --------------------------------------------------------------------------- +# Reply fallback stripping +# --------------------------------------------------------------------------- + +class TestMatrixReplyFallbackStripping: + """Test that Matrix reply fallback lines ('> ' prefix) are stripped.""" + + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._user_id = "@bot:example.org" + self.adapter._startup_ts = 0.0 + self.adapter._dm_rooms = {} + self.adapter._message_handler = AsyncMock() + + def _strip_fallback(self, body: str, has_reply: bool = True) -> str: + """Simulate the reply fallback stripping logic from _on_room_message.""" + reply_to = "some_event_id" if has_reply else None + if reply_to and body.startswith("> "): + lines = body.split("\n") + stripped = [] + past_fallback = False + for line in lines: + if not past_fallback: + if line.startswith("> ") or line == ">": + continue + if line == "": + past_fallback = True + continue + past_fallback = True + stripped.append(line) + body = "\n".join(stripped) if stripped else body + return body + + def test_simple_reply_fallback(self): + body = "> <@alice:ex.org> Original message\n\nActual reply" + result = self._strip_fallback(body) + assert result == "Actual reply" + + def test_multiline_reply_fallback(self): + body = "> <@alice:ex.org> Line 1\n> Line 2\n\nMy response" + result = self._strip_fallback(body) + assert result == "My response" + + def test_no_reply_fallback_preserved(self): + body = "Just a normal message" + result = self._strip_fallback(body, has_reply=False) + assert result == "Just a normal message" + + def test_quote_without_reply_preserved(self): + """'> ' lines without a reply_to context should be preserved.""" + body = "> This is a blockquote" + result = self._strip_fallback(body, has_reply=False) + assert result == "> This is a blockquote" + + def test_empty_fallback_separator(self): + """The blank line between fallback and actual content should be stripped.""" + body = "> <@alice:ex.org> hi\n>\n\nResponse" + result = self._strip_fallback(body) + assert result == "Response" + + def test_multiline_response_after_fallback(self): + body = "> <@alice:ex.org> Original\n\nLine 1\nLine 2\nLine 3" + result = self._strip_fallback(body) + assert result == "Line 1\nLine 2\nLine 3" + + +# --------------------------------------------------------------------------- +# Thread detection +# --------------------------------------------------------------------------- + +class TestMatrixThreadDetection: + def test_thread_id_from_m_relates_to(self): + """m.relates_to with rel_type=m.thread should extract the event_id.""" + relates_to = { + "rel_type": "m.thread", + "event_id": "$thread_root_event", + "is_falling_back": True, + "m.in_reply_to": {"event_id": "$some_event"}, + } + # Simulate the extraction logic from _on_room_message + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id == "$thread_root_event" + + def test_no_thread_for_reply(self): + """m.in_reply_to without m.thread should not set thread_id.""" + relates_to = { + "m.in_reply_to": {"event_id": "$reply_event"}, + } + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id is None + + def test_no_thread_for_edit(self): + """m.replace relation should not set thread_id.""" + relates_to = { + "rel_type": "m.replace", + "event_id": "$edited_event", + } + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id is None + + def test_empty_relates_to(self): + """Empty m.relates_to should not set thread_id.""" + relates_to = {} + thread_id = None + if relates_to.get("rel_type") == "m.thread": + thread_id = relates_to.get("event_id") + assert thread_id is None + + +# --------------------------------------------------------------------------- +# Format message +# --------------------------------------------------------------------------- + +class TestMatrixFormatMessage: + def setup_method(self): + self.adapter = _make_adapter() + + def test_image_markdown_stripped(self): + """![alt](url) should be converted to just the URL.""" + result = self.adapter.format_message("![cat](https://img.example.com/cat.png)") + assert result == "https://img.example.com/cat.png" + + def test_regular_markdown_preserved(self): + """Standard markdown should be preserved (Matrix supports it).""" + content = "**bold** and *italic* and `code`" + assert self.adapter.format_message(content) == content + + def test_plain_text_unchanged(self): + content = "Hello, world!" + assert self.adapter.format_message(content) == content + + def test_multiple_images_stripped(self): + content = "![a](http://a.com/1.png) and ![b](http://b.com/2.png)" + result = self.adapter.format_message(content) + assert "![" not in result + assert "http://a.com/1.png" in result + assert "http://b.com/2.png" in result + + +# --------------------------------------------------------------------------- +# Markdown to HTML conversion +# --------------------------------------------------------------------------- + +class TestMatrixMarkdownToHtml: + def setup_method(self): + self.adapter = _make_adapter() + + def test_bold_conversion(self): + """**bold** should produce tags.""" + result = self.adapter._markdown_to_html("**bold**") + assert "" in result or "" in result + assert "bold" in result + + def test_italic_conversion(self): + """*italic* should produce tags.""" + result = self.adapter._markdown_to_html("*italic*") + assert "" in result or "" in result + + def test_inline_code(self): + """`code` should produce tags.""" + result = self.adapter._markdown_to_html("`code`") + assert "" in result + + def test_plain_text_returns_html(self): + """Plain text should still be returned (possibly with
or

).""" + result = self.adapter._markdown_to_html("Hello world") + assert "Hello world" in result + + +# --------------------------------------------------------------------------- +# Helper: display name extraction +# --------------------------------------------------------------------------- + +class TestMatrixDisplayName: + def setup_method(self): + self.adapter = _make_adapter() + + def test_get_display_name_from_room_users(self): + """Should get display name from room's users dict.""" + mock_room = MagicMock() + mock_user = MagicMock() + mock_user.display_name = "Alice" + mock_room.users = {"@alice:ex.org": mock_user} + + name = self.adapter._get_display_name(mock_room, "@alice:ex.org") + assert name == "Alice" + + def test_get_display_name_fallback_to_localpart(self): + """Should extract localpart from @user:server format.""" + mock_room = MagicMock() + mock_room.users = {} + + name = self.adapter._get_display_name(mock_room, "@bob:example.org") + assert name == "bob" + + def test_get_display_name_no_room(self): + """Should handle None room gracefully.""" + name = self.adapter._get_display_name(None, "@charlie:ex.org") + assert name == "charlie" + + +# --------------------------------------------------------------------------- +# Requirements check +# --------------------------------------------------------------------------- + +class TestMatrixRequirements: + def test_check_requirements_with_token(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test") + monkeypatch.setenv("MATRIX_HOMESERVER", "https://matrix.example.org") + from gateway.platforms.matrix import check_matrix_requirements + try: + import nio # noqa: F401 + assert check_matrix_requirements() is True + except ImportError: + assert check_matrix_requirements() is False + + def test_check_requirements_without_creds(self, monkeypatch): + monkeypatch.delenv("MATRIX_ACCESS_TOKEN", raising=False) + monkeypatch.delenv("MATRIX_PASSWORD", raising=False) + monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) + from gateway.platforms.matrix import check_matrix_requirements + assert check_matrix_requirements() is False + + def test_check_requirements_without_homeserver(self, monkeypatch): + monkeypatch.setenv("MATRIX_ACCESS_TOKEN", "syt_test") + monkeypatch.delenv("MATRIX_HOMESERVER", raising=False) + from gateway.platforms.matrix import check_matrix_requirements + assert check_matrix_requirements() is False diff --git a/tests/gateway/test_mattermost.py b/tests/gateway/test_mattermost.py new file mode 100644 index 000000000..6b0fbd899 --- /dev/null +++ b/tests/gateway/test_mattermost.py @@ -0,0 +1,574 @@ +"""Tests for Mattermost platform adapter.""" +import json +import time +import pytest +from unittest.mock import MagicMock, patch, AsyncMock + +from gateway.config import Platform, PlatformConfig + + +# --------------------------------------------------------------------------- +# Platform & Config +# --------------------------------------------------------------------------- + +class TestMattermostPlatformEnum: + def test_mattermost_enum_exists(self): + assert Platform.MATTERMOST.value == "mattermost" + + def test_mattermost_in_platform_list(self): + platforms = [p.value for p in Platform] + assert "mattermost" in platforms + + +class TestMattermostConfigLoading: + def test_apply_env_overrides_mattermost(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATTERMOST in config.platforms + mc = config.platforms[Platform.MATTERMOST] + assert mc.enabled is True + assert mc.token == "mm-tok-abc123" + assert mc.extra.get("url") == "https://mm.example.com" + + def test_mattermost_not_loaded_without_token(self, monkeypatch): + monkeypatch.delenv("MATTERMOST_TOKEN", raising=False) + monkeypatch.delenv("MATTERMOST_URL", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATTERMOST not in config.platforms + + def test_connected_platforms_includes_mattermost(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + connected = config.get_connected_platforms() + assert Platform.MATTERMOST in connected + + def test_mattermost_home_channel(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + monkeypatch.setenv("MATTERMOST_HOME_CHANNEL", "ch_abc123") + monkeypatch.setenv("MATTERMOST_HOME_CHANNEL_NAME", "General") + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + home = config.get_home_channel(Platform.MATTERMOST) + assert home is not None + assert home.chat_id == "ch_abc123" + assert home.name == "General" + + def test_mattermost_url_warning_without_url(self, monkeypatch): + """MATTERMOST_TOKEN set but MATTERMOST_URL missing should still load.""" + monkeypatch.setenv("MATTERMOST_TOKEN", "mm-tok-abc123") + monkeypatch.delenv("MATTERMOST_URL", raising=False) + + from gateway.config import GatewayConfig, _apply_env_overrides + config = GatewayConfig() + _apply_env_overrides(config) + + assert Platform.MATTERMOST in config.platforms + assert config.platforms[Platform.MATTERMOST].extra.get("url") == "" + + +# --------------------------------------------------------------------------- +# Adapter format / truncate +# --------------------------------------------------------------------------- + +def _make_adapter(): + """Create a MattermostAdapter with mocked config.""" + from gateway.platforms.mattermost import MattermostAdapter + config = PlatformConfig( + enabled=True, + token="test-token", + extra={"url": "https://mm.example.com"}, + ) + adapter = MattermostAdapter(config) + return adapter + + +class TestMattermostFormatMessage: + def setup_method(self): + self.adapter = _make_adapter() + + def test_image_markdown_to_url(self): + """![alt](url) should be converted to just the URL.""" + result = self.adapter.format_message("![cat](https://img.example.com/cat.png)") + assert result == "https://img.example.com/cat.png" + + def test_image_markdown_strips_alt_text(self): + result = self.adapter.format_message("Here: ![my image](https://x.com/a.jpg) done") + assert "![" not in result + assert "https://x.com/a.jpg" in result + + def test_regular_markdown_preserved(self): + """Regular markdown (bold, italic, code) should be kept as-is.""" + content = "**bold** and *italic* and `code`" + assert self.adapter.format_message(content) == content + + def test_regular_links_preserved(self): + """Non-image links should be preserved.""" + content = "[click](https://example.com)" + assert self.adapter.format_message(content) == content + + def test_plain_text_unchanged(self): + content = "Hello, world!" + assert self.adapter.format_message(content) == content + + def test_multiple_images(self): + content = "![a](http://a.com/1.png) text ![b](http://b.com/2.png)" + result = self.adapter.format_message(content) + assert "![" not in result + assert "http://a.com/1.png" in result + assert "http://b.com/2.png" in result + + +class TestMattermostTruncateMessage: + def setup_method(self): + self.adapter = _make_adapter() + + def test_short_message_single_chunk(self): + msg = "Hello, world!" + chunks = self.adapter.truncate_message(msg, 4000) + assert len(chunks) == 1 + assert chunks[0] == msg + + def test_long_message_splits(self): + msg = "a " * 2500 # 5000 chars + chunks = self.adapter.truncate_message(msg, 4000) + assert len(chunks) >= 2 + for chunk in chunks: + assert len(chunk) <= 4000 + + def test_custom_max_length(self): + msg = "Hello " * 20 + chunks = self.adapter.truncate_message(msg, max_length=50) + assert all(len(c) <= 50 for c in chunks) + + def test_exactly_at_limit(self): + msg = "x" * 4000 + chunks = self.adapter.truncate_message(msg, 4000) + assert len(chunks) == 1 + + +# --------------------------------------------------------------------------- +# Send +# --------------------------------------------------------------------------- + +class TestMattermostSend: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._session = MagicMock() + + @pytest.mark.asyncio + async def test_send_calls_api_post(self): + """send() should POST to /api/v4/posts with channel_id and message.""" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"id": "post123"}) + mock_resp.text = AsyncMock(return_value="") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Hello!") + + assert result.success is True + assert result.message_id == "post123" + + # Verify post was called with correct URL + call_args = self.adapter._session.post.call_args + assert "/api/v4/posts" in call_args[0][0] + # Verify payload + payload = call_args[1]["json"] + assert payload["channel_id"] == "channel_1" + assert payload["message"] == "Hello!" + + @pytest.mark.asyncio + async def test_send_empty_content_succeeds(self): + """Empty content should return success without calling the API.""" + result = await self.adapter.send("channel_1", "") + assert result.success is True + + @pytest.mark.asyncio + async def test_send_with_thread_reply(self): + """When reply_mode is 'thread', reply_to should become root_id.""" + self.adapter._reply_mode = "thread" + + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"id": "post456"}) + mock_resp.text = AsyncMock(return_value="") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post") + + assert result.success is True + payload = self.adapter._session.post.call_args[1]["json"] + assert payload["root_id"] == "root_post" + + @pytest.mark.asyncio + async def test_send_without_thread_no_root_id(self): + """When reply_mode is 'off', reply_to should NOT set root_id.""" + self.adapter._reply_mode = "off" + + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.json = AsyncMock(return_value={"id": "post789"}) + mock_resp.text = AsyncMock(return_value="") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Reply!", reply_to="root_post") + + assert result.success is True + payload = self.adapter._session.post.call_args[1]["json"] + assert "root_id" not in payload + + @pytest.mark.asyncio + async def test_send_api_failure(self): + """When API returns error, send should return failure.""" + mock_resp = AsyncMock() + mock_resp.status = 500 + mock_resp.json = AsyncMock(return_value={}) + mock_resp.text = AsyncMock(return_value="Internal Server Error") + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + self.adapter._session.post = MagicMock(return_value=mock_resp) + + result = await self.adapter.send("channel_1", "Hello!") + + assert result.success is False + + +# --------------------------------------------------------------------------- +# WebSocket event parsing +# --------------------------------------------------------------------------- + +class TestMattermostWebSocketParsing: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._bot_user_id = "bot_user_id" + # Mock handle_message to capture the MessageEvent without processing + self.adapter.handle_message = AsyncMock() + + @pytest.mark.asyncio + async def test_parse_posted_event(self): + """'posted' events should extract message from double-encoded post JSON.""" + post_data = { + "id": "post_abc", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "Hello from Matrix!", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), # double-encoded JSON string + "channel_type": "O", + "sender_name": "@alice", + }, + } + + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.called + msg_event = self.adapter.handle_message.call_args[0][0] + assert msg_event.text == "Hello from Matrix!" + assert msg_event.message_id == "post_abc" + + @pytest.mark.asyncio + async def test_ignore_own_messages(self): + """Messages from the bot's own user_id should be ignored.""" + post_data = { + "id": "post_self", + "user_id": "bot_user_id", # same as bot + "channel_id": "chan_456", + "message": "Bot echo", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + }, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_ignore_non_posted_events(self): + """Non-'posted' events should be ignored.""" + event = { + "event": "typing", + "data": {"user_id": "user_123"}, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_ignore_system_posts(self): + """Posts with a 'type' field (system messages) should be ignored.""" + post_data = { + "id": "sys_post", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "user joined", + "type": "system_join_channel", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + }, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + @pytest.mark.asyncio + async def test_channel_type_mapping(self): + """channel_type 'D' should map to 'dm'.""" + post_data = { + "id": "post_dm", + "user_id": "user_123", + "channel_id": "chan_dm", + "message": "DM message", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "D", + "sender_name": "@bob", + }, + } + + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.called + msg_event = self.adapter.handle_message.call_args[0][0] + assert msg_event.source.chat_type == "dm" + + @pytest.mark.asyncio + async def test_thread_id_from_root_id(self): + """Post with root_id should have thread_id set.""" + post_data = { + "id": "post_reply", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "Thread reply", + "root_id": "root_post_123", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + "sender_name": "@alice", + }, + } + + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.called + msg_event = self.adapter.handle_message.call_args[0][0] + assert msg_event.source.thread_id == "root_post_123" + + @pytest.mark.asyncio + async def test_invalid_post_json_ignored(self): + """Invalid JSON in data.post should be silently ignored.""" + event = { + "event": "posted", + "data": { + "post": "not-valid-json{{{", + "channel_type": "O", + }, + } + + await self.adapter._handle_ws_event(event) + assert not self.adapter.handle_message.called + + +# --------------------------------------------------------------------------- +# File upload (send_image) +# --------------------------------------------------------------------------- + +class TestMattermostFileUpload: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._session = MagicMock() + + @pytest.mark.asyncio + async def test_send_image_downloads_and_uploads(self): + """send_image should download the URL, upload via /api/v4/files, then post.""" + # Mock the download (GET) + mock_dl_resp = AsyncMock() + mock_dl_resp.status = 200 + mock_dl_resp.read = AsyncMock(return_value=b"\x89PNG\x00fake-image-data") + mock_dl_resp.content_type = "image/png" + mock_dl_resp.__aenter__ = AsyncMock(return_value=mock_dl_resp) + mock_dl_resp.__aexit__ = AsyncMock(return_value=False) + + # Mock the upload (POST to /files) + mock_upload_resp = AsyncMock() + mock_upload_resp.status = 200 + mock_upload_resp.json = AsyncMock(return_value={ + "file_infos": [{"id": "file_abc123"}] + }) + mock_upload_resp.text = AsyncMock(return_value="") + mock_upload_resp.__aenter__ = AsyncMock(return_value=mock_upload_resp) + mock_upload_resp.__aexit__ = AsyncMock(return_value=False) + + # Mock the post (POST to /posts) + mock_post_resp = AsyncMock() + mock_post_resp.status = 200 + mock_post_resp.json = AsyncMock(return_value={"id": "post_with_file"}) + mock_post_resp.text = AsyncMock(return_value="") + mock_post_resp.__aenter__ = AsyncMock(return_value=mock_post_resp) + mock_post_resp.__aexit__ = AsyncMock(return_value=False) + + # Route calls: first GET (download), then POST (upload), then POST (create post) + self.adapter._session.get = MagicMock(return_value=mock_dl_resp) + post_call_count = 0 + original_post_returns = [mock_upload_resp, mock_post_resp] + + def post_side_effect(*args, **kwargs): + nonlocal post_call_count + resp = original_post_returns[min(post_call_count, len(original_post_returns) - 1)] + post_call_count += 1 + return resp + + self.adapter._session.post = MagicMock(side_effect=post_side_effect) + + result = await self.adapter.send_image( + "channel_1", "https://img.example.com/cat.png", caption="A cat" + ) + + assert result.success is True + assert result.message_id == "post_with_file" + + +# --------------------------------------------------------------------------- +# Dedup cache +# --------------------------------------------------------------------------- + +class TestMattermostDedup: + def setup_method(self): + self.adapter = _make_adapter() + self.adapter._bot_user_id = "bot_user_id" + # Mock handle_message to capture calls without processing + self.adapter.handle_message = AsyncMock() + + @pytest.mark.asyncio + async def test_duplicate_post_ignored(self): + """The same post_id within the TTL window should be ignored.""" + post_data = { + "id": "post_dup", + "user_id": "user_123", + "channel_id": "chan_456", + "message": "Hello!", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + "sender_name": "@alice", + }, + } + + # First time: should process + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.call_count == 1 + + # Second time (same post_id): should be deduped + await self.adapter._handle_ws_event(event) + assert self.adapter.handle_message.call_count == 1 # still 1 + + @pytest.mark.asyncio + async def test_different_post_ids_both_processed(self): + """Different post IDs should both be processed.""" + for i, pid in enumerate(["post_a", "post_b"]): + post_data = { + "id": pid, + "user_id": "user_123", + "channel_id": "chan_456", + "message": f"Message {i}", + } + event = { + "event": "posted", + "data": { + "post": json.dumps(post_data), + "channel_type": "O", + "sender_name": "@alice", + }, + } + await self.adapter._handle_ws_event(event) + + assert self.adapter.handle_message.call_count == 2 + + def test_prune_seen_clears_expired(self): + """_prune_seen should remove entries older than _SEEN_TTL.""" + now = time.time() + # Fill with enough expired entries to trigger pruning + for i in range(self.adapter._SEEN_MAX + 10): + self.adapter._seen_posts[f"old_{i}"] = now - 600 # 10 min ago + + # Add a fresh one + self.adapter._seen_posts["fresh"] = now + + self.adapter._prune_seen() + + # Old entries should be pruned, fresh one kept + assert "fresh" in self.adapter._seen_posts + assert len(self.adapter._seen_posts) < self.adapter._SEEN_MAX + + def test_seen_cache_tracks_post_ids(self): + """Posts are tracked in _seen_posts dict.""" + self.adapter._seen_posts["test_post"] = time.time() + assert "test_post" in self.adapter._seen_posts + + +# --------------------------------------------------------------------------- +# Requirements check +# --------------------------------------------------------------------------- + +class TestMattermostRequirements: + def test_check_requirements_with_token_and_url(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "test-token") + monkeypatch.setenv("MATTERMOST_URL", "https://mm.example.com") + from gateway.platforms.mattermost import check_mattermost_requirements + assert check_mattermost_requirements() is True + + def test_check_requirements_without_token(self, monkeypatch): + monkeypatch.delenv("MATTERMOST_TOKEN", raising=False) + monkeypatch.delenv("MATTERMOST_URL", raising=False) + from gateway.platforms.mattermost import check_mattermost_requirements + assert check_mattermost_requirements() is False + + def test_check_requirements_without_url(self, monkeypatch): + monkeypatch.setenv("MATTERMOST_TOKEN", "test-token") + monkeypatch.delenv("MATTERMOST_URL", raising=False) + from gateway.platforms.mattermost import check_mattermost_requirements + assert check_mattermost_requirements() is False