fix(openviking): track writers per-session so commit waits for all

sync_turn's bounded join could drop a still-alive previous worker by
replacing the single _sync_thread slot. The dropped worker kept POSTing
under the old sid but was no longer visible to on_session_end /
on_session_switch, so the commit could fire while orphaned writes were
still in flight — those writes landed past the commit boundary and were
never extracted.

Replace the single _sync_thread slot with _inflight_writers:
Dict[sid, Set[Thread]]. Writers self-register on spawn (sync_turn,
on_memory_write) and self-deregister on exit. The commit path drains
_drain_writers(sid, 10.0) and skips the commit if any writer for that
sid is still alive after the bounded budget.

Also trim inline review-rationale comments to short invariants per
reviewer style ask: "commit only after session writes drain" and
"drop prefetch results from older switch generations."

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
harshitAgr 2026-05-21 11:51:54 +03:00
parent 3791a87dbe
commit 7537ee6f5b
2 changed files with 170 additions and 92 deletions

View file

@ -31,10 +31,11 @@ import mimetypes
import os
import tempfile
import threading
import time
import uuid
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Set
from urllib.parse import urlparse
from urllib.request import url2pathname
@ -399,14 +400,16 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._api_key = ""
self._session_id = ""
self._turn_count = 0
self._sync_thread: Optional[threading.Thread] = None
# Commit only after session writes drain. The set is keyed by the sid
# the writer is POSTing under (snapshotted at spawn), so on_session_end
# / on_session_switch see every still-alive writer for that sid even
# if later writes have replaced the latest-tracked thread.
self._inflight_writers: Dict[str, Set[threading.Thread]] = {}
self._inflight_lock = threading.Lock()
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
# Monotonic counter incremented on every session switch. Prefetch
# workers capture the value when spawned and refuse to write their
# result if the generation has advanced — otherwise a slow worker
# from session N can repopulate session N+1 with stale recall.
# Drop prefetch results from older switch generations.
self._prefetch_generation = 0
@property
@ -520,9 +523,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client or not query:
return
# Snapshot the generation at spawn time. If on_session_switch bumps it
# before this worker finishes, the worker drops its result instead of
# repopulating the new session with stale recall from the old one.
# Drop prefetch results from older switch generations.
gen = self._prefetch_generation
def _run():
@ -558,15 +559,59 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
self._prefetch_thread.start()
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
"""Spawn a daemon writer tracked in _inflight_writers[sid].
Tracking is keyed by sid (not by a single latest-thread slot) so that
on_session_end / on_session_switch can drain every still-alive writer
for the session being committed.
"""
holder: List[threading.Thread] = []
def _wrapped():
try:
target()
finally:
with self._inflight_lock:
workers = self._inflight_writers.get(sid)
if workers is not None:
workers.discard(holder[0])
if not workers:
self._inflight_writers.pop(sid, None)
thread = threading.Thread(target=_wrapped, daemon=True, name=name)
holder.append(thread)
with self._inflight_lock:
self._inflight_writers.setdefault(sid, set()).add(thread)
thread.start()
def _drain_writers(self, sid: str, timeout: float) -> bool:
"""Join every in-flight writer for sid within a shared timeout budget.
Returns True if all writers drained, False if any are still alive when
the budget runs out. Callers use the False return to skip the commit.
"""
deadline = time.monotonic() + timeout
while True:
with self._inflight_lock:
workers = [t for t in self._inflight_writers.get(sid, ()) if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
t.join(timeout=slice_left)
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
return
# Capture the target session id NOW, not inside the worker. Otherwise
# a delayed worker can read self._session_id after on_session_switch
# has rotated it (the switch's join on _sync_thread is bounded), and
# the OLD turn's content lands in the NEW session.
# Snapshot the sid so a delayed worker can't write into a rotated session.
sid = str(session_id or self._session_id).strip()
if not sid:
return
@ -579,13 +624,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._endpoint, self._api_key,
account=self._account, user=self._user, agent=self._agent,
)
# Add user message
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "user",
"content": user_content[:4000], # trim very long messages
"content": user_content[:4000],
})
# Add assistant message
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "assistant",
"content": assistant_content[:4000],
@ -593,14 +635,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
except Exception as e:
logger.debug("OpenViking sync_turn failed: %s", e)
# Wait for any previous sync to finish before starting a new one
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=5.0)
self._sync_thread = threading.Thread(
target=_sync, daemon=True, name="openviking-sync"
)
self._sync_thread.start()
self._spawn_writer(sid, _sync, name="openviking-sync")
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
"""Commit the session to trigger memory extraction.
@ -611,35 +646,22 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client:
return
# Wait for any pending sync to finish first — do this before the
# turn_count check so the last turn's messages are flushed even if
# the count hasn't been incremented yet.
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=10.0)
if self._sync_thread.is_alive():
# Worker outlived the bounded join — each POST has _TIMEOUT=30s
# and there are two of them per turn. Committing now would
# orphan the worker's late writes past the commit boundary
# (they'd land in an already-committed session and never be
# extracted). Skip the commit; leave _turn_count untouched so
# the session stays marked dirty for any retry path.
logger.warning(
"OpenViking sync worker still alive after 10s join — "
"skipping commit on session %s to avoid orphaning late writes",
self._session_id,
)
return
sid = self._session_id
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=10.0):
logger.warning(
"OpenViking writer for %s still alive after drain — skipping commit",
sid,
)
return
if self._turn_count == 0:
return
try:
self._client.post(f"/api/v1/sessions/{self._session_id}/commit")
logger.info("OpenViking session %s committed (%d turns)", self._session_id, self._turn_count)
# Mark the session clean so a subsequent on_session_switch (fired
# by /new and compression right after commit_memory_session) skips
# its commit instead of double-committing. On commit failure we
# leave the count intact so the switch hook gets a retry.
self._client.post(f"/api/v1/sessions/{sid}/commit")
logger.info("OpenViking session %s committed (%d turns)", sid, self._turn_count)
# Mark clean so a follow-up on_session_switch skips its own commit.
self._turn_count = 0
except Exception as e:
logger.warning("OpenViking session commit failed: %s", e)
@ -670,33 +692,21 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not new_id or not self._client:
return
# Snapshot the old session id BEFORE rotation so the join+commit
# below always target the session whose writes we want to flush.
old_session_id = self._session_id
old_turn_count = self._turn_count
sync_worker_drained = True
# 1. Wait for any in-flight sync_turn to finish writing under the
# OLD session id — otherwise it races the commit below.
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=10.0)
if self._sync_thread.is_alive():
# Same hazard as on_session_end: worker outlived the bounded
# join. Skip the commit so its late writes aren't orphaned
# past a commit boundary they can't recover from.
sync_worker_drained = False
# Commit only after session writes drain.
writers_drained = True
if old_session_id:
writers_drained = self._drain_writers(old_session_id, timeout=10.0)
if not writers_drained:
logger.warning(
"OpenViking sync worker still alive after 10s join — "
"skipping commit-on-switch for session %s; late writes "
"will remain in the uncommitted old session",
"OpenViking writer for %s still alive after drain — "
"skipping commit-on-switch",
old_session_id,
)
# 2. Commit the old session if it accumulated turns — same
# extraction semantics as on_session_end. Skip if empty (nothing
# to extract), if the provider was never initialized, or if the
# sync worker is still mid-flight.
if sync_worker_drained and old_session_id and old_turn_count > 0:
if writers_drained and old_session_id and old_turn_count > 0:
try:
self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
logger.info(
@ -709,17 +719,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
old_session_id, e,
)
# 3. Bump the prefetch generation so any in-flight prefetch worker
# finishing AFTER this point drops its result. Then drain the
# current worker and clear the cached result so the new session
# doesn't see stale recall from the old one.
# Drop prefetch results from older switch generations.
self._prefetch_generation += 1
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
# 4. Rotate to the new session.
self._session_id = new_id
self._turn_count = 0
logger.debug(
@ -732,10 +738,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client or action != "add" or not content:
return
# Snapshot the target session id at call time — see sync_turn() for
# the rationale. A delayed worker that reads self._session_id after
# on_session_switch has rotated it would land the memory note in the
# NEW session.
# Snapshot the sid so a delayed worker can't write into a rotated session.
sid = str(self._session_id or "").strip()
if not sid:
return
@ -746,8 +749,6 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._endpoint, self._api_key,
account=self._account, user=self._user, agent=self._agent,
)
# Add as a user message with memory context so the commit
# picks it up as an explicit memory during extraction
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "user",
"parts": [
@ -757,8 +758,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
except Exception as e:
logger.debug("OpenViking memory mirror failed: %s", e)
t = threading.Thread(target=_write, daemon=True, name="openviking-memwrite")
t.start()
self._spawn_writer(sid, _write, name="openviking-memwrite")
def get_tool_schemas(self) -> List[Dict[str, Any]]:
return [SEARCH_SCHEMA, READ_SCHEMA, BROWSE_SCHEMA, REMEMBER_SCHEMA, ADD_RESOURCE_SCHEMA]
@ -783,11 +783,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
return tool_error(str(e))
def shutdown(self) -> None:
# Wait for background threads to finish
for t in (self._sync_thread, self._prefetch_thread):
if t and t.is_alive():
# Wait for every in-flight writer across all tracked sessions.
with self._inflight_lock:
all_workers = [
t for workers in self._inflight_writers.values() for t in workers
]
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
global _last_active_provider
if _last_active_provider is self:
_last_active_provider = None

View file

@ -483,7 +483,7 @@ def test_on_session_switch_waits_for_inflight_sync_thread():
# Simulate a worker that finishes within the join window.
self._alive = False
provider._sync_thread = FakeThread()
provider._inflight_writers["old-sid"] = {FakeThread()}
provider.on_session_switch("new-sid")
@ -561,8 +561,8 @@ def test_sync_turn_captures_session_id_before_worker_runs():
# Rotate the provider's session id while the worker is mid-flight.
provider._session_id = "new-sid"
release.set()
if provider._sync_thread:
provider._sync_thread.join(timeout=2.0)
for t in list(provider._inflight_writers.get("old-sid", set())):
t.join(timeout=2.0)
finally:
_mod._VikingClient = real_client_cls
@ -582,7 +582,7 @@ def test_sync_turn_noop_when_session_id_blank():
# No turn counted, no worker spawned.
assert provider._turn_count == 0
assert provider._sync_thread is None
assert provider._inflight_writers == {}
def test_on_session_end_marks_session_clean_after_successful_commit():
@ -660,7 +660,7 @@ def test_on_session_end_skips_commit_when_sync_worker_outlives_join():
already-committed session and never be extracted. Leave _turn_count
intact so the session stays marked dirty."""
provider = _make_provider_with_session("old-sid", turn_count=3)
provider._sync_thread = _HungThread()
provider._inflight_writers["old-sid"] = {_HungThread()}
provider.on_session_end([])
@ -673,7 +673,7 @@ def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
session needs to start) but the old-session commit is skipped to avoid
orphaning the worker's late writes past commit."""
provider = _make_provider_with_session("old-sid", turn_count=2)
provider._sync_thread = _HungThread()
provider._inflight_writers["old-sid"] = {_HungThread()}
provider.on_session_switch("new-sid")
@ -682,6 +682,78 @@ def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
assert provider._turn_count == 0
# ---------------------------------------------------------------------------
# Orphaned-writer hazard: commit must wait for ALL writers for the session,
# not just the latest tracked one. sync_turn's bounded rate-limit can drop a
# still-alive previous worker — that dropped writer keeps POSTing under the
# old sid and would otherwise land its writes past the commit boundary.
# ---------------------------------------------------------------------------
def test_on_session_end_waits_for_all_writers_not_just_latest():
provider = _make_provider_with_session("old-sid", turn_count=2)
provider._inflight_writers["old-sid"] = {_HungThread()}
provider.on_session_end([])
provider._client.post.assert_not_called()
assert provider._turn_count == 2
def test_on_session_switch_waits_for_all_writers_not_just_latest():
provider = _make_provider_with_session("old-sid", turn_count=2)
provider._inflight_writers["old-sid"] = {_HungThread()}
provider.on_session_switch("new-sid")
provider._client.post.assert_not_called()
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
def test_sync_turn_tracks_writer_under_session_id():
"""Every sync_turn writer must register under its captured sid so the
drain at end/switch sees it even if a later sync_turn replaces the
latest-tracked reference."""
import threading
provider = OpenVikingMemoryProvider()
provider._client = MagicMock()
provider._endpoint = "http://test"
provider._api_key = ""
provider._account = "acct"
provider._user = "usr"
provider._agent = "hermes"
provider._session_id = "sid-1"
release = threading.Event()
started = threading.Event()
class StubClient:
def __init__(self, *a, **kw):
pass
def post(self, path, payload=None, **kwargs):
started.set()
release.wait(timeout=2.0)
return {}
import plugins.memory.openviking as _mod
real_client_cls = _mod._VikingClient
_mod._VikingClient = StubClient
try:
provider.sync_turn("u", "a")
assert started.wait(timeout=2.0), "worker never entered post()"
assert len(provider._inflight_writers.get("sid-1", set())) == 1
release.set()
for t in list(provider._inflight_writers.get("sid-1", set())):
t.join(timeout=2.0)
finally:
_mod._VikingClient = real_client_cls
# Worker should have removed itself from the inflight set on exit.
assert provider._inflight_writers.get("sid-1", set()) == set()
# ---------------------------------------------------------------------------
# on_memory_write: same late-capture hazard as sync_turn — worker must use
# the session id snapshotted at call time, not re-read self._session_id.