mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-13 03:52:00 +00:00
fix(gateway): preserve pending update prompts across restarts
This commit is contained in:
parent
2785355750
commit
8ad5e98f8d
2 changed files with 87 additions and 6 deletions
|
|
@ -5020,10 +5020,12 @@ class GatewayRunner:
|
||||||
response_text = raw
|
response_text = raw
|
||||||
if response_text:
|
if response_text:
|
||||||
response_path = _hermes_home / ".update_response"
|
response_path = _hermes_home / ".update_response"
|
||||||
|
prompt_path = _hermes_home / ".update_prompt.json"
|
||||||
try:
|
try:
|
||||||
tmp = response_path.with_suffix(".tmp")
|
tmp = response_path.with_suffix(".tmp")
|
||||||
tmp.write_text(response_text)
|
tmp.write_text(response_text)
|
||||||
tmp.replace(response_path)
|
tmp.replace(response_path)
|
||||||
|
prompt_path.unlink(missing_ok=True)
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logger.warning("Failed to write update response: %s", e)
|
logger.warning("Failed to write update response: %s", e)
|
||||||
return f"✗ Failed to send response to update process: {e}"
|
return f"✗ Failed to send response to update process: {e}"
|
||||||
|
|
@ -5038,10 +5040,12 @@ class GatewayRunner:
|
||||||
# The slash command then falls through to normal dispatch.
|
# The slash command then falls through to normal dispatch.
|
||||||
if _recognized_cmd:
|
if _recognized_cmd:
|
||||||
response_path = _hermes_home / ".update_response"
|
response_path = _hermes_home / ".update_response"
|
||||||
|
prompt_path = _hermes_home / ".update_prompt.json"
|
||||||
try:
|
try:
|
||||||
tmp = response_path.with_suffix(".tmp")
|
tmp = response_path.with_suffix(".tmp")
|
||||||
tmp.write_text("")
|
tmp.write_text("")
|
||||||
tmp.replace(response_path)
|
tmp.replace(response_path)
|
||||||
|
prompt_path.unlink(missing_ok=True)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Recognized /%s during pending update prompt for %s; "
|
"Recognized /%s during pending update prompt for %s; "
|
||||||
"cancelled prompt with default and dispatching command",
|
"cancelled prompt with default and dispatching command",
|
||||||
|
|
@ -11488,12 +11492,13 @@ class GatewayRunner:
|
||||||
f"or type your answer directly.",
|
f"or type your answer directly.",
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
)
|
)
|
||||||
|
# Keep the prompt marker on disk until the user
|
||||||
|
# answers. If the gateway restarts mid-prompt, the
|
||||||
|
# next watcher can recover by re-forwarding it from
|
||||||
|
# disk. Duplicate sends in the same process are
|
||||||
|
# still suppressed by _update_prompt_pending.
|
||||||
self._update_prompt_pending[session_key] = True
|
self._update_prompt_pending[session_key] = True
|
||||||
# Remove the prompt file so it isn't re-read on the
|
|
||||||
# next poll cycle. The update process only needs
|
|
||||||
# .update_response to continue — it doesn't re-check
|
# .update_response to continue — it doesn't re-check
|
||||||
# .update_prompt.json while waiting.
|
|
||||||
prompt_path.unlink(missing_ok=True)
|
|
||||||
logger.info("Forwarded update prompt to %s: %s", session_key, prompt_text[:80])
|
logger.info("Forwarded update prompt to %s: %s", session_key, prompt_text[:80])
|
||||||
except (json.JSONDecodeError, OSError) as e:
|
except (json.JSONDecodeError, OSError) as e:
|
||||||
logger.debug("Failed to read update prompt: %s", e)
|
logger.debug("Failed to read update prompt: %s", e)
|
||||||
|
|
|
||||||
|
|
@ -459,8 +459,9 @@ class TestWatchUpdateProgress:
|
||||||
async def test_prompt_forwarded_only_once(self, tmp_path):
|
async def test_prompt_forwarded_only_once(self, tmp_path):
|
||||||
"""Regression: prompt must not be re-sent on every poll cycle.
|
"""Regression: prompt must not be re-sent on every poll cycle.
|
||||||
|
|
||||||
Before the fix, the watcher never deleted .update_prompt.json after
|
The in-memory pending flag should suppress duplicate sends within a
|
||||||
forwarding, causing the same prompt to be sent every poll_interval.
|
single watcher process even when the prompt marker stays on disk for
|
||||||
|
restart recovery.
|
||||||
"""
|
"""
|
||||||
runner = _make_runner()
|
runner = _make_runner()
|
||||||
hermes_home = tmp_path / "hermes"
|
hermes_home = tmp_path / "hermes"
|
||||||
|
|
@ -505,6 +506,75 @@ class TestWatchUpdateProgress:
|
||||||
f"All sends: {all_sent}"
|
f"All sends: {all_sent}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_is_recovered_after_watcher_restart(self, tmp_path):
|
||||||
|
"""A forwarded prompt stays on disk until answered so a new watcher can recover it."""
|
||||||
|
hermes_home = tmp_path / "hermes"
|
||||||
|
hermes_home.mkdir()
|
||||||
|
|
||||||
|
pending = {
|
||||||
|
"platform": "telegram",
|
||||||
|
"chat_id": "111",
|
||||||
|
"user_id": "222",
|
||||||
|
"session_key": "agent:main:telegram:dm:111",
|
||||||
|
}
|
||||||
|
prompt = {
|
||||||
|
"prompt": "Restore local changes? [Y/n]",
|
||||||
|
"default": "y",
|
||||||
|
"id": "restart-recover",
|
||||||
|
}
|
||||||
|
(hermes_home / ".update_pending.json").write_text(json.dumps(pending))
|
||||||
|
(hermes_home / ".update_output.txt").write_text("")
|
||||||
|
(hermes_home / ".update_prompt.json").write_text(json.dumps(prompt))
|
||||||
|
|
||||||
|
runner1 = _make_runner()
|
||||||
|
adapter1 = AsyncMock()
|
||||||
|
runner1.adapters = {Platform.TELEGRAM: adapter1}
|
||||||
|
|
||||||
|
with patch("gateway.run._hermes_home", hermes_home):
|
||||||
|
watch1 = asyncio.create_task(
|
||||||
|
runner1._watch_update_progress(
|
||||||
|
poll_interval=0.05,
|
||||||
|
stream_interval=0.1,
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(40):
|
||||||
|
if adapter1.send.call_count:
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
assert adapter1.send.call_count == 1
|
||||||
|
assert (hermes_home / ".update_prompt.json").exists()
|
||||||
|
|
||||||
|
watch1.cancel()
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await watch1
|
||||||
|
|
||||||
|
runner2 = _make_runner()
|
||||||
|
adapter2 = AsyncMock()
|
||||||
|
runner2.adapters = {Platform.TELEGRAM: adapter2}
|
||||||
|
|
||||||
|
async def respond_and_finish():
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
(hermes_home / ".update_response").write_text("y")
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
(hermes_home / ".update_exit_code").write_text("0")
|
||||||
|
|
||||||
|
finisher = asyncio.create_task(respond_and_finish())
|
||||||
|
await runner2._watch_update_progress(
|
||||||
|
poll_interval=0.05,
|
||||||
|
stream_interval=0.1,
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
await finisher
|
||||||
|
|
||||||
|
prompt_sends = [
|
||||||
|
str(call) for call in adapter2.send.call_args_list
|
||||||
|
if "Restore local changes" in str(call)
|
||||||
|
]
|
||||||
|
assert len(prompt_sends) == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Message interception for update prompts
|
# Message interception for update prompts
|
||||||
|
|
@ -525,6 +595,7 @@ class TestUpdatePromptInterception:
|
||||||
# The session key uses the full format from build_session_key
|
# The session key uses the full format from build_session_key
|
||||||
session_key = "agent:main:telegram:dm:67890"
|
session_key = "agent:main:telegram:dm:67890"
|
||||||
runner._update_prompt_pending[session_key] = True
|
runner._update_prompt_pending[session_key] = True
|
||||||
|
(hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"}))
|
||||||
|
|
||||||
# Mock authorization and _session_key_for_source
|
# Mock authorization and _session_key_for_source
|
||||||
runner._is_user_authorized = MagicMock(return_value=True)
|
runner._is_user_authorized = MagicMock(return_value=True)
|
||||||
|
|
@ -538,6 +609,7 @@ class TestUpdatePromptInterception:
|
||||||
response_path = hermes_home / ".update_response"
|
response_path = hermes_home / ".update_response"
|
||||||
assert response_path.exists()
|
assert response_path.exists()
|
||||||
assert response_path.read_text() == "y"
|
assert response_path.read_text() == "y"
|
||||||
|
assert not (hermes_home / ".update_prompt.json").exists()
|
||||||
# Should clear the pending flag
|
# Should clear the pending flag
|
||||||
assert session_key not in runner._update_prompt_pending
|
assert session_key not in runner._update_prompt_pending
|
||||||
|
|
||||||
|
|
@ -560,6 +632,7 @@ class TestUpdatePromptInterception:
|
||||||
runner._is_user_authorized = MagicMock(return_value=True)
|
runner._is_user_authorized = MagicMock(return_value=True)
|
||||||
runner._session_key_for_source = MagicMock(return_value=session_key)
|
runner._session_key_for_source = MagicMock(return_value=session_key)
|
||||||
runner._handle_reset_command = AsyncMock(return_value="reset ok")
|
runner._handle_reset_command = AsyncMock(return_value="reset ok")
|
||||||
|
(hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"}))
|
||||||
|
|
||||||
with patch("gateway.run._hermes_home", hermes_home):
|
with patch("gateway.run._hermes_home", hermes_home):
|
||||||
result = await runner._handle_message(event)
|
result = await runner._handle_message(event)
|
||||||
|
|
@ -572,6 +645,7 @@ class TestUpdatePromptInterception:
|
||||||
response_path = hermes_home / ".update_response"
|
response_path = hermes_home / ".update_response"
|
||||||
assert response_path.exists()
|
assert response_path.exists()
|
||||||
assert response_path.read_text() == ""
|
assert response_path.read_text() == ""
|
||||||
|
assert not (hermes_home / ".update_prompt.json").exists()
|
||||||
# Pending flag is cleared so stray future input won't be
|
# Pending flag is cleared so stray future input won't be
|
||||||
# re-intercepted for a prompt that is no longer outstanding.
|
# re-intercepted for a prompt that is no longer outstanding.
|
||||||
assert session_key not in runner._update_prompt_pending
|
assert session_key not in runner._update_prompt_pending
|
||||||
|
|
@ -588,6 +662,7 @@ class TestUpdatePromptInterception:
|
||||||
runner._update_prompt_pending[session_key] = True
|
runner._update_prompt_pending[session_key] = True
|
||||||
runner._is_user_authorized = MagicMock(return_value=True)
|
runner._is_user_authorized = MagicMock(return_value=True)
|
||||||
runner._session_key_for_source = MagicMock(return_value=session_key)
|
runner._session_key_for_source = MagicMock(return_value=session_key)
|
||||||
|
(hermes_home / ".update_prompt.json").write_text(json.dumps({"prompt": "test"}))
|
||||||
|
|
||||||
with patch("gateway.run._hermes_home", hermes_home):
|
with patch("gateway.run._hermes_home", hermes_home):
|
||||||
result = await runner._handle_message(event)
|
result = await runner._handle_message(event)
|
||||||
|
|
@ -595,6 +670,7 @@ class TestUpdatePromptInterception:
|
||||||
response_path = hermes_home / ".update_response"
|
response_path = hermes_home / ".update_response"
|
||||||
assert response_path.exists()
|
assert response_path.exists()
|
||||||
assert response_path.read_text() == "/foobarbaz"
|
assert response_path.read_text() == "/foobarbaz"
|
||||||
|
assert not (hermes_home / ".update_prompt.json").exists()
|
||||||
assert "Sent" in (result or "")
|
assert "Sent" in (result or "")
|
||||||
assert session_key not in runner._update_prompt_pending
|
assert session_key not in runner._update_prompt_pending
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue