fix(hindsight): drain retain queue cleanly on shutdown

The plugin used to spawn one daemon thread per sync_turn() to do the
aretain_batch network write. On CLI exit, that pattern raced interpreter
shutdown — the last retain could reach aiohttp after asyncio's
"cannot schedule new futures" guard had fired, producing noisy logs and
silently losing the final unsaved turn:

    WARNING ... Hindsight sync failed: cannot schedule new futures after
            interpreter shutdown
    ERROR asyncio: Unclosed client session
            client_session: <aiohttp.client.ClientSession object at 0x...>

Switch to a single-writer model: each provider owns one long-lived
writer thread plus a queue. sync_turn() snapshots state and enqueues a
job; the writer drains sequentially. Once shutdown() is called:

  - new sync_turn() / queue_prefetch() calls are dropped, not enqueued
  - a sentinel wakes the writer so it finishes in-flight work
  - shutdown joins the writer (10s) before nulling the client

Also register an idempotent atexit hook from the first sync_turn(), so
exit paths that don't go through MemoryManager.shutdown_all() (Ctrl-C,
abrupt exit) still get a chance to drain.

Tests: keep _sync_thread as a legacy alias to the writer, swap join()
calls to _retain_queue.join() (canonical wait-for-drain), add a new
TestShutdownRace suite covering single-writer reuse, post-shutdown drop,
queue draining, and shutdown idempotency.
This commit is contained in:
Nicolò Boschi 2026-04-28 14:49:14 +02:00 committed by Teknium
parent 5662ac2afc
commit 0565497dcc
2 changed files with 228 additions and 57 deletions

View file

@ -29,10 +29,12 @@ Or via $HERMES_HOME/hindsight/config.json (profile-scoped), falling back to
from __future__ import annotations
import asyncio
import atexit
import importlib
import json
import logging
import os
import queue
import threading
from datetime import datetime, timezone
@ -100,6 +102,10 @@ _loop: asyncio.AbstractEventLoop | None = None
_loop_thread: threading.Thread | None = None
_loop_lock = threading.Lock()
# Sentinel pushed to the per-provider retain queue to wake the writer for a
# clean exit. A unique object so it can never collide with a real job.
_WRITER_SENTINEL = object()
def _get_loop() -> asyncio.AbstractEventLoop:
"""Return a long-lived event loop running on a background thread."""
@ -444,6 +450,16 @@ class HindsightMemoryProvider(MemoryProvider):
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread = None
# Single-writer model for retain. sync_turn() enqueues; the writer
# thread drains sequentially. Avoids spawning ad-hoc threads that
# can race the interpreter shutdown and emit "cannot schedule new
# futures after interpreter shutdown" / "Unclosed client session".
self._retain_queue: queue.Queue = queue.Queue()
self._writer_thread: threading.Thread | None = None
self._shutting_down = threading.Event()
self._atexit_registered = False
# Legacy alias — older tests/callers reference _sync_thread directly.
# Points at _writer_thread once the writer is running.
self._sync_thread = None
self._session_id = ""
self._parent_session_id = ""
@ -818,6 +834,73 @@ class HindsightMemoryProvider(MemoryProvider):
)
)
def _ensure_writer(self) -> None:
    """Start the single retain-writer thread if one isn't already running.

    Started lazily rather than in initialize() so providers that never
    retain (e.g. tools-only mode) don't pay for an idle thread.
    """
    current = self._writer_thread
    if current is None or not current.is_alive():
        # A dead writer means a prior shutdown (or writer exit) happened;
        # clear the flag so the replacement writer accepts new jobs.
        self._shutting_down.clear()
        worker = threading.Thread(
            target=self._writer_loop,
            daemon=True,
            name="hindsight-writer",
        )
        self._writer_thread = worker
        # _sync_thread is a legacy alias — external code that joins
        # _sync_thread keeps working if it tracks the writer.
        self._sync_thread = worker
        worker.start()
def _writer_loop(self) -> None:
    """Execute queued retain jobs one at a time until told to stop.

    Blocks on the queue with a 1s timeout so the loop can also notice a
    shutdown that never delivered a sentinel. A failing job is logged and
    swallowed so one bad retain can't kill the writer, and task_done()
    fires for every dequeued item so queue.join() stays a reliable
    wait-for-drain primitive in tests.
    """
    pending = self._retain_queue
    while True:
        try:
            job = pending.get(timeout=1.0)
        except queue.Empty:
            # Idle wakeup: keep waiting unless shutdown was requested.
            if not self._shutting_down.is_set():
                continue
            return
        stop = job is _WRITER_SENTINEL
        try:
            if not stop:
                try:
                    job()
                except Exception as exc:
                    logger.warning("Hindsight retain failed: %s", exc, exc_info=True)
        finally:
            pending.task_done()
        if stop:
            return
def _register_atexit(self) -> None:
    """Install the drain-on-exit hook, at most once per provider.

    Exit paths that never reach MemoryManager.shutdown_all() (Ctrl-C,
    abrupt exit) would otherwise leave in-flight retain jobs racing
    interpreter teardown — producing "cannot schedule new futures"
    warnings and unclosed aiohttp sessions.
    """
    if not self._atexit_registered:
        self._atexit_registered = True
        atexit.register(self._atexit_shutdown)
def _atexit_shutdown(self) -> None:
    """atexit callback: best-effort shutdown, skipped if one already ran."""
    if not self._shutting_down.is_set():
        try:
            self.shutdown()
        except Exception as exc:
            # Interpreter is tearing down — a failed drain is only worth
            # a debug line, never a crash.
            logger.debug("Hindsight atexit shutdown failed: %s", exc)
def _run_hindsight_operation(self, operation):
"""Run an async Hindsight client operation, retrying once after idle shutdown."""
client = self._get_client()
@ -1081,6 +1164,9 @@ class HindsightMemoryProvider(MemoryProvider):
if not self._auto_recall:
logger.debug("Prefetch: skipped (auto_recall disabled)")
return
if self._shutting_down.is_set():
logger.debug("Prefetch: skipped (shutting down)")
return
# Truncate query to max chars
if self._recall_max_input_chars and len(query) > self._recall_max_input_chars:
query = query[:self._recall_max_input_chars]
@ -1189,13 +1275,19 @@ class HindsightMemoryProvider(MemoryProvider):
return kwargs
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Retain conversation turn in background (non-blocking).
"""Enqueue a retain for the current turn. Non-blocking.
Respects retain_every_n_turns for batching.
The actual aretain_batch runs on a single long-lived writer thread
that drains an in-memory queue. Once shutdown() has been called,
further sync_turn() calls are dropped this prevents post-exit
retains from reaching aiohttp after interpreter shutdown begins.
"""
if not self._auto_retain:
logger.debug("sync_turn: skipped (auto_retain disabled)")
return
if self._shutting_down.is_set():
logger.debug("sync_turn: skipped (shutting down)")
return
if session_id:
self._session_id = str(session_id).strip()
@ -1220,37 +1312,42 @@ class HindsightMemoryProvider(MemoryProvider):
if self._parent_session_id:
lineage_tags.append(f"parent:{self._parent_session_id}")
def _sync():
try:
item = self._build_retain_kwargs(
content,
context=self._retain_context,
metadata=self._build_metadata(
message_count=len(self._session_turns) * 2,
turn_index=self._turn_index,
),
tags=lineage_tags or None,
)
item.pop("bank_id", None)
item.pop("retain_async", None)
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
self._bank_id, self._document_id, self._retain_async, len(content), len(self._session_turns))
self._run_hindsight_operation(
lambda client: client.aretain_batch(
bank_id=self._bank_id,
items=[item],
document_id=self._document_id,
retain_async=self._retain_async,
)
)
logger.debug("Hindsight retain succeeded")
except Exception as e:
logger.warning("Hindsight sync failed: %s", e, exc_info=True)
# Snapshot the state needed for the retain. The writer may run after
# _session_turns / _turn_index are mutated by a later sync_turn().
metadata_snapshot = self._build_metadata(
message_count=len(self._session_turns) * 2,
turn_index=self._turn_index,
)
num_turns = len(self._session_turns)
document_id = self._document_id
bank_id = self._bank_id
retain_async_flag = self._retain_async
retain_context = self._retain_context
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=5.0)
self._sync_thread = threading.Thread(target=_sync, daemon=True, name="hindsight-sync")
self._sync_thread.start()
def _do_retain() -> None:
item = self._build_retain_kwargs(
content,
context=retain_context,
metadata=metadata_snapshot,
tags=lineage_tags or None,
)
item.pop("bank_id", None)
item.pop("retain_async", None)
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
bank_id, document_id, retain_async_flag, len(content), num_turns)
self._run_hindsight_operation(
lambda client: client.aretain_batch(
bank_id=bank_id,
items=[item],
document_id=document_id,
retain_async=retain_async_flag,
)
)
logger.debug("Hindsight retain succeeded")
self._ensure_writer()
self._register_atexit()
self._retain_queue.put(_do_retain)
def get_tool_schemas(self) -> List[Dict[str, Any]]:
if self._memory_mode == "context":
@ -1371,10 +1468,28 @@ class HindsightMemoryProvider(MemoryProvider):
)
def shutdown(self) -> None:
logger.debug("Hindsight shutdown: waiting for background threads")
for t in (self._prefetch_thread, self._sync_thread):
if t and t.is_alive():
t.join(timeout=5.0)
logger.debug("Hindsight shutdown: stopping writer + waiting for background threads")
# Stop accepting new retain jobs first so anyone still calling
# sync_turn() during teardown is dropped, not enqueued.
self._shutting_down.set()
# Drain the writer: it will finish in-flight work, then exit on
# the sentinel. Bounded join keeps shutdown predictable even if
# the daemon is wedged.
writer = self._writer_thread
if writer is not None and writer.is_alive():
try:
self._retain_queue.put(_WRITER_SENTINEL)
except Exception:
pass
writer.join(timeout=10.0)
if writer.is_alive():
logger.warning(
"Hindsight writer did not stop within 10s; "
"abandoning %d pending retain(s)",
self._retain_queue.qsize(),
)
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
if self._client is not None:
try:
if self._mode == "local_embedded":