fix(openviking): don't block the command thread on session switch; lock turn state

Follow-up hardening on @ehz0ah / @harshitAgr's session-switch work (#28296):

- on_session_switch no longer runs the old-session writer-drain + pending-token
  GET + commit POST inline on the caller's command thread. /new, /branch,
  /resume, /undo call it synchronously, so a slow drain (up to 10s) or wedged
  commit blocked the user-facing command — the same hazard #41945 fixed for
  end-of-turn sync. State now rotates synchronously (cheap) and the old-session
  commit is offloaded to a daemon finalizer (generalized _finalize_session_async).
- Guard the (_session_id, _turn_count) pair with _session_state_lock: sync_turn
  runs on the memory-manager executor thread while the session hooks run on the
  command thread, so the snapshot+reset vs increment was a cross-thread race.
- _session_needs_commit checks the committed-session guard BEFORE the
  turn_count>0 shortcut, closing a double-commit window when a racing sync_turn
  re-increments after commit+reset.
- Add a _shutting_down flag so deferred finalizers stop POSTing against a
  torn-down client; track all prefetch threads in a set so invalidate/shutdown
  join every one, not just the latest slot.

Tests: regression for the non-blocking switch (asserts the caller returns while
a slow drain is parked off-thread) and the committed-guard ordering; updated the
deferred-commit test to the unified finalizer contract.
This commit is contained in:
kshitijk4poor 2026-06-18 00:21:21 +05:30
parent 0c1e8d0ba9
commit c835448908
2 changed files with 210 additions and 55 deletions

View file

@ -435,6 +435,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._api_key = ""
self._session_id = ""
self._turn_count = 0
# Guards the (_session_id, _turn_count) pair. sync_turn runs on the
# MemoryManager's background sync executor while on_session_end /
# on_session_switch run on the caller's thread, so the snapshot+reset
# of the turn counter and the session-id rotation must be atomic
# against a concurrent increment. See hermes-agent#28296 review.
self._session_state_lock = threading.Lock()
# Commit only after session writes drain. The set is keyed by the sid
# the writer is POSTing under (snapshotted at spawn), so on_session_end
# / on_session_switch see every still-alive writer for that sid even
@ -449,6 +455,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
# shutdown() can drain them and rapid re-queues don't orphan a still-
# running thread by overwriting the single _prefetch_thread slot.
self._prefetch_threads: Set[threading.Thread] = set()
# Set on shutdown so deferred-commit / writer finalizers stop issuing
# network writes against a torn-down provider.
self._shutting_down = False
# Drop prefetch results from older switch generations.
self._prefetch_generation = 0
@ -568,6 +581,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
with self._prefetch_lock:
gen = self._prefetch_generation
holder: List[threading.Thread] = []
def _run():
try:
client = _VikingClient(
@ -595,11 +610,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._prefetch_result = "\n".join(parts)
except Exception as e:
logger.debug("OpenViking prefetch failed: %s", e)
finally:
with self._prefetch_lock:
if holder:
self._prefetch_threads.discard(holder[0])
self._prefetch_thread = threading.Thread(
thread = threading.Thread(
target=_run, daemon=True, name="openviking-prefetch"
)
self._prefetch_thread.start()
holder.append(thread)
with self._prefetch_lock:
self._prefetch_thread = thread
self._prefetch_threads.add(thread)
thread.start()
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
"""Spawn a daemon writer tracked in _inflight_writers[sid].
@ -627,6 +650,30 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._inflight_writers.setdefault(sid, set()).add(thread)
thread.start()
def _drain_finalizers(self, timeout: float) -> bool:
"""Join every in-flight async session finalizer within a timeout.
The switch-path commit runs on a daemon finalizer thread so it never
blocks the caller's command thread; this lets shutdown and tests wait
for those commits deterministically. Returns True if all drained.
"""
deadline = time.monotonic() + timeout
while True:
with self._deferred_commit_lock:
workers = [t for t in self._deferred_commit_threads if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
# Floor the per-join wait so a thread whose join() returns
# instantly while still reporting alive can't hot-spin this loop.
t.join(timeout=min(slice_left, 0.05))
def _drain_writers(self, sid: str, timeout: float) -> bool:
"""Join every in-flight writer for sid within a shared timeout budget.
@ -707,10 +754,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._committed_session_ids.add(sid)
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
if turn_count > 0:
return True
# Already-committed sessions never need a second commit, regardless of
# the turn counter — a racing sync_turn can re-increment _turn_count
# after a commit+reset, so the committed-guard must win over turn_count.
if self._has_committed_session(sid):
return False
if turn_count > 0:
return True
return self._session_has_pending_tokens(sid)
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
@ -723,11 +773,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
logger.warning("OpenViking session commit failed for %s: %s", sid, e)
return False
def _schedule_deferred_commit(self, sid: str, turn_count: int) -> None:
if not sid or turn_count <= 0:
def _finalize_session_async(self, sid: str, turn_count: int, *, context: str) -> None:
"""Drain the old session's writers and commit it on a daemon thread.
Used by on_session_switch (and the deferred-commit fallback) so the
potentially-multi-second drain + pending-token GET + commit POST never
runs on the caller's command thread. Deduped by sid so a rapid second
switch can't stack two finalizers for the same session, and a no-op
once shutdown has begun so we don't POST against a torn-down client.
"""
if not sid:
return
with self._deferred_commit_lock:
if sid in self._deferred_commit_sids:
if self._shutting_down or sid in self._deferred_commit_sids:
return
self._deferred_commit_sids.add(sid)
@ -735,14 +793,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
def _finalize() -> None:
try:
if self._shutting_down:
return
if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after deferred drain — "
"OpenViking writer for %s still alive after drain — "
"leaving session uncommitted",
sid,
)
return
self._commit_session(sid, turn_count, context="after deferred drain")
if self._shutting_down:
return
if self._session_needs_commit(sid, turn_count):
self._commit_session(sid, turn_count, context=context)
finally:
with self._deferred_commit_lock:
self._deferred_commit_sids.discard(sid)
@ -765,8 +828,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
with self._prefetch_lock:
self._prefetch_generation += 1
self._prefetch_result = ""
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
# Join EVERY tracked prefetch thread, not just the latest slot — a
# rapid re-queue can leave an older thread for the abandoned session
# still running (consistent with shutdown()).
workers = [t for t in self._prefetch_threads if t.is_alive()]
for t in workers:
t.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
@ -779,12 +846,15 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not user_content:
return
# Snapshot the sid so a delayed worker can't write into a rotated session.
sid = str(session_id or self._session_id).strip()
if not sid:
return
self._turn_count += 1
# Snapshot the sid and bump the turn counter atomically so a
# concurrent on_session_switch/on_session_end can't interleave its
# snapshot+reset between the read and the increment (lost turn) and so
# the turn is unambiguously attributed to the session it targets.
with self._session_state_lock:
sid = str(session_id or self._session_id).strip()
if not sid:
return
self._turn_count += 1
def _sync():
try:
@ -819,7 +889,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client:
return
sid = self._session_id
# Snapshot sid + turn count atomically against a concurrent sync_turn
# increment. on_session_end runs at teardown so the drain+commit stays
# synchronous here (we want it to land before the process exits), but
# the counter read must still be consistent.
with self._session_state_lock:
sid = self._session_id
turn_count = self._turn_count
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
logger.warning(
@ -828,12 +905,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
return
if not self._session_needs_commit(sid, self._turn_count):
if not self._session_needs_commit(sid, turn_count):
return
if self._commit_session(sid, self._turn_count, context="on session end"):
if self._commit_session(sid, turn_count, context="on session end"):
# Mark clean so a follow-up on_session_switch skips its own commit.
self._turn_count = 0
with self._session_state_lock:
if self._session_id == sid:
self._turn_count = 0
def on_session_switch(
self,
@ -862,10 +941,31 @@ class OpenVikingMemoryProvider(MemoryProvider):
return
rewound = bool(kwargs.get("rewound"))
old_session_id = self._session_id
old_turn_count = self._turn_count
if rewound or new_id == old_session_id:
self._invalidate_prefetch_state()
# Rotate cached session state synchronously (cheap, in-memory) and
# snapshot the old session under the lock so a concurrent sync_turn
# either lands fully before the rotation (counted under old) or fully
# after (counted under new) — never split. The OLD session's commit
# (drain + pending-token GET + commit POST, potentially many seconds)
# is then offloaded so /new, /branch, /resume, /undo never block the
# caller's command thread (cf. the end-of-turn-sync offload in #41945).
with self._session_state_lock:
old_session_id = self._session_id
old_turn_count = self._turn_count
rotate = not (rewound or new_id == old_session_id)
if rotate:
self._session_id = new_id
self._turn_count = 0
# Invalidate stale prefetch OUTSIDE the session lock — it takes its own
# _prefetch_lock and may join a prefetch thread for up to 3s, which we
# must not do while holding the session lock (would block sync_turn and
# risk lock-ordering coupling).
self._invalidate_prefetch_state()
if not rotate:
# Same-session rewind (/undo) or no-op rotation: no commit, no
# counter reset — just the prefetch invalidation above.
logger.debug(
"OpenViking on_session_switch invalidated state without rotation: "
"session=%s rewound=%s",
@ -873,30 +973,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
return
# Commit only after session writes drain.
writers_drained = True
# Drain + commit the OLD session off the command thread.
if old_session_id:
writers_drained = self._drain_writers(old_session_id, timeout=_SESSION_DRAIN_TIMEOUT)
if not writers_drained:
logger.warning(
"OpenViking writer for %s still alive after drain — "
"skipping commit-on-switch",
old_session_id,
)
self._finalize_session_async(old_session_id, old_turn_count, context="on switch")
if (
writers_drained
and old_session_id
and self._session_needs_commit(old_session_id, old_turn_count)
):
self._commit_session(old_session_id, old_turn_count, context="on switch")
elif not writers_drained:
self._schedule_deferred_commit(old_session_id, old_turn_count)
self._invalidate_prefetch_state()
self._session_id = new_id
self._turn_count = 0
logger.debug(
"OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s",
old_session_id, new_id, parent_session_id, reset,
@ -961,6 +1041,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
return tool_error(str(e))
def shutdown(self) -> None:
# Stop deferred finalizers from issuing new commits against a
# torn-down client, then drain everything still in flight.
self._shutting_down = True
# Wait for every in-flight writer across all tracked sessions.
with self._inflight_lock:
all_workers = [
@ -968,14 +1051,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
]
with self._deferred_commit_lock:
deferred_workers = list(self._deferred_commit_threads)
with self._prefetch_lock:
prefetch_workers = list(self._prefetch_threads)
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
for t in deferred_workers:
if t.is_alive():
t.join(timeout=5.0)
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
for t in prefetch_workers:
if t.is_alive():
t.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
global _last_active_provider
if _last_active_provider is self:

View file

@ -5,7 +5,11 @@ from unittest.mock import MagicMock
import pytest
from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient
from plugins.memory.openviking import (
OpenVikingMemoryProvider,
_DEFERRED_COMMIT_TIMEOUT,
_VikingClient,
)
def _clear_openviking_tenant_env(monkeypatch):
@ -744,6 +748,22 @@ def test_end_then_switch_with_pending_tokens_does_not_double_commit():
assert provider._turn_count == 0
def test_session_needs_commit_guard_wins_over_stale_turn_count():
"""Regression for hermes-agent#28296 review (M3): once a session is marked
committed, _session_needs_commit must return False even if turn_count is
still positive. A racing sync_turn can re-increment _turn_count after the
commit+reset; without the guard ordering, a follow-up finalizer would
double-commit the same session. The committed-guard must be checked BEFORE
the turn_count>0 shortcut."""
provider = _make_provider_with_session("old-sid", turn_count=5)
provider._mark_session_committed("old-sid")
# turn_count is a (stale) 5 but the session is already committed.
assert provider._session_needs_commit("old-sid", 5) is False
# An uncommitted session with turns still needs a commit.
assert provider._session_needs_commit("fresh-sid", 5) is True
def test_on_session_switch_swallows_commit_failure():
"""Commit-on-switch must not propagate exceptions: a failing commit on the
old session must still allow the rotate to the new session to complete,
@ -831,7 +851,54 @@ def test_on_session_switch_waits_for_all_writers_not_just_latest():
assert provider._turn_count == 0
def test_on_session_switch_defers_old_commit_when_writers_finish_after_initial_drain():
def test_on_session_switch_does_not_block_caller_on_slow_drain():
"""Regression for hermes-agent#28296 review (H1): on_session_switch must
NOT run the old-session drain/commit on the caller's thread. /new, /branch,
/resume, /undo call this synchronously on the command thread, so a slow
writer drain (up to _SESSION_DRAIN_TIMEOUT/_DEFERRED_COMMIT_TIMEOUT) or a
wedged commit POST must not stall the user-facing command. The rotation is
cheap and synchronous; the commit is offloaded. Mirrors the #41945
'do not block the turn thread' contract."""
import threading
import time
provider = _make_provider_with_session("old-sid", turn_count=2)
drain_entered = threading.Event()
release_drain = threading.Event()
def slow_drain(sid, timeout):
drain_entered.set()
# Simulate a writer that takes a long time to drain.
release_drain.wait(timeout=10.0)
return True
provider._drain_writers = slow_drain
start = time.monotonic()
provider.on_session_switch("new-sid")
elapsed = time.monotonic() - start
# The caller returned promptly with state already rotated, even though the
# drain is still parked on the finalizer thread.
assert elapsed < 1.0, f"on_session_switch blocked the caller for {elapsed:.2f}s"
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
assert drain_entered.wait(timeout=2.0), "finalizer never started draining"
# No commit yet — drain is still blocked off-thread.
provider._client.post.assert_not_called()
# Let the finalizer finish so it doesn't leak past the test.
release_drain.set()
assert provider._drain_finalizers(timeout=5.0)
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_defers_old_commit_to_finalizer_thread():
"""The switch path rotates session state synchronously (cheap, in-memory)
but offloads the old-session drain + commit onto a daemon finalizer so the
caller's command thread (/new, /branch, /resume) never blocks on the up-to
-_DEFERRED_COMMIT_TIMEOUT drain or the commit POST. See hermes-agent#28296
review (the #41945 'do not block the turn thread' contract)."""
import threading
provider = _make_provider_with_session("old-sid", turn_count=2)
@ -844,19 +911,21 @@ def test_on_session_switch_defers_old_commit_when_writers_finish_after_initial_d
def fake_drain(sid, timeout):
drain_timeouts.append(timeout)
return len(drain_timeouts) > 1
return True
provider._client.post.side_effect = fake_post
provider._drain_writers = fake_drain
provider.on_session_switch("new-sid")
# Rotation is synchronous and immediate — the new session is live at once.
assert provider._session_id == "new-sid"
assert provider._turn_count == 0
assert committed.wait(timeout=2.0), "old session was not finalized after writers drained"
# The old-session commit lands on the finalizer thread, not inline.
assert committed.wait(timeout=5.0), "old session was not finalized off-thread"
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
assert drain_timeouts[0] == 10.0
assert drain_timeouts[1] > 10.0
# The finalizer drains with the deferred (longer) budget, not inline 10s.
assert drain_timeouts == [_DEFERRED_COMMIT_TIMEOUT]
def test_sync_turn_tracks_writer_under_session_id():