fix(gateway): preserve thread routing for /update progress and prompts

This commit is contained in:
Yukipukii1 2026-04-29 16:55:12 +03:00 committed by Teknium
parent f48ba47d1e
commit 25cbe3e1d6
5 changed files with 138 additions and 9 deletions

View file

@ -3078,6 +3078,7 @@ class DiscordAdapter(BasePlatformAdapter):
async def send_update_prompt(
self, chat_id: str, prompt: str, default: str = "",
session_key: str = "",
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send an interactive button-based update prompt (Yes / No).
@ -3087,9 +3088,10 @@ class DiscordAdapter(BasePlatformAdapter):
if not self._client or not DISCORD_AVAILABLE:
return SendResult(success=False, error="Not connected")
try:
channel = self._client.get_channel(int(chat_id))
target_id = metadata.get("thread_id") if metadata and metadata.get("thread_id") else chat_id
channel = self._client.get_channel(int(target_id))
if not channel:
channel = await self._client.fetch_channel(int(chat_id))
channel = await self._client.fetch_channel(int(target_id))
default_hint = f" (default: {default})" if default else ""
embed = discord.Embed(

View file

@ -1360,6 +1360,7 @@ class TelegramAdapter(BasePlatformAdapter):
async def send_update_prompt(
self, chat_id: str, prompt: str, default: str = "",
session_key: str = "",
metadata: Optional[Dict[str, Any]] = None,
) -> SendResult:
"""Send an inline-keyboard update prompt (Yes / No buttons).
@ -1377,11 +1378,14 @@ class TelegramAdapter(BasePlatformAdapter):
InlineKeyboardButton("✗ No", callback_data="update_prompt:n"),
]
])
thread_id = self._metadata_thread_id(metadata)
message_thread_id = self._message_thread_id_for_send(thread_id)
msg = await self._bot.send_message(
chat_id=int(chat_id),
text=text,
parse_mode=ParseMode.MARKDOWN,
reply_markup=keyboard,
message_thread_id=message_thread_id,
**self._link_preview_kwargs(),
)
return SendResult(success=True, message_id=str(msg.message_id))

View file

@ -9490,6 +9490,8 @@ class GatewayRunner:
"session_key": session_key,
"timestamp": datetime.now().isoformat(),
}
if event.source.thread_id:
pending["thread_id"] = event.source.thread_id
_tmp_pending = pending_path.with_suffix(".tmp")
_tmp_pending.write_text(json.dumps(pending))
_tmp_pending.replace(pending_path)
@ -9575,6 +9577,7 @@ class GatewayRunner:
adapter = None
chat_id = None
session_key = None
metadata = None
for path in (claimed_path, pending_path):
if path.exists():
try:
@ -9582,6 +9585,8 @@ class GatewayRunner:
platform_str = pending.get("platform")
chat_id = pending.get("chat_id")
session_key = pending.get("session_key")
thread_id = pending.get("thread_id")
metadata = {"thread_id": thread_id} if thread_id else None
if platform_str and chat_id:
platform = Platform(platform_str)
adapter = self.adapters.get(platform)
@ -9629,7 +9634,7 @@ class GatewayRunner:
chunks = [clean[i:i + max_chunk] for i in range(0, len(clean), max_chunk)]
for chunk in chunks:
try:
await adapter.send(chat_id, f"```\n{chunk}\n```")
await adapter.send(chat_id, f"```\n{chunk}\n```", metadata=metadata)
except Exception as e:
logger.debug("Update stream send failed: %s", e)
@ -9652,9 +9657,13 @@ class GatewayRunner:
exit_code_raw = exit_code_path.read_text().strip() or "1"
exit_code = int(exit_code_raw)
if exit_code == 0:
await adapter.send(chat_id, "✅ Hermes update finished.")
await adapter.send(chat_id, "✅ Hermes update finished.", metadata=metadata)
else:
await adapter.send(chat_id, "❌ Hermes update failed (exit code {}).".format(exit_code))
await adapter.send(
chat_id,
"❌ Hermes update failed (exit code {}).".format(exit_code),
metadata=metadata,
)
logger.info("Update finished (exit=%s), notified %s", exit_code, session_key)
except Exception as e:
logger.warning("Update final notification failed: %s", e)
@ -9704,6 +9713,7 @@ class GatewayRunner:
prompt=prompt_text,
default=default,
session_key=session_key,
metadata=metadata,
)
sent_buttons = True
except Exception as btn_err:
@ -9715,7 +9725,8 @@ class GatewayRunner:
f"⚕ **Update needs your input:**\n\n"
f"{prompt_text}{default_hint}\n\n"
f"Reply `/approve` (yes) or `/deny` (no), "
f"or type your answer directly."
f"or type your answer directly.",
metadata=metadata,
)
self._update_prompt_pending[session_key] = True
# Remove the prompt file so it isn't re-read on the
@ -9735,7 +9746,11 @@ class GatewayRunner:
exit_code_path.write_text("124")
await _flush_buffer()
try:
await adapter.send(chat_id, "❌ Hermes update timed out after 30 minutes.")
await adapter.send(
chat_id,
"❌ Hermes update timed out after 30 minutes.",
metadata=metadata,
)
except Exception:
pass
for p in (pending_path, claimed_path, output_path,
@ -9777,6 +9792,7 @@ class GatewayRunner:
pending = json.loads(claimed_path.read_text())
platform_str = pending.get("platform")
chat_id = pending.get("chat_id")
thread_id = pending.get("thread_id")
if not exit_code_path.exists():
logger.info("Update notification deferred: update still running")
@ -9798,6 +9814,7 @@ class GatewayRunner:
adapter = self.adapters.get(platform)
if adapter and chat_id:
metadata = {"thread_id": thread_id} if thread_id else None
# Strip ANSI escape codes for clean display
output = re.sub(r'\x1b\[[0-9;]*m', '', output).strip()
if output:
@ -9812,7 +9829,7 @@ class GatewayRunner:
msg = "✅ Hermes update finished successfully."
else:
msg = "❌ Hermes update failed. Check the gateway logs or run `hermes update` manually for details."
await adapter.send(chat_id, msg)
await adapter.send(chat_id, msg, metadata=metadata)
logger.info(
"Sent post-update notification to %s:%s (exit=%s)",
platform_str,

View file

@ -17,13 +17,14 @@ from gateway.session import SessionSource
def _make_event(text="/update", platform=Platform.TELEGRAM,
user_id="12345", chat_id="67890"):
user_id="12345", chat_id="67890", thread_id=None):
"""Build a MessageEvent for testing."""
source = SessionSource(
platform=platform,
user_id=user_id,
chat_id=chat_id,
user_name="testuser",
thread_id=thread_id,
)
return MessageEvent(text=text, source=source)
@ -214,6 +215,34 @@ class TestHandleUpdateCommand:
assert "timestamp" in data
assert not (hermes_home / ".update_exit_code").exists()
@pytest.mark.asyncio
async def test_writes_pending_marker_with_thread_id(self, tmp_path):
"""Persists thread_id so update notifications can route back to the thread."""
runner = _make_runner()
event = _make_event(
platform=Platform.TELEGRAM,
chat_id="99999",
thread_id="777",
)
fake_root = tmp_path / "project"
fake_root.mkdir()
(fake_root / ".git").mkdir()
(fake_root / "gateway").mkdir()
(fake_root / "gateway" / "run.py").touch()
fake_file = str(fake_root / "gateway" / "run.py")
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()
with patch("gateway.run._hermes_home", hermes_home), \
patch("gateway.run.__file__", fake_file), \
patch("shutil.which", side_effect=lambda x: "/usr/bin/hermes" if x == "hermes" else "/usr/bin/setsid"), \
patch("subprocess.Popen"):
await runner._handle_update_command(event)
data = json.loads((hermes_home / ".update_pending.json").read_text())
assert data["thread_id"] == "777"
@pytest.mark.asyncio
async def test_spawns_setsid(self, tmp_path):
"""Uses setsid when available."""
@ -432,6 +461,31 @@ class TestSendUpdateNotification:
assert call_args[0][0] == "67890" # chat_id
assert "Update complete" in call_args[0][1] or "update finished" in call_args[0][1].lower()
@pytest.mark.asyncio
async def test_sends_notification_with_thread_metadata(self, tmp_path):
"""Final update notification preserves thread metadata when present."""
runner = _make_runner()
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()
pending = {
"platform": "telegram",
"chat_id": "67890",
"thread_id": "777",
"user_id": "12345",
}
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
(hermes_home / ".update_output.txt").write_text("done")
(hermes_home / ".update_exit_code").write_text("0")
mock_adapter = AsyncMock()
runner.adapters = {Platform.TELEGRAM: mock_adapter}
with patch("gateway.run._hermes_home", hermes_home):
await runner._send_update_notification()
assert mock_adapter.send.call_args.kwargs["metadata"] == {"thread_id": "777"}
@pytest.mark.asyncio
async def test_strips_ansi_codes(self, tmp_path):
"""ANSI escape codes are removed from output."""

View file

@ -321,6 +321,58 @@ class TestWatchUpdateProgress:
# Check session was marked as having pending prompt
# (may be cleared by the time we check since update finished)
@pytest.mark.asyncio
async def test_prompt_forwarding_preserves_thread_metadata(self, tmp_path):
"""Forwarded update prompts keep the originating thread/topic metadata."""
runner = _make_runner()
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()
pending = {
"platform": "telegram",
"chat_id": "111",
"thread_id": "777",
"user_id": "222",
"session_key": "agent:main:telegram:group:111:777",
}
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
(hermes_home / ".update_output.txt").write_text("")
(hermes_home / ".update_prompt.json").write_text(json.dumps({
"prompt": "Restore local changes? [Y/n]",
"default": "y",
"id": "threaded-prompt",
}))
class _PromptCapableAdapter:
def __init__(self):
self.send = AsyncMock()
self.prompt_calls = AsyncMock()
async def send_update_prompt(self, **kwargs):
return await self.prompt_calls(**kwargs)
mock_adapter = _PromptCapableAdapter()
runner.adapters = {Platform.TELEGRAM: mock_adapter}
async def finish_after_prompt():
await asyncio.sleep(0.3)
(hermes_home / ".update_response").write_text("y")
await asyncio.sleep(0.2)
(hermes_home / ".update_exit_code").write_text("0")
with patch("gateway.run._hermes_home", hermes_home):
task = asyncio.create_task(finish_after_prompt())
await runner._watch_update_progress(
poll_interval=0.1,
stream_interval=0.2,
timeout=5.0,
)
await task
assert mock_adapter.prompt_calls.call_args.kwargs["metadata"] == {
"thread_id": "777"
}
@pytest.mark.asyncio
async def test_cleans_up_on_completion(self, tmp_path):
"""All marker files are cleaned up when update finishes."""