diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index db7603498..666796727 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -1104,6 +1104,7 @@ class BasePlatformAdapter(ABC): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """ diff --git a/gateway/platforms/dingtalk.py b/gateway/platforms/dingtalk.py index 3037e402b..c3263a7a9 100644 --- a/gateway/platforms/dingtalk.py +++ b/gateway/platforms/dingtalk.py @@ -1009,6 +1009,7 @@ class DingTalkAdapter(BasePlatformAdapter): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """Edit an AI Card by streaming updated content. diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index f741d45b5..9b7d49c6b 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -1280,15 +1280,17 @@ class DiscordAdapter(BasePlatformAdapter): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """Edit a previously sent Discord message.""" if not self._client: return SendResult(success=False, error="Not connected") try: - channel = self._client.get_channel(int(chat_id)) + target_chat_id = str((metadata or {}).get("thread_id") or chat_id) + channel = self._client.get_channel(int(target_chat_id)) if not channel: - channel = await self._client.fetch_channel(int(chat_id)) + channel = await self._client.fetch_channel(int(target_chat_id)) msg = await channel.fetch_message(int(message_id)) formatted = self.format_message(content) if len(formatted) > self.MAX_MESSAGE_LENGTH: @@ -1989,13 +1991,15 @@ class DiscordAdapter(BasePlatformAdapter): if chat_id in self._typing_tasks: return + target_chat_id = str((metadata or {}).get("thread_id") or chat_id) + async def _typing_loop() -> None: try: while True: try: route = discord.http.Route( "POST", "/channels/{channel_id}/typing", - channel_id=chat_id, + channel_id=target_chat_id, ) await self._client.http.request(route) except asyncio.CancelledError: diff --git a/gateway/platforms/feishu.py b/gateway/platforms/feishu.py index 718f01e99..fe2d9d420 100644 --- a/gateway/platforms/feishu.py +++ b/gateway/platforms/feishu.py @@ -1694,6 +1694,7 @@ class FeishuAdapter(BasePlatformAdapter): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """Edit a previously sent Feishu text/post message.""" diff --git a/gateway/platforms/matrix.py b/gateway/platforms/matrix.py index a5f9352b5..100eb0ff7 100644 --- a/gateway/platforms/matrix.py +++ b/gateway/platforms/matrix.py @@ -825,7 +825,13 @@ class MatrixAdapter(BasePlatformAdapter): async def edit_message( - self, chat_id: str, message_id: str, content: str, *, finalize: bool = False + self, + chat_id: str, + message_id: str, + content: str, + *, + metadata=None, + finalize: bool = False, ) -> SendResult: """Edit an existing message (via m.replace).""" diff --git a/gateway/platforms/mattermost.py b/gateway/platforms/mattermost.py index 0e6c9631d..d07e66828 100644 --- a/gateway/platforms/mattermost.py +++ b/gateway/platforms/mattermost.py @@ -304,7 +304,13 @@ class MattermostAdapter(BasePlatformAdapter): ) async def edit_message( - self, chat_id: str, message_id: str, content: str, *, finalize: bool = False + self, + chat_id: str, + message_id: str, + content: str, + *, + metadata=None, + finalize: bool = False, ) -> SendResult: """Edit an existing post.""" formatted = self.format_message(content) diff --git a/gateway/platforms/slack.py b/gateway/platforms/slack.py index 191689a5a..006d971bf 100644 --- a/gateway/platforms/slack.py +++ b/gateway/platforms/slack.py @@ -328,6 +328,7 @@ class SlackAdapter(BasePlatformAdapter): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """Edit a previously sent Slack message.""" diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index bec0d690a..4e5621a37 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -1127,6 +1127,7 @@ class TelegramAdapter(BasePlatformAdapter): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """Edit a previously sent Telegram message.""" diff --git a/gateway/platforms/whatsapp.py b/gateway/platforms/whatsapp.py index a82417a60..029c54c16 100644 --- a/gateway/platforms/whatsapp.py +++ b/gateway/platforms/whatsapp.py @@ -728,6 +728,7 @@ class WhatsAppAdapter(BasePlatformAdapter): message_id: str, content: str, *, + metadata=None, finalize: bool = False, ) -> SendResult: """Edit a previously sent message via the WhatsApp bridge.""" diff --git a/gateway/run.py b/gateway/run.py index db3f8b00d..33c380859 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -9538,6 +9538,7 @@ class GatewayRunner: chat_id=source.chat_id, message_id=progress_msg_id, content=full_text, + metadata=_progress_metadata, ) if not result.success: _err = (getattr(result, "error", "") or "").lower() @@ -9592,6 +9593,7 @@ class GatewayRunner: chat_id=source.chat_id, message_id=progress_msg_id, content=full_text, + metadata=_progress_metadata, ) except Exception: pass diff --git a/gateway/stream_consumer.py b/gateway/stream_consumer.py index 78e365712..658c2b286 100644 --- a/gateway/stream_consumer.py +++ b/gateway/stream_consumer.py @@ -590,6 +590,7 @@ class GatewayStreamConsumer: chat_id=self.chat_id, message_id=self._message_id, content=clean_text, + metadata=self.metadata, ) if result.success: self._last_sent_text = clean_text @@ -708,6 +709,7 @@ class GatewayStreamConsumer: chat_id=self.chat_id, message_id=self._message_id, content=prefix, + metadata=self.metadata, ) self._last_sent_text = prefix except Exception: @@ -791,6 +793,7 @@ class GatewayStreamConsumer: chat_id=self.chat_id, message_id=self._message_id, content=text, + metadata=self.metadata, finalize=finalize, ) if result.success: diff --git a/tests/gateway/test_discord_free_response.py b/tests/gateway/test_discord_free_response.py index f1ee99606..638ae2fb2 100644 --- a/tests/gateway/test_discord_free_response.py +++ b/tests/gateway/test_discord_free_response.py @@ -1,5 +1,6 @@ """Tests for Discord free-response defaults and mention gating.""" +import asyncio from datetime import datetime, timezone from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock @@ -82,6 +83,15 @@ class FakeThread: self.topic = None +class FakeMessage: + def __init__(self, message_id: int = 1): + self.id = message_id + self.edits = [] + + async def edit(self, *, content: str): + self.edits.append(content) + + @pytest.fixture def adapter(monkeypatch): monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False) @@ -179,6 +189,55 @@ async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?" +@pytest.mark.asyncio +async def test_discord_edit_message_uses_thread_metadata(adapter): + parent_channel = FakeTextChannel(channel_id=123, name="general") + thread_channel = FakeThread(channel_id=456, name="planning", parent=parent_channel) + message = FakeMessage(message_id=789) + thread_channel.fetch_message = AsyncMock(return_value=message) + adapter._client = SimpleNamespace( + get_channel=MagicMock(side_effect=lambda cid: thread_channel if cid == 456 else parent_channel), + fetch_channel=AsyncMock(return_value=None), + ) + + result = await adapter.edit_message( + chat_id="123", + message_id="789", + content="updated in thread", + metadata={"thread_id": "456"}, + ) + + assert result.success is True + thread_channel.fetch_message.assert_awaited_once_with(789) + assert message.edits == ["updated in thread"] + + +@pytest.mark.asyncio +async def test_discord_send_typing_uses_thread_metadata(adapter, monkeypatch): + recorded = [] + + class FakeRoute: + def __init__(self, _method, _path, *, channel_id): + self.channel_id = channel_id + + async def fake_request(route): + recorded.append(route.channel_id) + task = adapter._typing_tasks.get("123") + if task: + task.cancel() + raise asyncio.CancelledError + + monkeypatch.setattr(discord_platform.discord.http, "Route", FakeRoute, raising=False) + adapter._client = SimpleNamespace(http=SimpleNamespace(request=fake_request)) + + await adapter.send_typing("123", metadata={"thread_id": "456"}) + task = adapter._typing_tasks["123"] + with pytest.raises(asyncio.CancelledError): + await task + + assert recorded == ["456"] + + @pytest.mark.asyncio async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch): monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") diff --git a/tests/gateway/test_run_progress_topics.py b/tests/gateway/test_run_progress_topics.py index 59e9fa040..2cb4ce925 100644 --- a/tests/gateway/test_run_progress_topics.py +++ b/tests/gateway/test_run_progress_topics.py @@ -38,12 +38,13 @@ class ProgressCaptureAdapter(BasePlatformAdapter): ) return SendResult(success=True, message_id="progress-1") - async def edit_message(self, chat_id, message_id, content) -> SendResult: + async def edit_message(self, chat_id, message_id, content, metadata=None) -> SendResult: self.edits.append( { "chat_id": chat_id, "message_id": message_id, "content": content, + "metadata": metadata, } ) return SendResult(success=True, message_id=message_id) @@ -110,6 +111,23 @@ class DelayedProgressAgent: } +class SlowProgressAgent: + def __init__(self, **kwargs): + self.tool_progress_callback = kwargs.get("tool_progress_callback") + self.tools = [] + + def run_conversation(self, message, conversation_history=None, task_id=None): + self.tool_progress_callback("tool.started", "terminal", "first command", {}) + time.sleep(1.8) + self.tool_progress_callback("tool.started", "terminal", "second command", {}) + time.sleep(1.8) + return { + "final_response": "done", + "messages": [], + "api_calls": 1, + } + + class DelayedInterimAgent: def __init__(self, **kwargs): self.interim_assistant_callback = kwargs.get("interim_assistant_callback") @@ -160,7 +178,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) fake_run_agent = types.ModuleType("run_agent") - fake_run_agent.AIAgent = FakeAgent + fake_run_agent.AIAgent = SlowProgressAgent monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test @@ -189,12 +207,13 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa assert adapter.sent == [ { "chat_id": "-1001", - "content": '💻 terminal: "pwd"', + "content": '💻 terminal: "first command"', "reply_to": None, "metadata": {"thread_id": "17585"}, } ] assert adapter.edits + assert any(call["metadata"] == {"thread_id": "17585"} for call in adapter.edits) assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing) diff --git a/tests/gateway/test_stream_consumer.py b/tests/gateway/test_stream_consumer.py index 7ae587dad..47dd8e00c 100644 --- a/tests/gateway/test_stream_consumer.py +++ b/tests/gateway/test_stream_consumer.py @@ -134,15 +134,14 @@ class TestFinalizeCapabilityGate: class TestEditMessageFinalizeSignature: - """Every concrete platform adapter must accept the ``finalize`` kwarg. + """Every concrete platform adapter must accept shared edit kwargs. - stream_consumer._send_or_edit always passes ``finalize=`` to - ``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An - adapter that overrides edit_message without accepting finalize raises - TypeError the first time streaming hits a segment break or final edit. - Guard the contract with an explicit signature check so it cannot - silently regress — existing tests use MagicMock which swallows any - kwarg and cannot catch this. + stream_consumer._send_or_edit always passes ``finalize=`` and shared + callers may pass ``metadata=`` to ``adapter.edit_message(...)``. An + adapter that overrides edit_message without accepting either kwarg + raises TypeError at runtime. Guard the contract with an explicit + signature check so it cannot silently regress — existing tests use + MagicMock which swallows any kwarg and cannot catch this. """ @pytest.mark.parametrize( @@ -168,6 +167,10 @@ class TestEditMessageFinalizeSignature: f"{class_name}.edit_message must accept 'finalize' kwarg; " f"stream_consumer._send_or_edit passes it unconditionally" ) + assert "metadata" in params, ( + f"{class_name}.edit_message must accept 'metadata' kwarg; " + f"shared edit call sites pass thread/topic routing context" + ) class TestSendOrEditMediaStripping: @@ -209,6 +212,23 @@ class TestSendOrEditMediaStripping: edited_text = adapter.edit_message.call_args[1]["content"] assert "MEDIA:" not in edited_text + @pytest.mark.asyncio + async def test_edit_preserves_metadata(self): + """Edit call forwards stored metadata so thread routing survives streaming.""" + adapter = MagicMock() + send_result = SimpleNamespace(success=True, message_id="msg_1") + edit_result = SimpleNamespace(success=True) + adapter.send = AsyncMock(return_value=send_result) + adapter.edit_message = AsyncMock(return_value=edit_result) + adapter.MAX_MESSAGE_LENGTH = 4096 + + consumer = GatewayStreamConsumer(adapter, "chat_123", metadata={"thread_id": "thread-42"}) + await consumer._send_or_edit("Starting response...") + await consumer._send_or_edit("Updated response") + + adapter.edit_message.assert_called_once() + assert adapter.edit_message.call_args[1]["metadata"] == {"thread_id": "thread-42"} + @pytest.mark.asyncio async def test_media_only_skips_send(self): """If text is entirely MEDIA: tags, the send is skipped."""