diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index 125218caebb..9b140b6da8c 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -403,6 +403,11 @@ class OpenVikingMemoryProvider(MemoryProvider): self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread: Optional[threading.Thread] = None + # Monotonic counter incremented on every session switch. Prefetch + # workers capture the value when spawned and refuse to write their + # result if the generation has advanced — otherwise a slow worker + # from session N can repopulate session N+1 with stale recall. + self._prefetch_generation = 0 @property def name(self) -> str: @@ -515,6 +520,11 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client or not query: return + # Snapshot the generation at spawn time. If on_session_switch bumps it + # before this worker finishes, the worker drops its result instead of + # repopulating the new session with stale recall from the old one. + gen = self._prefetch_generation + def _run(): try: client = _VikingClient( @@ -537,6 +547,8 @@ class OpenVikingMemoryProvider(MemoryProvider): parts.append(f"- [{score:.2f}] {abstract} ({uri})") if parts: with self._prefetch_lock: + if gen != self._prefetch_generation: + return self._prefetch_result = "\n".join(parts) except Exception as e: logger.debug("OpenViking prefetch failed: %s", e) @@ -604,6 +616,19 @@ class OpenVikingMemoryProvider(MemoryProvider): # the count hasn't been incremented yet. if self._sync_thread and self._sync_thread.is_alive(): self._sync_thread.join(timeout=10.0) + if self._sync_thread.is_alive(): + # Worker outlived the bounded join — each POST has _TIMEOUT=30s + # and there are two of them per turn. Committing now would + # orphan the worker's late writes past the commit boundary + # (they'd land in an already-committed session and never be + # extracted). Skip the commit; leave _turn_count untouched so + # the session stays marked dirty for any retry path. + logger.warning( + "OpenViking sync worker still alive after 10s join — " + "skipping commit on session %s to avoid orphaning late writes", + self._session_id, + ) + return if self._turn_count == 0: return @@ -649,16 +674,29 @@ class OpenVikingMemoryProvider(MemoryProvider): # below always target the session whose writes we want to flush. old_session_id = self._session_id old_turn_count = self._turn_count + sync_worker_drained = True # 1. Wait for any in-flight sync_turn to finish writing under the # OLD session id — otherwise it races the commit below. if self._sync_thread and self._sync_thread.is_alive(): self._sync_thread.join(timeout=10.0) + if self._sync_thread.is_alive(): + # Same hazard as on_session_end: worker outlived the bounded + # join. Skip the commit so its late writes aren't orphaned + # past a commit boundary they can't recover from. + sync_worker_drained = False + logger.warning( + "OpenViking sync worker still alive after 10s join — " + "skipping commit-on-switch for session %s; late writes " + "will remain in the uncommitted old session", + old_session_id, + ) # 2. Commit the old session if it accumulated turns — same # extraction semantics as on_session_end. Skip if empty (nothing - # to extract) or if the provider was never initialized. - if old_session_id and old_turn_count > 0: + # to extract), if the provider was never initialized, or if the + # sync worker is still mid-flight. + if sync_worker_drained and old_session_id and old_turn_count > 0: try: self._client.post(f"/api/v1/sessions/{old_session_id}/commit") logger.info( @@ -671,8 +709,11 @@ class OpenVikingMemoryProvider(MemoryProvider): old_session_id, e, ) - # 3. Drain in-flight prefetch from the old session and drop its - # cached result so the new session doesn't see stale recall. + # 3. Bump the prefetch generation so any in-flight prefetch worker + # finishing AFTER this point drops its result. Then drain the + # current worker and clear the cached result so the new session + # doesn't see stale recall from the old one. + self._prefetch_generation += 1 if self._prefetch_thread and self._prefetch_thread.is_alive(): self._prefetch_thread.join(timeout=3.0) with self._prefetch_lock: @@ -691,6 +732,14 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client or action != "add" or not content: return + # Snapshot the target session id at call time — see sync_turn() for + # the rationale. A delayed worker that reads self._session_id after + # on_session_switch has rotated it would land the memory note in the + # NEW session. + sid = str(self._session_id or "").strip() + if not sid: + return + def _write(): try: client = _VikingClient( @@ -699,7 +748,7 @@ class OpenVikingMemoryProvider(MemoryProvider): ) # Add as a user message with memory context so the commit # picks it up as an explicit memory during extraction - client.post(f"/api/v1/sessions/{self._session_id}/messages", { + client.post(f"/api/v1/sessions/{sid}/messages", { "role": "user", "parts": [ {"type": "text", "text": f"[Memory note — {target}] {content}"}, diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index 78c672851d7..7bbecee1e83 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -472,11 +472,16 @@ def test_on_session_switch_waits_for_inflight_sync_thread(): join_calls = [] class FakeThread: + def __init__(self): + self._alive = True + def is_alive(self): - return True + return self._alive def join(self, timeout=None): join_calls.append(timeout) + # Simulate a worker that finishes within the join window. + self._alive = False provider._sync_thread = FakeThread() @@ -629,3 +634,162 @@ def test_on_session_switch_swallows_commit_failure(): assert provider._session_id == "new-sid" assert provider._turn_count == 0 + + +# --------------------------------------------------------------------------- +# Hung-writer protection: the sync worker can outlive the bounded join +# because each OpenViking POST has _TIMEOUT=30s and there are two per turn. +# Committing while late writes are still in flight would orphan them past +# the commit boundary — they would never be extracted. +# --------------------------------------------------------------------------- + +class _HungThread: + """Thread stand-in that stays alive across joins.""" + + def is_alive(self): + return True + + def join(self, timeout=None): + # Pretend the join timed out — worker still running. + return None + + +def test_on_session_end_skips_commit_when_sync_worker_outlives_join(): + """If the sync worker is still alive after the 10s join, the commit must + be skipped — late writes from the worker would otherwise land in an + already-committed session and never be extracted. Leave _turn_count + intact so the session stays marked dirty.""" + provider = _make_provider_with_session("old-sid", turn_count=3) + provider._sync_thread = _HungThread() + + provider.on_session_end([]) + + provider._client.post.assert_not_called() + assert provider._turn_count == 3 + + +def test_on_session_switch_skips_commit_when_sync_worker_outlives_join(): + """Same hazard on the switch path. Rotation must still proceed (the new + session needs to start) but the old-session commit is skipped to avoid + orphaning the worker's late writes past commit.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._sync_thread = _HungThread() + + provider.on_session_switch("new-sid") + + provider._client.post.assert_not_called() + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +# --------------------------------------------------------------------------- +# on_memory_write: same late-capture hazard as sync_turn — worker must use +# the session id snapshotted at call time, not re-read self._session_id. +# Block inside the stub ctor (BEFORE the f-string for the post path is +# evaluated) so the rotation deterministically beats the f-string. +# --------------------------------------------------------------------------- + +def test_on_memory_write_captures_session_id_at_call_time(): + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "old-sid" + + in_ctor = threading.Event() + release = threading.Event() + done = threading.Event() + captured_paths = [] + + class StubClient: + def __init__(self, *a, **kw): + in_ctor.set() + release.wait(timeout=2.0) + + def post(self, path, payload=None, **kwargs): + captured_paths.append(path) + done.set() + return {} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.on_memory_write("add", "viking://memories/x", "remember this") + assert in_ctor.wait(timeout=2.0), "worker never entered ctor" + # Rotate provider's session id while the worker is parked in the ctor, + # BEFORE it evaluates the f-string for the post path. If the worker + # reads self._session_id inside the closure, it will now see "new-sid". + provider._session_id = "new-sid" + release.set() + assert done.wait(timeout=2.0), "worker never reached post()" + finally: + _mod._VikingClient = real_client_cls + + # The write must target the OLD session id captured at call time. + assert captured_paths == ["/api/v1/sessions/old-sid/messages"] + + +# --------------------------------------------------------------------------- +# Prefetch staleness: a prefetch worker that finishes AFTER a session switch +# must drop its result instead of repopulating the new session with stale +# recall from the old generation. Bump the generation directly (rather than +# calling on_session_switch, whose own join blocks on the test worker) so +# the test isolates the generation-gating behavior. +# --------------------------------------------------------------------------- + +def test_queue_prefetch_drops_result_when_generation_changed_mid_flight(): + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "old-sid" + + started = threading.Event() + release = threading.Event() + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + started.set() + release.wait(timeout=2.0) + return { + "result": { + "memories": [ + {"uri": "viking://memories/old", "score": 0.9, + "abstract": "stale from old session"}, + ], + "resources": [], + } + } + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.queue_prefetch("anything") + assert started.wait(timeout=2.0), "prefetch worker never entered post()" + # Simulate a session switch by bumping the generation directly. + # The worker captured the pre-bump generation when it was spawned. + provider._prefetch_generation += 1 + release.set() + if provider._prefetch_thread: + provider._prefetch_thread.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + # The stale result from the pre-bump generation must NOT have been written + # into the new generation's prefetch slot. + assert provider._prefetch_result == ""