From 91e9459e10062b3e68e58dc84ac379993ce51c2b Mon Sep 17 00:00:00 2001 From: harshitAgr <28730481+harshitAgr@users.noreply.github.com> Date: Thu, 21 May 2026 11:51:54 +0300 Subject: [PATCH] fix(openviking): track writers per-session so commit waits for all MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit sync_turn's bounded join could drop a still-alive previous worker by replacing the single _sync_thread slot. The dropped worker kept POSTing under the old sid but was no longer visible to on_session_end / on_session_switch, so the commit could fire while orphaned writes were still in flight — those writes landed past the commit boundary and were never extracted. Replace the single _sync_thread slot with _inflight_writers: Dict[sid, Set[Thread]]. Writers self-register on spawn (sync_turn, on_memory_write) and self-deregister on exit. The commit path drains _drain_writers(sid, 10.0) and skips the commit if any writer for that sid is still alive after the bounded budget. Also trim inline review-rationale comments to short invariants per reviewer style ask: "commit only after session writes drain" and "drop prefetch results from older switch generations." Co-Authored-By: Claude Opus 4.7 (1M context) (cherry picked from commit 7537ee6f5b9ffa4c0f7b79af053aab449caf5af5) --- plugins/memory/openviking/__init__.py | 163 ++++++++++-------- .../memory/test_openviking_provider.py | 84 ++++++++- 2 files changed, 168 insertions(+), 79 deletions(-) diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index b076f411ebe..df0a0cd9a42 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -31,10 +31,11 @@ import mimetypes import os import tempfile import threading +import time import uuid import zipfile from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Set from urllib.parse import urlparse from urllib.request import url2pathname @@ -432,14 +433,16 @@ class OpenVikingMemoryProvider(MemoryProvider): self._api_key = "" self._session_id = "" self._turn_count = 0 - self._sync_thread: Optional[threading.Thread] = None + # Commit only after session writes drain. The set is keyed by the sid + # the writer is POSTing under (snapshotted at spawn), so on_session_end + # / on_session_switch see every still-alive writer for that sid even + # if later writes have replaced the latest-tracked thread. + self._inflight_writers: Dict[str, Set[threading.Thread]] = {} + self._inflight_lock = threading.Lock() self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread: Optional[threading.Thread] = None - # Monotonic counter incremented on every session switch. Prefetch - # workers capture the value when spawned and refuse to write their - # result if the generation has advanced — otherwise a slow worker - # from session N can repopulate session N+1 with stale recall. + # Drop prefetch results from older switch generations. self._prefetch_generation = 0 @property @@ -554,9 +557,7 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client or not query: return - # Snapshot the generation at spawn time. If on_session_switch bumps it - # before this worker finishes, the worker drops its result instead of - # repopulating the new session with stale recall from the old one. + # Drop prefetch results from older switch generations. gen = self._prefetch_generation def _run(): @@ -592,6 +593,53 @@ class OpenVikingMemoryProvider(MemoryProvider): ) self._prefetch_thread.start() + def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None: + """Spawn a daemon writer tracked in _inflight_writers[sid]. + + Tracking is keyed by sid (not by a single latest-thread slot) so that + on_session_end / on_session_switch can drain every still-alive writer + for the session being committed. + """ + holder: List[threading.Thread] = [] + + def _wrapped(): + try: + target() + finally: + with self._inflight_lock: + workers = self._inflight_writers.get(sid) + if workers is not None: + workers.discard(holder[0]) + if not workers: + self._inflight_writers.pop(sid, None) + + thread = threading.Thread(target=_wrapped, daemon=True, name=name) + holder.append(thread) + with self._inflight_lock: + self._inflight_writers.setdefault(sid, set()).add(thread) + thread.start() + + def _drain_writers(self, sid: str, timeout: float) -> bool: + """Join every in-flight writer for sid within a shared timeout budget. + + Returns True if all writers drained, False if any are still alive when + the budget runs out. Callers use the False return to skip the commit. + """ + deadline = time.monotonic() + timeout + while True: + with self._inflight_lock: + workers = [t for t in self._inflight_writers.get(sid, ()) if t.is_alive()] + if not workers: + return True + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + for t in workers: + slice_left = deadline - time.monotonic() + if slice_left <= 0: + break + t.join(timeout=slice_left) + def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: """Record the conversation turn in OpenViking's session (non-blocking).""" if not self._client: @@ -601,10 +649,7 @@ class OpenVikingMemoryProvider(MemoryProvider): if not user_content: return - # Capture the target session id NOW, not inside the worker. Otherwise - # a delayed worker can read self._session_id after on_session_switch - # has rotated it (the switch's join on _sync_thread is bounded), and - # the OLD turn's content lands in the NEW session. + # Snapshot the sid so a delayed worker can't write into a rotated session. sid = str(session_id or self._session_id).strip() if not sid: return @@ -617,13 +662,10 @@ class OpenVikingMemoryProvider(MemoryProvider): self._endpoint, self._api_key, account=self._account, user=self._user, agent=self._agent, ) - - # Add user message client.post(f"/api/v1/sessions/{sid}/messages", { "role": "user", - "content": user_content[:4000], # trim very long messages + "content": user_content[:4000], }) - # Add assistant message client.post(f"/api/v1/sessions/{sid}/messages", { "role": "assistant", "content": assistant_content[:4000], @@ -631,14 +673,7 @@ class OpenVikingMemoryProvider(MemoryProvider): except Exception as e: logger.debug("OpenViking sync_turn failed: %s", e) - # Wait for any previous sync to finish before starting a new one - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=5.0) - - self._sync_thread = threading.Thread( - target=_sync, daemon=True, name="openviking-sync" - ) - self._sync_thread.start() + self._spawn_writer(sid, _sync, name="openviking-sync") def on_session_end(self, messages: List[Dict[str, Any]]) -> None: """Commit the session to trigger memory extraction. @@ -649,35 +684,22 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client: return - # Wait for any pending sync to finish first — do this before the - # turn_count check so the last turn's messages are flushed even if - # the count hasn't been incremented yet. - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=10.0) - if self._sync_thread.is_alive(): - # Worker outlived the bounded join — each POST has _TIMEOUT=30s - # and there are two of them per turn. Committing now would - # orphan the worker's late writes past the commit boundary - # (they'd land in an already-committed session and never be - # extracted). Skip the commit; leave _turn_count untouched so - # the session stays marked dirty for any retry path. - logger.warning( - "OpenViking sync worker still alive after 10s join — " - "skipping commit on session %s to avoid orphaning late writes", - self._session_id, - ) - return + sid = self._session_id + # Commit only after session writes drain. + if not self._drain_writers(sid, timeout=10.0): + logger.warning( + "OpenViking writer for %s still alive after drain — skipping commit", + sid, + ) + return if self._turn_count == 0: return try: - self._client.post(f"/api/v1/sessions/{self._session_id}/commit") - logger.info("OpenViking session %s committed (%d turns)", self._session_id, self._turn_count) - # Mark the session clean so a subsequent on_session_switch (fired - # by /new and compression right after commit_memory_session) skips - # its commit instead of double-committing. On commit failure we - # leave the count intact so the switch hook gets a retry. + self._client.post(f"/api/v1/sessions/{sid}/commit") + logger.info("OpenViking session %s committed (%d turns)", sid, self._turn_count) + # Mark clean so a follow-up on_session_switch skips its own commit. self._turn_count = 0 except Exception as e: logger.warning("OpenViking session commit failed: %s", e) @@ -710,27 +732,19 @@ class OpenVikingMemoryProvider(MemoryProvider): old_session_id = self._session_id old_turn_count = self._turn_count - sync_worker_drained = True - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=10.0) - if self._sync_thread.is_alive(): - # Same hazard as on_session_end: worker outlived the bounded - # join. Skip the commit so its late writes aren't orphaned - # past a commit boundary they can't recover from. - sync_worker_drained = False + # Commit only after session writes drain. + writers_drained = True + if old_session_id: + writers_drained = self._drain_writers(old_session_id, timeout=10.0) + if not writers_drained: logger.warning( - "OpenViking sync worker still alive after 10s join — " - "skipping commit-on-switch for session %s; late writes " - "will remain in the uncommitted old session", + "OpenViking writer for %s still alive after drain — " + "skipping commit-on-switch", old_session_id, ) - # 2. Commit the old session if it accumulated turns — same - # extraction semantics as on_session_end. Skip if empty (nothing - # to extract), if the provider was never initialized, or if the - # sync worker is still mid-flight. - if sync_worker_drained and old_session_id and old_turn_count > 0: + if writers_drained and old_session_id and old_turn_count > 0: try: self._client.post(f"/api/v1/sessions/{old_session_id}/commit") logger.info( @@ -743,10 +757,7 @@ class OpenVikingMemoryProvider(MemoryProvider): old_session_id, e, ) - # 3. Bump the prefetch generation so any in-flight prefetch worker - # finishing AFTER this point drops its result. Then drain the - # current worker and clear the cached result so the new session - # doesn't see stale recall from the old one. + # Drop prefetch results from older switch generations. self._prefetch_generation += 1 if self._prefetch_thread and self._prefetch_thread.is_alive(): self._prefetch_thread.join(timeout=3.0) @@ -819,11 +830,17 @@ class OpenVikingMemoryProvider(MemoryProvider): return tool_error(str(e)) def shutdown(self) -> None: - # Wait for background threads to finish - for t in (self._sync_thread, self._prefetch_thread): - if t and t.is_alive(): + # Wait for every in-flight writer across all tracked sessions. + with self._inflight_lock: + all_workers = [ + t for workers in self._inflight_writers.values() for t in workers + ] + for t in all_workers: + if t.is_alive(): t.join(timeout=5.0) - # Clear atexit reference so it doesn't double-commit + if self._prefetch_thread and self._prefetch_thread.is_alive(): + self._prefetch_thread.join(timeout=5.0) + # Clear atexit reference so it doesn't double-commit. global _last_active_provider if _last_active_provider is self: _last_active_provider = None diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index 7bbecee1e83..0353a1b6f00 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -483,7 +483,7 @@ def test_on_session_switch_waits_for_inflight_sync_thread(): # Simulate a worker that finishes within the join window. self._alive = False - provider._sync_thread = FakeThread() + provider._inflight_writers["old-sid"] = {FakeThread()} provider.on_session_switch("new-sid") @@ -561,8 +561,8 @@ def test_sync_turn_captures_session_id_before_worker_runs(): # Rotate the provider's session id while the worker is mid-flight. provider._session_id = "new-sid" release.set() - if provider._sync_thread: - provider._sync_thread.join(timeout=2.0) + for t in list(provider._inflight_writers.get("old-sid", set())): + t.join(timeout=2.0) finally: _mod._VikingClient = real_client_cls @@ -582,7 +582,7 @@ def test_sync_turn_noop_when_session_id_blank(): # No turn counted, no worker spawned. assert provider._turn_count == 0 - assert provider._sync_thread is None + assert provider._inflight_writers == {} def test_on_session_end_marks_session_clean_after_successful_commit(): @@ -660,7 +660,7 @@ def test_on_session_end_skips_commit_when_sync_worker_outlives_join(): already-committed session and never be extracted. Leave _turn_count intact so the session stays marked dirty.""" provider = _make_provider_with_session("old-sid", turn_count=3) - provider._sync_thread = _HungThread() + provider._inflight_writers["old-sid"] = {_HungThread()} provider.on_session_end([]) @@ -673,7 +673,7 @@ def test_on_session_switch_skips_commit_when_sync_worker_outlives_join(): session needs to start) but the old-session commit is skipped to avoid orphaning the worker's late writes past commit.""" provider = _make_provider_with_session("old-sid", turn_count=2) - provider._sync_thread = _HungThread() + provider._inflight_writers["old-sid"] = {_HungThread()} provider.on_session_switch("new-sid") @@ -682,6 +682,78 @@ def test_on_session_switch_skips_commit_when_sync_worker_outlives_join(): assert provider._turn_count == 0 +# --------------------------------------------------------------------------- +# Orphaned-writer hazard: commit must wait for ALL writers for the session, +# not just the latest tracked one. sync_turn's bounded rate-limit can drop a +# still-alive previous worker — that dropped writer keeps POSTing under the +# old sid and would otherwise land its writes past the commit boundary. +# --------------------------------------------------------------------------- + +def test_on_session_end_waits_for_all_writers_not_just_latest(): + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._inflight_writers["old-sid"] = {_HungThread()} + + provider.on_session_end([]) + + provider._client.post.assert_not_called() + assert provider._turn_count == 2 + + +def test_on_session_switch_waits_for_all_writers_not_just_latest(): + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._inflight_writers["old-sid"] = {_HungThread()} + + provider.on_session_switch("new-sid") + + provider._client.post.assert_not_called() + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_sync_turn_tracks_writer_under_session_id(): + """Every sync_turn writer must register under its captured sid so the + drain at end/switch sees it even if a later sync_turn replaces the + latest-tracked reference.""" + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "sid-1" + + release = threading.Event() + started = threading.Event() + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + started.set() + release.wait(timeout=2.0) + return {} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.sync_turn("u", "a") + assert started.wait(timeout=2.0), "worker never entered post()" + assert len(provider._inflight_writers.get("sid-1", set())) == 1 + release.set() + for t in list(provider._inflight_writers.get("sid-1", set())): + t.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + # Worker should have removed itself from the inflight set on exit. + assert provider._inflight_writers.get("sid-1", set()) == set() + + # --------------------------------------------------------------------------- # on_memory_write: same late-capture hazard as sync_turn — worker must use # the session id snapshotted at call time, not re-read self._session_id.