fix(openviking): don't block the command thread on session switch; lock turn state

Follow-up hardening on @ehz0ah / @harshitAgr's session-switch work (#28296):

- on_session_switch no longer runs the old-session writer-drain + pending-token
  GET + commit POST inline on the caller's command thread. /new, /branch,
  /resume, /undo call it synchronously, so a slow drain (up to 10s) or wedged
  commit blocked the user-facing command — the same hazard #41945 fixed for
  end-of-turn sync. State now rotates synchronously (cheap) and the old-session
  commit is offloaded to a daemon finalizer (generalized _finalize_session_async).
- Guard the (_session_id, _turn_count) pair with _session_state_lock: sync_turn
  runs on the memory-manager executor thread while the session hooks run on the
  command thread, so the snapshot+reset vs increment was a cross-thread race.
- _session_needs_commit checks the committed-session guard BEFORE the
  turn_count>0 shortcut, closing a double-commit window when a racing sync_turn
  re-increments after commit+reset.
- Add a _shutting_down flag so deferred finalizers stop POSTing against a
  torn-down client; track all prefetch threads in a set so invalidate/shutdown
  join every one, not just the latest slot.

Tests: regression for the non-blocking switch (asserts the caller returns while
a slow drain is parked off-thread) and the committed-guard ordering; updated the
deferred-commit test to the unified finalizer contract.
This commit is contained in:
kshitijk4poor 2026-06-18 00:21:21 +05:30
parent 0c1e8d0ba9
commit c835448908
2 changed files with 210 additions and 55 deletions

View file

@ -435,6 +435,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._api_key = ""
self._session_id = ""
self._turn_count = 0
# Guards the (_session_id, _turn_count) pair. sync_turn runs on the
# MemoryManager's background sync executor while on_session_end /
# on_session_switch run on the caller's thread, so the snapshot+reset
# of the turn counter and the session-id rotation must be atomic
# against a concurrent increment. See hermes-agent#28296 review.
self._session_state_lock = threading.Lock()
# Commit only after session writes drain. The set is keyed by the sid
# the writer is POSTing under (snapshotted at spawn), so on_session_end
# / on_session_switch see every still-alive writer for that sid even
@ -449,6 +455,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
# shutdown() can drain them and rapid re-queues don't orphan a still-
# running thread by overwriting the single _prefetch_thread slot.
self._prefetch_threads: Set[threading.Thread] = set()
# Set on shutdown so deferred-commit / writer finalizers stop issuing
# network writes against a torn-down provider.
self._shutting_down = False
# Drop prefetch results from older switch generations.
self._prefetch_generation = 0
@ -568,6 +581,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
with self._prefetch_lock:
gen = self._prefetch_generation
holder: List[threading.Thread] = []
def _run():
try:
client = _VikingClient(
@ -595,11 +610,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._prefetch_result = "\n".join(parts)
except Exception as e:
logger.debug("OpenViking prefetch failed: %s", e)
finally:
with self._prefetch_lock:
if holder:
self._prefetch_threads.discard(holder[0])
self._prefetch_thread = threading.Thread(
thread = threading.Thread(
target=_run, daemon=True, name="openviking-prefetch"
)
self._prefetch_thread.start()
holder.append(thread)
with self._prefetch_lock:
self._prefetch_thread = thread
self._prefetch_threads.add(thread)
thread.start()
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
"""Spawn a daemon writer tracked in _inflight_writers[sid].
@ -627,6 +650,30 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._inflight_writers.setdefault(sid, set()).add(thread)
thread.start()
def _drain_finalizers(self, timeout: float) -> bool:
"""Join every in-flight async session finalizer within a timeout.
The switch-path commit runs on a daemon finalizer thread so it never
blocks the caller's command thread; this lets shutdown and tests wait
for those commits deterministically. Returns True if all drained.
"""
deadline = time.monotonic() + timeout
while True:
with self._deferred_commit_lock:
workers = [t for t in self._deferred_commit_threads if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
# Floor the per-join wait so a thread whose join() returns
# instantly while still reporting alive can't hot-spin this loop.
t.join(timeout=min(slice_left, 0.05))
def _drain_writers(self, sid: str, timeout: float) -> bool:
"""Join every in-flight writer for sid within a shared timeout budget.
@ -707,10 +754,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._committed_session_ids.add(sid)
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
if turn_count > 0:
return True
# Already-committed sessions never need a second commit, regardless of
# the turn counter — a racing sync_turn can re-increment _turn_count
# after a commit+reset, so the committed-guard must win over turn_count.
if self._has_committed_session(sid):
return False
if turn_count > 0:
return True
return self._session_has_pending_tokens(sid)
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
@ -723,11 +773,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
logger.warning("OpenViking session commit failed for %s: %s", sid, e)
return False
def _schedule_deferred_commit(self, sid: str, turn_count: int) -> None:
if not sid or turn_count <= 0:
def _finalize_session_async(self, sid: str, turn_count: int, *, context: str) -> None:
"""Drain the old session's writers and commit it on a daemon thread.
Used by on_session_switch (and the deferred-commit fallback) so the
potentially-multi-second drain + pending-token GET + commit POST never
runs on the caller's command thread. Deduped by sid so a rapid second
switch can't stack two finalizers for the same session, and a no-op
once shutdown has begun so we don't POST against a torn-down client.
"""
if not sid:
return
with self._deferred_commit_lock:
if sid in self._deferred_commit_sids:
if self._shutting_down or sid in self._deferred_commit_sids:
return
self._deferred_commit_sids.add(sid)
@ -735,14 +793,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
def _finalize() -> None:
try:
if self._shutting_down:
return
if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after deferred drain — "
"OpenViking writer for %s still alive after drain — "
"leaving session uncommitted",
sid,
)
return
self._commit_session(sid, turn_count, context="after deferred drain")
if self._shutting_down:
return
if self._session_needs_commit(sid, turn_count):
self._commit_session(sid, turn_count, context=context)
finally:
with self._deferred_commit_lock:
self._deferred_commit_sids.discard(sid)
@ -765,8 +828,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
with self._prefetch_lock:
self._prefetch_generation += 1
self._prefetch_result = ""
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
# Join EVERY tracked prefetch thread, not just the latest slot — a
# rapid re-queue can leave an older thread for the abandoned session
# still running (consistent with shutdown()).
workers = [t for t in self._prefetch_threads if t.is_alive()]
for t in workers:
t.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
@ -779,12 +846,15 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not user_content:
return
# Snapshot the sid so a delayed worker can't write into a rotated session.
sid = str(session_id or self._session_id).strip()
if not sid:
return
self._turn_count += 1
# Snapshot the sid and bump the turn counter atomically so a
# concurrent on_session_switch/on_session_end can't interleave its
# snapshot+reset between the read and the increment (lost turn) and so
# the turn is unambiguously attributed to the session it targets.
with self._session_state_lock:
sid = str(session_id or self._session_id).strip()
if not sid:
return
self._turn_count += 1
def _sync():
try:
@ -819,7 +889,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client:
return
sid = self._session_id
# Snapshot sid + turn count atomically against a concurrent sync_turn
# increment. on_session_end runs at teardown so the drain+commit stays
# synchronous here (we want it to land before the process exits), but
# the counter read must still be consistent.
with self._session_state_lock:
sid = self._session_id
turn_count = self._turn_count
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
logger.warning(
@ -828,12 +905,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
return
if not self._session_needs_commit(sid, self._turn_count):
if not self._session_needs_commit(sid, turn_count):
return
if self._commit_session(sid, self._turn_count, context="on session end"):
if self._commit_session(sid, turn_count, context="on session end"):
# Mark clean so a follow-up on_session_switch skips its own commit.
self._turn_count = 0
with self._session_state_lock:
if self._session_id == sid:
self._turn_count = 0
def on_session_switch(
self,
@ -862,10 +941,31 @@ class OpenVikingMemoryProvider(MemoryProvider):
return
rewound = bool(kwargs.get("rewound"))
old_session_id = self._session_id
old_turn_count = self._turn_count
if rewound or new_id == old_session_id:
self._invalidate_prefetch_state()
# Rotate cached session state synchronously (cheap, in-memory) and
# snapshot the old session under the lock so a concurrent sync_turn
# either lands fully before the rotation (counted under old) or fully
# after (counted under new) — never split. The OLD session's commit
# (drain + pending-token GET + commit POST, potentially many seconds)
# is then offloaded so /new, /branch, /resume, /undo never block the
# caller's command thread (cf. the end-of-turn-sync offload in #41945).
with self._session_state_lock:
old_session_id = self._session_id
old_turn_count = self._turn_count
rotate = not (rewound or new_id == old_session_id)
if rotate:
self._session_id = new_id
self._turn_count = 0
# Invalidate stale prefetch OUTSIDE the session lock — it takes its own
# _prefetch_lock and may join a prefetch thread for up to 3s, which we
# must not do while holding the session lock (would block sync_turn and
# risk lock-ordering coupling).
self._invalidate_prefetch_state()
if not rotate:
# Same-session rewind (/undo) or no-op rotation: no commit, no
# counter reset — just the prefetch invalidation above.
logger.debug(
"OpenViking on_session_switch invalidated state without rotation: "
"session=%s rewound=%s",
@ -873,30 +973,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
return
# Commit only after session writes drain.
writers_drained = True
# Drain + commit the OLD session off the command thread.
if old_session_id:
writers_drained = self._drain_writers(old_session_id, timeout=_SESSION_DRAIN_TIMEOUT)
if not writers_drained:
logger.warning(
"OpenViking writer for %s still alive after drain — "
"skipping commit-on-switch",
old_session_id,
)
self._finalize_session_async(old_session_id, old_turn_count, context="on switch")
if (
writers_drained
and old_session_id
and self._session_needs_commit(old_session_id, old_turn_count)
):
self._commit_session(old_session_id, old_turn_count, context="on switch")
elif not writers_drained:
self._schedule_deferred_commit(old_session_id, old_turn_count)
self._invalidate_prefetch_state()
self._session_id = new_id
self._turn_count = 0
logger.debug(
"OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s",
old_session_id, new_id, parent_session_id, reset,
@ -961,6 +1041,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
return tool_error(str(e))
def shutdown(self) -> None:
# Stop deferred finalizers from issuing new commits against a
# torn-down client, then drain everything still in flight.
self._shutting_down = True
# Wait for every in-flight writer across all tracked sessions.
with self._inflight_lock:
all_workers = [
@ -968,14 +1051,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
]
with self._deferred_commit_lock:
deferred_workers = list(self._deferred_commit_threads)
with self._prefetch_lock:
prefetch_workers = list(self._prefetch_threads)
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
for t in deferred_workers:
if t.is_alive():
t.join(timeout=5.0)
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
for t in prefetch_workers:
if t.is_alive():
t.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
global _last_active_provider
if _last_active_provider is self: