mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): preserve thread metadata edits
This commit is contained in:
parent
6051fba9dc
commit
fa8464e200
14 changed files with 141 additions and 16 deletions
|
|
@ -1104,6 +1104,7 @@ class BasePlatformAdapter(ABC):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1009,6 +1009,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit an AI Card by streaming updated content.
|
||||
|
|
|
|||
|
|
@ -1280,15 +1280,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Discord message."""
|
||||
if not self._client:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
|
||||
channel = self._client.get_channel(int(target_chat_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
channel = await self._client.fetch_channel(int(target_chat_id))
|
||||
msg = await channel.fetch_message(int(message_id))
|
||||
formatted = self.format_message(content)
|
||||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||||
|
|
@ -1989,13 +1991,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if chat_id in self._typing_tasks:
|
||||
return
|
||||
|
||||
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
|
||||
|
||||
async def _typing_loop() -> None:
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
route = discord.http.Route(
|
||||
"POST", "/channels/{channel_id}/typing",
|
||||
channel_id=chat_id,
|
||||
channel_id=target_chat_id,
|
||||
)
|
||||
await self._client.http.request(route)
|
||||
except asyncio.CancelledError:
|
||||
|
|
|
|||
|
|
@ -1694,6 +1694,7 @@ class FeishuAdapter(BasePlatformAdapter):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Feishu text/post message."""
|
||||
|
|
|
|||
|
|
@ -825,7 +825,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
|||
|
||||
|
||||
async def edit_message(
|
||||
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit an existing message (via m.replace)."""
|
||||
|
||||
|
|
|
|||
|
|
@ -304,7 +304,13 @@ class MattermostAdapter(BasePlatformAdapter):
|
|||
)
|
||||
|
||||
async def edit_message(
|
||||
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit an existing post."""
|
||||
formatted = self.format_message(content)
|
||||
|
|
|
|||
|
|
@ -328,6 +328,7 @@ class SlackAdapter(BasePlatformAdapter):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Slack message."""
|
||||
|
|
|
|||
|
|
@ -1127,6 +1127,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent Telegram message."""
|
||||
|
|
|
|||
|
|
@ -728,6 +728,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
|||
message_id: str,
|
||||
content: str,
|
||||
*,
|
||||
metadata=None,
|
||||
finalize: bool = False,
|
||||
) -> SendResult:
|
||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||
|
|
|
|||
|
|
@ -9538,6 +9538,7 @@ class GatewayRunner:
|
|||
chat_id=source.chat_id,
|
||||
message_id=progress_msg_id,
|
||||
content=full_text,
|
||||
metadata=_progress_metadata,
|
||||
)
|
||||
if not result.success:
|
||||
_err = (getattr(result, "error", "") or "").lower()
|
||||
|
|
@ -9592,6 +9593,7 @@ class GatewayRunner:
|
|||
chat_id=source.chat_id,
|
||||
message_id=progress_msg_id,
|
||||
content=full_text,
|
||||
metadata=_progress_metadata,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -590,6 +590,7 @@ class GatewayStreamConsumer:
|
|||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=clean_text,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
if result.success:
|
||||
self._last_sent_text = clean_text
|
||||
|
|
@ -708,6 +709,7 @@ class GatewayStreamConsumer:
|
|||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=prefix,
|
||||
metadata=self.metadata,
|
||||
)
|
||||
self._last_sent_text = prefix
|
||||
except Exception:
|
||||
|
|
@ -791,6 +793,7 @@ class GatewayStreamConsumer:
|
|||
chat_id=self.chat_id,
|
||||
message_id=self._message_id,
|
||||
content=text,
|
||||
metadata=self.metadata,
|
||||
finalize=finalize,
|
||||
)
|
||||
if result.success:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Tests for Discord free-response defaults and mention gating."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
|
@ -82,6 +83,15 @@ class FakeThread:
|
|||
self.topic = None
|
||||
|
||||
|
||||
class FakeMessage:
|
||||
def __init__(self, message_id: int = 1):
|
||||
self.id = message_id
|
||||
self.edits = []
|
||||
|
||||
async def edit(self, *, content: str):
|
||||
self.edits.append(content)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def adapter(monkeypatch):
|
||||
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
||||
|
|
@ -179,6 +189,55 @@ async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch
|
|||
assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_edit_message_uses_thread_metadata(adapter):
|
||||
parent_channel = FakeTextChannel(channel_id=123, name="general")
|
||||
thread_channel = FakeThread(channel_id=456, name="planning", parent=parent_channel)
|
||||
message = FakeMessage(message_id=789)
|
||||
thread_channel.fetch_message = AsyncMock(return_value=message)
|
||||
adapter._client = SimpleNamespace(
|
||||
get_channel=MagicMock(side_effect=lambda cid: thread_channel if cid == 456 else parent_channel),
|
||||
fetch_channel=AsyncMock(return_value=None),
|
||||
)
|
||||
|
||||
result = await adapter.edit_message(
|
||||
chat_id="123",
|
||||
message_id="789",
|
||||
content="updated in thread",
|
||||
metadata={"thread_id": "456"},
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
thread_channel.fetch_message.assert_awaited_once_with(789)
|
||||
assert message.edits == ["updated in thread"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_send_typing_uses_thread_metadata(adapter, monkeypatch):
|
||||
recorded = []
|
||||
|
||||
class FakeRoute:
|
||||
def __init__(self, _method, _path, *, channel_id):
|
||||
self.channel_id = channel_id
|
||||
|
||||
async def fake_request(route):
|
||||
recorded.append(route.channel_id)
|
||||
task = adapter._typing_tasks.get("123")
|
||||
if task:
|
||||
task.cancel()
|
||||
raise asyncio.CancelledError
|
||||
|
||||
monkeypatch.setattr(discord_platform.discord.http, "Route", FakeRoute, raising=False)
|
||||
adapter._client = SimpleNamespace(http=SimpleNamespace(request=fake_request))
|
||||
|
||||
await adapter.send_typing("123", metadata={"thread_id": "456"})
|
||||
task = adapter._typing_tasks["123"]
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await task
|
||||
|
||||
assert recorded == ["456"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
|
||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||
|
|
|
|||
|
|
@ -38,12 +38,13 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
|||
)
|
||||
return SendResult(success=True, message_id="progress-1")
|
||||
|
||||
async def edit_message(self, chat_id, message_id, content) -> SendResult:
|
||||
async def edit_message(self, chat_id, message_id, content, metadata=None) -> SendResult:
|
||||
self.edits.append(
|
||||
{
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"content": content,
|
||||
"metadata": metadata,
|
||||
}
|
||||
)
|
||||
return SendResult(success=True, message_id=message_id)
|
||||
|
|
@ -110,6 +111,23 @@ class DelayedProgressAgent:
|
|||
}
|
||||
|
||||
|
||||
class SlowProgressAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||
self.tool_progress_callback("tool.started", "terminal", "first command", {})
|
||||
time.sleep(1.8)
|
||||
self.tool_progress_callback("tool.started", "terminal", "second command", {})
|
||||
time.sleep(1.8)
|
||||
return {
|
||||
"final_response": "done",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
class DelayedInterimAgent:
|
||||
def __init__(self, **kwargs):
|
||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||
|
|
@ -160,7 +178,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
|||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = FakeAgent
|
||||
fake_run_agent.AIAgent = SlowProgressAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
||||
|
||||
|
|
@ -189,12 +207,13 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
|||
assert adapter.sent == [
|
||||
{
|
||||
"chat_id": "-1001",
|
||||
"content": '💻 terminal: "pwd"',
|
||||
"content": '💻 terminal: "first command"',
|
||||
"reply_to": None,
|
||||
"metadata": {"thread_id": "17585"},
|
||||
}
|
||||
]
|
||||
assert adapter.edits
|
||||
assert any(call["metadata"] == {"thread_id": "17585"} for call in adapter.edits)
|
||||
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -134,15 +134,14 @@ class TestFinalizeCapabilityGate:
|
|||
|
||||
|
||||
class TestEditMessageFinalizeSignature:
|
||||
"""Every concrete platform adapter must accept the ``finalize`` kwarg.
|
||||
"""Every concrete platform adapter must accept shared edit kwargs.
|
||||
|
||||
stream_consumer._send_or_edit always passes ``finalize=`` to
|
||||
``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An
|
||||
adapter that overrides edit_message without accepting finalize raises
|
||||
TypeError the first time streaming hits a segment break or final edit.
|
||||
Guard the contract with an explicit signature check so it cannot
|
||||
silently regress — existing tests use MagicMock which swallows any
|
||||
kwarg and cannot catch this.
|
||||
stream_consumer._send_or_edit always passes ``finalize=`` and shared
|
||||
callers may pass ``metadata=`` to ``adapter.edit_message(...)``. An
|
||||
adapter that overrides edit_message without accepting either kwarg
|
||||
raises TypeError at runtime. Guard the contract with an explicit
|
||||
signature check so it cannot silently regress — existing tests use
|
||||
MagicMock which swallows any kwarg and cannot catch this.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
@ -168,6 +167,10 @@ class TestEditMessageFinalizeSignature:
|
|||
f"{class_name}.edit_message must accept 'finalize' kwarg; "
|
||||
f"stream_consumer._send_or_edit passes it unconditionally"
|
||||
)
|
||||
assert "metadata" in params, (
|
||||
f"{class_name}.edit_message must accept 'metadata' kwarg; "
|
||||
f"shared edit call sites pass thread/topic routing context"
|
||||
)
|
||||
|
||||
|
||||
class TestSendOrEditMediaStripping:
|
||||
|
|
@ -209,6 +212,23 @@ class TestSendOrEditMediaStripping:
|
|||
edited_text = adapter.edit_message.call_args[1]["content"]
|
||||
assert "MEDIA:" not in edited_text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_edit_preserves_metadata(self):
|
||||
"""Edit call forwards stored metadata so thread routing survives streaming."""
|
||||
adapter = MagicMock()
|
||||
send_result = SimpleNamespace(success=True, message_id="msg_1")
|
||||
edit_result = SimpleNamespace(success=True)
|
||||
adapter.send = AsyncMock(return_value=send_result)
|
||||
adapter.edit_message = AsyncMock(return_value=edit_result)
|
||||
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||
|
||||
consumer = GatewayStreamConsumer(adapter, "chat_123", metadata={"thread_id": "thread-42"})
|
||||
await consumer._send_or_edit("Starting response...")
|
||||
await consumer._send_or_edit("Updated response")
|
||||
|
||||
adapter.edit_message.assert_called_once()
|
||||
assert adapter.edit_message.call_args[1]["metadata"] == {"thread_id": "thread-42"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_media_only_skips_send(self):
|
||||
"""If text is entirely MEDIA: tags, the send is skipped."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue