mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-30 01:41:43 +00:00
fix(gateway): preserve thread metadata edits
This commit is contained in:
parent
6051fba9dc
commit
fa8464e200
14 changed files with 141 additions and 16 deletions
|
|
@ -1104,6 +1104,7 @@ class BasePlatformAdapter(ABC):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1009,6 +1009,7 @@ class DingTalkAdapter(BasePlatformAdapter):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit an AI Card by streaming updated content.
|
"""Edit an AI Card by streaming updated content.
|
||||||
|
|
|
||||||
|
|
@ -1280,15 +1280,17 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit a previously sent Discord message."""
|
"""Edit a previously sent Discord message."""
|
||||||
if not self._client:
|
if not self._client:
|
||||||
return SendResult(success=False, error="Not connected")
|
return SendResult(success=False, error="Not connected")
|
||||||
try:
|
try:
|
||||||
channel = self._client.get_channel(int(chat_id))
|
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
|
||||||
|
channel = self._client.get_channel(int(target_chat_id))
|
||||||
if not channel:
|
if not channel:
|
||||||
channel = await self._client.fetch_channel(int(chat_id))
|
channel = await self._client.fetch_channel(int(target_chat_id))
|
||||||
msg = await channel.fetch_message(int(message_id))
|
msg = await channel.fetch_message(int(message_id))
|
||||||
formatted = self.format_message(content)
|
formatted = self.format_message(content)
|
||||||
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
if len(formatted) > self.MAX_MESSAGE_LENGTH:
|
||||||
|
|
@ -1989,13 +1991,15 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||||
if chat_id in self._typing_tasks:
|
if chat_id in self._typing_tasks:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
target_chat_id = str((metadata or {}).get("thread_id") or chat_id)
|
||||||
|
|
||||||
async def _typing_loop() -> None:
|
async def _typing_loop() -> None:
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
route = discord.http.Route(
|
route = discord.http.Route(
|
||||||
"POST", "/channels/{channel_id}/typing",
|
"POST", "/channels/{channel_id}/typing",
|
||||||
channel_id=chat_id,
|
channel_id=target_chat_id,
|
||||||
)
|
)
|
||||||
await self._client.http.request(route)
|
await self._client.http.request(route)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
|
|
|
||||||
|
|
@ -1694,6 +1694,7 @@ class FeishuAdapter(BasePlatformAdapter):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit a previously sent Feishu text/post message."""
|
"""Edit a previously sent Feishu text/post message."""
|
||||||
|
|
|
||||||
|
|
@ -825,7 +825,13 @@ class MatrixAdapter(BasePlatformAdapter):
|
||||||
|
|
||||||
|
|
||||||
async def edit_message(
|
async def edit_message(
|
||||||
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
message_id: str,
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
metadata=None,
|
||||||
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit an existing message (via m.replace)."""
|
"""Edit an existing message (via m.replace)."""
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -304,7 +304,13 @@ class MattermostAdapter(BasePlatformAdapter):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def edit_message(
|
async def edit_message(
|
||||||
self, chat_id: str, message_id: str, content: str, *, finalize: bool = False
|
self,
|
||||||
|
chat_id: str,
|
||||||
|
message_id: str,
|
||||||
|
content: str,
|
||||||
|
*,
|
||||||
|
metadata=None,
|
||||||
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit an existing post."""
|
"""Edit an existing post."""
|
||||||
formatted = self.format_message(content)
|
formatted = self.format_message(content)
|
||||||
|
|
|
||||||
|
|
@ -328,6 +328,7 @@ class SlackAdapter(BasePlatformAdapter):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit a previously sent Slack message."""
|
"""Edit a previously sent Slack message."""
|
||||||
|
|
|
||||||
|
|
@ -1127,6 +1127,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit a previously sent Telegram message."""
|
"""Edit a previously sent Telegram message."""
|
||||||
|
|
|
||||||
|
|
@ -728,6 +728,7 @@ class WhatsAppAdapter(BasePlatformAdapter):
|
||||||
message_id: str,
|
message_id: str,
|
||||||
content: str,
|
content: str,
|
||||||
*,
|
*,
|
||||||
|
metadata=None,
|
||||||
finalize: bool = False,
|
finalize: bool = False,
|
||||||
) -> SendResult:
|
) -> SendResult:
|
||||||
"""Edit a previously sent message via the WhatsApp bridge."""
|
"""Edit a previously sent message via the WhatsApp bridge."""
|
||||||
|
|
|
||||||
|
|
@ -9538,6 +9538,7 @@ class GatewayRunner:
|
||||||
chat_id=source.chat_id,
|
chat_id=source.chat_id,
|
||||||
message_id=progress_msg_id,
|
message_id=progress_msg_id,
|
||||||
content=full_text,
|
content=full_text,
|
||||||
|
metadata=_progress_metadata,
|
||||||
)
|
)
|
||||||
if not result.success:
|
if not result.success:
|
||||||
_err = (getattr(result, "error", "") or "").lower()
|
_err = (getattr(result, "error", "") or "").lower()
|
||||||
|
|
@ -9592,6 +9593,7 @@ class GatewayRunner:
|
||||||
chat_id=source.chat_id,
|
chat_id=source.chat_id,
|
||||||
message_id=progress_msg_id,
|
message_id=progress_msg_id,
|
||||||
content=full_text,
|
content=full_text,
|
||||||
|
metadata=_progress_metadata,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -590,6 +590,7 @@ class GatewayStreamConsumer:
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
message_id=self._message_id,
|
message_id=self._message_id,
|
||||||
content=clean_text,
|
content=clean_text,
|
||||||
|
metadata=self.metadata,
|
||||||
)
|
)
|
||||||
if result.success:
|
if result.success:
|
||||||
self._last_sent_text = clean_text
|
self._last_sent_text = clean_text
|
||||||
|
|
@ -708,6 +709,7 @@ class GatewayStreamConsumer:
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
message_id=self._message_id,
|
message_id=self._message_id,
|
||||||
content=prefix,
|
content=prefix,
|
||||||
|
metadata=self.metadata,
|
||||||
)
|
)
|
||||||
self._last_sent_text = prefix
|
self._last_sent_text = prefix
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -791,6 +793,7 @@ class GatewayStreamConsumer:
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
message_id=self._message_id,
|
message_id=self._message_id,
|
||||||
content=text,
|
content=text,
|
||||||
|
metadata=self.metadata,
|
||||||
finalize=finalize,
|
finalize=finalize,
|
||||||
)
|
)
|
||||||
if result.success:
|
if result.success:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""Tests for Discord free-response defaults and mention gating."""
|
"""Tests for Discord free-response defaults and mention gating."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
@ -82,6 +83,15 @@ class FakeThread:
|
||||||
self.topic = None
|
self.topic = None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMessage:
|
||||||
|
def __init__(self, message_id: int = 1):
|
||||||
|
self.id = message_id
|
||||||
|
self.edits = []
|
||||||
|
|
||||||
|
async def edit(self, *, content: str):
|
||||||
|
self.edits.append(content)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def adapter(monkeypatch):
|
def adapter(monkeypatch):
|
||||||
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
monkeypatch.setattr(discord_platform.discord, "DMChannel", FakeDMChannel, raising=False)
|
||||||
|
|
@ -179,6 +189,55 @@ async def test_discord_forum_threads_are_handled_as_threads(adapter, monkeypatch
|
||||||
assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
|
assert event.source.chat_name == "Hermes Server / support-forum / Can Hermes reply here?"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discord_edit_message_uses_thread_metadata(adapter):
|
||||||
|
parent_channel = FakeTextChannel(channel_id=123, name="general")
|
||||||
|
thread_channel = FakeThread(channel_id=456, name="planning", parent=parent_channel)
|
||||||
|
message = FakeMessage(message_id=789)
|
||||||
|
thread_channel.fetch_message = AsyncMock(return_value=message)
|
||||||
|
adapter._client = SimpleNamespace(
|
||||||
|
get_channel=MagicMock(side_effect=lambda cid: thread_channel if cid == 456 else parent_channel),
|
||||||
|
fetch_channel=AsyncMock(return_value=None),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await adapter.edit_message(
|
||||||
|
chat_id="123",
|
||||||
|
message_id="789",
|
||||||
|
content="updated in thread",
|
||||||
|
metadata={"thread_id": "456"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
thread_channel.fetch_message.assert_awaited_once_with(789)
|
||||||
|
assert message.edits == ["updated in thread"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discord_send_typing_uses_thread_metadata(adapter, monkeypatch):
|
||||||
|
recorded = []
|
||||||
|
|
||||||
|
class FakeRoute:
|
||||||
|
def __init__(self, _method, _path, *, channel_id):
|
||||||
|
self.channel_id = channel_id
|
||||||
|
|
||||||
|
async def fake_request(route):
|
||||||
|
recorded.append(route.channel_id)
|
||||||
|
task = adapter._typing_tasks.get("123")
|
||||||
|
if task:
|
||||||
|
task.cancel()
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
monkeypatch.setattr(discord_platform.discord.http, "Route", FakeRoute, raising=False)
|
||||||
|
adapter._client = SimpleNamespace(http=SimpleNamespace(request=fake_request))
|
||||||
|
|
||||||
|
await adapter.send_typing("123", metadata={"thread_id": "456"})
|
||||||
|
task = adapter._typing_tasks["123"]
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await task
|
||||||
|
|
||||||
|
assert recorded == ["456"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
|
async def test_discord_can_still_require_mentions_when_enabled(adapter, monkeypatch):
|
||||||
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
monkeypatch.setenv("DISCORD_REQUIRE_MENTION", "true")
|
||||||
|
|
|
||||||
|
|
@ -38,12 +38,13 @@ class ProgressCaptureAdapter(BasePlatformAdapter):
|
||||||
)
|
)
|
||||||
return SendResult(success=True, message_id="progress-1")
|
return SendResult(success=True, message_id="progress-1")
|
||||||
|
|
||||||
async def edit_message(self, chat_id, message_id, content) -> SendResult:
|
async def edit_message(self, chat_id, message_id, content, metadata=None) -> SendResult:
|
||||||
self.edits.append(
|
self.edits.append(
|
||||||
{
|
{
|
||||||
"chat_id": chat_id,
|
"chat_id": chat_id,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
"content": content,
|
"content": content,
|
||||||
|
"metadata": metadata,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
return SendResult(success=True, message_id=message_id)
|
return SendResult(success=True, message_id=message_id)
|
||||||
|
|
@ -110,6 +111,23 @@ class DelayedProgressAgent:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class SlowProgressAgent:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.tool_progress_callback = kwargs.get("tool_progress_callback")
|
||||||
|
self.tools = []
|
||||||
|
|
||||||
|
def run_conversation(self, message, conversation_history=None, task_id=None):
|
||||||
|
self.tool_progress_callback("tool.started", "terminal", "first command", {})
|
||||||
|
time.sleep(1.8)
|
||||||
|
self.tool_progress_callback("tool.started", "terminal", "second command", {})
|
||||||
|
time.sleep(1.8)
|
||||||
|
return {
|
||||||
|
"final_response": "done",
|
||||||
|
"messages": [],
|
||||||
|
"api_calls": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class DelayedInterimAgent:
|
class DelayedInterimAgent:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
self.interim_assistant_callback = kwargs.get("interim_assistant_callback")
|
||||||
|
|
@ -160,7 +178,7 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||||
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
monkeypatch.setitem(sys.modules, "dotenv", fake_dotenv)
|
||||||
|
|
||||||
fake_run_agent = types.ModuleType("run_agent")
|
fake_run_agent = types.ModuleType("run_agent")
|
||||||
fake_run_agent.AIAgent = FakeAgent
|
fake_run_agent.AIAgent = SlowProgressAgent
|
||||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||||
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
import tools.terminal_tool # noqa: F401 - register terminal emoji for this fake-agent test
|
||||||
|
|
||||||
|
|
@ -189,12 +207,13 @@ async def test_run_agent_progress_stays_in_originating_topic(monkeypatch, tmp_pa
|
||||||
assert adapter.sent == [
|
assert adapter.sent == [
|
||||||
{
|
{
|
||||||
"chat_id": "-1001",
|
"chat_id": "-1001",
|
||||||
"content": '💻 terminal: "pwd"',
|
"content": '💻 terminal: "first command"',
|
||||||
"reply_to": None,
|
"reply_to": None,
|
||||||
"metadata": {"thread_id": "17585"},
|
"metadata": {"thread_id": "17585"},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
assert adapter.edits
|
assert adapter.edits
|
||||||
|
assert any(call["metadata"] == {"thread_id": "17585"} for call in adapter.edits)
|
||||||
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)
|
assert all(call["metadata"] == {"thread_id": "17585"} for call in adapter.typing)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -134,15 +134,14 @@ class TestFinalizeCapabilityGate:
|
||||||
|
|
||||||
|
|
||||||
class TestEditMessageFinalizeSignature:
|
class TestEditMessageFinalizeSignature:
|
||||||
"""Every concrete platform adapter must accept the ``finalize`` kwarg.
|
"""Every concrete platform adapter must accept shared edit kwargs.
|
||||||
|
|
||||||
stream_consumer._send_or_edit always passes ``finalize=`` to
|
stream_consumer._send_or_edit always passes ``finalize=`` and shared
|
||||||
``adapter.edit_message(...)`` (see gateway/stream_consumer.py). An
|
callers may pass ``metadata=`` to ``adapter.edit_message(...)``. An
|
||||||
adapter that overrides edit_message without accepting finalize raises
|
adapter that overrides edit_message without accepting either kwarg
|
||||||
TypeError the first time streaming hits a segment break or final edit.
|
raises TypeError at runtime. Guard the contract with an explicit
|
||||||
Guard the contract with an explicit signature check so it cannot
|
signature check so it cannot silently regress — existing tests use
|
||||||
silently regress — existing tests use MagicMock which swallows any
|
MagicMock which swallows any kwarg and cannot catch this.
|
||||||
kwarg and cannot catch this.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -168,6 +167,10 @@ class TestEditMessageFinalizeSignature:
|
||||||
f"{class_name}.edit_message must accept 'finalize' kwarg; "
|
f"{class_name}.edit_message must accept 'finalize' kwarg; "
|
||||||
f"stream_consumer._send_or_edit passes it unconditionally"
|
f"stream_consumer._send_or_edit passes it unconditionally"
|
||||||
)
|
)
|
||||||
|
assert "metadata" in params, (
|
||||||
|
f"{class_name}.edit_message must accept 'metadata' kwarg; "
|
||||||
|
f"shared edit call sites pass thread/topic routing context"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSendOrEditMediaStripping:
|
class TestSendOrEditMediaStripping:
|
||||||
|
|
@ -209,6 +212,23 @@ class TestSendOrEditMediaStripping:
|
||||||
edited_text = adapter.edit_message.call_args[1]["content"]
|
edited_text = adapter.edit_message.call_args[1]["content"]
|
||||||
assert "MEDIA:" not in edited_text
|
assert "MEDIA:" not in edited_text
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_edit_preserves_metadata(self):
|
||||||
|
"""Edit call forwards stored metadata so thread routing survives streaming."""
|
||||||
|
adapter = MagicMock()
|
||||||
|
send_result = SimpleNamespace(success=True, message_id="msg_1")
|
||||||
|
edit_result = SimpleNamespace(success=True)
|
||||||
|
adapter.send = AsyncMock(return_value=send_result)
|
||||||
|
adapter.edit_message = AsyncMock(return_value=edit_result)
|
||||||
|
adapter.MAX_MESSAGE_LENGTH = 4096
|
||||||
|
|
||||||
|
consumer = GatewayStreamConsumer(adapter, "chat_123", metadata={"thread_id": "thread-42"})
|
||||||
|
await consumer._send_or_edit("Starting response...")
|
||||||
|
await consumer._send_or_edit("Updated response")
|
||||||
|
|
||||||
|
adapter.edit_message.assert_called_once()
|
||||||
|
assert adapter.edit_message.call_args[1]["metadata"] == {"thread_id": "thread-42"}
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_media_only_skips_send(self):
|
async def test_media_only_skips_send(self):
|
||||||
"""If text is entirely MEDIA: tags, the send is skipped."""
|
"""If text is entirely MEDIA: tags, the send is skipped."""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue