hermes-agent/tests/gateway/test_restart_drain.py
Teknium e7475b1582 feat: auto-continue interrupted agent work after gateway restart (#4493)
When the gateway restarts mid-agent-work, the session transcript ends
on a tool result the agent never processed. Previously, the user had
to type 'continue' or use /retry (which replays from scratch, losing
all prior work).

Now, when the next user message arrives and the loaded history ends
with role='tool', a system note is prepended:

  [System note: Your previous turn was interrupted before you could
  process the last tool result(s). Please finish processing those
  results and summarize what was accomplished, then address the
  user's new message below.]

This is injected in _run_agent()'s run_sync closure, right before
calling agent.run_conversation(). The agent sees the full history
(including the pending tool results) and the system note, so it can
summarize what was accomplished and then handle the user's new input.

Design decisions:
- No new session flags or schema changes — purely detects trailing
  tool messages in the loaded history
- Works for any restart scenario (clean, crash, SIGTERM, drain timeout)
  as long as the session wasn't suspended (suspended = fresh start)
- The user's actual message is preserved after the note
- If the session WAS suspended (unclean shutdown), the old history is
  abandoned and the user starts fresh — no false auto-continue

Also updates the shutdown notification message from 'Use /retry after
restart to continue' to 'Send any message after restart to resume
where it left off' — which is now accurate.

Test plan:
- 6 new auto-continue tests (trailing tool detection, no false
  positives for assistant/user/empty history, multi-tool, message
  preservation)
- All 13 restart drain tests pass (updated /retry assertion)
2026-04-14 16:56:49 -07:00

244 lines
8.5 KiB
Python

import asyncio
import shutil
import subprocess
from unittest.mock import AsyncMock, MagicMock
import pytest
import gateway.run as gateway_run
from gateway.platforms.base import MessageEvent, MessageType
from gateway.restart import DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
from gateway.session import build_session_key
from tests.gateway.restart_test_helpers import make_restart_runner, make_restart_source
@pytest.mark.asyncio
async def test_restart_command_while_busy_requests_drain_without_interrupt(monkeypatch):
    """``/restart`` received while an agent is running drains instead of interrupting.

    A mock agent is registered under the event's session key, then
    ``/restart`` is sent through ``_handle_message``. The test checks that
    the draining reply is returned, that the running agent's ``interrupt``
    is never called, and that ``request_restart`` is called exactly once
    with ``detached=True, via_service=False``.
    """
    # systemd sets INVOCATION_ID in service mode, which changes the restart
    # call signature; remove it so the expected call below holds.
    monkeypatch.delenv("INVOCATION_ID", raising=False)

    runner, _adapter = make_restart_runner()
    runner.request_restart = MagicMock(return_value=True)

    restart_event = MessageEvent(
        text="/restart",
        message_type=MessageType.TEXT,
        source=make_restart_source(),
        message_id="m1",
    )
    key = build_session_key(restart_event.source)
    busy_agent = MagicMock()
    runner._running_agents[key] = busy_agent

    reply = await runner._handle_message(restart_event)

    assert reply == "⏳ Draining 1 active agent(s) before restart..."
    busy_agent.interrupt.assert_not_called()
    runner.request_restart.assert_called_once_with(detached=True, via_service=False)
@pytest.mark.asyncio
async def test_drain_queue_mode_queues_follow_up_without_interrupt():
    """In ``queue`` busy-input mode, a follow-up sent while draining is queued.

    The runner is put into the draining/restart-requested state and the
    session is marked active on the adapter with a fresh ``asyncio.Event``.
    After the adapter handles the follow-up, the message must be stored in
    ``_pending_messages``, the session's event must still be unset, and one
    of the sent messages must say the input was queued for the next turn.
    """
    runner, adapter = make_restart_runner()
    runner._draining = True
    runner._restart_requested = True
    runner._busy_input_mode = "queue"

    follow_up = MessageEvent(
        text="follow up",
        message_type=MessageType.TEXT,
        source=make_restart_source(),
        message_id="m2",
    )
    key = build_session_key(follow_up.source)
    adapter._active_sessions[key] = asyncio.Event()

    await adapter.handle_message(follow_up)

    assert key in adapter._pending_messages
    assert adapter._pending_messages[key].text == "follow up"
    # The event staying unset is what "without interrupt" means in this test.
    assert not adapter._active_sessions[key].is_set()
    assert any("queued for the next turn" in sent_text for sent_text in adapter.sent)
@pytest.mark.asyncio
async def test_draining_rejects_new_session_messages():
    """While draining for a restart, a message from a session with no running agent is refused.

    No agent is registered for the ``"fresh"`` source, so ``_handle_message``
    must answer with the "not accepting new work" notice.
    """
    runner, _adapter = make_restart_runner()
    runner._draining = True
    runner._restart_requested = True

    incoming = MessageEvent(
        text="hello",
        message_type=MessageType.TEXT,
        source=make_restart_source("fresh"),
        message_id="m3",
    )

    reply = await runner._handle_message(incoming)

    assert reply == "⏳ Gateway is restarting and is not accepting new work right now."
def test_load_busy_input_mode_prefers_env_then_config_then_default(tmp_path, monkeypatch):
    """``_load_busy_input_mode`` resolves the env var first, then config.yaml, then the default.

    The three sources are checked from lowest to highest priority:
    1. nothing configured           -> ``"interrupt"`` (the default)
    2. ``display.busy_input_mode``  -> ``"queue"`` read from config.yaml
    3. ``HERMES_GATEWAY_BUSY_INPUT_MODE`` set -> wins over the config file
    """
    env_var = "HERMES_GATEWAY_BUSY_INPUT_MODE"
    load_mode = gateway_run.GatewayRunner._load_busy_input_mode

    # Point the gateway at an empty home directory with no env override.
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_var, raising=False)
    assert load_mode() == "interrupt"

    config_file = tmp_path / "config.yaml"
    config_file.write_text("display:\n busy_input_mode: queue\n", encoding="utf-8")
    assert load_mode() == "queue"

    # The config file still says "queue"; the env var must take precedence.
    monkeypatch.setenv(env_var, "interrupt")
    assert load_mode() == "interrupt"
def test_load_restart_drain_timeout_prefers_env_then_config_then_default(
    tmp_path, monkeypatch, caplog
):
    """``_load_restart_drain_timeout`` resolves the env var first, then config.yaml, then the default.

    Also checks that an unparseable env value falls back to
    ``DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT`` and logs a message containing
    ``"Invalid restart_drain_timeout"``.
    """
    env_var = "HERMES_RESTART_DRAIN_TIMEOUT"
    load_timeout = gateway_run.GatewayRunner._load_restart_drain_timeout

    # Empty home directory, no env override: the default applies.
    monkeypatch.setattr(gateway_run, "_hermes_home", tmp_path)
    monkeypatch.delenv(env_var, raising=False)
    assert load_timeout() == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT

    # A value in config.yaml replaces the default.
    config_file = tmp_path / "config.yaml"
    config_file.write_text("agent:\n restart_drain_timeout: 12\n", encoding="utf-8")
    assert load_timeout() == 12.0

    # The env var beats the config file.
    monkeypatch.setenv(env_var, "7")
    assert load_timeout() == 7.0

    # An invalid env value reverts to the default and is reported in the log.
    monkeypatch.setenv(env_var, "invalid")
    assert load_timeout() == DEFAULT_GATEWAY_RESTART_DRAIN_TIMEOUT
    assert "Invalid restart_drain_timeout" in caplog.text
@pytest.mark.asyncio
async def test_request_restart_is_idempotent():
    """A second ``request_restart`` call is a no-op while the first is pending.

    The first call returns ``True`` and leaves a task in
    ``_background_tasks``; the second returns ``False``. Once that task is
    awaited, ``stop`` must have been awaited exactly once with the restart
    flags matching the first request.
    """
    runner, _adapter = make_restart_runner()
    runner.stop = AsyncMock()
    restart_kwargs = {"detached": True, "via_service": False}

    first_accepted = runner.request_restart(**restart_kwargs)
    assert first_accepted is True
    scheduled_task = next(iter(runner._background_tasks))

    second_accepted = runner.request_restart(**restart_kwargs)
    assert second_accepted is False

    await scheduled_task

    runner.stop.assert_awaited_once_with(
        restart=True,
        detached_restart=True,
        service_restart=False,
    )
@pytest.mark.asyncio
async def test_launch_detached_restart_command_uses_setsid(monkeypatch):
    """The detached restart command is launched through ``setsid`` when it is available.

    ``shutil.which`` is stubbed to resolve only ``setsid`` and
    ``subprocess.Popen`` is replaced with a recorder, so no process is
    spawned. The single recorded call must start with ``setsid bash``, its
    final argument must contain ``gateway restart`` and a ``kill -0`` probe
    of the stubbed PID, and it must run in a new session with
    stdout/stderr sent to ``DEVNULL``.
    """
    runner, _adapter = make_restart_runner()
    recorded = []

    def fake_which(cmd):
        # Only setsid "exists" on this fake system.
        return "/usr/bin/setsid" if cmd == "setsid" else None

    def fake_popen(cmd, **kwargs):
        recorded.append((cmd, kwargs))
        return MagicMock()

    monkeypatch.setattr(gateway_run, "_resolve_hermes_bin", lambda: ["/usr/bin/hermes"])
    monkeypatch.setattr(gateway_run.os, "getpid", lambda: 321)
    monkeypatch.setattr(shutil, "which", fake_which)
    monkeypatch.setattr(subprocess, "Popen", fake_popen)

    await runner._launch_detached_restart_command()

    assert len(recorded) == 1
    argv, popen_kwargs = recorded[0]
    shell_script = argv[-1]
    assert argv[:2] == ["/usr/bin/setsid", "bash"]
    assert "gateway restart" in shell_script
    # 321 is the PID returned by the stubbed os.getpid above.
    assert "kill -0 321" in shell_script
    assert popen_kwargs["start_new_session"] is True
    assert popen_kwargs["stdout"] is subprocess.DEVNULL
    assert popen_kwargs["stderr"] is subprocess.DEVNULL
# ── Shutdown notification tests ──────────────────────────────────────
@pytest.mark.asyncio
async def test_shutdown_notification_sent_to_active_sessions():
    """Active sessions receive a notification when the gateway starts shutting down.

    One mock agent is registered under a Telegram DM session key. After
    ``_notify_active_sessions_of_shutdown`` runs, exactly one message must
    have been sent, and it must mention both "shutting down" and
    "interrupted".
    """
    runner, adapter = make_restart_runner()
    # Plain string literal: the key has no interpolated parts.
    session_key = "agent:main:telegram:dm:999"
    runner._running_agents[session_key] = MagicMock()

    await runner._notify_active_sessions_of_shutdown()

    assert len(adapter.sent) == 1
    assert "shutting down" in adapter.sent[0]
    assert "interrupted" in adapter.sent[0]
@pytest.mark.asyncio
async def test_shutdown_notification_says_restarting_when_restart_requested():
    """When _restart_requested is True, the message says 'restarting' and mentions resuming.

    The notification is expected to contain "resume" (the assertion below
    no longer looks for a ``/retry`` hint).
    """
    runner, adapter = make_restart_runner()
    runner._restart_requested = True
    session_key = "agent:main:telegram:dm:999"
    runner._running_agents[session_key] = MagicMock()

    await runner._notify_active_sessions_of_shutdown()

    assert len(adapter.sent) == 1
    assert "restarting" in adapter.sent[0]
    assert "resume" in adapter.sent[0]
@pytest.mark.asyncio
async def test_shutdown_notification_deduplicates_per_chat():
    """Several sessions that share one chat produce a single shutdown notification."""
    runner, adapter = make_restart_runner()

    # Same group chat ("chat1"), two different users.
    same_chat_keys = (
        "agent:main:telegram:group:chat1:u1",
        "agent:main:telegram:group:chat1:u2",
    )
    for key in same_chat_keys:
        runner._running_agents[key] = MagicMock()

    await runner._notify_active_sessions_of_shutdown()

    assert len(adapter.sent) == 1
@pytest.mark.asyncio
async def test_shutdown_notification_skipped_when_no_active_agents():
    """With nothing registered in ``_running_agents``, no shutdown message goes out."""
    gateway_runner, adapter = make_restart_runner()

    # No agents are added to the runner before notifying.
    await gateway_runner._notify_active_sessions_of_shutdown()

    assert len(adapter.sent) == 0
@pytest.mark.asyncio
async def test_shutdown_notification_ignores_pending_sentinels():
    """A session whose slot holds the pending sentinel gets no shutdown notification.

    The sentinel stands in for an agent that has not started yet, so there
    is no work in progress to warn the user about.
    """
    runner, adapter = make_restart_runner()
    pending_marker = gateway_run._AGENT_PENDING_SENTINEL
    runner._running_agents["agent:main:telegram:dm:999"] = pending_marker

    await runner._notify_active_sessions_of_shutdown()

    assert len(adapter.sent) == 0
@pytest.mark.asyncio
async def test_shutdown_notification_send_failure_does_not_block():
    """A failing ``adapter.send`` must not make the notification routine raise.

    The test has no explicit assertion: it passes as long as
    ``_notify_active_sessions_of_shutdown`` returns normally even though
    every send attempt raises.
    """
    runner, adapter = make_restart_runner()
    failing_send = AsyncMock(side_effect=Exception("network error"))
    adapter.send = failing_send
    runner._running_agents["agent:main:telegram:dm:999"] = MagicMock()

    # Completing this await without an exception is the whole check.
    await runner._notify_active_sessions_of_shutdown()