diff --git a/gateway/run.py b/gateway/run.py index db3f8b00d5..733c3714d1 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -1681,8 +1681,7 @@ class GatewayRunner: action = "restarting" if self._restart_requested else "shutting down" hint = ( - "Your current task will be interrupted. " - "Send any message after restart and I'll try to resume where you left off." + "Your current task will be interrupted and automatically resumed after restart." if self._restart_requested else "Your current task will be interrupted." ) @@ -2028,29 +2027,32 @@ class GatewayRunner: except Exception as e: logger.warning("Process checkpoint recovery: %s", e) - # Suspend sessions that were active when the gateway last exited. - # This prevents stuck sessions from being blindly resumed on restart, - # which can create an unrecoverable loop (#7536). Suspended sessions - # auto-reset on the next incoming message, giving the user a clean start. + # Mark sessions that were active when the gateway last exited for + # automatic continuation. Previously these were suspended/reset to + # avoid stuck loops (#7536); stuck-loop protection still runs below, + # but the normal case now preserves the transcript and auto-injects a + # continuation turn after adapters reconnect. # - # SKIP suspension after a clean (graceful) shutdown — the previous + # SKIP marking after a clean (graceful) shutdown — the previous # process already drained active agents, so sessions aren't stuck. - # This prevents unwanted auto-resets after `hermes update`, - # `hermes gateway restart`, or `/restart`. + # This prevents unwanted auto-resumes after `hermes update`, + # `hermes gateway restart`, or `/restart` when no turn was interrupted. 
_clean_marker = _hermes_home / ".clean_shutdown" if _clean_marker.exists(): - logger.info("Previous gateway exited cleanly — skipping session suspension") + logger.info("Previous gateway exited cleanly — skipping in-flight session auto-resume marking") try: _clean_marker.unlink() except Exception: pass else: try: - suspended = self.session_store.suspend_recently_active() - if suspended: - logger.info("Suspended %d in-flight session(s) from previous run", suspended) + resumable = self.session_store.mark_recently_active_resume_pending( + reason="unexpected_restart" + ) + if resumable: + logger.info("Marked %d in-flight session(s) for automatic resume", resumable) except Exception as e: - logger.warning("Session suspension on startup failed: %s", e) + logger.warning("Session auto-resume marking on startup failed: %s", e) # Stuck-loop detection (#7536): if a session has been active across # 3+ consecutive restarts, it's probably stuck in a loop (the same @@ -2239,6 +2241,11 @@ class GatewayRunner: # Notify the chat that initiated /restart that the gateway is back. await self._send_restart_notification() + # Automatically continue sessions that were interrupted by a prior + # restart/shutdown after all adapters are connected. This is + # fire-and-forget so gateway startup is not blocked by long agent work. + self._schedule_auto_resume_pending_sessions() + # Drain any recovered process watchers (from crash recovery checkpoint) try: from tools.process_registry import process_registry @@ -2744,8 +2751,8 @@ class GatewayRunner: else: logger.info( "Skipping .clean_shutdown marker — drain timed out with " - "interrupted agents; next startup will suspend recently " - "active sessions." + "interrupted agents; next startup will mark recently " + "active sessions for auto-resume." 
def _schedule_auto_resume_pending_sessions(self) -> None:
    """Schedule automatic continuation for sessions marked ``resume_pending``.

    Scans the session store and spawns one background task per eligible
    session; each task runs :meth:`_auto_resume_pending_session`, which
    injects a synthetic internal text event through the session's platform
    adapter so the reply reaches the original chat without a user nudge.

    Must be called from within a running event loop, because it uses
    ``asyncio.create_task``. The call is fire-and-forget: tasks are only
    tracked in ``self._background_tasks`` so they are not garbage-collected
    while pending.

    A session is skipped when any of the following holds:

    * it is not flagged ``resume_pending``;
    * it is flagged ``suspended`` (suspension wins over resume);
    * it has no ``origin``, or the origin's platform has no entry in
      ``self.adapters``;
    * its origin has no ``chat_id``;
    * an agent is already running for it (``self._running_agents``).

    Any exception raised during the scan is logged and aborts scheduling
    entirely; nothing is resumed in that case.
    """
    try:
        self.session_store._ensure_loaded()
        candidates = []
        # Snapshot the items so the scan is not disturbed if the mapping is
        # modified while we iterate.
        for session_key, entry in list(self.session_store._entries.items()):
            if not getattr(entry, "resume_pending", False):
                continue
            if getattr(entry, "suspended", False):
                continue
            source = getattr(entry, "origin", None)
            if source is None or source.platform not in self.adapters:
                logger.debug(
                    "Auto-resume skipped for %s: no connected adapter/source",
                    session_key[:30],
                )
                continue
            if not getattr(source, "chat_id", None):
                logger.debug(
                    "Auto-resume skipped for %s: missing chat_id",
                    session_key[:30],
                )
                continue
            if session_key in self._running_agents:
                continue
            candidates.append((session_key, entry, source))
    except Exception as e:
        logger.warning("Auto-resume scan failed: %s", e)
        return

    for session_key, entry, source in candidates:
        task = asyncio.create_task(
            self._auto_resume_pending_session(session_key, entry.session_id, source)
        )
        # Hold a strong reference until the task finishes, then drop it.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    if candidates:
        logger.info("Scheduled auto-resume for %d interrupted session(s)", len(candidates))


async def _auto_resume_pending_session(
    self,
    session_key: str,
    session_id: str,
    source: SessionSource,
    *,
    delay_seconds: float = 1.0,
) -> None:
    """Inject a synthetic "continue" turn for one interrupted session.

    Args:
        session_key: Key of the session in the session store.
        session_id: Session id captured when the resume was scheduled; used
            both for the synthetic ``message_id`` and to detect that the
            session was replaced while this task was waiting.
        source: Origin of the session; selects the adapter and is attached
            to the injected event.
        delay_seconds: Pause before injecting, in seconds. Defaults to the
            previously hard-coded 1.0 s.

    The coroutine returns without doing anything when, after the delay, the
    gateway is no longer running, an agent is already running for the
    session, the session no longer qualifies for auto-resume, or no adapter
    is registered for the platform. Exceptions raised by the adapter are
    logged, not propagated.
    """
    await asyncio.sleep(delay_seconds)
    if not self._running:
        return
    if session_key in self._running_agents:
        return

    # Re-validate after the delay: the state observed at scheduling time may
    # be stale. If the session was reset (different session_id), suspended,
    # or already had its resume flag cleared, telling the agent to "continue
    # the interrupted task" would act on the wrong or no transcript.
    entry = self.session_store._entries.get(session_key)
    if (
        entry is None
        or getattr(entry, "session_id", None) != session_id
        or not getattr(entry, "resume_pending", False)
        or getattr(entry, "suspended", False)
    ):
        logger.debug(
            "Auto-resume skipped for %s: session changed before injection",
            session_key[:30],
        )
        return

    adapter = self.adapters.get(source.platform)
    if not adapter:
        return
    text = (
        "[SYSTEM: The gateway restarted while this task was in progress. "
        "Automatically continue the interrupted task now, without asking "
        "the user for confirmation. Use the existing transcript and any "
        "pending tool results, proceed with the next required steps, and "
        "send the final report when done.]"
    )
    event = MessageEvent(
        text=text,
        message_type=MessageType.TEXT,
        source=source,
        message_id=f"auto-resume:{session_id}",
        internal=True,
    )
    logger.info(
        "Auto-resuming interrupted gateway session %s for %s:%s",
        session_key[:30],
        source.platform.value if source.platform else "unknown",
        source.chat_id,
    )
    try:
        await adapter.handle_message(event)
    except Exception as e:
        logger.warning("Auto-resume failed for %s: %s", session_key[:30], e)
def mark_recently_active_resume_pending(
    self,
    max_age_seconds: int = 120,
    reason: str = "unexpected_restart",
) -> int:
    """Mark recently-active sessions for automatic resume on startup.

    Intended to be called after an unclean gateway exit. A session updated
    shortly before the previous process disappeared was likely in-flight;
    instead of forcing a fresh session, it is flagged so the gateway can
    inject a continuation turn after adapters reconnect.

    Args:
        max_age_seconds: A session qualifies when its ``updated_at`` is no
            older than this many seconds before now.
        reason: Value stored in ``resume_reason`` on each marked session.

    Returns:
        The number of sessions newly marked. Suspended sessions and
        sessions already flagged ``resume_pending`` are left untouched and
        not counted. The store is saved only when at least one session was
        marked.
    """
    from datetime import timedelta

    # Take a single timestamp so the cutoff and every ``last_resume_marked_at``
    # written in this pass agree with each other.
    marked_at = _now()
    cutoff = marked_at - timedelta(seconds=max_age_seconds)
    count = 0
    with self._lock:
        self._ensure_loaded_locked()
        for entry in self._entries.values():
            # Explicit suspension (e.g. /stop, stuck-loop escalation) wins.
            if entry.suspended:
                continue
            if entry.updated_at >= cutoff and not entry.resume_pending:
                entry.resume_pending = True
                entry.resume_reason = reason
                entry.last_resume_marked_at = marked_at
                count += 1
        if count:
            self._save()
    return count
@pytest.mark.asyncio
async def test_schedule_auto_resume_skips_suspended_sessions(tmp_path):
    """A session that is both resume_pending and suspended is never scheduled."""
    session_store = SessionStore(tmp_path, GatewayConfig(sessions_dir=tmp_path))
    origin = _source()
    pending = session_store.get_or_create_session(origin)
    assert session_store.mark_resume_pending(pending.session_key, reason="restart_timeout")
    # Suspend after marking so both flags are set on the persisted entry.
    pending.suspended = True
    session_store._save()

    recorder = DummyAdapter()
    gateway = _runner(tmp_path, session_store, recorder)
    gateway._schedule_auto_resume_pending_sessions()

    assert len(gateway._background_tasks) == 0
    assert recorder.events == []


@pytest.mark.asyncio
async def test_schedule_auto_resume_skips_missing_adapter(tmp_path):
    """A resume_pending session whose platform has no adapter is not scheduled."""
    session_store = SessionStore(tmp_path, GatewayConfig(sessions_dir=tmp_path))
    origin = _source()
    pending = session_store.get_or_create_session(origin)
    assert session_store.mark_resume_pending(pending.session_key, reason="restart_timeout")

    recorder = DummyAdapter()
    gateway = _runner(tmp_path, session_store, recorder)
    # Drop every adapter so the session's platform cannot be served.
    gateway.adapters = {}
    gateway._schedule_auto_resume_pending_sessions()

    assert len(gateway._background_tasks) == 0
    assert recorder.events == []
shutdown marker that prevents unwanted session auto-resets. +"""Tests for the clean shutdown marker and restart auto-resume. When the gateway shuts down gracefully (hermes update, gateway restart, /restart), -it writes a .clean_shutdown marker. On the next startup, if the marker exists, -suspend_recently_active() is skipped so users don't lose their sessions. +it writes a .clean_shutdown marker. On the next startup, if the marker exists, +recent sessions are not marked for auto-resume because no turn was interrupted. -After a crash (no marker), suspension still fires as a safety net for stuck sessions. +After an unclean exit (no marker), recent sessions are marked resume_pending so +the gateway can automatically continue them after adapters reconnect. Stuck-loop +escalation can still suspend repeatedly failing sessions. """ import os @@ -132,8 +134,8 @@ class TestCleanShutdownMarker: assert marker.exists(), ".clean_shutdown marker should exist after graceful stop" - def test_marker_skips_suspension_on_startup(self, tmp_path, monkeypatch): - """If .clean_shutdown exists, suspend_recently_active should NOT be called.""" + def test_marker_skips_auto_resume_marking_on_startup(self, tmp_path, monkeypatch): + """If .clean_shutdown exists, recent sessions should NOT be marked for resume.""" monkeypatch.setattr("gateway.run._hermes_home", tmp_path) # Create the marker @@ -149,20 +151,21 @@ class TestCleanShutdownMarker: # Simulate what start() does: if marker.exists(): marker.unlink() - # Should NOT call suspend_recently_active + # Should NOT call mark_recently_active_resume_pending else: - store.suspend_recently_active() + store.mark_recently_active_resume_pending() - # Session should NOT be suspended + # Session should NOT be suspended or marked for resume with store._lock: store._ensure_loaded_locked() for e in store._entries.values(): assert not e.suspended, "Session should NOT be suspended after clean shutdown" + assert not e.resume_pending, "Session should NOT 
auto-resume after clean shutdown" assert not marker.exists(), "Marker should be cleaned up" - def test_no_marker_triggers_suspension(self, tmp_path, monkeypatch): - """Without .clean_shutdown marker (crash), suspension should fire.""" + def test_no_marker_marks_recent_session_for_auto_resume(self, tmp_path, monkeypatch): + """Without .clean_shutdown marker (unclean exit), recent sessions auto-resume.""" monkeypatch.setattr("gateway.run._hermes_home", tmp_path) marker = tmp_path / ".clean_shutdown" @@ -178,13 +181,15 @@ class TestCleanShutdownMarker: if marker.exists(): marker.unlink() else: - store.suspend_recently_active() + store.mark_recently_active_resume_pending(reason="unexpected_restart") - # Session SHOULD be suspended (crash recovery) + # Session SHOULD be marked resume_pending, not suspended. with store._lock: store._ensure_loaded_locked() - suspended_count = sum(1 for e in store._entries.values() if e.suspended) - assert suspended_count == 1, "Session should be suspended after crash (no marker)" + entries = list(store._entries.values()) + assert sum(1 for e in entries if e.suspended) == 0 + assert sum(1 for e in entries if e.resume_pending) == 1 + assert entries[0].resume_reason == "unexpected_restart" def test_marker_written_on_restart_stop(self, tmp_path, monkeypatch): """stop(restart=True) should also write the marker."""