diff --git a/gateway/platforms/discord.py b/gateway/platforms/discord.py index 102e055ffc..fcd2cbc996 100644 --- a/gateway/platforms/discord.py +++ b/gateway/platforms/discord.py @@ -3078,6 +3078,7 @@ class DiscordAdapter(BasePlatformAdapter): async def send_update_prompt( self, chat_id: str, prompt: str, default: str = "", session_key: str = "", + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send an interactive button-based update prompt (Yes / No). @@ -3087,9 +3088,10 @@ class DiscordAdapter(BasePlatformAdapter): if not self._client or not DISCORD_AVAILABLE: return SendResult(success=False, error="Not connected") try: - channel = self._client.get_channel(int(chat_id)) + target_id = metadata.get("thread_id") if metadata and metadata.get("thread_id") else chat_id + channel = self._client.get_channel(int(target_id)) if not channel: - channel = await self._client.fetch_channel(int(chat_id)) + channel = await self._client.fetch_channel(int(target_id)) default_hint = f" (default: {default})" if default else "" embed = discord.Embed( diff --git a/gateway/platforms/telegram.py b/gateway/platforms/telegram.py index 307c6b89ab..3822cb72f8 100644 --- a/gateway/platforms/telegram.py +++ b/gateway/platforms/telegram.py @@ -1360,6 +1360,7 @@ class TelegramAdapter(BasePlatformAdapter): async def send_update_prompt( self, chat_id: str, prompt: str, default: str = "", session_key: str = "", + metadata: Optional[Dict[str, Any]] = None, ) -> SendResult: """Send an inline-keyboard update prompt (Yes / No buttons). @@ -1377,11 +1378,14 @@ class TelegramAdapter(BasePlatformAdapter): InlineKeyboardButton("✗ No", callback_data="update_prompt:n"), ] ]) + thread_id = self._metadata_thread_id(metadata) + message_thread_id = self._message_thread_id_for_send(thread_id) msg = await self._bot.send_message( chat_id=int(chat_id), text=text, parse_mode=ParseMode.MARKDOWN, reply_markup=keyboard, + message_thread_id=message_thread_id, **self._link_preview_kwargs(), ) return SendResult(success=True, message_id=str(msg.message_id)) diff --git a/gateway/run.py b/gateway/run.py index 7714ca99d8..4890ebe66f 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -9490,6 +9490,8 @@ class GatewayRunner: "session_key": session_key, "timestamp": datetime.now().isoformat(), } + if event.source.thread_id: + pending["thread_id"] = event.source.thread_id _tmp_pending = pending_path.with_suffix(".tmp") _tmp_pending.write_text(json.dumps(pending)) _tmp_pending.replace(pending_path) @@ -9575,6 +9577,7 @@ class GatewayRunner: adapter = None chat_id = None session_key = None + metadata = None for path in (claimed_path, pending_path): if path.exists(): try: @@ -9582,6 +9585,8 @@ class GatewayRunner: platform_str = pending.get("platform") chat_id = pending.get("chat_id") session_key = pending.get("session_key") + thread_id = pending.get("thread_id") + metadata = {"thread_id": thread_id} if thread_id else None if platform_str and chat_id: platform = Platform(platform_str) adapter = self.adapters.get(platform) @@ -9629,7 +9634,7 @@ class GatewayRunner: chunks = [clean[i:i + max_chunk] for i in range(0, len(clean), max_chunk)] for chunk in chunks: try: - await adapter.send(chat_id, f"```\n{chunk}\n```") + await adapter.send(chat_id, f"```\n{chunk}\n```", metadata=metadata) except Exception as e: logger.debug("Update stream send failed: %s", e) @@ -9652,9 +9657,13 @@ class GatewayRunner: exit_code_raw = exit_code_path.read_text().strip() or "1" exit_code = int(exit_code_raw) if exit_code == 0: - await adapter.send(chat_id, "✅ Hermes update finished.") + await adapter.send(chat_id, "✅ Hermes update finished.", metadata=metadata) else: - await adapter.send(chat_id, "❌ Hermes update failed (exit code {}).".format(exit_code)) + await adapter.send( + chat_id, + "❌ Hermes update failed (exit code {}).".format(exit_code), + metadata=metadata, + ) logger.info("Update finished (exit=%s), notified %s", exit_code, session_key) except Exception as e: logger.warning("Update final notification failed: %s", e) @@ -9704,6 +9713,7 @@ class GatewayRunner: prompt=prompt_text, default=default, session_key=session_key, + metadata=metadata, ) sent_buttons = True except Exception as btn_err: @@ -9715,7 +9725,8 @@ class GatewayRunner: f"⚕ **Update needs your input:**\n\n" f"{prompt_text}{default_hint}\n\n" f"Reply `/approve` (yes) or `/deny` (no), " - f"or type your answer directly." + f"or type your answer directly.", + metadata=metadata, ) self._update_prompt_pending[session_key] = True # Remove the prompt file so it isn't re-read on the @@ -9735,7 +9746,11 @@ class GatewayRunner: exit_code_path.write_text("124") await _flush_buffer() try: - await adapter.send(chat_id, "❌ Hermes update timed out after 30 minutes.") + await adapter.send( + chat_id, + "❌ Hermes update timed out after 30 minutes.", + metadata=metadata, + ) except Exception: pass for p in (pending_path, claimed_path, output_path, @@ -9777,6 +9792,7 @@ class GatewayRunner: pending = json.loads(claimed_path.read_text()) platform_str = pending.get("platform") chat_id = pending.get("chat_id") + thread_id = pending.get("thread_id") if not exit_code_path.exists(): logger.info("Update notification deferred: update still running") @@ -9798,6 +9814,7 @@ class GatewayRunner: adapter = self.adapters.get(platform) if adapter and chat_id: + metadata = {"thread_id": thread_id} if thread_id else None # Strip ANSI escape codes for clean display output = re.sub(r'\x1b\[[0-9;]*m', '', output).strip() if output: @@ -9812,7 +9829,7 @@ class GatewayRunner: msg = "✅ Hermes update finished successfully." else: msg = "❌ Hermes update failed. Check the gateway logs or run `hermes update` manually for details." - await adapter.send(chat_id, msg) + await adapter.send(chat_id, msg, metadata=metadata) logger.info( "Sent post-update notification to %s:%s (exit=%s)", platform_str, diff --git a/tests/gateway/test_update_command.py b/tests/gateway/test_update_command.py index 05be88c2c6..aa6240aa5b 100644 --- a/tests/gateway/test_update_command.py +++ b/tests/gateway/test_update_command.py @@ -17,13 +17,14 @@ from gateway.session import SessionSource def _make_event(text="/update", platform=Platform.TELEGRAM, - user_id="12345", chat_id="67890"): + user_id="12345", chat_id="67890", thread_id=None): """Build a MessageEvent for testing.""" source = SessionSource( platform=platform, user_id=user_id, chat_id=chat_id, user_name="testuser", + thread_id=thread_id, ) return MessageEvent(text=text, source=source) @@ -214,6 +215,34 @@ class TestHandleUpdateCommand: assert "timestamp" in data assert not (hermes_home / ".update_exit_code").exists() + @pytest.mark.asyncio + async def test_writes_pending_marker_with_thread_id(self, tmp_path): + """Persists thread_id so update notifications can route back to the thread.""" + runner = _make_runner() + event = _make_event( + platform=Platform.TELEGRAM, + chat_id="99999", + thread_id="777", + ) + + fake_root = tmp_path / "project" + fake_root.mkdir() + (fake_root / ".git").mkdir() + (fake_root / "gateway").mkdir() + (fake_root / "gateway" / "run.py").touch() + fake_file = str(fake_root / "gateway" / "run.py") + hermes_home = tmp_path / "hermes" + hermes_home.mkdir() + + with patch("gateway.run._hermes_home", hermes_home), \ + patch("gateway.run.__file__", fake_file), \ + patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/setsid"), \ + patch("subprocess.Popen"): + await runner._handle_update_command(event) + + data = json.loads((hermes_home / ".update_pending.json").read_text()) + assert data["thread_id"] == "777" + @pytest.mark.asyncio async def test_spawns_setsid(self, tmp_path): """Uses setsid when available.""" @@ -432,6 +461,31 @@ class TestSendUpdateNotification: assert call_args[0][0] == "67890" # chat_id assert "Update complete" in call_args[0][1] or "update finished" in call_args[0][1].lower() + @pytest.mark.asyncio + async def test_sends_notification_with_thread_metadata(self, tmp_path): + """Final update notification preserves thread metadata when present.""" + runner = _make_runner() + hermes_home = tmp_path / "hermes" + hermes_home.mkdir() + + pending = { + "platform": "telegram", + "chat_id": "67890", + "thread_id": "777", + "user_id": "12345", + } + (hermes_home / ".update_pending.json").write_text(json.dumps(pending)) + (hermes_home / ".update_output.txt").write_text("done") + (hermes_home / ".update_exit_code").write_text("0") + + mock_adapter = AsyncMock() + runner.adapters = {Platform.TELEGRAM: mock_adapter} + + with patch("gateway.run._hermes_home", hermes_home): + await runner._send_update_notification() + + assert mock_adapter.send.call_args.kwargs["metadata"] == {"thread_id": "777"} + @pytest.mark.asyncio async def test_strips_ansi_codes(self, tmp_path): """ANSI escape codes are removed from output.""" diff --git a/tests/gateway/test_update_streaming.py b/tests/gateway/test_update_streaming.py index 1020ea6c46..b78eaa3327 100644 --- a/tests/gateway/test_update_streaming.py +++ b/tests/gateway/test_update_streaming.py @@ -321,6 +321,58 @@ class TestWatchUpdateProgress: # Check session was marked as having pending prompt # (may be cleared by the time we check since update finished) + @pytest.mark.asyncio + async def test_prompt_forwarding_preserves_thread_metadata(self, tmp_path): + """Forwarded update prompts keep the originating thread/topic metadata.""" + runner = _make_runner() + hermes_home = tmp_path / "hermes" + hermes_home.mkdir() + + pending = { + "platform": "telegram", + "chat_id": "111", + "thread_id": "777", + "user_id": "222", + "session_key": "agent:main:telegram:group:111:777", + } + (hermes_home / ".update_pending.json").write_text(json.dumps(pending)) + (hermes_home / ".update_output.txt").write_text("") + (hermes_home / ".update_prompt.json").write_text(json.dumps({ + "prompt": "Restore local changes? [Y/n]", + "default": "y", + "id": "threaded-prompt", + })) + + class _PromptCapableAdapter: + def __init__(self): + self.send = AsyncMock() + self.prompt_calls = AsyncMock() + + async def send_update_prompt(self, **kwargs): + return await self.prompt_calls(**kwargs) + + mock_adapter = _PromptCapableAdapter() + runner.adapters = {Platform.TELEGRAM: mock_adapter} + + async def finish_after_prompt(): + await asyncio.sleep(0.3) + (hermes_home / ".update_response").write_text("y") + await asyncio.sleep(0.2) + (hermes_home / ".update_exit_code").write_text("0") + + with patch("gateway.run._hermes_home", hermes_home): + task = asyncio.create_task(finish_after_prompt()) + await runner._watch_update_progress( + poll_interval=0.1, + stream_interval=0.2, + timeout=5.0, + ) + await task + + assert mock_adapter.prompt_calls.call_args.kwargs["metadata"] == { + "thread_id": "777" + } + @pytest.mark.asyncio async def test_cleans_up_on_completion(self, tmp_path): """All marker files are cleaned up when update finishes."""