feat(gateway): proactive async memory flush on session expiry

Previously, when a session expired (idle/daily reset), the memory flush
ran synchronously inside get_or_create_session — blocking the user's
message for 10-60s while an LLM call saved memories.

Now a background watcher task (_session_expiry_watcher) runs every 5 min,
detects expired sessions, and flushes memories proactively in a thread
pool.  By the time the user sends their next message, memories are
already saved and the response is immediate.

Changes:
- Add _is_session_expired(entry) to SessionStore — works from entry
  alone without needing a SessionSource
- Add _pre_flushed_sessions set to track already-flushed sessions
- Remove sync _on_auto_reset callback from get_or_create_session
- Refactor flush into _flush_memories_for_session (sync worker) +
  _async_flush_memories (thread pool wrapper)
- Add _session_expiry_watcher background task, started in start()
- Simplify /reset command to use shared fire-and-forget flush
- Add 10 tests for expiry detection, callback removal, tracking
This commit is contained in:
teknium1 2026-03-07 11:27:50 -08:00
parent e64d646bad
commit d80c30cc92
3 changed files with 282 additions and 42 deletions

View file

@ -178,7 +178,6 @@ class GatewayRunner:
self.session_store = SessionStore(
self.config.sessions_dir, self.config,
has_active_processes_fn=lambda key: process_registry.has_active_for_session(key),
on_auto_reset=self._flush_memories_before_reset,
)
self.delivery_router = DeliveryRouter(self.config)
self._running = False
@ -209,15 +208,14 @@ class GatewayRunner:
from gateway.hooks import HookRegistry
self.hooks = HookRegistry()
def _flush_memories_before_reset(self, old_entry):
"""Prompt the agent to save memories/skills before an auto-reset.
Called synchronously by SessionStore before destroying an expired session.
Loads the transcript, gives the agent a real turn with memory + skills
tools, and explicitly asks it to preserve anything worth keeping.
def _flush_memories_for_session(self, old_session_id: str):
"""Prompt the agent to save memories/skills before context is lost.
Synchronous worker meant to be called via run_in_executor from
an async context so it doesn't block the event loop.
"""
try:
history = self.session_store.load_transcript(old_entry.session_id)
history = self.session_store.load_transcript(old_session_id)
if not history or len(history) < 4:
return
@ -231,7 +229,7 @@ class GatewayRunner:
max_iterations=8,
quiet_mode=True,
enabled_toolsets=["memory", "skills"],
session_id=old_entry.session_id,
session_id=old_session_id,
)
# Build conversation history from transcript
@ -260,9 +258,14 @@ class GatewayRunner:
user_message=flush_prompt,
conversation_history=msgs,
)
logger.info("Pre-reset save completed for session %s", old_entry.session_id)
logger.info("Pre-reset memory flush completed for session %s", old_session_id)
except Exception as e:
logger.debug("Pre-reset save failed for session %s: %s", old_entry.session_id, e)
logger.debug("Pre-reset memory flush failed for session %s: %s", old_session_id, e)
async def _async_flush_memories(self, old_session_id: str):
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, self._flush_memories_for_session, old_session_id)
@staticmethod
def _load_prefill_messages() -> List[Dict[str, Any]]:
@ -464,10 +467,50 @@ class GatewayRunner:
# Check if we're restarting after a /update command
await self._send_update_notification()
# Start background session expiry watcher for proactive memory flushing
asyncio.create_task(self._session_expiry_watcher())
logger.info("Press Ctrl+C to stop")
return True
async def _session_expiry_watcher(self, interval: int = 300):
"""Background task that proactively flushes memories for expired sessions.
Runs every `interval` seconds (default 5 min). For each session that
has expired according to its reset policy, flushes memories in a thread
pool and marks the session so it won't be flushed again.
This means memories are already saved by the time the user sends their
next message, so there's no blocking delay.
"""
await asyncio.sleep(60) # initial delay — let the gateway fully start
while self._running:
try:
self.session_store._ensure_loaded()
for key, entry in list(self.session_store._entries.items()):
if entry.session_id in self.session_store._pre_flushed_sessions:
continue # already flushed this session
if not self.session_store._is_session_expired(entry):
continue # session still active
# Session has expired — flush memories in the background
logger.info(
"Session %s expired (key=%s), flushing memories proactively",
entry.session_id, key,
)
try:
await self._async_flush_memories(entry.session_id)
self.session_store._pre_flushed_sessions.add(entry.session_id)
except Exception as e:
logger.debug("Proactive memory flush failed for %s: %s", entry.session_id, e)
except Exception as e:
logger.debug("Session expiry watcher error: %s", e)
# Sleep in small increments so we can stop quickly
for _ in range(interval):
if not self._running:
break
await asyncio.sleep(1)
async def stop(self) -> None:
"""Stop the gateway and disconnect all adapters."""
logger.info("Stopping gateway...")
@ -1012,33 +1055,12 @@ class GatewayRunner:
# Get existing session key
session_key = self.session_store._generate_session_key(source)
# Memory flush before reset: load the old transcript and let a
# temporary agent save memories before the session is wiped.
# Flush memories in the background (fire-and-forget) so the user
# gets the "Session reset!" response immediately.
try:
old_entry = self.session_store._entries.get(session_key)
if old_entry:
old_history = self.session_store.load_transcript(old_entry.session_id)
if old_history:
from run_agent import AIAgent
loop = asyncio.get_event_loop()
_flush_kwargs = _resolve_runtime_agent_kwargs()
def _do_flush():
tmp_agent = AIAgent(
**_flush_kwargs,
max_iterations=5,
quiet_mode=True,
enabled_toolsets=["memory"],
session_id=old_entry.session_id,
)
# Build simple message list from transcript
msgs = []
for m in old_history:
role = m.get("role")
content = m.get("content")
if role in ("user", "assistant") and content:
msgs.append({"role": role, "content": content})
tmp_agent.flush_memories(msgs)
await loop.run_in_executor(None, _do_flush)
asyncio.create_task(self._async_flush_memories(old_entry.session_id))
except Exception as e:
logger.debug("Gateway memory flush on reset failed: %s", e)

View file

@ -311,7 +311,9 @@ class SessionStore:
self._entries: Dict[str, SessionEntry] = {}
self._loaded = False
self._has_active_processes_fn = has_active_processes_fn
self._on_auto_reset = on_auto_reset # callback(old_entry) before auto-reset
# on_auto_reset is deprecated — memory flush now runs proactively
# via the background session expiry watcher in GatewayRunner.
self._pre_flushed_sessions: set = set() # session_ids already flushed by watcher
# Initialize SQLite session database
self._db = None
@ -353,6 +355,44 @@ class SessionStore:
"""Generate a session key from a source."""
return build_session_key(source)
def _is_session_expired(self, entry: SessionEntry) -> bool:
"""Check if a session has expired based on its reset policy.
Works from the entry alone no SessionSource needed.
Used by the background expiry watcher to proactively flush memories.
Sessions with active background processes are never considered expired.
"""
if self._has_active_processes_fn:
if self._has_active_processes_fn(entry.session_key):
return False
policy = self.config.get_reset_policy(
platform=entry.platform,
session_type=entry.chat_type,
)
if policy.mode == "none":
return False
now = datetime.now()
if policy.mode in ("idle", "both"):
idle_deadline = entry.updated_at + timedelta(minutes=policy.idle_minutes)
if now > idle_deadline:
return True
if policy.mode in ("daily", "both"):
today_reset = now.replace(
hour=policy.at_hour,
minute=0, second=0, microsecond=0,
)
if now.hour < policy.at_hour:
today_reset -= timedelta(days=1)
if entry.updated_at < today_reset:
return True
return False
def _should_reset(self, entry: SessionEntry, source: SessionSource) -> bool:
"""
Check if a session should be reset based on policy.
@ -439,13 +479,11 @@ class SessionStore:
self._save()
return entry
else:
# Session is being auto-reset — flush memories before destroying
# Session is being auto-reset. The background expiry watcher
# should have already flushed memories proactively; discard
# the marker so it doesn't accumulate.
was_auto_reset = True
if self._on_auto_reset:
try:
self._on_auto_reset(entry)
except Exception as e:
logger.debug("Auto-reset callback failed: %s", e)
self._pre_flushed_sessions.discard(entry.session_id)
if self._db:
try:
self._db.end_session(entry.session_id, "session_reset")

View file

@ -0,0 +1,180 @@
"""Tests for proactive memory flush on session expiry.
Verifies that:
1. _is_session_expired() works from a SessionEntry alone (no source needed)
2. The sync callback is no longer called in get_or_create_session
3. _pre_flushed_sessions tracking works correctly
4. The background watcher can detect expired sessions
"""
import pytest
from datetime import datetime, timedelta
from pathlib import Path
from unittest.mock import patch, MagicMock
from gateway.config import Platform, GatewayConfig, SessionResetPolicy
from gateway.session import SessionSource, SessionStore, SessionEntry
@pytest.fixture()
def idle_store(tmp_path):
"""SessionStore with a 60-minute idle reset policy."""
config = GatewayConfig(
default_reset_policy=SessionResetPolicy(mode="idle", idle_minutes=60),
)
with patch("gateway.session.SessionStore._ensure_loaded"):
s = SessionStore(sessions_dir=tmp_path, config=config)
s._db = None
s._loaded = True
return s
@pytest.fixture()
def no_reset_store(tmp_path):
"""SessionStore with no reset policy (mode=none)."""
config = GatewayConfig(
default_reset_policy=SessionResetPolicy(mode="none"),
)
with patch("gateway.session.SessionStore._ensure_loaded"):
s = SessionStore(sessions_dir=tmp_path, config=config)
s._db = None
s._loaded = True
return s
class TestIsSessionExpired:
"""_is_session_expired should detect expiry from entry alone."""
def test_idle_session_expired(self, idle_store):
entry = SessionEntry(
session_key="agent:main:telegram:dm",
session_id="sid_1",
created_at=datetime.now() - timedelta(hours=3),
updated_at=datetime.now() - timedelta(minutes=120),
platform=Platform.TELEGRAM,
chat_type="dm",
)
assert idle_store._is_session_expired(entry) is True
def test_active_session_not_expired(self, idle_store):
entry = SessionEntry(
session_key="agent:main:telegram:dm",
session_id="sid_2",
created_at=datetime.now() - timedelta(hours=1),
updated_at=datetime.now() - timedelta(minutes=10),
platform=Platform.TELEGRAM,
chat_type="dm",
)
assert idle_store._is_session_expired(entry) is False
def test_none_mode_never_expires(self, no_reset_store):
entry = SessionEntry(
session_key="agent:main:telegram:dm",
session_id="sid_3",
created_at=datetime.now() - timedelta(days=30),
updated_at=datetime.now() - timedelta(days=30),
platform=Platform.TELEGRAM,
chat_type="dm",
)
assert no_reset_store._is_session_expired(entry) is False
def test_active_processes_prevent_expiry(self, idle_store):
"""Sessions with active background processes should never expire."""
idle_store._has_active_processes_fn = lambda key: True
entry = SessionEntry(
session_key="agent:main:telegram:dm",
session_id="sid_4",
created_at=datetime.now() - timedelta(hours=5),
updated_at=datetime.now() - timedelta(hours=5),
platform=Platform.TELEGRAM,
chat_type="dm",
)
assert idle_store._is_session_expired(entry) is False
def test_daily_mode_expired(self, tmp_path):
"""Daily mode should expire sessions from before today's reset hour."""
config = GatewayConfig(
default_reset_policy=SessionResetPolicy(mode="daily", at_hour=4),
)
with patch("gateway.session.SessionStore._ensure_loaded"):
store = SessionStore(sessions_dir=tmp_path, config=config)
store._db = None
store._loaded = True
entry = SessionEntry(
session_key="agent:main:telegram:dm",
session_id="sid_5",
created_at=datetime.now() - timedelta(days=2),
updated_at=datetime.now() - timedelta(days=2),
platform=Platform.TELEGRAM,
chat_type="dm",
)
assert store._is_session_expired(entry) is True
class TestGetOrCreateSessionNoCallback:
"""get_or_create_session should NOT call a sync flush callback."""
def test_auto_reset_cleans_pre_flushed_marker(self, idle_store):
"""When a session auto-resets, the pre_flushed marker should be discarded."""
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="123",
chat_type="dm",
)
# Create initial session
entry1 = idle_store.get_or_create_session(source)
old_sid = entry1.session_id
# Simulate the watcher having flushed it
idle_store._pre_flushed_sessions.add(old_sid)
# Simulate the session going idle
entry1.updated_at = datetime.now() - timedelta(minutes=120)
idle_store._save()
# Next call should auto-reset
entry2 = idle_store.get_or_create_session(source)
assert entry2.session_id != old_sid
assert entry2.was_auto_reset is True
# The old session_id should be removed from pre_flushed
assert old_sid not in idle_store._pre_flushed_sessions
def test_no_sync_callback_invoked(self, idle_store):
"""No synchronous callback should block during auto-reset."""
source = SessionSource(
platform=Platform.TELEGRAM,
chat_id="123",
chat_type="dm",
)
entry1 = idle_store.get_or_create_session(source)
entry1.updated_at = datetime.now() - timedelta(minutes=120)
idle_store._save()
# Verify no _on_auto_reset attribute
assert not hasattr(idle_store, '_on_auto_reset')
# This should NOT block (no sync LLM call)
entry2 = idle_store.get_or_create_session(source)
assert entry2.was_auto_reset is True
class TestPreFlushedSessionsTracking:
"""The _pre_flushed_sessions set should prevent double-flushing."""
def test_starts_empty(self, idle_store):
assert len(idle_store._pre_flushed_sessions) == 0
def test_add_and_check(self, idle_store):
idle_store._pre_flushed_sessions.add("sid_old")
assert "sid_old" in idle_store._pre_flushed_sessions
assert "sid_other" not in idle_store._pre_flushed_sessions
def test_discard_on_reset(self, idle_store):
"""discard should remove without raising if not present."""
idle_store._pre_flushed_sessions.add("sid_a")
idle_store._pre_flushed_sessions.discard("sid_a")
assert "sid_a" not in idle_store._pre_flushed_sessions
# discard on non-existent should not raise
idle_store._pre_flushed_sessions.discard("sid_nonexistent")