fix(openviking): close remaining session-boundary races on switch

Three follow-ups from review on #28296:

1. Sync worker outliving the bounded join. Each sync_turn POST has
   _TIMEOUT=30s and there are two per turn, but on_session_end and
   on_session_switch only join for 10s. If the worker is still alive
   after the join, committing the old session orphans the worker's
   late writes past the commit boundary — they land in an already-
   committed session and never get extracted. Both hooks now re-check
   is_alive() after the join and skip the commit when the worker
   hasn't drained.

2. on_memory_write late session_id capture. Same shape as the
   pre-fix sync_turn: f-string for the post path read self._session_id
   inside the worker, so a switch between thread spawn and post call
   landed the memory note in the new session. Snapshot sid at call
   time, same pattern as sync_turn.

3. Stale prefetch repopulating the new session. The pre-switch
   drain+clear only protects against workers that finish before the
   join completes; one finishing after the clear would write its
   result into the new generation's slot. Added a monotonic
   _prefetch_generation; workers capture it at spawn and refuse to
   write if it has advanced.

Tests: existing in-flight-sync test updated to drain (it tested the
join-before-commit happy path); four new tests cover hung-writer skip
on end + switch, on_memory_write sid capture, and prefetch generation
gating. 177/177 memory tests pass.

(cherry picked from commit 3791a87dbe)
This commit is contained in:
harshitAgr 2026-05-21 07:21:18 +03:00 committed by Hao Zhe
parent a30b40c73a
commit eddbf291a4
2 changed files with 212 additions and 2 deletions

View file

@ -436,6 +436,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
# Monotonic counter incremented on every session switch. Prefetch
# workers capture the value when spawned and refuse to write their
# result if the generation has advanced — otherwise a slow worker
# from session N can repopulate session N+1 with stale recall.
self._prefetch_generation = 0
@property
def name(self) -> str:
@ -549,6 +554,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client or not query:
return
# Snapshot the generation at spawn time. If on_session_switch bumps it
# before this worker finishes, the worker drops its result instead of
# repopulating the new session with stale recall from the old one.
gen = self._prefetch_generation
def _run():
try:
client = _VikingClient(
@ -571,6 +581,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
parts.append(f"- [{score:.2f}] {abstract} ({uri})")
if parts:
with self._prefetch_lock:
if gen != self._prefetch_generation:
return
self._prefetch_result = "\n".join(parts)
except Exception as e:
logger.debug("OpenViking prefetch failed: %s", e)
@ -642,6 +654,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
# the count hasn't been incremented yet.
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=10.0)
if self._sync_thread.is_alive():
# Worker outlived the bounded join — each POST has _TIMEOUT=30s
# and there are two of them per turn. Committing now would
# orphan the worker's late writes past the commit boundary
# (they'd land in an already-committed session and never be
# extracted). Skip the commit; leave _turn_count untouched so
# the session stays marked dirty for any retry path.
logger.warning(
"OpenViking sync worker still alive after 10s join — "
"skipping commit on session %s to avoid orphaning late writes",
self._session_id,
)
return
if self._turn_count == 0:
return
@ -685,11 +710,27 @@ class OpenVikingMemoryProvider(MemoryProvider):
old_session_id = self._session_id
old_turn_count = self._turn_count
sync_worker_drained = True
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=10.0)
if self._sync_thread.is_alive():
# Same hazard as on_session_end: worker outlived the bounded
# join. Skip the commit so its late writes aren't orphaned
# past a commit boundary they can't recover from.
sync_worker_drained = False
logger.warning(
"OpenViking sync worker still alive after 10s join — "
"skipping commit-on-switch for session %s; late writes "
"will remain in the uncommitted old session",
old_session_id,
)
if old_session_id and old_turn_count > 0:
# 2. Commit the old session if it accumulated turns — same
# extraction semantics as on_session_end. Skip if empty (nothing
# to extract), if the provider was never initialized, or if the
# sync worker is still mid-flight.
if sync_worker_drained and old_session_id and old_turn_count > 0:
try:
self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
logger.info(
@ -702,6 +743,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
old_session_id, e,
)
# 3. Bump the prefetch generation so any in-flight prefetch worker
# finishing AFTER this point drops its result. Then drain the
# current worker and clear the cached result so the new session
# doesn't see stale recall from the old one.
self._prefetch_generation += 1
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock:

View file

@ -472,11 +472,16 @@ def test_on_session_switch_waits_for_inflight_sync_thread():
join_calls = []
class FakeThread:
def __init__(self):
self._alive = True
def is_alive(self):
return True
return self._alive
def join(self, timeout=None):
join_calls.append(timeout)
# Simulate a worker that finishes within the join window.
self._alive = False
provider._sync_thread = FakeThread()
@ -629,3 +634,162 @@ def test_on_session_switch_swallows_commit_failure():
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
# ---------------------------------------------------------------------------
# Hung-writer protection: the sync worker can outlive the bounded join
# because each OpenViking POST has _TIMEOUT=30s and there are two per turn.
# Committing while late writes are still in flight would orphan them past
# the commit boundary — they would never be extracted.
# ---------------------------------------------------------------------------
class _HungThread:
"""Thread stand-in that stays alive across joins."""
def is_alive(self):
return True
def join(self, timeout=None):
# Pretend the join timed out — worker still running.
return None
def test_on_session_end_skips_commit_when_sync_worker_outlives_join():
"""If the sync worker is still alive after the 10s join, the commit must
be skipped late writes from the worker would otherwise land in an
already-committed session and never be extracted. Leave _turn_count
intact so the session stays marked dirty."""
provider = _make_provider_with_session("old-sid", turn_count=3)
provider._sync_thread = _HungThread()
provider.on_session_end([])
provider._client.post.assert_not_called()
assert provider._turn_count == 3
def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
"""Same hazard on the switch path. Rotation must still proceed (the new
session needs to start) but the old-session commit is skipped to avoid
orphaning the worker's late writes past commit."""
provider = _make_provider_with_session("old-sid", turn_count=2)
provider._sync_thread = _HungThread()
provider.on_session_switch("new-sid")
provider._client.post.assert_not_called()
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
# ---------------------------------------------------------------------------
# on_memory_write: same late-capture hazard as sync_turn — worker must use
# the session id snapshotted at call time, not re-read self._session_id.
# Block inside the stub ctor (BEFORE the f-string for the post path is
# evaluated) so the rotation deterministically beats the f-string.
# ---------------------------------------------------------------------------
def test_on_memory_write_captures_session_id_at_call_time():
import threading
provider = OpenVikingMemoryProvider()
provider._client = MagicMock()
provider._endpoint = "http://test"
provider._api_key = ""
provider._account = "acct"
provider._user = "usr"
provider._agent = "hermes"
provider._session_id = "old-sid"
in_ctor = threading.Event()
release = threading.Event()
done = threading.Event()
captured_paths = []
class StubClient:
def __init__(self, *a, **kw):
in_ctor.set()
release.wait(timeout=2.0)
def post(self, path, payload=None, **kwargs):
captured_paths.append(path)
done.set()
return {}
import plugins.memory.openviking as _mod
real_client_cls = _mod._VikingClient
_mod._VikingClient = StubClient
try:
provider.on_memory_write("add", "viking://memories/x", "remember this")
assert in_ctor.wait(timeout=2.0), "worker never entered ctor"
# Rotate provider's session id while the worker is parked in the ctor,
# BEFORE it evaluates the f-string for the post path. If the worker
# reads self._session_id inside the closure, it will now see "new-sid".
provider._session_id = "new-sid"
release.set()
assert done.wait(timeout=2.0), "worker never reached post()"
finally:
_mod._VikingClient = real_client_cls
# The write must target the OLD session id captured at call time.
assert captured_paths == ["/api/v1/sessions/old-sid/messages"]
# ---------------------------------------------------------------------------
# Prefetch staleness: a prefetch worker that finishes AFTER a session switch
# must drop its result instead of repopulating the new session with stale
# recall from the old generation. Bump the generation directly (rather than
# calling on_session_switch, whose own join blocks on the test worker) so
# the test isolates the generation-gating behavior.
# ---------------------------------------------------------------------------
def test_queue_prefetch_drops_result_when_generation_changed_mid_flight():
import threading
provider = OpenVikingMemoryProvider()
provider._client = MagicMock()
provider._endpoint = "http://test"
provider._api_key = ""
provider._account = "acct"
provider._user = "usr"
provider._agent = "hermes"
provider._session_id = "old-sid"
started = threading.Event()
release = threading.Event()
class StubClient:
def __init__(self, *a, **kw):
pass
def post(self, path, payload=None, **kwargs):
started.set()
release.wait(timeout=2.0)
return {
"result": {
"memories": [
{"uri": "viking://memories/old", "score": 0.9,
"abstract": "stale from old session"},
],
"resources": [],
}
}
import plugins.memory.openviking as _mod
real_client_cls = _mod._VikingClient
_mod._VikingClient = StubClient
try:
provider.queue_prefetch("anything")
assert started.wait(timeout=2.0), "prefetch worker never entered post()"
# Simulate a session switch by bumping the generation directly.
# The worker captured the pre-bump generation when it was spawned.
provider._prefetch_generation += 1
release.set()
if provider._prefetch_thread:
provider._prefetch_thread.join(timeout=2.0)
finally:
_mod._VikingClient = real_client_cls
# The stale result from the pre-bump generation must NOT have been written
# into the new generation's prefetch slot.
assert provider._prefetch_result == ""