fix(openviking): close remaining session-boundary races on switch

Three follow-ups from review on #28296:

1. Sync worker outliving the bounded join. Each sync_turn POST has
   _TIMEOUT=30s and there are two per turn, but on_session_end and
   on_session_switch only join for 10s. If the worker is still alive
   after the join, committing the old session orphans the worker's
   late writes past the commit boundary: they land in an already-
   committed session and never get extracted. Both hooks now re-check
   is_alive() after the join and skip the commit when the worker
   hasn't drained.

2. on_memory_write late session_id capture. Same shape as the
   pre-fix sync_turn: the f-string that builds the POST path read
   self._session_id inside the worker, so a switch between thread
   spawn and the POST call landed the memory note in the new session.
   The sid is now snapshotted at call time, same pattern as sync_turn.

3. Stale prefetch repopulating the new session. The pre-switch
   drain+clear only protects against workers that finish before the
   join completes; one finishing after the clear would write its
   result into the new generation's slot. Added a monotonic
   _prefetch_generation; workers capture it at spawn and refuse to
   write if it has advanced.
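
The generation gate from item 3 can be sketched in isolation as below. This is a minimal illustration, not the provider's code: PrefetchCache, spawn, switch, result, and demo are hypothetical names, and the sketch bumps the counter under the lock for simplicity.

```python
import threading


class PrefetchCache:
    """Hypothetical stand-in for the provider's prefetch state."""

    def __init__(self):
        self._lock = threading.Lock()
        self._generation = 0
        self._result = ""

    def spawn(self, fetch):
        # Capture the generation at spawn time, not inside the worker.
        gen = self._generation

        def _run():
            value = fetch()
            with self._lock:
                if gen != self._generation:
                    return  # a switch happened; drop the stale result
                self._result = value

        t = threading.Thread(target=_run, daemon=True)
        t.start()
        return t

    def switch(self):
        # Advance the generation and clear the slot for the new session.
        with self._lock:
            self._generation += 1
            self._result = ""

    def result(self):
        with self._lock:
            return self._result


def demo():
    cache = PrefetchCache()
    release = threading.Event()

    def slow_fetch():
        release.wait(timeout=5)
        return "stale recall"

    worker = cache.spawn(slow_fetch)
    cache.switch()   # session switch while the worker is in flight
    release.set()    # worker finishes only after the clear
    worker.join(timeout=5)
    return cache.result()
```

Because the worker compares its captured value under the same lock that guards the write, a result that arrives after the switch is dropped and the new session's slot stays empty.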

Tests: the existing in-flight-sync test now lets the worker drain (it
covers the join-before-commit happy path); four new tests cover the
hung-writer skip on end and on switch, on_memory_write sid capture,
and prefetch generation gating. 177/177 memory tests pass.
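
A hung-writer case of the kind listed above might be exercised roughly as follows. This is a sketch under stated assumptions, not the actual test suite: FakeClient, MiniProvider, and run_hung_writer_case are hypothetical, and the join timeout is shortened so the case runs quickly. Only the commit path and the join-then-is_alive() shape are taken from the diff.

```python
import threading


class FakeClient:
    """Hypothetical client that records POSTed paths."""

    def __init__(self):
        self.posts = []

    def post(self, path):
        self.posts.append(path)


class MiniProvider:
    """Hypothetical reduction of the end-of-session hook."""

    def __init__(self, client, session_id, join_timeout=10.0):
        self._client = client
        self._session_id = session_id
        self._turn_count = 1
        self._sync_thread = None
        self._join_timeout = join_timeout

    def on_session_end(self):
        t = self._sync_thread
        if t and t.is_alive():
            t.join(timeout=self._join_timeout)
            if t.is_alive():
                return False  # worker still writing; skip the commit
        if self._turn_count == 0:
            return False
        self._client.post(f"/api/v1/sessions/{self._session_id}/commit")
        return True


def run_hung_writer_case():
    client = FakeClient()
    provider = MiniProvider(client, "sess-old", join_timeout=0.05)
    release = threading.Event()
    # The worker blocks until released, so it outlives the short join.
    worker = threading.Thread(target=release.wait, args=(5,), daemon=True)
    worker.start()
    provider._sync_thread = worker
    committed = provider.on_session_end()
    release.set()
    worker.join(timeout=5)
    return committed, list(client.posts)
```

The case asserts on the absence of a commit POST rather than on the log line, which keeps it independent of the warning's wording.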

(cherry picked from commit 3791a87dbe)
harshitAgr 2026-05-21 07:21:18 +03:00 committed by Hao Zhe
parent a30b40c73a
commit eddbf291a4
2 changed files with 212 additions and 2 deletions


@@ -436,6 +436,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
         self._prefetch_result = ""
         self._prefetch_lock = threading.Lock()
         self._prefetch_thread: Optional[threading.Thread] = None
+        # Monotonic counter incremented on every session switch. Prefetch
+        # workers capture the value when spawned and refuse to write their
+        # result if the generation has advanced — otherwise a slow worker
+        # from session N can repopulate session N+1 with stale recall.
+        self._prefetch_generation = 0
     @property
     def name(self) -> str:
@@ -549,6 +554,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
         if not self._client or not query:
             return
+        # Snapshot the generation at spawn time. If on_session_switch bumps it
+        # before this worker finishes, the worker drops its result instead of
+        # repopulating the new session with stale recall from the old one.
+        gen = self._prefetch_generation
         def _run():
             try:
                 client = _VikingClient(
@@ -571,6 +581,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
                     parts.append(f"- [{score:.2f}] {abstract} ({uri})")
                 if parts:
                     with self._prefetch_lock:
+                        if gen != self._prefetch_generation:
+                            return
                         self._prefetch_result = "\n".join(parts)
             except Exception as e:
                 logger.debug("OpenViking prefetch failed: %s", e)
@@ -642,6 +654,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
         # the count hasn't been incremented yet.
         if self._sync_thread and self._sync_thread.is_alive():
             self._sync_thread.join(timeout=10.0)
+            if self._sync_thread.is_alive():
+                # Worker outlived the bounded join — each POST has _TIMEOUT=30s
+                # and there are two of them per turn. Committing now would
+                # orphan the worker's late writes past the commit boundary
+                # (they'd land in an already-committed session and never be
+                # extracted). Skip the commit; leave _turn_count untouched so
+                # the session stays marked dirty for any retry path.
+                logger.warning(
+                    "OpenViking sync worker still alive after 10s join — "
+                    "skipping commit on session %s to avoid orphaning late writes",
+                    self._session_id,
+                )
+                return
         if self._turn_count == 0:
             return
@@ -685,11 +710,27 @@ class OpenVikingMemoryProvider(MemoryProvider):
         old_session_id = self._session_id
         old_turn_count = self._turn_count
+        sync_worker_drained = True
         if self._sync_thread and self._sync_thread.is_alive():
             self._sync_thread.join(timeout=10.0)
+            if self._sync_thread.is_alive():
+                # Same hazard as on_session_end: worker outlived the bounded
+                # join. Skip the commit so its late writes aren't orphaned
+                # past a commit boundary they can't recover from.
+                sync_worker_drained = False
+                logger.warning(
+                    "OpenViking sync worker still alive after 10s join — "
+                    "skipping commit-on-switch for session %s; late writes "
+                    "will remain in the uncommitted old session",
+                    old_session_id,
+                )
-        if old_session_id and old_turn_count > 0:
+        # 2. Commit the old session if it accumulated turns — same
+        # extraction semantics as on_session_end. Skip if empty (nothing
+        # to extract), if the provider was never initialized, or if the
+        # sync worker is still mid-flight.
+        if sync_worker_drained and old_session_id and old_turn_count > 0:
             try:
                 self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
                 logger.info(
@@ -702,6 +743,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
                     old_session_id, e,
                 )
+        # 3. Bump the prefetch generation so any in-flight prefetch worker
+        # finishing AFTER this point drops its result. Then drain the
+        # current worker and clear the cached result so the new session
+        # doesn't see stale recall from the old one.
+        self._prefetch_generation += 1
         if self._prefetch_thread and self._prefetch_thread.is_alive():
             self._prefetch_thread.join(timeout=3.0)
         with self._prefetch_lock: