fix(hindsight): route flush-on-switch through writer queue, not raw thread

Follow-up to the cherry-picked PR #17447. The original flush spawned a
bare threading.Thread for the buffer-flush path, overwriting
self._sync_thread — which is aliased to the long-lived writer thread.
Two consequences:

1. No serialization with the writer queue. If old-session retains were
   still queued in _retain_queue, the flush ran concurrently with the
   writer and both threads could call aretain_batch against the same
   document_id.
2. The pre-spawn 'self._sync_thread.join(timeout=5.0)' tried to join the
   long-lived writer, which never exits, so the join was a no-op that
   just timed out — never actually serialized anything.

Fix: enqueue the flush closure on _retain_queue via _ensure_writer +
put(). Natural FIFO ordering behind any pending retains, no new thread,
no broken join. Shutdown-aware so it doesn't enqueue after teardown.

Tests updated to drain via _retain_queue.join() instead of the stale
_sync_thread.join(). Added regression guard
test_flush_serializes_behind_pending_retains_via_writer_queue that
blocks the writer mid-retain to prove the flush waits in FIFO behind
the old retain.

Also seeds _retain_queue / _shutting_down / stubbed _ensure_writer on
the bare-object test helper in test_memory_session_switch.py so that
path doesn't blow up under the new queue-enqueue.

tests/plugins/memory/test_hindsight_provider.py + tests/agent/test_memory_session_switch.py: 103/103 passing.
This commit is contained in:
teknium1 2026-04-29 08:08:02 -07:00 committed by Teknium
parent c38dac742b
commit 0a5ee01e48
3 changed files with 86 additions and 19 deletions

View file

@ -235,11 +235,21 @@ def _make_hindsight_provider():
provider._prefetch_thread = None
provider._prefetch_lock = threading.Lock()
provider._prefetch_result = ""
# Sync thread tracking — flush spawn target.
# Sync thread tracking (legacy alias at the writer).
provider._sync_thread = None
# Stub the network-touching helper so the spawned flush thread is a
# no-op in unit tests. Real plugin behavior is covered by the
# mock-client tests in tests/plugins/memory/test_hindsight_provider.py.
# Writer queue infra the flush-on-switch path enqueues onto. We stub
# _ensure_writer / _register_atexit so no real thread is spawned;
# tests exercising flush delivery live in
# tests/plugins/memory/test_hindsight_provider.py where the full
# writer-queue wiring is in place.
import queue as _queue
provider._retain_queue = _queue.Queue()
provider._shutting_down = threading.Event()
provider._atexit_registered = True
provider._ensure_writer = lambda: None
provider._register_atexit = lambda: None
# Stub the network-touching helper so any enqueued flush closure is
# a no-op if ever drained in a unit test.
provider._run_hindsight_operation = lambda _op: None
return provider

View file

@ -940,16 +940,17 @@ class TestSessionSwitchBufferFlush:
p = provider_with_config(retain_every_n_turns=3, retain_async=False)
old_doc = p._document_id
# Two turns buffered, no retain yet (boundary is at turn 3).
# Two turns buffered, no retain yet (boundary is at turn 3). The
# writer hasn't been started either — sync_turn's early return
# skips _ensure_writer when no retain is due.
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
assert p._sync_thread is None
p._client.aretain_batch.assert_not_called()
# Switch — flush should fire under OLD document_id.
# Switch — flush should fire under OLD document_id via the writer queue.
p.on_session_switch("new-sid", parent_session_id="test-session", reset=True)
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._retain_queue.join()
p._client.aretain_batch.assert_called_once()
kw = p._client.aretain_batch.call_args.kwargs
@ -974,8 +975,8 @@ class TestSessionSwitchBufferFlush:
def test_no_flush_when_buffer_empty(self, provider):
"""Switch with no buffered turns must not fire a spurious retain."""
provider.on_session_switch("new-sid")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
# Nothing enqueued — join is immediate.
provider._retain_queue.join()
provider._client.aretain_batch.assert_not_called()
assert provider._session_id == "new-sid"
@ -1015,6 +1016,61 @@ class TestSessionSwitchBufferFlush:
assert finished.is_set(), "switch returned before prefetch thread settled"
assert provider._prefetch_result == ""
def test_flush_serializes_behind_pending_retains_via_writer_queue(
self, provider_with_config
):
"""The flush closure must ride the same _retain_queue sync_turn
uses, so it lands FIFO behind any still-queued old-session
retains rather than racing them on a separate thread.
Regression guard: an earlier draft spawned a raw threading.Thread
for flush, overwriting _sync_thread and racing the writer against
the same document_id.
"""
import threading as _threading
p = provider_with_config(retain_every_n_turns=2, retain_async=False)
# Block the first writer job until we've enqueued the flush
# behind it. This proves ordering — the flush MUST wait.
gate = _threading.Event()
call_order: list[str] = []
def _aretain_batch_tracking(**kw):
idx = kw["items"][0]["metadata"].get("turn_index", "")
call_order.append(str(idx))
if idx == "2":
# First retain blocks until we've enqueued the flush.
gate.wait(timeout=5.0)
p._client.aretain_batch = AsyncMock(side_effect=_aretain_batch_tracking)
# Turn 1+2 → boundary hit → retain enqueued (will block).
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
# One more buffered turn so flush has something to land.
p.sync_turn("turn3-user", "turn3-asst")
# Switch while the first retain is still blocked on `gate`.
p.on_session_switch("new-sid", parent_session_id="test-session")
# Release the first retain. Flush must have been enqueued
# BEHIND it, and run second.
gate.set()
p._retain_queue.join()
# The flush carries all buffered turns; sync_turn's retain #2
# carried the batch at boundary time. Two distinct calls.
assert p._client.aretain_batch.call_count == 2
# First call landed while buffer was [t1, t2]; flush landed
# after we added t3. So the second call must be strictly after.
assert call_order[0] == "2"
# Flush retain has turn_index matching the buffered count at
# switch time (3 turns accumulated, _turn_index was set to 3
# by the last sync_turn).
assert call_order[1] == "3"
# ---------------------------------------------------------------------------
# System prompt tests