diff --git a/gateway/run.py b/gateway/run.py index 3ed81379a..9fd5ac0b7 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -178,7 +178,6 @@ class GatewayRunner: self.session_store = SessionStore( self.config.sessions_dir, self.config, has_active_processes_fn=lambda key: process_registry.has_active_for_session(key), - on_auto_reset=self._flush_memories_before_reset, ) self.delivery_router = DeliveryRouter(self.config) self._running = False @@ -209,15 +208,14 @@ class GatewayRunner: from gateway.hooks import HookRegistry self.hooks = HookRegistry() - def _flush_memories_before_reset(self, old_entry): - """Prompt the agent to save memories/skills before an auto-reset. - - Called synchronously by SessionStore before destroying an expired session. - Loads the transcript, gives the agent a real turn with memory + skills - tools, and explicitly asks it to preserve anything worth keeping. + def _flush_memories_for_session(self, old_session_id: str): + """Prompt the agent to save memories/skills before context is lost. + + Synchronous worker — meant to be called via run_in_executor from + an async context so it doesn't block the event loop. """ try: - history = self.session_store.load_transcript(old_entry.session_id) + history = self.session_store.load_transcript(old_session_id) if not history or len(history) < 4: return @@ -231,7 +229,7 @@ class GatewayRunner: max_iterations=8, quiet_mode=True, enabled_toolsets=["memory", "skills"], - session_id=old_entry.session_id, + session_id=old_session_id, ) # Build conversation history from transcript @@ -260,9 +258,14 @@ class GatewayRunner: user_message=flush_prompt, conversation_history=msgs, ) - logger.info("Pre-reset save completed for session %s", old_entry.session_id) + logger.info("Pre-reset memory flush completed for session %s", old_session_id) except Exception as e: - logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e) + logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e) + + async def _async_flush_memories(self, old_session_id: str): + """Run the sync memory flush in a thread pool so it won't block the event loop.""" + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id) @staticmethod def _load_prefill_messages() -> List[Dict[str, Any]]: @@ -464,10 +467,50 @@ class GatewayRunner: # Check if we're restarting after a /update command await self._send_update_notification() + # Start background session expiry watcher for proactive memory flushing + asyncio.create_task(self._session_expiry_watcher()) + logger.info("Press Ctrl+C to stop") return True + async def _session_expiry_watcher(self, interval: int = 300): + """Background task that proactively flushes memories for expired sessions. + + Runs every `interval` seconds (default 5 min). For each session that + has expired according to its reset policy, flushes memories in a thread + pool and marks the session so it won't be flushed again. + + This means memories are already saved by the time the user sends their + next message, so there's no blocking delay. + """ + await asyncio.sleep(60) # initial delay — let the gateway fully start + while self._running: + try: + self.session_store._ensure_loaded() + for key, entry in list(self.session_store._entries.items()): + if entry.session_id in self.session_store._pre_flushed_sessions: + continue # already flushed this session + if not self.session_store._is_session_expired(entry): + continue # session still active + # Session has expired — flush memories in the background + logger.info( + "Session %s expired (key=%s), flushing memories proactively", + entry.session_id, key, + ) + try: + await self._async_flush_memories(entry.session_id) + self.session_store._pre_flushed_sessions.add(entry.session_id) + except Exception as e: + logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e) + except Exception as e: + logger.debug("Session expiry watcher error: %s", e) + # Sleep in small increments so we can stop quickly + for _ in range(interval): + if not self._running: + break + await asyncio.sleep(1) + async def stop(self) -> None: """Stop the gateway and disconnect all adapters.""" logger.info("Stopping gateway...") @@ -1012,33 +1055,12 @@ class GatewayRunner: # Get existing session key session_key = self.session_store._generate_session_key(source) - # Memory flush before reset: load the old transcript and let a - # temporary agent save memories before the session is wiped. + # Flush memories in the background (fire-and-forget) so the user + # gets the "Session reset!" response immediately. try: old_entry = self.session_store._entries.get(session_key) if old_entry: - old_history = self.session_store.load_transcript(old_entry.session_id) - if old_history: - from run_agent import AIAgent - loop = asyncio.get_event_loop() - _flush_kwargs = _resolve_runtime_agent_kwargs() - def _do_flush(): - tmp_agent = AIAgent( - **_flush_kwargs, - max_iterations=5, - quiet_mode=True, - enabled_toolsets=["memory"], - session_id=old_entry.session_id, - ) - # Build simple message list from transcript - msgs = [] - for m in old_history: - role = m.get("role") - content = m.get("content") - if role in ("user", "assistant") and content: - msgs.append({"role": role, "content": content}) - tmp_agent.flush_memories(msgs) - await loop.run_in_executor(None, _do_flush) + asyncio.create_task(self._async_flush_memories(old_entry.session_id)) except Exception as e: logger.debug("Gateway memory flush on reset failed: %s", e) diff --git a/gateway/session.py b/gateway/session.py index 091cb46a1..4c2d9c208 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -311,7 +311,9 @@ class SessionStore: self._entries: Dict[str, SessionEntry] = {} self._loaded = False self._has_active_processes_fn = has_active_processes_fn - self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset + # on_auto_reset is deprecated — memory flush now runs proactively + # via the background session expiry watcher in GatewayRunner. + self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher # Initialize SQLite session database self._db = None @@ -353,6 +355,44 @@ class SessionStore: """Generate a session key from a source.""" return build_session_key(source) + def _is_session_expired(self, entry: SessionEntry) -> bool: + """Check if a session has expired based on its reset policy. + + Works from the entry alone — no SessionSource needed. + Used by the background expiry watcher to proactively flush memories. + Sessions with active background processes are never considered expired. + """ + if self._has_active_processes_fn: + if self._has_active_processes_fn(entry.session_key): + return False + + policy = self.config.get_reset_policy( + platform=entry.platform, + session_type=entry.chat_type, + ) + + if policy.mode == "none": + return False + + now = datetime.now() + + if policy.mode in ("idle", "both"): + idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes) + if now > idle_deadline: + return True + + if policy.mode in ("daily", "both"): + today_reset = now.replace( + hour=policy.at_hour, + minute=0, second=0, microsecond=0, + ) + if now.hour < policy.at_hour: + today_reset -= timedelta(days=1) + if entry.updated_at < today_reset: + return True + + return False + def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool: """ Check if a session should be reset based on policy. @@ -439,13 +479,11 @@ class SessionStore: self._save() return entry else: - # Session is being auto-reset — flush memories before destroying + # Session is being auto-reset. The background expiry watcher + # should have already flushed memories proactively; discard + # the marker so it doesn't accumulate. was_auto_reset = True - if self._on_auto_reset: - try: - self._on_auto_reset(entry) - except Exception as e: - logger.debug("Auto-reset callback failed: %s", e) + self._pre_flushed_sessions.discard(entry.session_id) if self._db: try: self._db.end_session(entry.session_id, "session_reset") diff --git a/tests/gateway/test_async_memory_flush.py b/tests/gateway/test_async_memory_flush.py new file mode 100644 index 000000000..675746920 --- /dev/null +++ b/tests/gateway/test_async_memory_flush.py @@ -0,0 +1,180 @@ +"""Tests for proactive memory flush on session expiry. + +Verifies that: +1. _is_session_expired() works from a SessionEntry alone (no source needed) +2. The sync callback is no longer called in get_or_create_session +3. _pre_flushed_sessions tracking works correctly +4. The background watcher can detect expired sessions +""" + +import pytest +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import patch, MagicMock + +from gateway.config import Platform, GatewayConfig, SessionResetPolicy +from gateway.session import SessionSource, SessionStore, SessionEntry + + +@pytest.fixture() +def idle_store(tmp_path): + """SessionStore with a 60-minute idle reset policy.""" + config = GatewayConfig( + default_reset_policy=SessionResetPolicy(mode="idle", idle_minutes=60), + ) + with patch("gateway.session.SessionStore._ensure_loaded"): + s = SessionStore(sessions_dir=tmp_path, config=config) + s._db = None + s._loaded = True + return s + + +@pytest.fixture() +def no_reset_store(tmp_path): + """SessionStore with no reset policy (mode=none).""" + config = GatewayConfig( + default_reset_policy=SessionResetPolicy(mode="none"), + ) + with patch("gateway.session.SessionStore._ensure_loaded"): + s = SessionStore(sessions_dir=tmp_path, config=config) + s._db = None + s._loaded = True + return s + + +class TestIsSessionExpired: + """_is_session_expired should detect expiry from entry alone.""" + + def test_idle_session_expired(self, idle_store): + entry = SessionEntry( + session_key="agent:main:telegram:dm", + session_id="sid_1", + created_at=datetime.now() - timedelta(hours=3), + updated_at=datetime.now() - timedelta(minutes=120), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + assert idle_store._is_session_expired(entry) is True + + def test_active_session_not_expired(self, idle_store): + entry = SessionEntry( + session_key="agent:main:telegram:dm", + session_id="sid_2", + created_at=datetime.now() - timedelta(hours=1), + updated_at=datetime.now() - timedelta(minutes=10), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + assert idle_store._is_session_expired(entry) is False + + def test_none_mode_never_expires(self, no_reset_store): + entry = SessionEntry( + session_key="agent:main:telegram:dm", + session_id="sid_3", + created_at=datetime.now() - timedelta(days=30), + updated_at=datetime.now() - timedelta(days=30), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + assert no_reset_store._is_session_expired(entry) is False + + def test_active_processes_prevent_expiry(self, idle_store): + """Sessions with active background processes should never expire.""" + idle_store._has_active_processes_fn = lambda key: True + entry = SessionEntry( + session_key="agent:main:telegram:dm", + session_id="sid_4", + created_at=datetime.now() - timedelta(hours=5), + updated_at=datetime.now() - timedelta(hours=5), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + assert idle_store._is_session_expired(entry) is False + + def test_daily_mode_expired(self, tmp_path): + """Daily mode should expire sessions from before today's reset hour.""" + config = GatewayConfig( + default_reset_policy=SessionResetPolicy(mode="daily", at_hour=4), + ) + with patch("gateway.session.SessionStore._ensure_loaded"): + store = SessionStore(sessions_dir=tmp_path, config=config) + store._db = None + store._loaded = True + + entry = SessionEntry( + session_key="agent:main:telegram:dm", + session_id="sid_5", + created_at=datetime.now() - timedelta(days=2), + updated_at=datetime.now() - timedelta(days=2), + platform=Platform.TELEGRAM, + chat_type="dm", + ) + assert store._is_session_expired(entry) is True + + +class TestGetOrCreateSessionNoCallback: + """get_or_create_session should NOT call a sync flush callback.""" + + def test_auto_reset_cleans_pre_flushed_marker(self, idle_store): + """When a session auto-resets, the pre_flushed marker should be discarded.""" + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + ) + # Create initial session + entry1 = idle_store.get_or_create_session(source) + old_sid = entry1.session_id + + # Simulate the watcher having flushed it + idle_store._pre_flushed_sessions.add(old_sid) + + # Simulate the session going idle + entry1.updated_at = datetime.now() - timedelta(minutes=120) + idle_store._save() + + # Next call should auto-reset + entry2 = idle_store.get_or_create_session(source) + assert entry2.session_id != old_sid + assert entry2.was_auto_reset is True + + # The old session_id should be removed from pre_flushed + assert old_sid not in idle_store._pre_flushed_sessions + + def test_no_sync_callback_invoked(self, idle_store): + """No synchronous callback should block during auto-reset.""" + source = SessionSource( + platform=Platform.TELEGRAM, + chat_id="123", + chat_type="dm", + ) + entry1 = idle_store.get_or_create_session(source) + entry1.updated_at = datetime.now() - timedelta(minutes=120) + idle_store._save() + + # Verify no _on_auto_reset attribute + assert not hasattr(idle_store, '_on_auto_reset') + + # This should NOT block (no sync LLM call) + entry2 = idle_store.get_or_create_session(source) + assert entry2.was_auto_reset is True + + +class TestPreFlushedSessionsTracking: + """The _pre_flushed_sessions set should prevent double-flushing.""" + + def test_starts_empty(self, idle_store): + assert len(idle_store._pre_flushed_sessions) == 0 + + def test_add_and_check(self, idle_store): + idle_store._pre_flushed_sessions.add("sid_old") + assert "sid_old" in idle_store._pre_flushed_sessions + assert "sid_other" not in idle_store._pre_flushed_sessions + + def test_discard_on_reset(self, idle_store): + """discard should remove without raising if not present.""" + idle_store._pre_flushed_sessions.add("sid_a") + idle_store._pre_flushed_sessions.discard("sid_a") + assert "sid_a" not in idle_store._pre_flushed_sessions + # discard on non-existent should not raise + idle_store._pre_flushed_sessions.discard("sid_nonexistent")