mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(discord): support forum channels
This commit is contained in:
parent
148459716c
commit
e5333e793c
6 changed files with 671 additions and 2 deletions
|
|
@ -100,7 +100,7 @@ def build_channel_directory(adapters: Dict[Any, Any]) -> Dict[str, Any]:
|
|||
|
||||
|
||||
def _build_discord(adapter) -> List[Dict[str, str]]:
|
||||
"""Enumerate all text channels the Discord bot can see."""
|
||||
"""Enumerate all text channels and forum channels the Discord bot can see."""
|
||||
channels = []
|
||||
client = getattr(adapter, "_client", None)
|
||||
if not client:
|
||||
|
|
@ -119,6 +119,15 @@ def _build_discord(adapter) -> List[Dict[str, str]]:
|
|||
"guild": guild.name,
|
||||
"type": "channel",
|
||||
})
|
||||
# Forum channels (type 15) — creating a message auto-spawns a thread post.
|
||||
forums = getattr(guild, "forum_channels", None) or []
|
||||
for ch in forums:
|
||||
channels.append({
|
||||
"id": str(ch.id),
|
||||
"name": ch.name,
|
||||
"guild": guild.name,
|
||||
"type": "forum",
|
||||
})
|
||||
# Also include DM-capable users we've interacted with is not
|
||||
# feasible via guild enumeration; those come from sessions.
|
||||
|
||||
|
|
@ -191,6 +200,15 @@ def load_directory() -> Dict[str, Any]:
|
|||
return {"updated_at": None, "platforms": {}}
|
||||
|
||||
|
||||
def lookup_channel_type(platform_name: str, chat_id: str) -> Optional[str]:
|
||||
"""Return the channel ``type`` string (e.g. ``"channel"``, ``"forum"``) for *chat_id*, or *None* if unknown."""
|
||||
directory = load_directory()
|
||||
for ch in directory.get("platforms", {}).get(platform_name, []):
|
||||
if ch.get("id") == chat_id:
|
||||
return ch.get("type")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_channel_name(platform_name: str, name: str) -> Optional[str]:
|
||||
"""
|
||||
Resolve a human-friendly channel name to a numeric ID.
|
||||
|
|
|
|||
|
|
@ -857,6 +857,9 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
|
||||
When metadata contains a thread_id, the message is sent to that
|
||||
thread instead of the parent channel identified by chat_id.
|
||||
|
||||
Forum channels (type 15) reject direct messages — a thread post is
|
||||
created automatically.
|
||||
"""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
|
|
@ -882,6 +885,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if not channel:
|
||||
return SendResult(success=False, error=f"Channel {chat_id} not found")
|
||||
|
||||
# Forum channels reject channel.send() — create a thread post instead.
|
||||
if self._is_forum_parent(channel):
|
||||
return await self._send_to_forum(channel, content)
|
||||
|
||||
# Format and split message if needed
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
|
@ -945,6 +952,51 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
logger.error("[%s] Failed to send Discord message: %s", self.name, e, exc_info=True)
|
||||
return SendResult(success=False, error=str(e))
|
||||
|
||||
async def _send_to_forum(self, forum_channel: Any, content: str) -> SendResult:
|
||||
"""Create a thread post in a forum channel with the message as starter content.
|
||||
|
||||
Forum channels (type 15) don't support direct messages. Instead we
|
||||
POST to /channels/{forum_id}/threads with a thread name derived from
|
||||
the first line of the message.
|
||||
"""
|
||||
from tools.send_message_tool import _derive_forum_thread_name
|
||||
|
||||
formatted = self.format_message(content)
|
||||
chunks = self.truncate_message(formatted, self.MAX_MESSAGE_LENGTH)
|
||||
|
||||
thread_name = _derive_forum_thread_name(content)
|
||||
|
||||
starter_content = chunks[0] if chunks else thread_name
|
||||
|
||||
try:
|
||||
thread = await forum_channel.create_thread(
|
||||
name=thread_name,
|
||||
content=starter_content,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("[%s] Failed to create forum thread in %s: %s", self.name, forum_channel.id, e)
|
||||
return SendResult(success=False, error=f"Forum thread creation failed: {e}")
|
||||
|
||||
thread_channel = thread if hasattr(thread, "send") else getattr(thread, "thread", None)
|
||||
thread_id = str(getattr(thread_channel, "id", getattr(thread, "id", "")))
|
||||
starter_msg = getattr(thread, "message", None)
|
||||
message_id = str(getattr(starter_msg, "id", thread_id)) if starter_msg else thread_id
|
||||
|
||||
# Send remaining chunks into the newly created thread.
|
||||
message_ids = [message_id]
|
||||
for chunk in chunks[1:]:
|
||||
try:
|
||||
msg = await thread_channel.send(content=chunk)
|
||||
message_ids.append(str(msg.id))
|
||||
except Exception as e:
|
||||
logger.warning("[%s] Failed to send follow-up chunk to forum thread %s: %s", self.name, thread_id, e)
|
||||
|
||||
return SendResult(
|
||||
success=True,
|
||||
message_id=message_ids[0],
|
||||
raw_response={"message_ids": message_ids, "thread_id": thread_id},
|
||||
)
|
||||
|
||||
async def edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from unittest.mock import patch
|
|||
|
||||
from gateway.channel_directory import (
|
||||
build_channel_directory,
|
||||
lookup_channel_type,
|
||||
resolve_channel_name,
|
||||
format_directory_for_display,
|
||||
load_directory,
|
||||
|
|
@ -285,3 +286,49 @@ class TestFormatDirectoryForDisplay:
|
|||
assert "Discord (Server1):" in result
|
||||
assert "Discord (Server2):" in result
|
||||
assert "discord:#general" in result
|
||||
|
||||
|
||||
class TestLookupChannelType:
|
||||
def _setup(self, tmp_path, platforms):
|
||||
cache_file = _write_directory(tmp_path, platforms)
|
||||
return patch("gateway.channel_directory.DIRECTORY_PATH", cache_file)
|
||||
|
||||
def test_forum_channel(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "100", "name": "ideas", "guild": "Server1", "type": "forum"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "100") == "forum"
|
||||
|
||||
def test_regular_channel(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "200", "name": "general", "guild": "Server1", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "200") == "channel"
|
||||
|
||||
def test_unknown_chat_id_returns_none(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "200", "name": "general", "guild": "Server1", "type": "channel"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "999") is None
|
||||
|
||||
def test_unknown_platform_returns_none(self, tmp_path):
|
||||
with self._setup(tmp_path, {}):
|
||||
assert lookup_channel_type("discord", "100") is None
|
||||
|
||||
def test_channel_without_type_key_returns_none(self, tmp_path):
|
||||
platforms = {
|
||||
"discord": [
|
||||
{"id": "300", "name": "general", "guild": "Server1"},
|
||||
]
|
||||
}
|
||||
with self._setup(tmp_path, platforms):
|
||||
assert lookup_channel_type("discord", "300") is None
|
||||
|
|
|
|||
|
|
@ -157,3 +157,122 @@ async def test_send_does_not_retry_on_unrelated_errors():
|
|||
# Only the first attempt happens — no reference-retry replay.
|
||||
assert channel.send.await_count == 1
|
||||
assert send_calls[0]["reference"] is reference_obj
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Forum channel tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
import discord as _discord_mod # noqa: E402 — imported after _ensure_discord_mock
|
||||
|
||||
|
||||
class TestIsForumParent:
|
||||
def test_none_returns_false(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
assert adapter._is_forum_parent(None) is False
|
||||
|
||||
def test_forum_channel_class_instance(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
forum_cls = getattr(_discord_mod, "ForumChannel", None)
|
||||
if forum_cls is None:
|
||||
# Re-create a type for the mock
|
||||
forum_cls = type("ForumChannel", (), {})
|
||||
_discord_mod.ForumChannel = forum_cls
|
||||
ch = forum_cls()
|
||||
assert adapter._is_forum_parent(ch) is True
|
||||
|
||||
def test_type_value_15(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
ch = SimpleNamespace(type=15)
|
||||
assert adapter._is_forum_parent(ch) is True
|
||||
|
||||
def test_regular_channel_returns_false(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
ch = SimpleNamespace(type=0)
|
||||
assert adapter._is_forum_parent(ch) is False
|
||||
|
||||
def test_thread_returns_false(self):
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
ch = SimpleNamespace(type=11) # public thread
|
||||
assert adapter._is_forum_parent(ch) is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_creates_thread_post():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
# thread object has no 'send' so _send_to_forum uses thread.thread
|
||||
thread_ch = SimpleNamespace(id=555, send=AsyncMock(return_value=SimpleNamespace(id=600)))
|
||||
thread = SimpleNamespace(
|
||||
id=555,
|
||||
message=SimpleNamespace(id=500),
|
||||
thread=thread_ch,
|
||||
)
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("999", "Hello forum!")
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "500"
|
||||
forum_channel.create_thread.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_sends_remaining_chunks():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
# Force a small max message length so the message splits
|
||||
adapter.MAX_MESSAGE_LENGTH = 20
|
||||
|
||||
chunk_msg_1 = SimpleNamespace(id=500)
|
||||
chunk_msg_2 = SimpleNamespace(id=501)
|
||||
thread_ch = SimpleNamespace(
|
||||
id=555,
|
||||
send=AsyncMock(return_value=chunk_msg_2),
|
||||
)
|
||||
# thread object has no 'send' so _send_to_forum uses thread.thread
|
||||
thread = SimpleNamespace(
|
||||
id=555,
|
||||
message=chunk_msg_1,
|
||||
thread=thread_ch,
|
||||
)
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(return_value=thread)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("999", "A" * 50)
|
||||
|
||||
assert result.success is True
|
||||
assert result.message_id == "500"
|
||||
# Should have sent at least one follow-up chunk
|
||||
assert thread_ch.send.await_count >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_to_forum_create_thread_failure():
|
||||
adapter = DiscordAdapter(PlatformConfig(enabled=True, token="***"))
|
||||
|
||||
forum_channel = _discord_mod.ForumChannel()
|
||||
forum_channel.id = 999
|
||||
forum_channel.name = "ideas"
|
||||
forum_channel.create_thread = AsyncMock(side_effect=Exception("rate limited"))
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=lambda _chat_id: forum_channel,
|
||||
fetch_channel=AsyncMock(),
|
||||
)
|
||||
|
||||
result = await adapter.send("999", "Hello forum!")
|
||||
|
||||
assert result.success is False
|
||||
assert "rate limited" in result.error
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
from gateway.config import Platform
|
||||
from tools.send_message_tool import (
|
||||
_derive_forum_thread_name,
|
||||
_parse_target_ref,
|
||||
_send_discord,
|
||||
_send_matrix_via_adapter,
|
||||
|
|
@ -1234,3 +1235,362 @@ class TestSendMatrixUrlEncoding:
|
|||
put_url = mock_session.put.call_args[0][0]
|
||||
assert "%21HLOQwxYGgFPMPJUSNR%3Amatrix.org" in put_url
|
||||
assert "!HLOQwxYGgFPMPJUSNR:matrix.org" not in put_url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _derive_forum_thread_name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeriveForumThreadName:
|
||||
def test_single_line_message(self):
|
||||
assert _derive_forum_thread_name("Hello world") == "Hello world"
|
||||
|
||||
def test_multi_line_uses_first_line(self):
|
||||
assert _derive_forum_thread_name("First line\nSecond line") == "First line"
|
||||
|
||||
def test_strips_markdown_heading(self):
|
||||
assert _derive_forum_thread_name("## My Heading") == "My Heading"
|
||||
|
||||
def test_strips_multiple_hash_levels(self):
|
||||
assert _derive_forum_thread_name("### Deep heading") == "Deep heading"
|
||||
|
||||
def test_empty_message_falls_back_to_default(self):
|
||||
assert _derive_forum_thread_name("") == "New Post"
|
||||
|
||||
def test_whitespace_only_falls_back(self):
|
||||
assert _derive_forum_thread_name(" \n ") == "New Post"
|
||||
|
||||
def test_hash_only_falls_back(self):
|
||||
assert _derive_forum_thread_name("###") == "New Post"
|
||||
|
||||
def test_truncates_to_100_chars(self):
|
||||
long_title = "A" * 200
|
||||
result = _derive_forum_thread_name(long_title)
|
||||
assert len(result) == 100
|
||||
|
||||
def test_strips_whitespace_around_first_line(self):
|
||||
assert _derive_forum_thread_name(" Title \nBody") == "Title"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _send_discord with forum channel support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendDiscordForum:
|
||||
"""_send_discord creates thread posts for forum channels."""
|
||||
|
||||
@staticmethod
|
||||
def _build_mock(response_status, response_data=None, response_text="error body"):
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = response_status
|
||||
mock_resp.json = AsyncMock(return_value=response_data or {})
|
||||
mock_resp.text = AsyncMock(return_value=response_text)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session.post = MagicMock(return_value=mock_resp)
|
||||
mock_session.get = MagicMock(return_value=mock_resp)
|
||||
|
||||
return mock_session, mock_resp
|
||||
|
||||
def test_directory_forum_creates_thread(self):
|
||||
"""Directory says 'forum' — creates a thread post."""
|
||||
thread_data = {
|
||||
"id": "t123",
|
||||
"message": {"id": "m456"},
|
||||
}
|
||||
mock_session, _ = self._build_mock(200, response_data=thread_data)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session), \
|
||||
patch("gateway.channel_directory.lookup_channel_type", return_value="forum"):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "forum_ch", "Hello forum")
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["thread_id"] == "t123"
|
||||
assert result["message_id"] == "m456"
|
||||
# Should POST to threads endpoint, not messages
|
||||
call_url = mock_session.post.call_args.args[0]
|
||||
assert "/threads" in call_url
|
||||
assert "/messages" not in call_url
|
||||
|
||||
def test_directory_forum_skips_probe(self):
|
||||
"""When directory says 'forum', no GET probe is made."""
|
||||
thread_data = {"id": "t123", "message": {"id": "m456"}}
|
||||
mock_session, _ = self._build_mock(200, response_data=thread_data)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session), \
|
||||
patch("gateway.channel_directory.lookup_channel_type", return_value="forum"):
|
||||
asyncio.run(
|
||||
_send_discord("tok", "forum_ch", "Hello")
|
||||
)
|
||||
|
||||
# get() should never be called — directory resolved the type
|
||||
mock_session.get.assert_not_called()
|
||||
|
||||
def test_directory_channel_skips_forum(self):
|
||||
"""When directory says 'channel', sends via normal messages endpoint."""
|
||||
mock_session, _ = self._build_mock(200, response_data={"id": "msg1"})
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session), \
|
||||
patch("gateway.channel_directory.lookup_channel_type", return_value="channel"):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "ch1", "Hello")
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
call_url = mock_session.post.call_args.args[0]
|
||||
assert "/messages" in call_url
|
||||
assert "/threads" not in call_url
|
||||
|
||||
def test_directory_none_probes_and_detects_forum(self):
|
||||
"""When directory has no entry, probes GET /channels/{id} and detects type 15."""
|
||||
probe_resp = MagicMock()
|
||||
probe_resp.status = 200
|
||||
probe_resp.json = AsyncMock(return_value={"type": 15})
|
||||
probe_resp.__aenter__ = AsyncMock(return_value=probe_resp)
|
||||
probe_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
thread_data = {"id": "t999", "message": {"id": "m888"}}
|
||||
thread_resp = MagicMock()
|
||||
thread_resp.status = 200
|
||||
thread_resp.json = AsyncMock(return_value=thread_data)
|
||||
thread_resp.text = AsyncMock(return_value="")
|
||||
thread_resp.__aenter__ = AsyncMock(return_value=thread_resp)
|
||||
thread_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
probe_session = MagicMock()
|
||||
probe_session.__aenter__ = AsyncMock(return_value=probe_session)
|
||||
probe_session.__aexit__ = AsyncMock(return_value=None)
|
||||
probe_session.get = MagicMock(return_value=probe_resp)
|
||||
|
||||
thread_session = MagicMock()
|
||||
thread_session.__aenter__ = AsyncMock(return_value=thread_session)
|
||||
thread_session.__aexit__ = AsyncMock(return_value=None)
|
||||
thread_session.post = MagicMock(return_value=thread_resp)
|
||||
|
||||
session_iter = iter([probe_session, thread_session])
|
||||
|
||||
with patch("aiohttp.ClientSession", side_effect=lambda **kw: next(session_iter)), \
|
||||
patch("gateway.channel_directory.lookup_channel_type", return_value=None):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "forum_ch", "Hello probe")
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["thread_id"] == "t999"
|
||||
|
||||
def test_directory_lookup_exception_falls_through_to_probe(self):
|
||||
"""When lookup_channel_type raises, falls through to API probe."""
|
||||
mock_session, _ = self._build_mock(200, response_data={"id": "msg1"})
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session), \
|
||||
patch("gateway.channel_directory.lookup_channel_type", side_effect=Exception("io error")):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "ch1", "Hello")
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
# Falls through to probe (GET)
|
||||
mock_session.get.assert_called_once()
|
||||
|
||||
def test_forum_thread_creation_error(self):
|
||||
"""Forum thread creation returning non-200/201 returns an error dict."""
|
||||
mock_session, _ = self._build_mock(403, response_text="Forbidden")
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session), \
|
||||
patch("gateway.channel_directory.lookup_channel_type", return_value="forum"):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "forum_ch", "Hello")
|
||||
)
|
||||
|
||||
assert "error" in result
|
||||
assert "403" in result["error"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Discord media attachment support
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendDiscordMedia:
|
||||
"""_send_discord uploads media files via multipart/form-data."""
|
||||
|
||||
@staticmethod
|
||||
def _build_mock(response_status, response_data=None, response_text="error body"):
|
||||
"""Build a properly-structured aiohttp mock chain."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status = response_status
|
||||
mock_resp.json = AsyncMock(return_value=response_data or {"id": "msg123"})
|
||||
mock_resp.text = AsyncMock(return_value=response_text)
|
||||
mock_resp.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session.post = MagicMock(return_value=mock_resp)
|
||||
|
||||
return mock_session, mock_resp
|
||||
|
||||
def test_text_and_media_sends_both(self, tmp_path):
|
||||
"""Text message is sent first, then each media file as multipart."""
|
||||
img = tmp_path / "photo.png"
|
||||
img.write_bytes(b"\x89PNG fake image data")
|
||||
|
||||
mock_session, _ = self._build_mock(200, {"id": "msg999"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "111", "hello", media_files=[(str(img), False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["message_id"] == "msg999"
|
||||
# Two POSTs: one text JSON, one multipart upload
|
||||
assert mock_session.post.call_count == 2
|
||||
|
||||
def test_media_only_skips_text_post(self, tmp_path):
|
||||
"""When message is empty and media is present, text POST is skipped."""
|
||||
img = tmp_path / "photo.png"
|
||||
img.write_bytes(b"\x89PNG fake image data")
|
||||
|
||||
mock_session, _ = self._build_mock(200, {"id": "media_only"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "222", " ", media_files=[(str(img), False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
# Only one POST: the media upload (text was whitespace-only)
|
||||
assert mock_session.post.call_count == 1
|
||||
|
||||
def test_missing_media_file_collected_as_warning(self):
|
||||
"""Non-existent media paths produce warnings but don't fail."""
|
||||
mock_session, _ = self._build_mock(200, {"id": "txt_ok"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "333", "hello", media_files=[("/nonexistent/file.png", False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert "warnings" in result
|
||||
assert any("not found" in w for w in result["warnings"])
|
||||
# Only the text POST was made, media was skipped
|
||||
assert mock_session.post.call_count == 1
|
||||
|
||||
def test_media_upload_failure_collected_as_warning(self, tmp_path):
|
||||
"""Failed media upload becomes a warning, text still succeeds."""
|
||||
img = tmp_path / "photo.png"
|
||||
img.write_bytes(b"\x89PNG fake image data")
|
||||
|
||||
# First call (text) succeeds, second call (media) returns 413
|
||||
text_resp = MagicMock()
|
||||
text_resp.status = 200
|
||||
text_resp.json = AsyncMock(return_value={"id": "txt_ok"})
|
||||
text_resp.__aenter__ = AsyncMock(return_value=text_resp)
|
||||
text_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
media_resp = MagicMock()
|
||||
media_resp.status = 413
|
||||
media_resp.text = AsyncMock(return_value="Request Entity Too Large")
|
||||
media_resp.__aenter__ = AsyncMock(return_value=media_resp)
|
||||
media_resp.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
mock_session.post = MagicMock(side_effect=[text_resp, media_resp])
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "444", "hello", media_files=[(str(img), False)])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["message_id"] == "txt_ok"
|
||||
assert "warnings" in result
|
||||
assert any("413" in w for w in result["warnings"])
|
||||
|
||||
def test_no_text_no_media_returns_error(self):
|
||||
"""Empty text with no media returns error dict."""
|
||||
mock_session, _ = self._build_mock(200)
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "555", "", media_files=[])
|
||||
)
|
||||
|
||||
# Text is empty but media_files is empty, so text POST fires
|
||||
# (the "skip text if media present" condition isn't met)
|
||||
assert result["success"] is True
|
||||
|
||||
def test_multiple_media_files_uploaded_separately(self, tmp_path):
|
||||
"""Each media file gets its own multipart POST."""
|
||||
img1 = tmp_path / "a.png"
|
||||
img1.write_bytes(b"img1")
|
||||
img2 = tmp_path / "b.jpg"
|
||||
img2.write_bytes(b"img2")
|
||||
|
||||
mock_session, _ = self._build_mock(200, {"id": "last"})
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = asyncio.run(
|
||||
_send_discord("tok", "666", "hi", media_files=[
|
||||
(str(img1), False), (str(img2), False)
|
||||
])
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
# 1 text POST + 2 media POSTs = 3
|
||||
assert mock_session.post.call_count == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _send_to_platform with forum channel detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendToPlatformDiscordForum:
|
||||
"""_send_to_platform delegates forum detection to _send_discord."""
|
||||
|
||||
def test_send_to_platform_discord_delegates_to_send_discord(self):
|
||||
"""Discord messages are routed through _send_discord, which handles forum detection."""
|
||||
send_mock = AsyncMock(return_value={"success": True, "message_id": "1"})
|
||||
|
||||
with patch("tools.send_message_tool._send_discord", send_mock):
|
||||
result = asyncio.run(
|
||||
_send_to_platform(
|
||||
Platform.DISCORD,
|
||||
SimpleNamespace(enabled=True, token="tok", extra={}),
|
||||
"forum_ch",
|
||||
"Hello forum",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
send_mock.assert_awaited_once_with(
|
||||
"tok", "forum_ch", "Hello forum", media_files=[], thread_id=None,
|
||||
)
|
||||
|
||||
def test_send_to_platform_discord_with_thread_id(self):
|
||||
"""Thread ID is still passed through when sending to Discord."""
|
||||
send_mock = AsyncMock(return_value={"success": True, "message_id": "1"})
|
||||
|
||||
with patch("tools.send_message_tool._send_discord", send_mock):
|
||||
result = asyncio.run(
|
||||
_send_to_platform(
|
||||
Platform.DISCORD,
|
||||
SimpleNamespace(enabled=True, token="tok", extra={}),
|
||||
"ch1",
|
||||
"Hello thread",
|
||||
thread_id="17585",
|
||||
)
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
_, call_kwargs = send_mock.await_args
|
||||
assert call_kwargs["thread_id"] == "17585"
|
||||
|
|
|
|||
|
|
@ -685,6 +685,16 @@ async def _send_telegram(token, chat_id, message, media_files=None, thread_id=No
|
|||
return _error(f"Telegram send failed: {e}")
|
||||
|
||||
|
||||
def _derive_forum_thread_name(message: str) -> str:
|
||||
"""Derive a thread name from the first line of the message, capped at 100 chars."""
|
||||
first_line = message.strip().split("\n", 1)[0].strip()
|
||||
# Strip common markdown heading prefixes
|
||||
first_line = first_line.lstrip("#").strip()
|
||||
if not first_line:
|
||||
first_line = "New Post"
|
||||
return first_line[:100]
|
||||
|
||||
|
||||
async def _send_discord(token, chat_id, message, thread_id=None, media_files=None):
|
||||
"""Send a single message via Discord REST API (no websocket client needed).
|
||||
|
||||
|
|
@ -693,6 +703,12 @@ async def _send_discord(token, chat_id, message, thread_id=None, media_files=Non
|
|||
When thread_id is provided, the message is sent directly to that thread
|
||||
via the /channels/{thread_id}/messages endpoint.
|
||||
|
||||
Forum channels (type 15) reject POST /messages — auto-create a thread
|
||||
post instead via POST /channels/{id}/threads.
|
||||
|
||||
Channel type is resolved from the channel directory first; only falls
|
||||
back to a GET /channels/{id} probe when the directory has no entry.
|
||||
|
||||
Media files are uploaded one-by-one via multipart/form-data after the
|
||||
text message is sent (same pattern as Telegram).
|
||||
"""
|
||||
|
|
@ -704,16 +720,73 @@ async def _send_discord(token, chat_id, message, thread_id=None, media_files=Non
|
|||
from gateway.platforms.base import resolve_proxy_url, proxy_kwargs_for_aiohttp
|
||||
_proxy = resolve_proxy_url(platform_env_var="DISCORD_PROXY")
|
||||
_sess_kw, _req_kw = proxy_kwargs_for_aiohttp(_proxy)
|
||||
headers = {"Authorization": f"Bot {token}", "Content-Type": "application/json"}
|
||||
|
||||
# Thread endpoint: Discord threads are channels; send directly to the thread ID.
|
||||
if thread_id:
|
||||
url = f"https://discord.com/api/v10/channels/{thread_id}/messages"
|
||||
else:
|
||||
# Check if the target channel is a forum channel (type 15).
|
||||
# Forum channels reject POST /messages — create a thread post instead.
|
||||
# Try the channel directory first; fall back to an API probe only
|
||||
# when the directory has no entry.
|
||||
_channel_type = None
|
||||
try:
|
||||
from gateway.channel_directory import lookup_channel_type
|
||||
_channel_type = lookup_channel_type("discord", chat_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if _channel_type == "forum":
|
||||
is_forum = True
|
||||
elif _channel_type is not None:
|
||||
# Known non-forum type — skip the probe.
|
||||
is_forum = False
|
||||
else:
|
||||
is_forum = False
|
||||
try:
|
||||
info_url = f"https://discord.com/api/v10/channels/{chat_id}"
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=15), **_sess_kw) as info_sess:
|
||||
async with info_sess.get(info_url, headers=headers, **_req_kw) as info_resp:
|
||||
if info_resp.status == 200:
|
||||
info = await info_resp.json()
|
||||
is_forum = info.get("type") == 15
|
||||
except Exception:
|
||||
logger.debug("Failed to probe channel type for %s", chat_id, exc_info=True)
|
||||
|
||||
|
||||
if is_forum:
|
||||
thread_name = _derive_forum_thread_name(message)
|
||||
thread_url = f"https://discord.com/api/v10/channels/{chat_id}/threads"
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session:
|
||||
async with session.post(
|
||||
thread_url,
|
||||
headers=headers,
|
||||
json={
|
||||
"name": thread_name,
|
||||
"message": {"content": message},
|
||||
},
|
||||
**_req_kw,
|
||||
) as resp:
|
||||
if resp.status not in (200, 201):
|
||||
body = await resp.text()
|
||||
return _error(f"Discord forum thread creation error ({resp.status}): {body}")
|
||||
data = await resp.json()
|
||||
thread_id_created = data.get("id")
|
||||
starter_msg_id = (data.get("message") or {}).get("id", thread_id_created)
|
||||
return {
|
||||
"success": True,
|
||||
"platform": "discord",
|
||||
"chat_id": chat_id,
|
||||
"thread_id": thread_id_created,
|
||||
"message_id": starter_msg_id,
|
||||
}
|
||||
|
||||
url = f"https://discord.com/api/v10/channels/{chat_id}/messages"
|
||||
auth_headers = {"Authorization": f"Bot {token}"}
|
||||
media_files = media_files or []
|
||||
last_data = None
|
||||
warnings = []
|
||||
|
||||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=30), **_sess_kw) as session:
|
||||
# Send text message (skip if empty and media is present)
|
||||
if message.strip() or not media_files:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue