mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-07 02:51:50 +00:00
fix(gateway): preserve thread routing for /update progress and prompts
This commit is contained in:
parent
f48ba47d1e
commit
25cbe3e1d6
5 changed files with 138 additions and 9 deletions
|
|
@ -3078,6 +3078,7 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
async def send_update_prompt(
|
||||
self, chat_id: str, prompt: str, default: str = "",
|
||||
session_key: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an interactive button-based update prompt (Yes / No).
|
||||
|
||||
|
|
@ -3087,9 +3088,10 @@ class DiscordAdapter(BasePlatformAdapter):
|
|||
if not self._client or not DISCORD_AVAILABLE:
|
||||
return SendResult(success=False, error="Not connected")
|
||||
try:
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
target_id = metadata.get("thread_id") if metadata and metadata.get("thread_id") else chat_id
|
||||
channel = self._client.get_channel(int(target_id))
|
||||
if not channel:
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
channel = await self._client.fetch_channel(int(target_id))
|
||||
|
||||
default_hint = f" (default: {default})" if default else ""
|
||||
embed = discord.Embed(
|
||||
|
|
|
|||
|
|
@ -1360,6 +1360,7 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
async def send_update_prompt(
|
||||
self, chat_id: str, prompt: str, default: str = "",
|
||||
session_key: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> SendResult:
|
||||
"""Send an inline-keyboard update prompt (Yes / No buttons).
|
||||
|
||||
|
|
@ -1377,11 +1378,14 @@ class TelegramAdapter(BasePlatformAdapter):
|
|||
InlineKeyboardButton("✗ No", callback_data="update_prompt:n"),
|
||||
]
|
||||
])
|
||||
thread_id = self._metadata_thread_id(metadata)
|
||||
message_thread_id = self._message_thread_id_for_send(thread_id)
|
||||
msg = await self._bot.send_message(
|
||||
chat_id=int(chat_id),
|
||||
text=text,
|
||||
parse_mode=ParseMode.MARKDOWN,
|
||||
reply_markup=keyboard,
|
||||
message_thread_id=message_thread_id,
|
||||
**self._link_preview_kwargs(),
|
||||
)
|
||||
return SendResult(success=True, message_id=str(msg.message_id))
|
||||
|
|
|
|||
|
|
@ -9490,6 +9490,8 @@ class GatewayRunner:
|
|||
"session_key": session_key,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
if event.source.thread_id:
|
||||
pending["thread_id"] = event.source.thread_id
|
||||
_tmp_pending = pending_path.with_suffix(".tmp")
|
||||
_tmp_pending.write_text(json.dumps(pending))
|
||||
_tmp_pending.replace(pending_path)
|
||||
|
|
@ -9575,6 +9577,7 @@ class GatewayRunner:
|
|||
adapter = None
|
||||
chat_id = None
|
||||
session_key = None
|
||||
metadata = None
|
||||
for path in (claimed_path, pending_path):
|
||||
if path.exists():
|
||||
try:
|
||||
|
|
@ -9582,6 +9585,8 @@ class GatewayRunner:
|
|||
platform_str = pending.get("platform")
|
||||
chat_id = pending.get("chat_id")
|
||||
session_key = pending.get("session_key")
|
||||
thread_id = pending.get("thread_id")
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
if platform_str and chat_id:
|
||||
platform = Platform(platform_str)
|
||||
adapter = self.adapters.get(platform)
|
||||
|
|
@ -9629,7 +9634,7 @@ class GatewayRunner:
|
|||
chunks = [clean[i:i + max_chunk] for i in range(0, len(clean), max_chunk)]
|
||||
for chunk in chunks:
|
||||
try:
|
||||
await adapter.send(chat_id, f"```\n{chunk}\n```")
|
||||
await adapter.send(chat_id, f"```\n{chunk}\n```", metadata=metadata)
|
||||
except Exception as e:
|
||||
logger.debug("Update stream send failed: %s", e)
|
||||
|
||||
|
|
@ -9652,9 +9657,13 @@ class GatewayRunner:
|
|||
exit_code_raw = exit_code_path.read_text().strip() or "1"
|
||||
exit_code = int(exit_code_raw)
|
||||
if exit_code == 0:
|
||||
await adapter.send(chat_id, "✅ Hermes update finished.")
|
||||
await adapter.send(chat_id, "✅ Hermes update finished.", metadata=metadata)
|
||||
else:
|
||||
await adapter.send(chat_id, "❌ Hermes update failed (exit code {}).".format(exit_code))
|
||||
await adapter.send(
|
||||
chat_id,
|
||||
"❌ Hermes update failed (exit code {}).".format(exit_code),
|
||||
metadata=metadata,
|
||||
)
|
||||
logger.info("Update finished (exit=%s), notified %s", exit_code, session_key)
|
||||
except Exception as e:
|
||||
logger.warning("Update final notification failed: %s", e)
|
||||
|
|
@ -9704,6 +9713,7 @@ class GatewayRunner:
|
|||
prompt=prompt_text,
|
||||
default=default,
|
||||
session_key=session_key,
|
||||
metadata=metadata,
|
||||
)
|
||||
sent_buttons = True
|
||||
except Exception as btn_err:
|
||||
|
|
@ -9715,7 +9725,8 @@ class GatewayRunner:
|
|||
f"⚕ **Update needs your input:**\n\n"
|
||||
f"{prompt_text}{default_hint}\n\n"
|
||||
f"Reply `/approve` (yes) or `/deny` (no), "
|
||||
f"or type your answer directly."
|
||||
f"or type your answer directly.",
|
||||
metadata=metadata,
|
||||
)
|
||||
self._update_prompt_pending[session_key] = True
|
||||
# Remove the prompt file so it isn't re-read on the
|
||||
|
|
@ -9735,7 +9746,11 @@ class GatewayRunner:
|
|||
exit_code_path.write_text("124")
|
||||
await _flush_buffer()
|
||||
try:
|
||||
await adapter.send(chat_id, "❌ Hermes update timed out after 30 minutes.")
|
||||
await adapter.send(
|
||||
chat_id,
|
||||
"❌ Hermes update timed out after 30 minutes.",
|
||||
metadata=metadata,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
for p in (pending_path, claimed_path, output_path,
|
||||
|
|
@ -9777,6 +9792,7 @@ class GatewayRunner:
|
|||
pending = json.loads(claimed_path.read_text())
|
||||
platform_str = pending.get("platform")
|
||||
chat_id = pending.get("chat_id")
|
||||
thread_id = pending.get("thread_id")
|
||||
|
||||
if not exit_code_path.exists():
|
||||
logger.info("Update notification deferred: update still running")
|
||||
|
|
@ -9798,6 +9814,7 @@ class GatewayRunner:
|
|||
adapter = self.adapters.get(platform)
|
||||
|
||||
if adapter and chat_id:
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
# Strip ANSI escape codes for clean display
|
||||
output = re.sub(r'\x1b\[[0-9;]*m', '', output).strip()
|
||||
if output:
|
||||
|
|
@ -9812,7 +9829,7 @@ class GatewayRunner:
|
|||
msg = "✅ Hermes update finished successfully."
|
||||
else:
|
||||
msg = "❌ Hermes update failed. Check the gateway logs or run `hermes update` manually for details."
|
||||
await adapter.send(chat_id, msg)
|
||||
await adapter.send(chat_id, msg, metadata=metadata)
|
||||
logger.info(
|
||||
"Sent post-update notification to %s:%s (exit=%s)",
|
||||
platform_str,
|
||||
|
|
|
|||
|
|
@ -17,13 +17,14 @@ from gateway.session import SessionSource
|
|||
|
||||
|
||||
def _make_event(text="/update", platform=Platform.TELEGRAM,
|
||||
user_id="12345", chat_id="67890"):
|
||||
user_id="12345", chat_id="67890", thread_id=None):
|
||||
"""Build a MessageEvent for testing."""
|
||||
source = SessionSource(
|
||||
platform=platform,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
user_name="testuser",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return MessageEvent(text=text, source=source)
|
||||
|
||||
|
|
@ -214,6 +215,34 @@ class TestHandleUpdateCommand:
|
|||
assert "timestamp" in data
|
||||
assert not (hermes_home / ".update_exit_code").exists()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writes_pending_marker_with_thread_id(self, tmp_path):
|
||||
"""Persists thread_id so update notifications can route back to the thread."""
|
||||
runner = _make_runner()
|
||||
event = _make_event(
|
||||
platform=Platform.TELEGRAM,
|
||||
chat_id="99999",
|
||||
thread_id="777",
|
||||
)
|
||||
|
||||
fake_root = tmp_path / "project"
|
||||
fake_root.mkdir()
|
||||
(fake_root / ".git").mkdir()
|
||||
(fake_root / "gateway").mkdir()
|
||||
(fake_root / "gateway" / "run.py").touch()
|
||||
fake_file = str(fake_root / "gateway" / "run.py")
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home), \
|
||||
patch("gateway.run.__file__", fake_file), \
|
||||
patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/setsid"), \
|
||||
patch("subprocess.Popen"):
|
||||
await runner._handle_update_command(event)
|
||||
|
||||
data = json.loads((hermes_home / ".update_pending.json").read_text())
|
||||
assert data["thread_id"] == "777"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spawns_setsid(self, tmp_path):
|
||||
"""Uses setsid when available."""
|
||||
|
|
@ -432,6 +461,31 @@ class TestSendUpdateNotification:
|
|||
assert call_args[0][0] == "67890" # chat_id
|
||||
assert "Update complete" in call_args[0][1] or "update finished" in call_args[0][1].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sends_notification_with_thread_metadata(self, tmp_path):
|
||||
"""Final update notification preserves thread metadata when present."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {
|
||||
"platform": "telegram",
|
||||
"chat_id": "67890",
|
||||
"thread_id": "777",
|
||||
"user_id": "12345",
|
||||
}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text("done")
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
mock_adapter = AsyncMock()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
await runner._send_update_notification()
|
||||
|
||||
assert mock_adapter.send.call_args.kwargs["metadata"] == {"thread_id": "777"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_strips_ansi_codes(self, tmp_path):
|
||||
"""ANSI escape codes are removed from output."""
|
||||
|
|
|
|||
|
|
@ -321,6 +321,58 @@ class TestWatchUpdateProgress:
|
|||
# Check session was marked as having pending prompt
|
||||
# (may be cleared by the time we check since update finished)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_forwarding_preserves_thread_metadata(self, tmp_path):
|
||||
"""Forwarded update prompts keep the originating thread/topic metadata."""
|
||||
runner = _make_runner()
|
||||
hermes_home = tmp_path / "hermes"
|
||||
hermes_home.mkdir()
|
||||
|
||||
pending = {
|
||||
"platform": "telegram",
|
||||
"chat_id": "111",
|
||||
"thread_id": "777",
|
||||
"user_id": "222",
|
||||
"session_key": "agent:main:telegram:group:111:777",
|
||||
}
|
||||
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||
(hermes_home / ".update_output.txt").write_text("")
|
||||
(hermes_home / ".update_prompt.json").write_text(json.dumps({
|
||||
"prompt": "Restore local changes? [Y/n]",
|
||||
"default": "y",
|
||||
"id": "threaded-prompt",
|
||||
}))
|
||||
|
||||
class _PromptCapableAdapter:
|
||||
def __init__(self):
|
||||
self.send = AsyncMock()
|
||||
self.prompt_calls = AsyncMock()
|
||||
|
||||
async def send_update_prompt(self, **kwargs):
|
||||
return await self.prompt_calls(**kwargs)
|
||||
|
||||
mock_adapter = _PromptCapableAdapter()
|
||||
runner.adapters = {Platform.TELEGRAM: mock_adapter}
|
||||
|
||||
async def finish_after_prompt():
|
||||
await asyncio.sleep(0.3)
|
||||
(hermes_home / ".update_response").write_text("y")
|
||||
await asyncio.sleep(0.2)
|
||||
(hermes_home / ".update_exit_code").write_text("0")
|
||||
|
||||
with patch("gateway.run._hermes_home", hermes_home):
|
||||
task = asyncio.create_task(finish_after_prompt())
|
||||
await runner._watch_update_progress(
|
||||
poll_interval=0.1,
|
||||
stream_interval=0.2,
|
||||
timeout=5.0,
|
||||
)
|
||||
await task
|
||||
|
||||
assert mock_adapter.prompt_calls.call_args.kwargs["metadata"] == {
|
||||
"thread_id": "777"
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleans_up_on_completion(self, tmp_path):
|
||||
"""All marker files are cleaned up when update finishes."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue