From 8ad5e98f8d433e6e302355c77e22931f7a047eea Mon Sep 17 00:00:00 2001 From: simbam99 Date: Fri, 1 May 2026 20:16:11 +0300 Subject: [PATCH] fix(gateway): preserve pending update prompts across restarts --- gateway/run.py | 13 +++-- tests/gateway/test_update_streaming.py | 80 +++++++++++++++++++++++++- 2 files changed, 87 insertions(+), 6 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index ebfd2731fe..5e2163e830 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -5020,10 +5020,12 @@ class GatewayRunner: response_text = raw if response_text: response_path = _hermes_home / ".update_response" + prompt_path = _hermes_home / ".update_prompt.json" try: tmp = response_path.with_suffix(".tmp") tmp.write_text(response_text) tmp.replace(response_path) + prompt_path.unlink(missing_ok=True) except OSError as e: logger.warning("Failed to write update response: %s", e) return f"✗ Failed to send response to update process: {e}" @@ -5038,10 +5040,12 @@ class GatewayRunner: # The slash command then falls through to normal dispatch. if _recognized_cmd: response_path = _hermes_home / ".update_response" + prompt_path = _hermes_home / ".update_prompt.json" try: tmp = response_path.with_suffix(".tmp") tmp.write_text("") tmp.replace(response_path) + prompt_path.unlink(missing_ok=True) logger.info( "Recognized /%s during pending update prompt for %s; " "cancelled prompt with default and dispatching command", @@ -11488,12 +11492,13 @@ class GatewayRunner: f"or type your answer directly.", metadata=metadata, ) + # Keep the prompt marker on disk until the user + # answers. If the gateway restarts mid-prompt, the + # next watcher can recover by re-forwarding it from + # disk. Duplicate sends in the same process are + # still suppressed by _update_prompt_pending. self._update_prompt_pending[session_key] = True - # Remove the prompt file so it isn't re-read on the - # next poll cycle. The update process only needs # .update_response to continue — it doesn't re-check - # .update_prompt.json while waiting. - prompt_path.unlink(missing_ok=True) logger.info("Forwarded update prompt to %s: %s", session_key, prompt_text[:80]) except (json.JSONDecodeError, OSError) as e: logger.debug("Failed to read update prompt: %s", e) diff --git a/tests/gateway/test_update_streaming.py b/tests/gateway/test_update_streaming.py index b78eaa3327..36923bc5f0 100644 --- a/tests/gateway/test_update_streaming.py +++ b/tests/gateway/test_update_streaming.py @@ -459,8 +459,9 @@ class TestWatchUpdateProgress: async def test_prompt_forwarded_only_once(self, tmp_path): """Regression: prompt must not be re-sent on every poll cycle. - Before the fix, the watcher never deleted .update_prompt.json after - forwarding, causing the same prompt to be sent every poll_interval. + The in-memory pending flag should suppress duplicate sends within a + single watcher process even when the prompt marker stays on disk for + restart recovery. """ runner = _make_runner() hermes_home = tmp_path / "hermes" @@ -505,6 +506,75 @@ class TestWatchUpdateProgress: f"All sends: {all_sent}" ) + @pytest.mark.asyncio + async def test_prompt_is_recovered_after_watcher_restart(self, tmp_path): + """A forwarded prompt stays on disk until answered so a new watcher can recover it.""" + hermes_home = tmp_path / "hermes" + hermes_home.mkdir() + + pending = { + "platform": "telegram", + "chat_id": "111", + "user_id": "222", + "session_key": "agent:main:telegram:dm:111", + } + prompt = { + "prompt": "Restore local changes? [Y/n]", + "default": "y", + "id": "restart-recover", + } + (hermes_home / ".update_pending.json").write_text(json.dumps(pending)) + (hermes_home / ".update_output.txt").write_text("") + (hermes_home / ".update_prompt.json").write_text(json.dumps(prompt)) + + runner1 = _make_runner() + adapter1 = AsyncMock() + runner1.adapters = {Platform.TELEGRAM: adapter1} + + with patch("gateway.run._hermes_home", hermes_home): + watch1 = asyncio.create_task( + runner1._watch_update_progress( + poll_interval=0.05, + stream_interval=0.1, + timeout=10.0, + ) + ) + for _ in range(40): + if adapter1.send.call_count: + break + await asyncio.sleep(0.05) + + assert adapter1.send.call_count == 1 + assert (hermes_home / ".update_prompt.json").exists() + + watch1.cancel() + with pytest.raises(asyncio.CancelledError): + await watch1 + + runner2 = _make_runner() + adapter2 = AsyncMock() + runner2.adapters = {Platform.TELEGRAM: adapter2} + + async def respond_and_finish(): + await asyncio.sleep(0.2) + (hermes_home / ".update_response").write_text("y") + await asyncio.sleep(0.2) + (hermes_home / ".update_exit_code").write_text("0") + + finisher = asyncio.create_task(respond_and_finish()) + await runner2._watch_update_progress( + poll_interval=0.05, + stream_interval=0.1, + timeout=10.0, + ) + await finisher + + prompt_sends = [ + str(call) for call in adapter2.send.call_args_list + if "Restore local changes" in str(call) + ] + assert len(prompt_sends) == 1 + # --------------------------------------------------------------------------- # Message interception for update prompts @@ -525,6 +595,7 @@ class TestUpdatePromptInterception: # The session key uses the full format from build_session_key session_key = "agent:main:telegram:dm:67890" runner._update_prompt_pending[session_key] = True + (hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"})) # Mock authorization and _session_key_for_source runner._is_user_authorized = MagicMock(return_value=True) @@ -538,6 +609,7 @@ class TestUpdatePromptInterception: response_path = hermes_home / ".update_response" assert response_path.exists() assert response_path.read_text() == "y" + assert not (hermes_home / ".update_prompt.json").exists() # Should clear the pending flag assert session_key not in runner._update_prompt_pending @@ -560,6 +632,7 @@ class TestUpdatePromptInterception: runner._is_user_authorized = MagicMock(return_value=True) runner._session_key_for_source = MagicMock(return_value=session_key) runner._handle_reset_command = AsyncMock(return_value="reset ok") + (hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"})) with patch("gateway.run._hermes_home", hermes_home): result = await runner._handle_message(event) @@ -572,6 +645,7 @@ class TestUpdatePromptInterception: response_path = hermes_home / ".update_response" assert response_path.exists() assert response_path.read_text() == "" + assert not (hermes_home / ".update_prompt.json").exists() # Pending flag is cleared so stray future input won't be # re-intercepted for a prompt that is no longer outstanding. assert session_key not in runner._update_prompt_pending @@ -588,6 +662,7 @@ class TestUpdatePromptInterception: runner._update_prompt_pending[session_key] = True runner._is_user_authorized = MagicMock(return_value=True) runner._session_key_for_source = MagicMock(return_value=session_key) + (hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"})) with patch("gateway.run._hermes_home", hermes_home): result = await runner._handle_message(event) @@ -595,6 +670,7 @@ class TestUpdatePromptInterception: response_path = hermes_home / ".update_response" assert response_path.exists() assert response_path.read_text() == "/foobarbaz" + assert not (hermes_home / ".update_prompt.json").exists() assert "Sent" in (result or "") assert session_key not in runner._update_prompt_pending