fix(gateway): preserve thread metadata edits

This commit is contained in:
Wooseong Kim 2026-04-23 19:50:05 +09:00
parent 6051fba9dc
commit fa8464e200
14 changed files with 141 additions and 16 deletions

View file

@ -1104,6 +1104,7 @@ class BasePlatformAdapter(ABC):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
""" """

View file

@ -1009,6 +1009,7 @@ class DingTalkAdapter(BasePlatformAdapter):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit an AI Card by streaming updated content. """Edit an AI Card by streaming updated content.

View file

@ -1280,15 +1280,17 @@ class DiscordAdapter(BasePlatformAdapter):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit a previously sent Discord message.""" """Edit a previously sent Discord message."""
if not self._client: if not self._client:
return SendResult(success=False, error="Not connected") return SendResult(success=False, error="Not connected")
try: try:
channel = self._client.get_channel(int(chat_id)) target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
channel = self._client.get_channel(int(target_chat_id))
if not channel: if not channel:
channel = await self._client.fetch_channel(int(chat_id)) channel = await self._client.fetch_channel(int(target_chat_id))
msg = await channel.fetch_message(int(message_id)) msg = await channel.fetch_message(int(message_id))
formatted = self.format_message(content) formatted = self.format_message(content)
if len(formatted) > self.MAX_MESSAGE_LENGTH: if len(formatted) > self.MAX_MESSAGE_LENGTH:
@ -1989,13 +1991,15 @@ class DiscordAdapter(BasePlatformAdapter):
if chat_id in self._typing_tasks: if chat_id in self._typing_tasks:
return return
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
async def _typing_loop() -> None: async def _typing_loop() -> None:
try: try:
while True: while True:
try: try:
route = discord.http.Route( route = discord.http.Route(
"POST", "/channels/{channel_id}/typing", "POST", "/channels/{channel_id}/typing",
channel_id=chat_id, channel_id=target_chat_id,
) )
await self._client.http.request(route) await self._client.http.request(route)
except asyncio.CancelledError: except asyncio.CancelledError:

View file

@ -1694,6 +1694,7 @@ class FeishuAdapter(BasePlatformAdapter):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit a previously sent Feishu text/post message.""" """Edit a previously sent Feishu text/post message."""

View file

@ -825,7 +825,13 @@ class MatrixAdapter(BasePlatformAdapter):
async def edit_message( async def edit_message(
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False self,
chat_id: str,
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit an existing message (via m.replace).""" """Edit an existing message (via m.replace)."""

View file

@ -304,7 +304,13 @@ class MattermostAdapter(BasePlatformAdapter):
) )
async def edit_message( async def edit_message(
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False self,
chat_id: str,
message_id: str,
content: str,
*,
metadata=None,
finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit an existing post.""" """Edit an existing post."""
formatted = self.format_message(content) formatted = self.format_message(content)

View file

@ -328,6 +328,7 @@ class SlackAdapter(BasePlatformAdapter):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit a previously sent Slack message.""" """Edit a previously sent Slack message."""

View file

@ -1127,6 +1127,7 @@ class TelegramAdapter(BasePlatformAdapter):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit a previously sent Telegram message.""" """Edit a previously sent Telegram message."""

View file

@ -728,6 +728,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
message_id: str, message_id: str,
content: str, content: str,
*, *,
metadata=None,
finalize: bool = False, finalize: bool = False,
) -> SendResult: ) -> SendResult:
"""Edit a previously sent message via the WhatsApp bridge.""" """Edit a previously sent message via the WhatsApp bridge."""

View file

@ -9538,6 +9538,7 @@ class GatewayRunner:
chat_id=source.chat_id, chat_id=source.chat_id,
message_id=progress_msg_id, message_id=progress_msg_id,
content=full_text, content=full_text,
metadata=_progress_metadata,
) )
if not result.success: if not result.success:
_err = (getattr(result, "error", "") or "").lower() _err = (getattr(result, "error", "") or "").lower()
@ -9592,6 +9593,7 @@ class GatewayRunner:
chat_id=source.chat_id, chat_id=source.chat_id,
message_id=progress_msg_id, message_id=progress_msg_id,
content=full_text, content=full_text,
metadata=_progress_metadata,
) )
except Exception: except Exception:
pass pass

View file

@ -590,6 +590,7 @@ class GatewayStreamConsumer:
chat_id=self.chat_id, chat_id=self.chat_id,
message_id=self._message_id, message_id=self._message_id,
content=clean_text, content=clean_text,
metadata=self.metadata,
) )
if result.success: if result.success:
self._last_sent_text = clean_text self._last_sent_text = clean_text
@ -708,6 +709,7 @@ class GatewayStreamConsumer:
chat_id=self.chat_id, chat_id=self.chat_id,
message_id=self._message_id, message_id=self._message_id,
content=prefix, content=prefix,
metadata=self.metadata,
) )
self._last_sent_text = prefix self._last_sent_text = prefix
except Exception: except Exception:
@ -791,6 +793,7 @@ class GatewayStreamConsumer:
chat_id=self.chat_id, chat_id=self.chat_id,
message_id=self._message_id, message_id=self._message_id,
content=text, content=text,
metadata=self.metadata,
finalize=finalize, finalize=finalize,
) )
if result.success: if result.success:

View file

@ -1,5 +1,6 @@
"""Tests for Discord free-response defaults and mention gating.""" """Tests for Discord free-response defaults and mention gating."""
import asyncio
from datetime import datetime, timezone from datetime import datetime, timezone
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock, MagicMock
@ -82,6 +83,15 @@ class FakeThread:
self.topic = None self.topic = None
class FakeMessage:
def __init__(self, message_id: int = 1):
self.id = message_id
self.edits = []
async def edit(self, *, content: str):
self.edits.append(content)
@pytest.fixture @pytest.fixture
def adapter(monkeypatch): def adapter(monkeypatch):
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False) monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
@ -179,6 +189,55 @@ async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch
assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?" assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
@pytest.mark.asyncio
async def test_discord_edit_message_uses_thread_metadata(adapter):
parent_channel = FakeTextChannel(channel_id=123, name="general")
thread_channel = FakeThread(channel_id=456, name="planning", parent=parent_channel)
message = FakeMessage(message_id=789)
thread_channel.fetch_message = AsyncMock(return_value=message)
adapter._client = SimpleNamespace(
get_channel=MagicMock(side_effect=lambda cid: thread_channel if cid == 456 else parent_channel),
fetch_channel=AsyncMock(return_value=None),
)
result = await adapter.edit_message(
chat_id="123",
message_id="789",
content="updated in thread",
metadata={"thread_id": "456"},
)
assert result.success is True
thread_channel.fetch_message.assert_awaited_once_with(789)
assert message.edits == ["updated in thread"]
@pytest.mark.asyncio
async def test_discord_send_typing_uses_thread_metadata(adapter, monkeypatch):
recorded = []
class FakeRoute:
def __init__(self, _method, _path, *, channel_id):
self.channel_id = channel_id
async def fake_request(route):
recorded.append(route.channel_id)
task = adapter._typing_tasks.get("123")
if task:
task.cancel()
raise asyncio.CancelledError
monkeypatch.setattr(discord_platform.discord.http, "Route", FakeRoute, raising=False)
adapter._client = SimpleNamespace(http=SimpleNamespace(request=fake_request))
await adapter.send_typing("123", metadata={"thread_id": "456"})
task = adapter._typing_tasks["123"]
with pytest.raises(asyncio.CancelledError):
await task
assert recorded == ["456"]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch): async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true") monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")

View file

@ -38,12 +38,13 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
) )
return SendResult(success=True, message_id="progress-1") return SendResult(success=True, message_id="progress-1")
async def edit_message(self, chat_id, message_id, content) -> SendResult: async def edit_message(self, chat_id, message_id, content, metadata=None) -> SendResult:
self.edits.append( self.edits.append(
{ {
"chat_id": chat_id, "chat_id": chat_id,
"message_id": message_id, "message_id": message_id,
"content": content, "content": content,
"metadata": metadata,
} }
) )
return SendResult(success=True, message_id=message_id) return SendResult(success=True, message_id=message_id)
@ -110,6 +111,23 @@ class DelayedProgressAgent:
} }
class SlowProgressAgent:
def __init__(self, **kwargs):
self.tool_progress_callback = kwargs.get("tool_progress_callback")
self.tools = []
def run_conversation(self, message, conversation_history=None, task_id=None):
self.tool_progress_callback("tool.started", "terminal", "first command", {})
time.sleep(1.8)
self.tool_progress_callback("tool.started", "terminal", "second command", {})
time.sleep(1.8)
return {
"final_response": "done",
"messages": [],
"api_calls": 1,
}
class DelayedInterimAgent: class DelayedInterimAgent:
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.interim_assistant_callback = kwargs.get("interim_assistant_callback") self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
@ -160,7 +178,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv) monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
fake_run_agent = types.ModuleType("run_agent") fake_run_agent = types.ModuleType("run_agent")
fake_run_agent.AIAgent = FakeAgent fake_run_agent.AIAgent = SlowProgressAgent
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent) monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
@ -189,12 +207,13 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
assert adapter.sent == [ assert adapter.sent == [
{ {
"chat_id": "-1001", "chat_id": "-1001",
"content": '💻 terminal: "pwd"', "content": '💻 terminal: "first command"',
"reply_to": None, "reply_to": None,
"metadata": {"thread_id": "17585"}, "metadata": {"thread_id": "17585"},
} }
] ]
assert adapter.edits assert adapter.edits
assert any(call["metadata"] == {"thread_id": "17585"} for call in adapter.edits)
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing) assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)

View file

@ -134,15 +134,14 @@ class TestFinalizeCapabilityGate:
class TestEditMessageFinalizeSignature: class TestEditMessageFinalizeSignature:
"""Every concrete platform adapter must accept the ``finalize`` kwarg. """Every concrete platform adapter must accept shared edit kwargs.
stream_consumer._send_or_edit always passes ``finalize=`` to stream_consumer._send_or_edit always passes ``finalize=`` and shared
``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An callers may pass ``metadata=`` to ``adapter.edit_message(...)``. An
adapter that overrides edit_message without accepting finalize raises adapter that overrides edit_message without accepting either kwarg
TypeError the first time streaming hits a segment break or final edit. raises TypeError at runtime. Guard the contract with an explicit
Guard the contract with an explicit signature check so it cannot signature check so it cannot silently regress existing tests use
silently regress existing tests use MagicMock which swallows any MagicMock which swallows any kwarg and cannot catch this.
kwarg and cannot catch this.
""" """
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -168,6 +167,10 @@ class TestEditMessageFinalizeSignature:
f"{class_name}.edit_message must accept 'finalize' kwarg; " f"{class_name}.edit_message must accept 'finalize' kwarg; "
f"stream_consumer._send_or_edit passes it unconditionally" f"stream_consumer._send_or_edit passes it unconditionally"
) )
assert "metadata" in params, (
f"{class_name}.edit_message must accept 'metadata' kwarg; "
f"shared edit call sites pass thread/topic routing context"
)
class TestSendOrEditMediaStripping: class TestSendOrEditMediaStripping:
@ -209,6 +212,23 @@ class TestSendOrEditMediaStripping:
edited_text = adapter.edit_message.call_args[1]["content"] edited_text = adapter.edit_message.call_args[1]["content"]
assert "MEDIA:" not in edited_text assert "MEDIA:" not in edited_text
@pytest.mark.asyncio
async def test_edit_preserves_metadata(self):
"""Edit call forwards stored metadata so thread routing survives streaming."""
adapter = MagicMock()
send_result = SimpleNamespace(success=True, message_id="msg_1")
edit_result = SimpleNamespace(success=True)
adapter.send = AsyncMock(return_value=send_result)
adapter.edit_message = AsyncMock(return_value=edit_result)
adapter.MAX_MESSAGE_LENGTH = 4096
consumer = GatewayStreamConsumer(adapter, "chat_123", metadata={"thread_id": "thread-42"})
await consumer._send_or_edit("Starting response...")
await consumer._send_or_edit("Updated response")
adapter.edit_message.assert_called_once()
assert adapter.edit_message.call_args[1]["metadata"] == {"thread_id": "thread-42"}
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_media_only_skips_send(self): async def test_media_only_skips_send(self):
"""If text is entirely MEDIA: tags, the send is skipped.""" """If text is entirely MEDIA: tags, the send is skipped."""