fix(gateway): preserve thread metadata edits

This commit is contained in:
Wooseong Kim 2026-04-23 19:50:05 +09:00
parent 6051fba9dc
commit fa8464e200
14 changed files with 141 additions and 16 deletions

View file

@ -1104,6 +1104,7 @@ class BasePlatformAdapter(ABC):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""

View file

@ -1009,6 +1009,7 @@ class DingTalkAdapter(BasePlatformAdapter):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit an AI Card by streaming updated content.

View file

@ -1280,15 +1280,17 @@ class DiscordAdapter(BasePlatformAdapter):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit a previously sent Discord message."""
if not self._client:
return SendResult(success=False, error="Not connected")
try:
channel = self._client.get_channel(int(chat_id))
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
channel = self._client.get_channel(int(target_chat_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
channel = await self._client.fetch_channel(int(target_chat_id))
msg = await channel.fetch_message(int(message_id))
formatted = self.format_message(content)
if len(formatted) > self.MAX_MESSAGE_LENGTH:
@ -1989,13 +1991,15 @@ class DiscordAdapter(BasePlatformAdapter):
if chat_id in self._typing_tasks:
return
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
async def _typing_loop() -> None:
try:
while True:
try:
route = discord.http.Route(
"POST", "/channels/{channel_id}/typing",
channel_id=chat_id,
channel_id=target_chat_id,
)
await self._client.http.request(route)
except asyncio.CancelledError:

View file

@ -1694,6 +1694,7 @@ class FeishuAdapter(BasePlatformAdapter):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit a previously sent Feishu text/post message."""

View file

@ -825,7 +825,13 @@ class MatrixAdapter(BasePlatformAdapter):
async def edit_message(
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False
self,
chat_id: str,
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit an existing message (via m.replace)."""

View file

@ -304,7 +304,13 @@ class MattermostAdapter(BasePlatformAdapter):
)
async def edit_message(
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False
self,
chat_id: str,
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit an existing post."""
formatted = self.format_message(content)

View file

@ -328,6 +328,7 @@ class SlackAdapter(BasePlatformAdapter):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit a previously sent Slack message."""

View file

@ -1127,6 +1127,7 @@ class TelegramAdapter(BasePlatformAdapter):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit a previously sent Telegram message."""

View file

@ -728,6 +728,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult:
"""Edit a previously sent message via the WhatsApp bridge."""

View file

@ -9538,6 +9538,7 @@ class GatewayRunner:
chat_id=source.chat_id,
message_id=progress_msg_id,
content=full_text,
metadata=_progress_metadata,
)
if not result.success:
_err = (getattr(result, "error", "") or "").lower()
@ -9592,6 +9593,7 @@ class GatewayRunner:
chat_id=source.chat_id,
message_id=progress_msg_id,
content=full_text,
metadata=_progress_metadata,
)
except Exception:
pass

View file

@ -590,6 +590,7 @@ class GatewayStreamConsumer:
chat_id=self.chat_id,
message_id=self._message_id,
content=clean_text,
metadata=self.metadata,
)
if result.success:
self._last_sent_text = clean_text
@ -708,6 +709,7 @@ class GatewayStreamConsumer:
chat_id=self.chat_id,
message_id=self._message_id,
content=prefix,
metadata=self.metadata,
)
self._last_sent_text = prefix
except Exception:
@ -791,6 +793,7 @@ class GatewayStreamConsumer:
chat_id=self.chat_id,
message_id=self._message_id,
content=text,
metadata=self.metadata,
finalize=finalize,
)
if result.success:

View file

@ -1,5 +1,6 @@
"""Tests for Discord free-response defaults and mention gating."""
import asyncio
from datetime import datetime, timezone
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
@ -82,6 +83,15 @@ class FakeThread:
self.topic = None
class FakeMessage:
def __init__(self, message_id: int = 1):
self.id = message_id
self.edits = []
async def edit(self, *, content: str):
self.edits.append(content)
@pytest.fixture
def adapter(monkeypatch):
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
@ -179,6 +189,55 @@ async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch
assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
@pytest.mark.asyncio
async def test_discord_edit_message_uses_thread_metadata(adapter):
parent_channel = FakeTextChannel(channel_id=123, name="general")
thread_channel = FakeThread(channel_id=456, name="planning", parent=parent_channel)
message = FakeMessage(message_id=789)
thread_channel.fetch_message = AsyncMock(return_value=message)
adapter._client = SimpleNamespace(
get_channel=MagicMock(side_effect=lambda cid: thread_channel if cid == 456 else parent_channel),
fetch_channel=AsyncMock(return_value=None),
)
result = await adapter.edit_message(
chat_id="123",
message_id="789",
content="updated in thread",
metadata={"thread_id": "456"},
)
assert result.success is True
thread_channel.fetch_message.assert_awaited_once_with(789)
assert message.edits == ["updated in thread"]
@pytest.mark.asyncio
async def test_discord_send_typing_uses_thread_metadata(adapter, monkeypatch):
recorded = []
class FakeRoute:
def __init__(self, _method, _path, *, channel_id):
self.channel_id = channel_id
async def fake_request(route):
recorded.append(route.channel_id)
task = adapter._typing_tasks.get("123")
if task:
task.cancel()
raise asyncio.CancelledError
monkeypatch.setattr(discord_platform.discord.http, "Route", FakeRoute, raising=False)
adapter._client = SimpleNamespace(http=SimpleNamespace(request=fake_request))
await adapter.send_typing("123", metadata={"thread_id": "456"})
task = adapter._typing_tasks["123"]
with pytest.raises(asyncio.CancelledError):
await task
assert recorded == ["456"]
@pytest.mark.asyncio
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

View file

@ -38,12 +38,13 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
)
return SendResult(success=True, message_id="progress-1")
async def edit_message(self, chat_id, message_id, content) -> SendResult:
async def edit_message(self, chat_id, message_id, content, metadata=None) -> SendResult:
self.edits.append(
{
"chat_id": chat_id,
"message_id": message_id,
"content": content,
"metadata": metadata,
}
)
return SendResult(success=True, message_id=message_id)
@ -110,6 +111,23 @@ class DelayedProgressAgent:
}
class SlowProgressAgent:
def __init__(self, **kwargs):
self.tool_progress_callback = kwargs.get("tool_progress_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.tool_progress_callback("tool.started", "terminal", "first command", {})
time.sleep(1.8)
self.tool_progress_callback("tool.started", "terminal", "second command", {})
time.sleep(1.8)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
class DelayedInterimAgent:
def __init__(self, **kwargs):
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
@ -160,7 +178,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = FakeAgent
fake_run_agent.AIAgent = SlowProgressAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
@ -189,12 +207,13 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
assert adapter.sent == [
{
"chat_id": "-1001",
"content": '💻 terminal: "pwd"',
"content": '💻 terminal: "first command"',
"reply_to": None,
"metadata": {"thread_id": "17585"},
}
]
assert adapter.edits
assert any(call["metadata"] == {"thread_id": "17585"} for call in adapter.edits)
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)

View file

@ -134,15 +134,14 @@ class TestFinalizeCapabilityGate:
class TestEditMessageFinalizeSignature:
"""Every concrete platform adapter must accept the ``finalize`` kwarg.
"""Every concrete platform adapter must accept shared edit kwargs.
stream_consumer._send_or_edit always passes ``finalize=`` to
``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An
adapter that overrides edit_message without accepting finalize raises
TypeError the first time streaming hits a segment break or final edit.
Guard the contract with an explicit signature check so it cannot
silently regress existing tests use MagicMock which swallows any
kwarg and cannot catch this.
stream_consumer._send_or_edit always passes ``finalize=`` and shared
callers may pass ``metadata=`` to ``adapter.edit_message(...)``. An
adapter that overrides edit_message without accepting either kwarg
raises TypeError at runtime. Guard the contract with an explicit
signature check so it cannot silently regress existing tests use
MagicMock which swallows any kwarg and cannot catch this.
"""
@pytest.mark.parametrize(
@ -168,6 +167,10 @@ class TestEditMessageFinalizeSignature:
f"{class_name}.edit_message must accept 'finalize' kwarg; "
f"stream_consumer._send_or_edit passes it unconditionally"
)
assert "metadata" in params, (
f"{class_name}.edit_message must accept 'metadata' kwarg; "
f"shared edit call sites pass thread/topic routing context"
)
class TestSendOrEditMediaStripping:
@ -209,6 +212,23 @@ class TestSendOrEditMediaStripping:
edited_text = adapter.edit_message.call_args[1]["content"]
assert "MEDIA:" not in edited_text
@pytest.mark.asyncio
async def test_edit_preserves_metadata(self):
"""Edit call forwards stored metadata so thread routing survives streaming."""
adapter = MagicMock()
send_result = SimpleNamespace(success=True, message_id="msg_1")
edit_result = SimpleNamespace(success=True)
adapter.send = AsyncMock(return_value=send_result)
adapter.edit_message = AsyncMock(return_value=edit_result)
adapter.MAX_MESSAGE_LENGTH = 4096
consumer = GatewayStreamConsumer(adapter, "chat_123", metadata={"thread_id": "thread-42"})
await consumer._send_or_edit("Starting response...")
await consumer._send_or_edit("Updated response")
adapter.edit_message.assert_called_once()
assert adapter.edit_message.call_args[1]["metadata"] == {"thread_id": "thread-42"}
@pytest.mark.asyncio
async def test_media_only_skips_send(self):
"""If text is entirely MEDIA: tags, the send is skipped."""