fix(hindsight): drain retain queue cleanly on shutdown

The plugin used to spawn one daemon thread per sync_turn() to do the
aretain_batch network write. On CLI exit, that pattern raced interpreter
shutdown — the last retain could reach aiohttp after asyncio's
"cannot schedule new futures" guard had fired, producing noisy logs and
silently losing the final unsaved turn:

    WARNING ... Hindsight sync failed: cannot schedule new futures after
            interpreter shutdown
    ERROR asyncio: Unclosed client session
            client_session: <aiohttp.client.ClientSession object at 0x...>

Switch to a single-writer model: each provider owns one long-lived
writer thread plus a queue. sync_turn() snapshots state and enqueues a
job; the writer drains sequentially. Once shutdown() is called:

  - new sync_turn() / queue_prefetch() calls are dropped, not enqueued
  - a sentinel wakes the writer so it finishes in-flight work
  - shutdown joins the writer (10s) before nulling the client

Also register an idempotent atexit hook from the first sync_turn(), so
exit paths that don't go through MemoryManager.shutdown_all() (Ctrl-C,
abrupt exit) still get a chance to drain.

Tests: keep _sync_thread as a legacy alias to the writer, swap join()
calls to _retain_queue.join() (canonical wait-for-drain), add a new
TestShutdownRace suite covering single-writer reuse, post-shutdown drop,
queue draining, and shutdown idempotency.
This commit is contained in:
Nicolò Boschi 2026-04-28 14:49:14 +02:00 committed by Teknium
parent 5662ac2afc
commit 0565497dcc
2 changed files with 228 additions and 57 deletions

View file

@ -669,7 +669,7 @@ class TestSyncTurn:
p._client = _make_mock_client()
p.sync_turn("hello", "hi there")
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs
@ -710,8 +710,7 @@ class TestSyncTurn:
def test_sync_turn_with_tags(self, provider_with_config):
p = provider_with_config(retain_tags=["conv", "session1"])
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert "conv" in item["tags"]
assert "session1" in item["tags"]
@ -720,8 +719,7 @@ class TestSyncTurn:
def test_sync_turn_uses_aretain_batch(self, provider):
"""sync_turn should use aretain_batch with retain_async."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
provider._client.aretain_batch.assert_called_once()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"].startswith("test-session-")
@ -732,8 +730,7 @@ class TestSyncTurn:
def test_sync_turn_custom_context(self, provider_with_config):
p = provider_with_config(retain_context="my-agent")
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["context"] == "my-agent"
@ -744,7 +741,7 @@ class TestSyncTurn:
p.sync_turn("turn2-user", "turn2-asst")
assert p._sync_thread is None
p.sync_turn("turn3-user", "turn3-asst")
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"].startswith("test-session-")
@ -765,15 +762,13 @@ class TestSyncTurn:
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.reset_mock()
p.sync_turn("turn3-user", "turn3-asst")
p.sync_turn("turn4-user", "turn4-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain ALL turns from the session
@ -785,8 +780,7 @@ class TestSyncTurn:
def test_sync_turn_passes_document_id(self, provider):
"""sync_turn should pass document_id (session_id + per-startup ts)."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
# Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
assert call_kwargs["document_id"].startswith("test-session-")
@ -819,8 +813,7 @@ class TestSyncTurn:
def test_sync_turn_session_tag(self, provider):
"""Each retain should be tagged with session:<id> for filtering."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
item = provider._client.aretain_batch.call_args.kwargs["items"][0]
assert "session:test-session" in item["tags"]
@ -841,8 +834,7 @@ class TestSyncTurn:
)
p._client = _make_mock_client()
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert "session:child-session" in item["tags"]
@ -851,15 +843,14 @@ class TestSyncTurn:
def test_sync_turn_error_does_not_raise(self, provider):
provider._client.aretain_batch.side_effect = RuntimeError("network error")
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._retain_queue.join()
def test_sync_turn_preserves_unicode(self, provider_with_config):
"""Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact."""
p = provider_with_config()
p._client = _make_mock_client()
p.sync_turn("안녕 こんにちは 你好", "👨‍👩‍👧‍👦 family")
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
item = p._client.aretain_batch.call_args.kwargs["items"][0]
# ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON,
@ -871,6 +862,71 @@ class TestSyncTurn:
assert "👨‍👩‍👧‍👦" in raw_json
# ---------------------------------------------------------------------------
# Shutdown / writer tests
# ---------------------------------------------------------------------------
class TestShutdownRace:
def test_sync_turn_uses_single_writer_thread(self, provider):
"""All retains run through one long-lived writer thread."""
provider.sync_turn("a", "b")
provider._retain_queue.join()
first_writer = provider._writer_thread
assert first_writer is not None
assert first_writer.is_alive()
provider.sync_turn("c", "d")
provider._retain_queue.join()
# Same thread reused — no ad-hoc thread per call.
assert provider._writer_thread is first_writer
assert provider._client.aretain_batch.call_count == 2
def test_sync_turn_after_shutdown_is_dropped(self, provider):
"""Once shutdown has fired, new sync_turn() calls are no-ops.
This is the core of the fix: the plugin must not enqueue a retain
during interpreter teardown that's what causes the
'cannot schedule new futures' RuntimeError + unclosed aiohttp
sessions on CLI exit.
"""
client = provider._client
provider.shutdown()
before_calls = client.aretain_batch.call_count
provider.sync_turn("late", "turn")
# No new enqueue — the retain queue stays empty.
assert provider._retain_queue.empty()
# And no new client call (would be impossible anyway since shutdown
# nulled self._client; we assert via the captured handle).
assert client.aretain_batch.call_count == before_calls
def test_queue_prefetch_after_shutdown_is_dropped(self, provider):
provider.shutdown()
provider.queue_prefetch("late query")
assert provider._prefetch_thread is None
def test_shutdown_drains_pending_retains(self, provider):
"""Shutdown must wait for queued retains to complete, not abandon them.
Otherwise the LAST in-flight turn typically the most important
is silently lost.
"""
client = provider._client
provider.sync_turn("a", "b")
provider.sync_turn("c", "d")
provider.shutdown()
# Both retains drained before shutdown returned.
assert client.aretain_batch.call_count == 2
assert provider._retain_queue.empty()
def test_shutdown_is_idempotent(self, provider):
provider.sync_turn("a", "b")
provider.shutdown()
# Second shutdown shouldn't blow up or re-close the client.
provider.shutdown()
assert provider._shutting_down.is_set()
# ---------------------------------------------------------------------------
# System prompt tests
# ---------------------------------------------------------------------------