fix(gateway): preserve pending update prompts across restarts

This commit is contained in:
simbam99 2026-05-01 20:16:11 +03:00 committed by Teknium
parent 2785355750
commit 8ad5e98f8d
2 changed files with 87 additions and 6 deletions

View file

@ -5020,10 +5020,12 @@ class GatewayRunner:
response_text = raw
if response_text:
response_path = _hermes_home / ".update_response"
prompt_path = _hermes_home / ".update_prompt.json"
try:
tmp = response_path.with_suffix(".tmp")
tmp.write_text(response_text)
tmp.replace(response_path)
prompt_path.unlink(missing_ok=True)
except OSError as e:
logger.warning("Failed to write update response: %s", e)
return f"✗ Failed to send response to update process: {e}"
@ -5038,10 +5040,12 @@ class GatewayRunner:
# The slash command then falls through to normal dispatch.
if _recognized_cmd:
response_path = _hermes_home / ".update_response"
prompt_path = _hermes_home / ".update_prompt.json"
try:
tmp = response_path.with_suffix(".tmp")
tmp.write_text("")
tmp.replace(response_path)
prompt_path.unlink(missing_ok=True)
logger.info(
"Recognized /%s during pending update prompt for %s; "
"cancelled prompt with default and dispatching command",
@ -11488,12 +11492,13 @@ class GatewayRunner:
f"or type your answer directly.",
metadata=metadata,
)
# Keep the prompt marker on disk until the user
# answers. If the gateway restarts mid-prompt, the
# next watcher can recover by re-forwarding it from
# disk. Duplicate sends in the same process are
# still suppressed by _update_prompt_pending.
self._update_prompt_pending[session_key] = True
# Remove the prompt file so it isn't re-read on the
# next poll cycle. The update process only needs
# .update_response to continue — it doesn't re-check
# .update_prompt.json while waiting.
prompt_path.unlink(missing_ok=True)
logger.info("Forwarded update prompt to %s: %s", session_key, prompt_text[:80])
except (json.JSONDecodeError, OSError) as e:
logger.debug("Failed to read update prompt: %s", e)

View file

@ -459,8 +459,9 @@ class TestWatchUpdateProgress:
async def test_prompt_forwarded_only_once(self, tmp_path):
"""Regression: prompt must not be re-sent on every poll cycle.
Before the fix, the watcher never deleted .update_prompt.json after
forwarding, causing the same prompt to be sent every poll_interval.
The in-memory pending flag should suppress duplicate sends within a
single watcher process even when the prompt marker stays on disk for
restart recovery.
"""
runner = _make_runner()
hermes_home = tmp_path / "hermes"
@ -505,6 +506,75 @@ class TestWatchUpdateProgress:
f"All sends: {all_sent}"
)
@pytest.mark.asyncio
async def test_prompt_is_recovered_after_watcher_restart(self, tmp_path):
"""A forwarded prompt stays on disk until answered so a new watcher can recover it."""
hermes_home = tmp_path / "hermes"
hermes_home.mkdir()
pending = {
"platform": "telegram",
"chat_id": "111",
"user_id": "222",
"session_key": "agent:main:telegram:dm:111",
}
prompt = {
"prompt": "Restore local changes? [Y/n]",
"default": "y",
"id": "restart-recover",
}
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
(hermes_home / ".update_output.txt").write_text("")
(hermes_home / ".update_prompt.json").write_text(json.dumps(prompt))
runner1 = _make_runner()
adapter1 = AsyncMock()
runner1.adapters = {Platform.TELEGRAM: adapter1}
with patch("gateway.run._hermes_home", hermes_home):
watch1 = asyncio.create_task(
runner1._watch_update_progress(
poll_interval=0.05,
stream_interval=0.1,
timeout=10.0,
)
)
for _ in range(40):
if adapter1.send.call_count:
break
await asyncio.sleep(0.05)
assert adapter1.send.call_count == 1
assert (hermes_home / ".update_prompt.json").exists()
watch1.cancel()
with pytest.raises(asyncio.CancelledError):
await watch1
runner2 = _make_runner()
adapter2 = AsyncMock()
runner2.adapters = {Platform.TELEGRAM: adapter2}
async def respond_and_finish():
await asyncio.sleep(0.2)
(hermes_home / ".update_response").write_text("y")
await asyncio.sleep(0.2)
(hermes_home / ".update_exit_code").write_text("0")
finisher = asyncio.create_task(respond_and_finish())
await runner2._watch_update_progress(
poll_interval=0.05,
stream_interval=0.1,
timeout=10.0,
)
await finisher
prompt_sends = [
str(call) for call in adapter2.send.call_args_list
if "Restore local changes" in str(call)
]
assert len(prompt_sends) == 1
# ---------------------------------------------------------------------------
# Message interception for update prompts
@ -525,6 +595,7 @@ class TestUpdatePromptInterception:
# The session key uses the full format from build_session_key
session_key = "agent:main:telegram:dm:67890"
runner._update_prompt_pending[session_key] = True
(hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"}))
# Mock authorization and _session_key_for_source
runner._is_user_authorized = MagicMock(return_value=True)
@ -538,6 +609,7 @@ class TestUpdatePromptInterception:
response_path = hermes_home / ".update_response"
assert response_path.exists()
assert response_path.read_text() == "y"
assert not (hermes_home / ".update_prompt.json").exists()
# Should clear the pending flag
assert session_key not in runner._update_prompt_pending
@ -560,6 +632,7 @@ class TestUpdatePromptInterception:
runner._is_user_authorized = MagicMock(return_value=True)
runner._session_key_for_source = MagicMock(return_value=session_key)
runner._handle_reset_command = AsyncMock(return_value="reset ok")
(hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"}))
with patch("gateway.run._hermes_home", hermes_home):
result = await runner._handle_message(event)
@ -572,6 +645,7 @@ class TestUpdatePromptInterception:
response_path = hermes_home / ".update_response"
assert response_path.exists()
assert response_path.read_text() == ""
assert not (hermes_home / ".update_prompt.json").exists()
# Pending flag is cleared so stray future input won't be
# re-intercepted for a prompt that is no longer outstanding.
assert session_key not in runner._update_prompt_pending
@ -588,6 +662,7 @@ class TestUpdatePromptInterception:
runner._update_prompt_pending[session_key] = True
runner._is_user_authorized = MagicMock(return_value=True)
runner._session_key_for_source = MagicMock(return_value=session_key)
(hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"}))
with patch("gateway.run._hermes_home", hermes_home):
result = await runner._handle_message(event)
@ -595,6 +670,7 @@ class TestUpdatePromptInterception:
response_path = hermes_home / ".update_response"
assert response_path.exists()
assert response_path.read_text() == "/foobarbaz"
assert not (hermes_home / ".update_prompt.json").exists()
assert "Sent" in (result or "")
assert session_key not in runner._update_prompt_pending