feat(gateway): auto-resume interrupted sessions after restart

This commit is contained in:
Kevin Yan 2026-05-06 16:18:20 -04:00 committed by Teknium
parent 233bfd3621
commit fad684b1f3
3 changed files with 190 additions and 1 deletions

View file

@ -2739,6 +2739,57 @@ class GatewayRunner:
task.add_done_callback(self._background_tasks.discard)
return True
def _schedule_resume_pending_sessions(self) -> int:
"""Auto-continue fresh restart-interrupted sessions after startup.
``resume_pending`` already preserves the transcript and injects the
recovery system note on the next user message. This method closes the
restart UX gap by synthesizing that next message once adapters are back
online, so users do not have to send a placeholder ping after restart.
"""
try:
entries = self.session_store.list_resume_pending(
window_secs=_auto_continue_freshness_window(),
allowed_reasons={"restart_timeout", "shutdown_timeout"},
)
except Exception as exc:
logger.warning("Failed to list resume-pending sessions: %s", exc)
return 0
scheduled = 0
for entry in entries:
source = getattr(entry, "origin", None)
platform = getattr(source, "platform", None)
adapter = self.adapters.get(platform) if platform is not None else None
if source is None or adapter is None:
logger.debug(
"Skipping auto-resume for %s: adapter unavailable for %s",
getattr(entry, "session_key", "?"),
getattr(platform, "value", platform),
)
continue
event = MessageEvent(
text=(
"[System note: The gateway restarted after interrupting "
"this session. Resume the previous turn now. Reconcile "
"the transcript first: if tool results are already present, "
"process them before taking new action; never claim work "
"completed unless it is visible in the transcript/tool output.]"
),
message_type=MessageType.TEXT,
source=source,
internal=True,
)
task = asyncio.create_task(adapter.handle_message(event))
self._background_tasks.add(task)
task.add_done_callback(self._background_tasks.discard)
scheduled += 1
if scheduled:
logger.info("Scheduled auto-resume for %d restart-interrupted session(s)", scheduled)
return scheduled
async def start(self) -> bool:
"""
Start the gateway and all configured platform adapters.
@ -3127,6 +3178,12 @@ class GatewayRunner:
skip_targets=skip_home_targets,
)
# Automatically continue fresh sessions that were interrupted by the
# previous gateway restart/shutdown. The resume_pending flag is cleared
# by the normal successful-turn path, so a failed auto-resume remains
# visible for manual recovery on the next user message.
self._schedule_resume_pending_sessions()
# Drain any recovered process watchers (from crash recovery checkpoint)
try:
from tools.process_registry import process_registry

View file

@ -1028,6 +1028,42 @@ class SessionStore:
self._save()
return True
def list_resume_pending(
self,
*,
window_secs: Optional[float] = None,
now: Optional[float] = None,
allowed_reasons: Optional[set[str]] = None,
) -> List[SessionEntry]:
"""Return fresh restart-interrupted sessions eligible for resume.
Only entries that still have an origin are returned; the gateway needs
that origin to route continuation back through the original
platform/chat/thread. ``suspended`` entries are excluded because
explicit suspension/stuck-loop escalation must win over resume.
"""
current = datetime.fromtimestamp(now) if now is not None else _now()
window = float(window_secs) if window_secs is not None else None
with self._lock:
self._ensure_loaded_locked()
entries = list(self._entries.values())
pending: List[SessionEntry] = []
for entry in entries:
if not entry.resume_pending or entry.suspended or entry.origin is None:
continue
if allowed_reasons is not None and entry.resume_reason not in allowed_reasons:
continue
if window is not None and window > 0:
marker = entry.last_resume_marked_at or entry.updated_at
if marker is not None and (current - marker).total_seconds() > window:
continue
pending.append(entry)
pending.sort(key=lambda entry: entry.last_resume_marked_at or entry.updated_at)
return pending
def prune_old_entries(self, max_age_days: int) -> int:
"""Drop SessionEntry records older than max_age_days.

View file

@ -33,7 +33,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from gateway.config import GatewayConfig, HomeChannel, Platform, PlatformConfig
from gateway.platforms.base import SendResult
from gateway.platforms.base import MessageEvent, MessageType, SendResult
from gateway.run import (
_auto_continue_freshness_window,
_coerce_gateway_timestamp,
@ -227,6 +227,30 @@ class TestSessionEntryResumeFields:
class TestMarkResumePending:
def test_list_resume_pending_returns_fresh_entries_with_origins(self, tmp_path):
store = _make_store(tmp_path)
fresh = store.get_or_create_session(_make_source(chat_id="fresh"))
stale = store.get_or_create_session(_make_source(chat_id="stale"))
missing_origin = store.get_or_create_session(_make_source(chat_id="missing-origin"))
suspended = store.get_or_create_session(_make_source(chat_id="suspended"))
store.mark_resume_pending(fresh.session_key, reason="restart_timeout")
store.mark_resume_pending(stale.session_key, reason="restart_timeout")
store.mark_resume_pending(missing_origin.session_key, reason="restart_timeout")
store.mark_resume_pending(suspended.session_key, reason="restart_timeout")
old = datetime.now() - timedelta(hours=3)
store._entries[stale.session_key].last_resume_marked_at = old
store._entries[missing_origin.session_key].origin = None
store._entries[suspended.session_key].suspended = True
pending = store.list_resume_pending(
window_secs=3600,
now=datetime.now().timestamp(),
allowed_reasons={"restart_timeout"},
)
assert [entry.session_key for entry in pending] == [fresh.session_key]
def test_marks_existing_session(self, tmp_path):
store = _make_store(tmp_path)
source = _make_source()
@ -910,6 +934,78 @@ async def test_drain_timeout_skips_pending_sentinel_sessions():
assert marked == {session_key_real}
# ---------------------------------------------------------------------------
# Gateway startup auto-resume
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_startup_auto_resume_schedules_fresh_pending_sessions():
"""Fresh resume_pending sessions should continue automatically after startup.
This closes the UX gap where restart recovery only happened if the user sent
another message after the gateway came back.
"""
runner, adapter = make_restart_runner()
source = make_restart_source(chat_id="resume-chat", thread_id="topic-1")
pending_entry = SessionEntry(
session_key="agent:main:telegram:group:resume-chat:topic-1",
session_id="sid",
created_at=datetime.now(),
updated_at=datetime.now(),
origin=source,
platform=Platform.TELEGRAM,
chat_type="group",
resume_pending=True,
resume_reason="restart_timeout",
last_resume_marked_at=datetime.now(),
)
runner.session_store.list_resume_pending = MagicMock(return_value=[pending_entry])
adapter.handle_message = AsyncMock()
scheduled = runner._schedule_resume_pending_sessions()
await asyncio.sleep(0)
assert scheduled == 1
runner.session_store.list_resume_pending.assert_called_once_with(
window_secs=_auto_continue_freshness_window(),
allowed_reasons={"restart_timeout", "shutdown_timeout"},
)
adapter.handle_message.assert_awaited_once()
event = adapter.handle_message.await_args.args[0]
assert isinstance(event, MessageEvent)
assert event.internal is True
assert event.message_type == MessageType.TEXT
assert event.source == source
assert event.text.startswith("[System note: The gateway restarted")
@pytest.mark.asyncio
async def test_startup_auto_resume_skips_when_adapter_unavailable():
runner, adapter = make_restart_runner()
source = make_restart_source(chat_id="resume-chat")
pending_entry = SessionEntry(
session_key="agent:main:telegram:dm:resume-chat",
session_id="sid",
created_at=datetime.now(),
updated_at=datetime.now(),
origin=source,
platform=Platform.TELEGRAM,
chat_type="dm",
resume_pending=True,
resume_reason="restart_timeout",
last_resume_marked_at=datetime.now(),
)
runner.session_store.list_resume_pending = MagicMock(return_value=[pending_entry])
runner.adapters = {}
adapter.handle_message = AsyncMock()
scheduled = runner._schedule_resume_pending_sessions()
assert scheduled == 0
adapter.handle_message.assert_not_called()
# ---------------------------------------------------------------------------
# Shutdown banner wording
# ---------------------------------------------------------------------------