From c38dac742b22c55581d4105a9727e55ba620a984 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Wed, 29 Apr 2026 14:58:34 +0200 Subject: [PATCH] fix(hindsight): flush buffered turns and drop stale prefetch on session switch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two data-loss / leak gaps in HindsightMemoryProvider.on_session_switch introduced by #17409. 1. Buffered turns silently lost when retain_every_n_turns > 1. on_session_switch unconditionally cleared _session_turns without flushing. Users who batched every N>1 turns and switched mid-batch (/reset, /new, /resume, /branch, or context compression) had those buffered turns disappear. Same data-loss class as the shutdown race, different lifecycle event. Note commit_memory_session() -> on_session_end() runs *before* on_session_switch on /reset, but Hindsight doesn't implement on_session_end so the buffer survives that step and dies at clear time. /resume, /branch, and compression skip commit_memory_session entirely so an on_session_end impl wouldn't help them anyway. Fix: snapshot the old _session_id, _document_id, _parent_session_id, _turn_index, and _session_turns; spawn one final retain that lands under the OLD document_id; then rotate state. Metadata is built synchronously against the old self._* so session_id / lineage tags on the flushed item all reference the prior session consistently. 2. Stale _prefetch_result leaks across switch. If queue_prefetch ran in the old session and the result hadn't been consumed by prefetch() yet, on_session_switch left the cached recall text in place. The next session's first prefetch() call would return text mined from the prior session's bank/query. Fix: join any in-flight _prefetch_thread (3s bounded — matches shutdown()), then clear _prefetch_result under _prefetch_lock before rotating session_id. Tests ----- - tests/plugins/memory/test_hindsight_provider.py (TestSessionSwitchBufferFlush): - buffered turns flushed under OLD document_id with OLD lineage tags - empty buffer => no spurious retain - _prefetch_result cleared on switch - in-flight prefetch thread is awaited before clear (no race) - tests/agent/test_memory_session_switch.py: factory extended to seed the attrs the new flush path reads (_retain_source, _platform, _bank_id, prefetch state, etc.) and stub _run_hindsight_operation so existing switch-state assertions keep passing without network setup. --- plugins/memory/hindsight/__init__.py | 74 +++++++++++++++ tests/agent/test_memory_session_switch.py | 28 ++++++ .../plugins/memory/test_hindsight_provider.py | 89 +++++++++++++++++++ 3 files changed, 191 insertions(+) diff --git a/plugins/memory/hindsight/__init__.py b/plugins/memory/hindsight/__init__.py index eb200d3328..a658fb9887 100644 --- a/plugins/memory/hindsight/__init__.py +++ b/plugins/memory/hindsight/__init__.py @@ -1447,6 +1447,16 @@ class HindsightMemoryProvider(MemoryProvider): batching must start from zero so an in-flight retain doesn't flush under the wrong ``_document_id``. + Before clearing, flush any buffered turns under the *old* + ``_document_id``. Users who set ``retain_every_n_turns > 1`` would + otherwise silently lose whatever's in ``_session_turns`` at the + moment of switch — the same data-loss class as the shutdown race, + just at a different lifecycle event. + + Also wait for any in-flight prefetch from the old session and drop + its cached result; otherwise the new session's first ``prefetch()`` + could read stale recall text from before the switch. + ``parent_session_id`` is recorded for lineage tags on future retains. ``reset`` is accepted but not needed for Hindsight's state model — buffer clearing is correct for every session switch, not only /reset. @@ -1454,6 +1464,70 @@ class HindsightMemoryProvider(MemoryProvider): new_id = str(new_session_id or "").strip() if not new_id: return + + # 1. Flush any buffered turns under the OLD identifiers. Snapshot + # everything before mutating self._* so metadata + tags + doc_id + # all reference the old session consistently. + if self._session_turns: + old_turns = list(self._session_turns) + old_session_id = self._session_id + old_document_id = self._document_id + old_parent_session_id = self._parent_session_id + old_turn_index = self._turn_index + old_metadata = self._build_metadata( + message_count=len(old_turns) * 2, + turn_index=old_turn_index, + ) + old_lineage_tags: list[str] = [] + if old_session_id: + old_lineage_tags.append(f"session:{old_session_id}") + if old_parent_session_id: + old_lineage_tags.append(f"parent:{old_parent_session_id}") + old_content = "[" + ",".join(old_turns) + "]" + + def _flush(): + try: + item = self._build_retain_kwargs( + old_content, + context=self._retain_context, + metadata=old_metadata, + tags=old_lineage_tags or None, + ) + item.pop("bank_id", None) + item.pop("retain_async", None) + logger.debug( + "Hindsight flush-on-switch: bank=%s, doc=%s, num_turns=%d", + self._bank_id, old_document_id, len(old_turns), + ) + self._run_hindsight_operation( + lambda client: client.aretain_batch( + bank_id=self._bank_id, + items=[item], + document_id=old_document_id, + retain_async=self._retain_async, + ) + ) + except Exception as e: + logger.warning("Hindsight flush-on-switch failed: %s", e, exc_info=True) + + # Match sync_turn's serialization — wait for any prior retain + # thread to finish before spawning the flush, so writes + # against the old document arrive in order. + if self._sync_thread and self._sync_thread.is_alive(): + self._sync_thread.join(timeout=5.0) + self._sync_thread = threading.Thread( + target=_flush, daemon=True, name="hindsight-flush-on-switch" + ) + self._sync_thread.start() + + # 2. Drain any in-flight prefetch from the old session and drop + # its cached result so the new session doesn't see stale recall. + if self._prefetch_thread and self._prefetch_thread.is_alive(): + self._prefetch_thread.join(timeout=3.0) + with self._prefetch_lock: + self._prefetch_result = "" + + # 3. Now rotate to the new session. if parent_session_id: self._parent_session_id = str(parent_session_id).strip() self._session_id = new_id diff --git a/tests/agent/test_memory_session_switch.py b/tests/agent/test_memory_session_switch.py index 1cf945e738..12b2aa0a77 100644 --- a/tests/agent/test_memory_session_switch.py +++ b/tests/agent/test_memory_session_switch.py @@ -205,6 +205,7 @@ def _make_hindsight_provider(): bypassing __init__ and seeding the attributes on_session_switch reads/writes. This keeps the test hermetic. """ + import threading hindsight_mod = pytest.importorskip("plugins.memory.hindsight") provider = object.__new__(hindsight_mod.HindsightMemoryProvider) provider._session_id = "old-sid" @@ -213,6 +214,33 @@ def _make_hindsight_provider(): provider._session_turns = ["turn-1", "turn-2"] provider._turn_counter = 2 provider._turn_index = 2 + # Attrs read by _build_metadata / _build_retain_kwargs when the + # buffer-flush path on session switch fires. Empty strings keep the + # metadata minimal but well-formed. + provider._retain_source = "" + provider._platform = "" + provider._user_id = "" + provider._user_name = "" + provider._chat_id = "" + provider._chat_name = "" + provider._chat_type = "" + provider._thread_id = "" + provider._agent_identity = "" + provider._agent_workspace = "" + provider._retain_tags = [] + provider._retain_context = "test-context" + provider._retain_async = False + provider._bank_id = "test-bank" + # Prefetch state the switch path drains/clears. + provider._prefetch_thread = None + provider._prefetch_lock = threading.Lock() + provider._prefetch_result = "" + # Sync thread tracking — flush spawn target. + provider._sync_thread = None + # Stub the network-touching helper so the spawned flush thread is a + # no-op in unit tests. Real plugin behavior is covered by the + # mock-client tests in tests/plugins/memory/test_hindsight_provider.py. + provider._run_hindsight_operation = lambda _op: None return provider diff --git a/tests/plugins/memory/test_hindsight_provider.py b/tests/plugins/memory/test_hindsight_provider.py index 1d6238475b..056e249351 100644 --- a/tests/plugins/memory/test_hindsight_provider.py +++ b/tests/plugins/memory/test_hindsight_provider.py @@ -927,6 +927,95 @@ class TestShutdownRace: assert provider._shutting_down.is_set() +# --------------------------------------------------------------------------- +# on_session_switch — flush + prefetch reset behavior +# --------------------------------------------------------------------------- + + +class TestSessionSwitchBufferFlush: + def test_buffered_turns_flushed_before_clear(self, provider_with_config): + """retain_every_n_turns > 1 must not silently drop partial buffers + on session switch. Whatever's in _session_turns at switch time + should land in the OLD document under the OLD session id.""" + p = provider_with_config(retain_every_n_turns=3, retain_async=False) + old_doc = p._document_id + + # Two turns buffered, no retain yet (boundary is at turn 3). + p.sync_turn("turn1-user", "turn1-asst") + p.sync_turn("turn2-user", "turn2-asst") + assert p._sync_thread is None + p._client.aretain_batch.assert_not_called() + + # Switch — flush should fire under OLD document_id. + p.on_session_switch("new-sid", parent_session_id="test-session", reset=True) + if p._sync_thread: + p._sync_thread.join(timeout=5.0) + + p._client.aretain_batch.assert_called_once() + kw = p._client.aretain_batch.call_args.kwargs + assert kw["document_id"] == old_doc + item = kw["items"][0] + # Both buffered turns must be present in the flushed payload. + content = json.loads(item["content"]) + flat = json.dumps(content) + assert "turn1-user" in flat + assert "turn2-user" in flat + # Old session id must appear in lineage tags / metadata. + assert "session:test-session" in item["tags"] + assert item["metadata"]["session_id"] == "test-session" + + # And the new session must start with a clean slate. + assert p._session_id == "new-sid" + assert p._session_turns == [] + assert p._turn_counter == 0 + assert p._document_id != old_doc + assert p._document_id.startswith("new-sid-") + + def test_no_flush_when_buffer_empty(self, provider): + """Switch with no buffered turns must not fire a spurious retain.""" + provider.on_session_switch("new-sid") + if provider._sync_thread: + provider._sync_thread.join(timeout=5.0) + provider._client.aretain_batch.assert_not_called() + assert provider._session_id == "new-sid" + + def test_prefetch_result_cleared_on_switch(self, provider): + """Stale recall text from the old session must not leak into the + next session's first prefetch read.""" + provider._prefetch_result = "old-session recall: User likes Rust" + provider.on_session_switch("new-sid") + assert provider._prefetch_result == "" + # And subsequent prefetch() should now report empty, not the leftover. + assert provider.prefetch("anything") == "" + + def test_in_flight_prefetch_thread_drained_on_switch(self, provider, monkeypatch): + """on_session_switch must wait for an in-flight prefetch from the + old session to settle before clearing _prefetch_result, otherwise + the thread can race and re-populate the field after the clear.""" + import threading + import time as _time + + gate = threading.Event() + finished = threading.Event() + + def _slow_prefetch(): + gate.wait(timeout=5.0) + with provider._prefetch_lock: + provider._prefetch_result = "old-session recall" + finished.set() + + provider._prefetch_thread = threading.Thread(target=_slow_prefetch, daemon=True) + provider._prefetch_thread.start() + + # Release the prefetch worker so it writes _prefetch_result, then + # call on_session_switch — it must join the thread before clearing. + gate.set() + provider.on_session_switch("new-sid") + + assert finished.is_set(), "switch returned before prefetch thread settled" + assert provider._prefetch_result == "" + + # --------------------------------------------------------------------------- # System prompt tests # ---------------------------------------------------------------------------