From e5333e793c622a236c2639a351e115979587632a Mon Sep 17 00:00:00 2001 From: ChimingLiu Date: Fri, 17 Apr 2026 19:30:16 -0700 Subject: [PATCH] feat(discord): support forum channels --- gateway/channel_directory.py | 20 +- gateway/platforms/discord.py | 52 ++++ tests/gateway/test_channel_directory.py | 47 ++++ tests/gateway/test_discord_send.py | 119 ++++++++ tests/tools/test_send_message_tool.py | 360 ++++++++++++++++++++++++ tools/send_message_tool.py | 75 ++++- 6 files changed, 671 insertions(+), 2 deletions(-) diff --git a/gateway/channel_directory.py b/gateway/channel_directory.py index ae2beda9e..2489b718f 100644 --- a/gateway/channel_directory.py +++ b/gateway/channel_directory.py @@ -100,7 +100,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]: def _build_discord(adapter) -> List[Dict[str, str]]: - """Enumerate all text channels the Discord bot can see.""" + """Enumerate all text channels and forum channels the Discord bot can see.""" channels = [] client = getattr(adapter, "_client", None) if not client: @@ -119,6 +119,15 @@ def _build_discord(adapter) -> List[Dict[str, str]]: "guild": guild.name, "type": "channel", }) + # Forum channels (type 15) — creating a message auto-spawns a thread post. + forums = getattr(guild, "forum_channels", None) or [] + for ch in forums: + channels.append({ + "id": str(ch.id), + "name": ch.name, + "guild": guild.name, + "type": "forum", + }) # Also include DM-capable users we've interacted with is not # feasible via guild enumeration; those come from sessions. @@ -191,6 +200,15 @@ def load_directory() -> Dict[str, Any]: return {"updated_at": None, "platforms": {}} +def lookup_channel_type(platform_name: str, chat_id: str) -> Optional[str]: + """Return the channel ``type`` string (e.g. ``"channel"``, ``"forum"``) for *chat_id*, or *None* if unknown.""" + directory = load_directory() + for ch in directory.get("platforms", {}).get(platform_name, []): + if ch.get("id") == chat_id: + return ch.get("type") + return None + + def resolve_channel_name(platform_name: str, name: str) -> Optional[str]: """ Resolve a human-friendly channel name to a numeric ID. diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index a53908145..7367b8669 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -857,6 +857,9 @@ class DiscordAdapter(BasePlatformAdapter): When metadata contains a thread_id, the message is sent to that thread instead of the parent channel identified by chat_id. + + Forum channels (type 15) reject direct messages — a thread post is + created automatically. """ if not self._client: return SendResult(success=False, error="Not connected") @@ -882,6 +885,10 @@ class DiscordAdapter(BasePlatformAdapter): if not channel: return SendResult(success=False, error=f"Channel {chat_id} not found") + # Forum channels reject channel.send() — create a thread post instead. + if self._is_forum_parent(channel): + return await self._send_to_forum(channel, content) + # Format and split message if needed formatted = self.format_message(content) chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) @@ -945,6 +952,51 @@ class DiscordAdapter(BasePlatformAdapter): logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True) return SendResult(success=False, error=str(e)) + async def _send_to_forum(self, forum_channel: Any, content: str) -> SendResult: + """Create a thread post in a forum channel with the message as starter content. + + Forum channels (type 15) don't support direct messages. Instead we + POST to /channels/{forum_id}/threads with a thread name derived from + the first line of the message. + """ + from tools.send_message_tool import _derive_forum_thread_name + + formatted = self.format_message(content) + chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH) + + thread_name = _derive_forum_thread_name(content) + + starter_content = chunks[0] if chunks else thread_name + + try: + thread = await forum_channel.create_thread( + name=thread_name, + content=starter_content, + ) + except Exception as e: + logger.error("[%s] Failed to create forum thread in %s: %s", self.name, forum_channel.id, e) + return SendResult(success=False, error=f"Forum thread creation failed: {e}") + + thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None) + thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", ""))) + starter_msg = getattr(thread, "message", None) + message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id + + # Send remaining chunks into the newly created thread. + message_ids = [message_id] + for chunk in chunks[1:]: + try: + msg = await thread_channel.send(content=chunk) + message_ids.append(str(msg.id)) + except Exception as e: + logger.warning("[%s] Failed to send follow-up chunk to forum thread %s: %s", self.name, thread_id, e) + + return SendResult( + success=True, + message_id=message_ids[0], + raw_response={"message_ids": message_ids, "thread_id": thread_id}, + ) + async def edit_message( self, chat_id: str, diff --git a/tests/gateway/test_channel_directory.py b/tests/gateway/test_channel_directory.py index 50d5b04b7..6c1b8fc73 100644 --- a/tests/gateway/test_channel_directory.py +++ b/tests/gateway/test_channel_directory.py @@ -7,6 +7,7 @@ from unittest.mock import patch from gateway.channel_directory import ( build_channel_directory, + lookup_channel_type, resolve_channel_name, format_directory_for_display, load_directory, @@ -285,3 +286,49 @@ class TestFormatDirectoryForDisplay: assert "Discord (Server1):" in result assert "Discord (Server2):" in result assert "discord:#general" in result + + +class TestLookupChannelType: + def _setup(self, tmp_path, platforms): + cache_file = _write_directory(tmp_path, platforms) + return patch("gateway.channel_directory.DIRECTORY_PATH", cache_file) + + def test_forum_channel(self, tmp_path): + platforms = { + "discord": [ + {"id": "100", "name": "ideas", "guild": "Server1", "type": "forum"}, + ] + } + with self._setup(tmp_path, platforms): + assert lookup_channel_type("discord", "100") == "forum" + + def test_regular_channel(self, tmp_path): + platforms = { + "discord": [ + {"id": "200", "name": "general", "guild": "Server1", "type": "channel"}, + ] + } + with self._setup(tmp_path, platforms): + assert lookup_channel_type("discord", "200") == "channel" + + def test_unknown_chat_id_returns_none(self, tmp_path): + platforms = { + "discord": [ + {"id": "200", "name": "general", "guild": "Server1", "type": "channel"}, + ] + } + with self._setup(tmp_path, platforms): + assert lookup_channel_type("discord", "999") is None + + def test_unknown_platform_returns_none(self, tmp_path): + with self._setup(tmp_path, {}): + assert lookup_channel_type("discord", "100") is None + + def test_channel_without_type_key_returns_none(self, tmp_path): + platforms = { + "discord": [ + {"id": "300", "name": "general", "guild": "Server1"}, + ] + } + with self._setup(tmp_path, platforms): + assert lookup_channel_type("discord", "300") is None diff --git a/tests/gateway/test_discord_send.py b/tests/gateway/test_discord_send.py index 7d387cb08..a8b1e1529 100644 --- a/tests/gateway/test_discord_send.py +++ b/tests/gateway/test_discord_send.py @@ -157,3 +157,122 @@ async def test_send_does_not_retry_on_unrelated_errors(): # Only the first attempt happens — no reference-retry replay. assert channel.send.await_count == 1 assert send_calls[0]["reference"] is reference_obj + + +# --------------------------------------------------------------------------- +# Forum channel tests +# --------------------------------------------------------------------------- + +import discord as _discord_mod # noqa: E402 — imported after _ensure_discord_mock + + +class TestIsForumParent: + def test_none_returns_false(self): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + assert adapter._is_forum_parent(None) is False + + def test_forum_channel_class_instance(self): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + forum_cls = getattr(_discord_mod, "ForumChannel", None) + if forum_cls is None: + # Re-create a type for the mock + forum_cls = type("ForumChannel", (), {}) + _discord_mod.ForumChannel = forum_cls + ch = forum_cls() + assert adapter._is_forum_parent(ch) is True + + def test_type_value_15(self): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + ch = SimpleNamespace(type=15) + assert adapter._is_forum_parent(ch) is True + + def test_regular_channel_returns_false(self): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + ch = SimpleNamespace(type=0) + assert adapter._is_forum_parent(ch) is False + + def test_thread_returns_false(self): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + ch = SimpleNamespace(type=11) # public thread + assert adapter._is_forum_parent(ch) is False + + +@pytest.mark.asyncio +async def test_send_to_forum_creates_thread_post(): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + + # thread object has no 'send' so _send_to_forum uses thread.thread + thread_ch = SimpleNamespace(id=555, send=AsyncMock(return_value=SimpleNamespace(id=600))) + thread = SimpleNamespace( + id=555, + message=SimpleNamespace(id=500), + thread=thread_ch, + ) + forum_channel = _discord_mod.ForumChannel() + forum_channel.id = 999 + forum_channel.name = "ideas" + forum_channel.create_thread = AsyncMock(return_value=thread) + adapter._client = SimpleNamespace( + get_channel=lambda _chat_id: forum_channel, + fetch_channel=AsyncMock(), + ) + + result = await adapter.send("999", "Hello forum!") + + assert result.success is True + assert result.message_id == "500" + forum_channel.create_thread.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_to_forum_sends_remaining_chunks(): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + # Force a small max message length so the message splits + adapter.MAX_MESSAGE_LENGTH = 20 + + chunk_msg_1 = SimpleNamespace(id=500) + chunk_msg_2 = SimpleNamespace(id=501) + thread_ch = SimpleNamespace( + id=555, + send=AsyncMock(return_value=chunk_msg_2), + ) + # thread object has no 'send' so _send_to_forum uses thread.thread + thread = SimpleNamespace( + id=555, + message=chunk_msg_1, + thread=thread_ch, + ) + forum_channel = _discord_mod.ForumChannel() + forum_channel.id = 999 + forum_channel.name = "ideas" + forum_channel.create_thread = AsyncMock(return_value=thread) + adapter._client = SimpleNamespace( + get_channel=lambda _chat_id: forum_channel, + fetch_channel=AsyncMock(), + ) + + result = await adapter.send("999", "A" * 50) + + assert result.success is True + assert result.message_id == "500" + # Should have sent at least one follow-up chunk + assert thread_ch.send.await_count >= 1 + + +@pytest.mark.asyncio +async def test_send_to_forum_create_thread_failure(): + adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***")) + + forum_channel = _discord_mod.ForumChannel() + forum_channel.id = 999 + forum_channel.name = "ideas" + forum_channel.create_thread = AsyncMock(side_effect=Exception("rate limited")) + adapter._client = SimpleNamespace( + get_channel=lambda _chat_id: forum_channel, + fetch_channel=AsyncMock(), + ) + + result = await adapter.send("999", "Hello forum!") + + assert result.success is False + assert "rate limited" in result.error diff --git a/tests/tools/test_send_message_tool.py b/tests/tools/test_send_message_tool.py index 729a1fdec..a2db83f5e 100644 --- a/tests/tools/test_send_message_tool.py +++ b/tests/tools/test_send_message_tool.py @@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from gateway.config import Platform from tools.send_message_tool import ( + _derive_forum_thread_name, _parse_target_ref, _send_discord, _send_matrix_via_adapter, @@ -1234,3 +1235,362 @@ class TestSendMatrixUrlEncoding: put_url = mock_session.put.call_args[0][0] assert "%21HLOQwxYGgFPMPJUSNR%3Amatrix.org" in put_url assert "!HLOQwxYGgFPMPJUSNR:matrix.org" not in put_url + + +# --------------------------------------------------------------------------- +# Tests for _derive_forum_thread_name +# --------------------------------------------------------------------------- + + +class TestDeriveForumThreadName: + def test_single_line_message(self): + assert _derive_forum_thread_name("Hello world") == "Hello world" + + def test_multi_line_uses_first_line(self): + assert _derive_forum_thread_name("First line\nSecond line") == "First line" + + def test_strips_markdown_heading(self): + assert _derive_forum_thread_name("## My Heading") == "My Heading" + + def test_strips_multiple_hash_levels(self): + assert _derive_forum_thread_name("### Deep heading") == "Deep heading" + + def test_empty_message_falls_back_to_default(self): + assert _derive_forum_thread_name("") == "New Post" + + def test_whitespace_only_falls_back(self): + assert _derive_forum_thread_name(" \n ") == "New Post" + + def test_hash_only_falls_back(self): + assert _derive_forum_thread_name("###") == "New Post" + + def test_truncates_to_100_chars(self): + long_title = "A" * 200 + result = _derive_forum_thread_name(long_title) + assert len(result) == 100 + + def test_strips_whitespace_around_first_line(self): + assert _derive_forum_thread_name(" Title \nBody") == "Title" + + +# --------------------------------------------------------------------------- +# Tests for _send_discord with forum channel support +# --------------------------------------------------------------------------- + + +class TestSendDiscordForum: + """_send_discord creates thread posts for forum channels.""" + + @staticmethod + def _build_mock(response_status, response_data=None, response_text="error body"): + mock_resp = MagicMock() + mock_resp.status = response_status + mock_resp.json = AsyncMock(return_value=response_data or {}) + mock_resp.text = AsyncMock(return_value=response_text) + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock(return_value=mock_resp) + mock_session.get = MagicMock(return_value=mock_resp) + + return mock_session, mock_resp + + def test_directory_forum_creates_thread(self): + """Directory says 'forum' — creates a thread post.""" + thread_data = { + "id": "t123", + "message": {"id": "m456"}, + } + mock_session, _ = self._build_mock(200, response_data=thread_data) + + with patch("aiohttp.ClientSession", return_value=mock_session), \ + patch("gateway.channel_directory.lookup_channel_type", return_value="forum"): + result = asyncio.run( + _send_discord("tok", "forum_ch", "Hello forum") + ) + + assert result["success"] is True + assert result["thread_id"] == "t123" + assert result["message_id"] == "m456" + # Should POST to threads endpoint, not messages + call_url = mock_session.post.call_args.args[0] + assert "/threads" in call_url + assert "/messages" not in call_url + + def test_directory_forum_skips_probe(self): + """When directory says 'forum', no GET probe is made.""" + thread_data = {"id": "t123", "message": {"id": "m456"}} + mock_session, _ = self._build_mock(200, response_data=thread_data) + + with patch("aiohttp.ClientSession", return_value=mock_session), \ + patch("gateway.channel_directory.lookup_channel_type", return_value="forum"): + asyncio.run( + _send_discord("tok", "forum_ch", "Hello") + ) + + # get() should never be called — directory resolved the type + mock_session.get.assert_not_called() + + def test_directory_channel_skips_forum(self): + """When directory says 'channel', sends via normal messages endpoint.""" + mock_session, _ = self._build_mock(200, response_data={"id": "msg1"}) + + with patch("aiohttp.ClientSession", return_value=mock_session), \ + patch("gateway.channel_directory.lookup_channel_type", return_value="channel"): + result = asyncio.run( + _send_discord("tok", "ch1", "Hello") + ) + + assert result["success"] is True + call_url = mock_session.post.call_args.args[0] + assert "/messages" in call_url + assert "/threads" not in call_url + + def test_directory_none_probes_and_detects_forum(self): + """When directory has no entry, probes GET /channels/{id} and detects type 15.""" + probe_resp = MagicMock() + probe_resp.status = 200 + probe_resp.json = AsyncMock(return_value={"type": 15}) + probe_resp.__aenter__ = AsyncMock(return_value=probe_resp) + probe_resp.__aexit__ = AsyncMock(return_value=None) + + thread_data = {"id": "t999", "message": {"id": "m888"}} + thread_resp = MagicMock() + thread_resp.status = 200 + thread_resp.json = AsyncMock(return_value=thread_data) + thread_resp.text = AsyncMock(return_value="") + thread_resp.__aenter__ = AsyncMock(return_value=thread_resp) + thread_resp.__aexit__ = AsyncMock(return_value=None) + + probe_session = MagicMock() + probe_session.__aenter__ = AsyncMock(return_value=probe_session) + probe_session.__aexit__ = AsyncMock(return_value=None) + probe_session.get = MagicMock(return_value=probe_resp) + + thread_session = MagicMock() + thread_session.__aenter__ = AsyncMock(return_value=thread_session) + thread_session.__aexit__ = AsyncMock(return_value=None) + thread_session.post = MagicMock(return_value=thread_resp) + + session_iter = iter([probe_session, thread_session]) + + with patch("aiohttp.ClientSession", side_effect=lambda **kw: next(session_iter)), \ + patch("gateway.channel_directory.lookup_channel_type", return_value=None): + result = asyncio.run( + _send_discord("tok", "forum_ch", "Hello probe") + ) + + assert result["success"] is True + assert result["thread_id"] == "t999" + + def test_directory_lookup_exception_falls_through_to_probe(self): + """When lookup_channel_type raises, falls through to API probe.""" + mock_session, _ = self._build_mock(200, response_data={"id": "msg1"}) + + with patch("aiohttp.ClientSession", return_value=mock_session), \ + patch("gateway.channel_directory.lookup_channel_type", side_effect=Exception("io error")): + result = asyncio.run( + _send_discord("tok", "ch1", "Hello") + ) + + assert result["success"] is True + # Falls through to probe (GET) + mock_session.get.assert_called_once() + + def test_forum_thread_creation_error(self): + """Forum thread creation returning non-200/201 returns an error dict.""" + mock_session, _ = self._build_mock(403, response_text="Forbidden") + + with patch("aiohttp.ClientSession", return_value=mock_session), \ + patch("gateway.channel_directory.lookup_channel_type", return_value="forum"): + result = asyncio.run( + _send_discord("tok", "forum_ch", "Hello") + ) + + assert "error" in result + assert "403" in result["error"] + + +# --------------------------------------------------------------------------- +# Discord media attachment support +# --------------------------------------------------------------------------- + + +class TestSendDiscordMedia: + """_send_discord uploads media files via multipart/form-data.""" + + @staticmethod + def _build_mock(response_status, response_data=None, response_text="error body"): + """Build a properly-structured aiohttp mock chain.""" + mock_resp = MagicMock() + mock_resp.status = response_status + mock_resp.json = AsyncMock(return_value=response_data or {"id": "msg123"}) + mock_resp.text = AsyncMock(return_value=response_text) + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock(return_value=mock_resp) + + return mock_session, mock_resp + + def test_text_and_media_sends_both(self, tmp_path): + """Text message is sent first, then each media file as multipart.""" + img = tmp_path / "photo.png" + img.write_bytes(b"\x89PNG fake image data") + + mock_session, _ = self._build_mock(200, {"id": "msg999"}) + with patch("aiohttp.ClientSession", return_value=mock_session): + result = asyncio.run( + _send_discord("tok", "111", "hello", media_files=[(str(img), False)]) + ) + + assert result["success"] is True + assert result["message_id"] == "msg999" + # Two POSTs: one text JSON, one multipart upload + assert mock_session.post.call_count == 2 + + def test_media_only_skips_text_post(self, tmp_path): + """When message is empty and media is present, text POST is skipped.""" + img = tmp_path / "photo.png" + img.write_bytes(b"\x89PNG fake image data") + + mock_session, _ = self._build_mock(200, {"id": "media_only"}) + with patch("aiohttp.ClientSession", return_value=mock_session): + result = asyncio.run( + _send_discord("tok", "222", " ", media_files=[(str(img), False)]) + ) + + assert result["success"] is True + # Only one POST: the media upload (text was whitespace-only) + assert mock_session.post.call_count == 1 + + def test_missing_media_file_collected_as_warning(self): + """Non-existent media paths produce warnings but don't fail.""" + mock_session, _ = self._build_mock(200, {"id": "txt_ok"}) + with patch("aiohttp.ClientSession", return_value=mock_session): + result = asyncio.run( + _send_discord("tok", "333", "hello", media_files=[("/nonexistent/file.png", False)]) + ) + + assert result["success"] is True + assert "warnings" in result + assert any("not found" in w for w in result["warnings"]) + # Only the text POST was made, media was skipped + assert mock_session.post.call_count == 1 + + def test_media_upload_failure_collected_as_warning(self, tmp_path): + """Failed media upload becomes a warning, text still succeeds.""" + img = tmp_path / "photo.png" + img.write_bytes(b"\x89PNG fake image data") + + # First call (text) succeeds, second call (media) returns 413 + text_resp = MagicMock() + text_resp.status = 200 + text_resp.json = AsyncMock(return_value={"id": "txt_ok"}) + text_resp.__aenter__ = AsyncMock(return_value=text_resp) + text_resp.__aexit__ = AsyncMock(return_value=None) + + media_resp = MagicMock() + media_resp.status = 413 + media_resp.text = AsyncMock(return_value="Request Entity Too Large") + media_resp.__aenter__ = AsyncMock(return_value=media_resp) + media_resp.__aexit__ = AsyncMock(return_value=None) + + mock_session = MagicMock() + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + mock_session.post = MagicMock(side_effect=[text_resp, media_resp]) + + with patch("aiohttp.ClientSession", return_value=mock_session): + result = asyncio.run( + _send_discord("tok", "444", "hello", media_files=[(str(img), False)]) + ) + + assert result["success"] is True + assert result["message_id"] == "txt_ok" + assert "warnings" in result + assert any("413" in w for w in result["warnings"]) + + def test_no_text_no_media_returns_error(self): + """Empty text with no media returns error dict.""" + mock_session, _ = self._build_mock(200) + with patch("aiohttp.ClientSession", return_value=mock_session): + result = asyncio.run( + _send_discord("tok", "555", "", media_files=[]) + ) + + # Text is empty but media_files is empty, so text POST fires + # (the "skip text if media present" condition isn't met) + assert result["success"] is True + + def test_multiple_media_files_uploaded_separately(self, tmp_path): + """Each media file gets its own multipart POST.""" + img1 = tmp_path / "a.png" + img1.write_bytes(b"img1") + img2 = tmp_path / "b.jpg" + img2.write_bytes(b"img2") + + mock_session, _ = self._build_mock(200, {"id": "last"}) + with patch("aiohttp.ClientSession", return_value=mock_session): + result = asyncio.run( + _send_discord("tok", "666", "hi", media_files=[ + (str(img1), False), (str(img2), False) + ]) + ) + + assert result["success"] is True + # 1 text POST + 2 media POSTs = 3 + assert mock_session.post.call_count == 3 + + +# --------------------------------------------------------------------------- +# Tests for _send_to_platform with forum channel detection +# --------------------------------------------------------------------------- + + +class TestSendToPlatformDiscordForum: + """_send_to_platform delegates forum detection to _send_discord.""" + + def test_send_to_platform_discord_delegates_to_send_discord(self): + """Discord messages are routed through _send_discord, which handles forum detection.""" + send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) + + with patch("tools.send_message_tool._send_discord", send_mock): + result = asyncio.run( + _send_to_platform( + Platform.DISCORD, + SimpleNamespace(enabled=True, token="tok", extra={}), + "forum_ch", + "Hello forum", + ) + ) + + assert result["success"] is True + send_mock.assert_awaited_once_with( + "tok", "forum_ch", "Hello forum", media_files=[], thread_id=None, + ) + + def test_send_to_platform_discord_with_thread_id(self): + """Thread ID is still passed through when sending to Discord.""" + send_mock = AsyncMock(return_value={"success": True, "message_id": "1"}) + + with patch("tools.send_message_tool._send_discord", send_mock): + result = asyncio.run( + _send_to_platform( + Platform.DISCORD, + SimpleNamespace(enabled=True, token="tok", extra={}), + "ch1", + "Hello thread", + thread_id="17585", + ) + ) + + assert result["success"] is True + _, call_kwargs = send_mock.await_args + assert call_kwargs["thread_id"] == "17585" diff --git a/tools/send_message_tool.py b/tools/send_message_tool.py index bb2747686..281185104 100644 --- a/tools/send_message_tool.py +++ b/tools/send_message_tool.py @@ -685,6 +685,16 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No return _error(f"Telegram send failed: {e}") +def _derive_forum_thread_name(message: str) -> str: + """Derive a thread name from the first line of the message, capped at 100 chars.""" + first_line = message.strip().split("\n", 1)[0].strip() + # Strip common markdown heading prefixes + first_line = first_line.lstrip("#").strip() + if not first_line: + first_line = "New Post" + return first_line[:100] + + async def _send_discord(token, chat_id, message, thread_id=None, media_files=None): """Send a single message via Discord REST API (no websocket client needed). @@ -693,6 +703,12 @@ async def _send_discord(token, chat_id, message, thread_id=None, media_files=Non When thread_id is provided, the message is sent directly to that thread via the /channels/{thread_id}/messages endpoint. + Forum channels (type 15) reject POST /messages — auto-create a thread + post instead via POST /channels/{id}/threads. + + Channel type is resolved from the channel directory first; only falls + back to a GET /channels/{id} probe when the directory has no entry. + Media files are uploaded one-by-one via multipart/form-data after the text message is sent (same pattern as Telegram). """ @@ -704,16 +720,73 @@ async def _send_discord(token, chat_id, message, thread_id=None, media_files=Non from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp _proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY") _sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy) + headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"} + # Thread endpoint: Discord threads are channels; send directly to the thread ID. if thread_id: url = f"https://discord.com/api/v10/channels/{thread_id}/messages" else: + # Check if the target channel is a forum channel (type 15). + # Forum channels reject POST /messages — create a thread post instead. + # Try the channel directory first; fall back to an API probe only + # when the directory has no entry. + _channel_type = None + try: + from gateway.channel_directory import lookup_channel_type + _channel_type = lookup_channel_type("discord", chat_id) + except Exception: + pass + + if _channel_type == "forum": + is_forum = True + elif _channel_type is not None: + # Known non-forum type — skip the probe. + is_forum = False + else: + is_forum = False + try: + info_url = f"https://discord.com/api/v10/channels/{chat_id}" + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15), **_sess_kw) as info_sess: + async with info_sess.get(info_url, headers=headers, **_req_kw) as info_resp: + if info_resp.status == 200: + info = await info_resp.json() + is_forum = info.get("type") == 15 + except Exception: + logger.debug("Failed to probe channel type for %s", chat_id, exc_info=True) + + + if is_forum: + thread_name = _derive_forum_thread_name(message) + thread_url = f"https://discord.com/api/v10/channels/{chat_id}/threads" + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session: + async with session.post( + thread_url, + headers=headers, + json={ + "name": thread_name, + "message": {"content": message}, + }, + **_req_kw, + ) as resp: + if resp.status not in (200, 201): + body = await resp.text() + return _error(f"Discord forum thread creation error ({resp.status}): {body}") + data = await resp.json() + thread_id_created = data.get("id") + starter_msg_id = (data.get("message") or {}).get("id", thread_id_created) + return { + "success": True, + "platform": "discord", + "chat_id": chat_id, + "thread_id": thread_id_created, + "message_id": starter_msg_id, + } + url = f"https://discord.com/api/v10/channels/{chat_id}/messages" auth_headers = {"Authorization": f"Bot {token}"} media_files = media_files or [] last_data = None warnings = [] - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session: # Send text message (skip if empty and media is present) if message.strip() or not media_files: