diff --git a/plugins/memory/hindsight/__init__.py b/plugins/memory/hindsight/__init__.py
index 31a04d5d4a..eb200d3328 100644
--- a/plugins/memory/hindsight/__init__.py
+++ b/plugins/memory/hindsight/__init__.py
@@ -29,10 +29,12 @@ Or via $HERMES_HOME/hindsight/config.json (profile-scoped), falling back to
 from __future__ import annotations
 
 import asyncio
+import atexit
 import importlib
 import json
 import logging
 import os
+import queue
 import threading
 from datetime import datetime, timezone
 
@@ -100,6 +102,10 @@ _loop: asyncio.AbstractEventLoop | None = None
 _loop_thread: threading.Thread | None = None
 _loop_lock = threading.Lock()
 
+# Sentinel pushed to the per-provider retain queue to wake the writer for a
+# clean exit. A unique object so it can never collide with a real job.
+_WRITER_SENTINEL = object()
+
 
 def _get_loop() -> asyncio.AbstractEventLoop:
     """Return a long-lived event loop running on a background thread."""
@@ -444,6 +450,16 @@ class HindsightMemoryProvider(MemoryProvider):
         self._prefetch_result = ""
         self._prefetch_lock = threading.Lock()
         self._prefetch_thread = None
+        # Single-writer model for retain. sync_turn() enqueues; the writer
+        # thread drains sequentially. Avoids spawning ad-hoc threads that
+        # can race the interpreter shutdown and emit "cannot schedule new
+        # futures after interpreter shutdown" / "Unclosed client session".
+        self._retain_queue: queue.Queue = queue.Queue()
+        self._writer_thread: threading.Thread | None = None
+        self._shutting_down = threading.Event()
+        self._atexit_registered = False
+        # Legacy alias — older tests/callers reference _sync_thread directly.
+        # Points at _writer_thread once the writer is running.
         self._sync_thread = None
         self._session_id = ""
         self._parent_session_id = ""
@@ -818,6 +834,73 @@ class HindsightMemoryProvider(MemoryProvider):
             )
         )
 
+    def _ensure_writer(self) -> None:
+        """Lazy-start the single retain-writer thread.
+
+        We don't start the writer in initialize() so providers that never
+        retain (e.g. tools-only mode) don't pay for an idle thread.
+        """
+        thread = self._writer_thread
+        if thread is not None and thread.is_alive():
+            return
+        # If the previous writer exited (e.g. after a prior shutdown), reset
+        # the flag so this fresh writer is allowed to drain new jobs.
+        self._shutting_down.clear()
+        thread = threading.Thread(
+            target=self._writer_loop,
+            daemon=True,
+            name="hindsight-writer",
+        )
+        self._writer_thread = thread
+        # Keep the legacy _sync_thread alias pointing at the writer so any
+        # external code that joins _sync_thread keeps working.
+        self._sync_thread = thread
+        thread.start()
+
+    def _writer_loop(self) -> None:
+        """Drain the retain queue serially. Exits on sentinel.
+
+        Each job() is wrapped so a single failure can't kill the writer.
+        task_done() always fires so queue.join() works in tests.
+        """
+        while True:
+            try:
+                job = self._retain_queue.get(timeout=1.0)
+            except queue.Empty:
+                if self._shutting_down.is_set():
+                    return
+                continue
+            try:
+                if job is _WRITER_SENTINEL:
+                    return
+                try:
+                    job()
+                except Exception as exc:
+                    logger.warning("Hindsight retain failed: %s", exc, exc_info=True)
+            finally:
+                self._retain_queue.task_done()
+
+    def _register_atexit(self) -> None:
+        """Register an idempotent atexit hook to drain the writer.
+
+        Without this, a CLI exit that doesn't go through
+        MemoryManager.shutdown_all() would leave in-flight retain jobs
+        racing interpreter teardown, producing "cannot schedule new
+        futures" warnings and unclosed aiohttp sessions.
+ """ + if self._atexit_registered: + return + self._atexit_registered = True + atexit.register(self._atexit_shutdown) + + def _atexit_shutdown(self) -> None: + if self._shutting_down.is_set(): + return + try: + self.shutdown() + except Exception as exc: + logger.debug("Hindsight atexit shutdown failed: %s", exc) + def _run_hindsight_operation(self, operation): """Run an async Hindsight client operation, retrying once after idle shutdown.""" client = self._get_client() @@ -1081,6 +1164,9 @@ class HindsightMemoryProvider(MemoryProvider): if not self._auto_recall: logger.debug("Prefetch: skipped (auto_recall disabled)") return + if self._shutting_down.is_set(): + logger.debug("Prefetch: skipped (shutting down)") + return # Truncate query to max chars if self._recall_max_input_chars and len(query) > self._recall_max_input_chars: query = query[:self._recall_max_input_chars] @@ -1189,13 +1275,19 @@ class HindsightMemoryProvider(MemoryProvider): return kwargs def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: - """Retain conversation turn in background (non-blocking). + """Enqueue a retain for the current turn. Non-blocking. - Respects retain_every_n_turns for batching. + The actual aretain_batch runs on a single long-lived writer thread + that drains an in-memory queue. Once shutdown() has been called, + further sync_turn() calls are dropped — this prevents post-exit + retains from reaching aiohttp after interpreter shutdown begins. """ if not self._auto_retain: logger.debug("sync_turn: skipped (auto_retain disabled)") return + if self._shutting_down.is_set(): + logger.debug("sync_turn: skipped (shutting down)") + return if session_id: self._session_id = str(session_id).strip() @@ -1220,37 +1312,42 @@ class HindsightMemoryProvider(MemoryProvider): if self._parent_session_id: lineage_tags.append(f"parent:{self._parent_session_id}") - def _sync(): - try: - item = self._build_retain_kwargs( - content, - context=self._retain_context, - metadata=self._build_metadata( - message_count=len(self._session_turns) * 2, - turn_index=self._turn_index, - ), - tags=lineage_tags or None, - ) - item.pop("bank_id", None) - item.pop("retain_async", None) - logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d", - self._bank_id, self._document_id, self._retain_async, len(content), len(self._session_turns)) - self._run_hindsight_operation( - lambda client: client.aretain_batch( - bank_id=self._bank_id, - items=[item], - document_id=self._document_id, - retain_async=self._retain_async, - ) - ) - logger.debug("Hindsight retain succeeded") - except Exception as e: - logger.warning("Hindsight sync failed: %s", e, exc_info=True) + # Snapshot the state needed for the retain. The writer may run after + # _session_turns / _turn_index are mutated by a later sync_turn(). 
+        metadata_snapshot = self._build_metadata(
+            message_count=len(self._session_turns) * 2,
+            turn_index=self._turn_index,
+        )
+        num_turns = len(self._session_turns)
+        document_id = self._document_id
+        bank_id = self._bank_id
+        retain_async_flag = self._retain_async
+        retain_context = self._retain_context
 
-        if self._sync_thread and self._sync_thread.is_alive():
-            self._sync_thread.join(timeout=5.0)
-        self._sync_thread = threading.Thread(target=_sync, daemon=True, name="hindsight-sync")
-        self._sync_thread.start()
+        def _do_retain() -> None:
+            item = self._build_retain_kwargs(
+                content,
+                context=retain_context,
+                metadata=metadata_snapshot,
+                tags=lineage_tags or None,
+            )
+            item.pop("bank_id", None)
+            item.pop("retain_async", None)
+            logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
+                         bank_id, document_id, retain_async_flag, len(content), num_turns)
+            self._run_hindsight_operation(
+                lambda client: client.aretain_batch(
+                    bank_id=bank_id,
+                    items=[item],
+                    document_id=document_id,
+                    retain_async=retain_async_flag,
+                )
+            )
+            logger.debug("Hindsight retain succeeded")
+
+        self._ensure_writer()
+        self._register_atexit()
+        self._retain_queue.put(_do_retain)
 
     def get_tool_schemas(self) -> List[Dict[str, Any]]:
         if self._memory_mode == "context":
@@ -1371,10 +1468,28 @@
         )
 
     def shutdown(self) -> None:
-        logger.debug("Hindsight shutdown: waiting for background threads")
-        for t in (self._prefetch_thread, self._sync_thread):
-            if t and t.is_alive():
-                t.join(timeout=5.0)
+        logger.debug("Hindsight shutdown: stopping writer + waiting for background threads")
+        # Stop accepting new retain jobs first so anyone still calling
+        # sync_turn() during teardown is dropped, not enqueued.
+        self._shutting_down.set()
+        # Drain the writer: it will finish in-flight work, then exit on
+        # the sentinel. Bounded join keeps shutdown predictable even if
+        # the daemon thread is wedged.
+        writer = self._writer_thread
+        if writer is not None and writer.is_alive():
+            try:
+                self._retain_queue.put(_WRITER_SENTINEL)
+            except Exception:
+                pass
+            writer.join(timeout=10.0)
+            if writer.is_alive():
+                logger.warning(
+                    "Hindsight writer did not stop within 10s; "
+                    "abandoning %d pending retain(s)",
+                    self._retain_queue.qsize(),
+                )
+        if self._prefetch_thread and self._prefetch_thread.is_alive():
+            self._prefetch_thread.join(timeout=5.0)
         if self._client is not None:
             try:
                 if self._mode == "local_embedded":
diff --git a/tests/plugins/memory/test_hindsight_provider.py b/tests/plugins/memory/test_hindsight_provider.py
index 4d363db326..1d6238475b 100644
--- a/tests/plugins/memory/test_hindsight_provider.py
+++ b/tests/plugins/memory/test_hindsight_provider.py
@@ -669,7 +669,7 @@ class TestSyncTurn:
         p._client = _make_mock_client()
 
         p.sync_turn("hello", "hi there")
-        p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
 
         p._client.aretain_batch.assert_called_once()
         call_kwargs = p._client.aretain_batch.call_args.kwargs
@@ -710,8 +710,7 @@
     def test_sync_turn_with_tags(self, provider_with_config):
         p = provider_with_config(retain_tags=["conv", "session1"])
         p.sync_turn("hello", "hi")
-        if p._sync_thread:
-            p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
         item = p._client.aretain_batch.call_args.kwargs["items"][0]
         assert "conv" in item["tags"]
         assert "session1" in item["tags"]
@@ -720,8 +719,7 @@
     def test_sync_turn_uses_aretain_batch(self, provider):
         """sync_turn should use aretain_batch with retain_async."""
         provider.sync_turn("hello", "hi")
-        if provider._sync_thread:
-            provider._sync_thread.join(timeout=5.0)
+        provider._retain_queue.join()
         provider._client.aretain_batch.assert_called_once()
         call_kwargs = provider._client.aretain_batch.call_args.kwargs
         assert call_kwargs["document_id"].startswith("test-session-")
@@ -732,8 +730,7 @@
     def test_sync_turn_custom_context(self, provider_with_config):
         p = provider_with_config(retain_context="my-agent")
         p.sync_turn("hello", "hi")
-        if p._sync_thread:
-            p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
         item = p._client.aretain_batch.call_args.kwargs["items"][0]
         assert item["context"] == "my-agent"
 
@@ -744,7 +741,7 @@
         p.sync_turn("turn2-user", "turn2-asst")
         assert p._sync_thread is None
         p.sync_turn("turn3-user", "turn3-asst")
-        p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
         p._client.aretain_batch.assert_called_once()
         call_kwargs = p._client.aretain_batch.call_args.kwargs
         assert call_kwargs["document_id"].startswith("test-session-")
@@ -765,15 +762,13 @@
 
         p.sync_turn("turn1-user", "turn1-asst")
         p.sync_turn("turn2-user", "turn2-asst")
-        if p._sync_thread:
-            p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
 
         p._client.aretain_batch.reset_mock()
 
         p.sync_turn("turn3-user", "turn3-asst")
         p.sync_turn("turn4-user", "turn4-asst")
-        if p._sync_thread:
-            p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
 
         content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
         # Should contain ALL turns from the session
@@ -785,8 +780,7 @@
     def test_sync_turn_passes_document_id(self, provider):
         """sync_turn should pass document_id (session_id + per-startup ts)."""
         provider.sync_turn("hello", "hi")
-        if provider._sync_thread:
-            provider._sync_thread.join(timeout=5.0)
+        provider._retain_queue.join()
         call_kwargs = provider._client.aretain_batch.call_args.kwargs
         # Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
         assert call_kwargs["document_id"].startswith("test-session-")
@@ -819,8 +813,7 @@
     def test_sync_turn_session_tag(self, provider):
         """Each retain should be tagged with session:<session_id> for filtering."""
         provider.sync_turn("hello", "hi")
-        if provider._sync_thread:
-            provider._sync_thread.join(timeout=5.0)
+        provider._retain_queue.join()
         item = provider._client.aretain_batch.call_args.kwargs["items"][0]
         assert "session:test-session" in item["tags"]
 
@@ -841,8 +834,7 @@
         )
         p._client = _make_mock_client()
         p.sync_turn("hello", "hi")
-        if p._sync_thread:
-            p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
         item = p._client.aretain_batch.call_args.kwargs["items"][0]
         assert "session:child-session" in item["tags"]
 
@@ -851,15 +843,14 @@
     def test_sync_turn_error_does_not_raise(self, provider):
         provider._client.aretain_batch.side_effect = RuntimeError("network error")
         provider.sync_turn("hello", "hi")
-        if provider._sync_thread:
-            provider._sync_thread.join(timeout=5.0)
+        provider._retain_queue.join()
 
     def test_sync_turn_preserves_unicode(self, provider_with_config):
         """Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact."""
         p = provider_with_config()
         p._client = _make_mock_client()
         p.sync_turn("안녕 こんにちは 你好", "👨‍👩‍👧‍👦 family")
-        p._sync_thread.join(timeout=5.0)
+        p._retain_queue.join()
         p._client.aretain_batch.assert_called_once()
         item = p._client.aretain_batch.call_args.kwargs["items"][0]
         # ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON,
@@ -871,6 +862,71 @@
         assert "👨‍👩‍👧‍👦" in raw_json
 
 
+# ---------------------------------------------------------------------------
+# Shutdown / writer tests
+# ---------------------------------------------------------------------------
+
+
+class TestShutdownRace:
+    def test_sync_turn_uses_single_writer_thread(self, provider):
+        """All retains run through one long-lived writer thread."""
+        provider.sync_turn("a", "b")
+        provider._retain_queue.join()
+        first_writer = provider._writer_thread
+        assert first_writer is not None
+        assert first_writer.is_alive()
+
+        provider.sync_turn("c", "d")
+        provider._retain_queue.join()
+        # Same thread reused — no ad-hoc thread per call.
+        assert provider._writer_thread is first_writer
+        assert provider._client.aretain_batch.call_count == 2
+
+    def test_sync_turn_after_shutdown_is_dropped(self, provider):
+        """Once shutdown has fired, new sync_turn() calls are no-ops.
+
+        This is the core of the fix: the plugin must not enqueue a retain
+        during interpreter teardown — that's what causes the
+        'cannot schedule new futures' RuntimeError + unclosed aiohttp
+        sessions on CLI exit.
+        """
+        client = provider._client
+        provider.shutdown()
+        before_calls = client.aretain_batch.call_count
+        provider.sync_turn("late", "turn")
+        # No new enqueue — the retain queue stays empty.
+        assert provider._retain_queue.empty()
+        # And no new client call (would be impossible anyway since shutdown
+        # nulled self._client; we assert via the captured handle).
+        assert client.aretain_batch.call_count == before_calls
+
+    def test_queue_prefetch_after_shutdown_is_dropped(self, provider):
+        provider.shutdown()
+        provider.queue_prefetch("late query")
+        assert provider._prefetch_thread is None
+
+    def test_shutdown_drains_pending_retains(self, provider):
+        """Shutdown must wait for queued retains to complete, not abandon them.
+
+        Otherwise the LAST in-flight turn — typically the most important —
+        is silently lost.
+        """
+        client = provider._client
+        provider.sync_turn("a", "b")
+        provider.sync_turn("c", "d")
+        provider.shutdown()
+        # Both retains drained before shutdown returned.
+        assert client.aretain_batch.call_count == 2
+        assert provider._retain_queue.empty()
+
+    def test_shutdown_is_idempotent(self, provider):
+        provider.sync_turn("a", "b")
+        provider.shutdown()
+        # Second shutdown shouldn't blow up or re-close the client.
+        provider.shutdown()
+        assert provider._shutting_down.is_set()
+
+
 # ---------------------------------------------------------------------------
 # System prompt tests
 # ---------------------------------------------------------------------------