mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-19 10:02:16 +00:00
fix(openviking): don't block the command thread on session switch; lock turn state
Follow-up hardening on @ehz0ah / @harshitAgr's session-switch work (#28296): - on_session_switch no longer runs the old-session writer-drain + pending-token GET + commit POST inline on the caller's command thread. /new, /branch, /resume, /undo call it synchronously, so a slow drain (up to 10s) or wedged commit blocked the user-facing command — the same hazard #41945 fixed for end-of-turn sync. State now rotates synchronously (cheap) and the old-session commit is offloaded to a daemon finalizer (generalized _finalize_session_async). - Guard the (_session_id, _turn_count) pair with _session_state_lock: sync_turn runs on the memory-manager executor thread while the session hooks run on the command thread, so the snapshot+reset vs increment was a cross-thread race. - _session_needs_commit checks the committed-session guard BEFORE the turn_count>0 shortcut, closing a double-commit window when a racing sync_turn re-increments after commit+reset. - Add a _shutting_down flag so deferred finalizers stop POSTing against a torn-down client; track all prefetch threads in a set so invalidate/shutdown join every one, not just the latest slot. Tests: regression for the non-blocking switch (asserts the caller returns while a slow drain is parked off-thread) and the committed-guard ordering; updated the deferred-commit test to the unified finalizer contract.
This commit is contained in:
parent
0c1e8d0ba9
commit
c835448908
2 changed files with 210 additions and 55 deletions
|
|
@ -435,6 +435,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._api_key = ""
|
||||
self._session_id = ""
|
||||
self._turn_count = 0
|
||||
# Guards the (_session_id, _turn_count) pair. sync_turn runs on the
|
||||
# MemoryManager's background sync executor while on_session_end /
|
||||
# on_session_switch run on the caller's thread, so the snapshot+reset
|
||||
# of the turn counter and the session-id rotation must be atomic
|
||||
# against a concurrent increment. See hermes-agent#28296 review.
|
||||
self._session_state_lock = threading.Lock()
|
||||
# Commit only after session writes drain. The set is keyed by the sid
|
||||
# the writer is POSTing under (snapshotted at spawn), so on_session_end
|
||||
# / on_session_switch see every still-alive writer for that sid even
|
||||
|
|
@ -449,6 +455,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
|
||||
# shutdown() can drain them and rapid re-queues don't orphan a still-
|
||||
# running thread by overwriting the single _prefetch_thread slot.
|
||||
self._prefetch_threads: Set[threading.Thread] = set()
|
||||
# Set on shutdown so deferred-commit / writer finalizers stop issuing
|
||||
# network writes against a torn-down provider.
|
||||
self._shutting_down = False
|
||||
# Drop prefetch results from older switch generations.
|
||||
self._prefetch_generation = 0
|
||||
|
||||
|
|
@ -568,6 +581,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
with self._prefetch_lock:
|
||||
gen = self._prefetch_generation
|
||||
|
||||
holder: List[threading.Thread] = []
|
||||
|
||||
def _run():
|
||||
try:
|
||||
client = _VikingClient(
|
||||
|
|
@ -595,11 +610,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._prefetch_result = "\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.debug("OpenViking prefetch failed: %s", e)
|
||||
finally:
|
||||
with self._prefetch_lock:
|
||||
if holder:
|
||||
self._prefetch_threads.discard(holder[0])
|
||||
|
||||
self._prefetch_thread = threading.Thread(
|
||||
thread = threading.Thread(
|
||||
target=_run, daemon=True, name="openviking-prefetch"
|
||||
)
|
||||
self._prefetch_thread.start()
|
||||
holder.append(thread)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_thread = thread
|
||||
self._prefetch_threads.add(thread)
|
||||
thread.start()
|
||||
|
||||
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
|
||||
"""Spawn a daemon writer tracked in _inflight_writers[sid].
|
||||
|
|
@ -627,6 +650,30 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._inflight_writers.setdefault(sid, set()).add(thread)
|
||||
thread.start()
|
||||
|
||||
def _drain_finalizers(self, timeout: float) -> bool:
|
||||
"""Join every in-flight async session finalizer within a timeout.
|
||||
|
||||
The switch-path commit runs on a daemon finalizer thread so it never
|
||||
blocks the caller's command thread; this lets shutdown and tests wait
|
||||
for those commits deterministically. Returns True if all drained.
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
while True:
|
||||
with self._deferred_commit_lock:
|
||||
workers = [t for t in self._deferred_commit_threads if t.is_alive()]
|
||||
if not workers:
|
||||
return True
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return False
|
||||
for t in workers:
|
||||
slice_left = deadline - time.monotonic()
|
||||
if slice_left <= 0:
|
||||
break
|
||||
# Floor the per-join wait so a thread whose join() returns
|
||||
# instantly while still reporting alive can't hot-spin this loop.
|
||||
t.join(timeout=min(slice_left, 0.05))
|
||||
|
||||
def _drain_writers(self, sid: str, timeout: float) -> bool:
|
||||
"""Join every in-flight writer for sid within a shared timeout budget.
|
||||
|
||||
|
|
@ -707,10 +754,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._committed_session_ids.add(sid)
|
||||
|
||||
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
|
||||
if turn_count > 0:
|
||||
return True
|
||||
# Already-committed sessions never need a second commit, regardless of
|
||||
# the turn counter — a racing sync_turn can re-increment _turn_count
|
||||
# after a commit+reset, so the committed-guard must win over turn_count.
|
||||
if self._has_committed_session(sid):
|
||||
return False
|
||||
if turn_count > 0:
|
||||
return True
|
||||
return self._session_has_pending_tokens(sid)
|
||||
|
||||
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
|
||||
|
|
@ -723,11 +773,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
logger.warning("OpenViking session commit failed for %s: %s", sid, e)
|
||||
return False
|
||||
|
||||
def _schedule_deferred_commit(self, sid: str, turn_count: int) -> None:
|
||||
if not sid or turn_count <= 0:
|
||||
def _finalize_session_async(self, sid: str, turn_count: int, *, context: str) -> None:
|
||||
"""Drain the old session's writers and commit it on a daemon thread.
|
||||
|
||||
Used by on_session_switch (and the deferred-commit fallback) so the
|
||||
potentially-multi-second drain + pending-token GET + commit POST never
|
||||
runs on the caller's command thread. Deduped by sid so a rapid second
|
||||
switch can't stack two finalizers for the same session, and a no-op
|
||||
once shutdown has begun so we don't POST against a torn-down client.
|
||||
"""
|
||||
if not sid:
|
||||
return
|
||||
with self._deferred_commit_lock:
|
||||
if sid in self._deferred_commit_sids:
|
||||
if self._shutting_down or sid in self._deferred_commit_sids:
|
||||
return
|
||||
self._deferred_commit_sids.add(sid)
|
||||
|
||||
|
|
@ -735,14 +793,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
|
||||
def _finalize() -> None:
|
||||
try:
|
||||
if self._shutting_down:
|
||||
return
|
||||
if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT):
|
||||
logger.warning(
|
||||
"OpenViking writer for %s still alive after deferred drain — "
|
||||
"OpenViking writer for %s still alive after drain — "
|
||||
"leaving session uncommitted",
|
||||
sid,
|
||||
)
|
||||
return
|
||||
self._commit_session(sid, turn_count, context="after deferred drain")
|
||||
if self._shutting_down:
|
||||
return
|
||||
if self._session_needs_commit(sid, turn_count):
|
||||
self._commit_session(sid, turn_count, context=context)
|
||||
finally:
|
||||
with self._deferred_commit_lock:
|
||||
self._deferred_commit_sids.discard(sid)
|
||||
|
|
@ -765,8 +828,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
with self._prefetch_lock:
|
||||
self._prefetch_generation += 1
|
||||
self._prefetch_result = ""
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
# Join EVERY tracked prefetch thread, not just the latest slot — a
|
||||
# rapid re-queue can leave an older thread for the abandoned session
|
||||
# still running (consistent with shutdown()).
|
||||
workers = [t for t in self._prefetch_threads if t.is_alive()]
|
||||
for t in workers:
|
||||
t.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = ""
|
||||
|
||||
|
|
@ -779,12 +846,15 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not user_content:
|
||||
return
|
||||
|
||||
# Snapshot the sid so a delayed worker can't write into a rotated session.
|
||||
sid = str(session_id or self._session_id).strip()
|
||||
if not sid:
|
||||
return
|
||||
|
||||
self._turn_count += 1
|
||||
# Snapshot the sid and bump the turn counter atomically so a
|
||||
# concurrent on_session_switch/on_session_end can't interleave its
|
||||
# snapshot+reset between the read and the increment (lost turn) and so
|
||||
# the turn is unambiguously attributed to the session it targets.
|
||||
with self._session_state_lock:
|
||||
sid = str(session_id or self._session_id).strip()
|
||||
if not sid:
|
||||
return
|
||||
self._turn_count += 1
|
||||
|
||||
def _sync():
|
||||
try:
|
||||
|
|
@ -819,7 +889,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not self._client:
|
||||
return
|
||||
|
||||
sid = self._session_id
|
||||
# Snapshot sid + turn count atomically against a concurrent sync_turn
|
||||
# increment. on_session_end runs at teardown so the drain+commit stays
|
||||
# synchronous here (we want it to land before the process exits), but
|
||||
# the counter read must still be consistent.
|
||||
with self._session_state_lock:
|
||||
sid = self._session_id
|
||||
turn_count = self._turn_count
|
||||
|
||||
# Commit only after session writes drain.
|
||||
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
|
||||
logger.warning(
|
||||
|
|
@ -828,12 +905,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
)
|
||||
return
|
||||
|
||||
if not self._session_needs_commit(sid, self._turn_count):
|
||||
if not self._session_needs_commit(sid, turn_count):
|
||||
return
|
||||
|
||||
if self._commit_session(sid, self._turn_count, context="on session end"):
|
||||
if self._commit_session(sid, turn_count, context="on session end"):
|
||||
# Mark clean so a follow-up on_session_switch skips its own commit.
|
||||
self._turn_count = 0
|
||||
with self._session_state_lock:
|
||||
if self._session_id == sid:
|
||||
self._turn_count = 0
|
||||
|
||||
def on_session_switch(
|
||||
self,
|
||||
|
|
@ -862,10 +941,31 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
return
|
||||
|
||||
rewound = bool(kwargs.get("rewound"))
|
||||
old_session_id = self._session_id
|
||||
old_turn_count = self._turn_count
|
||||
if rewound or new_id == old_session_id:
|
||||
self._invalidate_prefetch_state()
|
||||
|
||||
# Rotate cached session state synchronously (cheap, in-memory) and
|
||||
# snapshot the old session under the lock so a concurrent sync_turn
|
||||
# either lands fully before the rotation (counted under old) or fully
|
||||
# after (counted under new) — never split. The OLD session's commit
|
||||
# (drain + pending-token GET + commit POST, potentially many seconds)
|
||||
# is then offloaded so /new, /branch, /resume, /undo never block the
|
||||
# caller's command thread (cf. the end-of-turn-sync offload in #41945).
|
||||
with self._session_state_lock:
|
||||
old_session_id = self._session_id
|
||||
old_turn_count = self._turn_count
|
||||
rotate = not (rewound or new_id == old_session_id)
|
||||
if rotate:
|
||||
self._session_id = new_id
|
||||
self._turn_count = 0
|
||||
|
||||
# Invalidate stale prefetch OUTSIDE the session lock — it takes its own
|
||||
# _prefetch_lock and may join a prefetch thread for up to 3s, which we
|
||||
# must not do while holding the session lock (would block sync_turn and
|
||||
# risk lock-ordering coupling).
|
||||
self._invalidate_prefetch_state()
|
||||
|
||||
if not rotate:
|
||||
# Same-session rewind (/undo) or no-op rotation: no commit, no
|
||||
# counter reset — just the prefetch invalidation above.
|
||||
logger.debug(
|
||||
"OpenViking on_session_switch invalidated state without rotation: "
|
||||
"session=%s rewound=%s",
|
||||
|
|
@ -873,30 +973,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
)
|
||||
return
|
||||
|
||||
# Commit only after session writes drain.
|
||||
writers_drained = True
|
||||
# Drain + commit the OLD session off the command thread.
|
||||
if old_session_id:
|
||||
writers_drained = self._drain_writers(old_session_id, timeout=_SESSION_DRAIN_TIMEOUT)
|
||||
if not writers_drained:
|
||||
logger.warning(
|
||||
"OpenViking writer for %s still alive after drain — "
|
||||
"skipping commit-on-switch",
|
||||
old_session_id,
|
||||
)
|
||||
self._finalize_session_async(old_session_id, old_turn_count, context="on switch")
|
||||
|
||||
if (
|
||||
writers_drained
|
||||
and old_session_id
|
||||
and self._session_needs_commit(old_session_id, old_turn_count)
|
||||
):
|
||||
self._commit_session(old_session_id, old_turn_count, context="on switch")
|
||||
elif not writers_drained:
|
||||
self._schedule_deferred_commit(old_session_id, old_turn_count)
|
||||
|
||||
self._invalidate_prefetch_state()
|
||||
|
||||
self._session_id = new_id
|
||||
self._turn_count = 0
|
||||
logger.debug(
|
||||
"OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s",
|
||||
old_session_id, new_id, parent_session_id, reset,
|
||||
|
|
@ -961,6 +1041,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
return tool_error(str(e))
|
||||
|
||||
def shutdown(self) -> None:
|
||||
# Stop deferred finalizers from issuing new commits against a
|
||||
# torn-down client, then drain everything still in flight.
|
||||
self._shutting_down = True
|
||||
# Wait for every in-flight writer across all tracked sessions.
|
||||
with self._inflight_lock:
|
||||
all_workers = [
|
||||
|
|
@ -968,14 +1051,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
]
|
||||
with self._deferred_commit_lock:
|
||||
deferred_workers = list(self._deferred_commit_threads)
|
||||
with self._prefetch_lock:
|
||||
prefetch_workers = list(self._prefetch_threads)
|
||||
for t in all_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
for t in deferred_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=5.0)
|
||||
for t in prefetch_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
# Clear atexit reference so it doesn't double-commit.
|
||||
global _last_active_provider
|
||||
if _last_active_provider is self:
|
||||
|
|
|
|||
|
|
@ -5,7 +5,11 @@ from unittest.mock import MagicMock
|
|||
|
||||
import pytest
|
||||
|
||||
from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient
|
||||
from plugins.memory.openviking import (
|
||||
OpenVikingMemoryProvider,
|
||||
_DEFERRED_COMMIT_TIMEOUT,
|
||||
_VikingClient,
|
||||
)
|
||||
|
||||
|
||||
def _clear_openviking_tenant_env(monkeypatch):
|
||||
|
|
@ -744,6 +748,22 @@ def test_end_then_switch_with_pending_tokens_does_not_double_commit():
|
|||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_session_needs_commit_guard_wins_over_stale_turn_count():
|
||||
"""Regression for hermes-agent#28296 review (M3): once a session is marked
|
||||
committed, _session_needs_commit must return False even if turn_count is
|
||||
still positive. A racing sync_turn can re-increment _turn_count after the
|
||||
commit+reset; without the guard ordering, a follow-up finalizer would
|
||||
double-commit the same session. The committed-guard must be checked BEFORE
|
||||
the turn_count>0 shortcut."""
|
||||
provider = _make_provider_with_session("old-sid", turn_count=5)
|
||||
provider._mark_session_committed("old-sid")
|
||||
|
||||
# turn_count is a (stale) 5 but the session is already committed.
|
||||
assert provider._session_needs_commit("old-sid", 5) is False
|
||||
# An uncommitted session with turns still needs a commit.
|
||||
assert provider._session_needs_commit("fresh-sid", 5) is True
|
||||
|
||||
|
||||
def test_on_session_switch_swallows_commit_failure():
|
||||
"""Commit-on-switch must not propagate exceptions: a failing commit on the
|
||||
old session must still allow the rotate to the new session to complete,
|
||||
|
|
@ -831,7 +851,54 @@ def test_on_session_switch_waits_for_all_writers_not_just_latest():
|
|||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_on_session_switch_defers_old_commit_when_writers_finish_after_initial_drain():
|
||||
def test_on_session_switch_does_not_block_caller_on_slow_drain():
|
||||
"""Regression for hermes-agent#28296 review (H1): on_session_switch must
|
||||
NOT run the old-session drain/commit on the caller's thread. /new, /branch,
|
||||
/resume, /undo call this synchronously on the command thread, so a slow
|
||||
writer drain (up to _SESSION_DRAIN_TIMEOUT/_DEFERRED_COMMIT_TIMEOUT) or a
|
||||
wedged commit POST must not stall the user-facing command. The rotation is
|
||||
cheap and synchronous; the commit is offloaded. Mirrors the #41945
|
||||
'do not block the turn thread' contract."""
|
||||
import threading
|
||||
import time
|
||||
|
||||
provider = _make_provider_with_session("old-sid", turn_count=2)
|
||||
|
||||
drain_entered = threading.Event()
|
||||
release_drain = threading.Event()
|
||||
|
||||
def slow_drain(sid, timeout):
|
||||
drain_entered.set()
|
||||
# Simulate a writer that takes a long time to drain.
|
||||
release_drain.wait(timeout=10.0)
|
||||
return True
|
||||
|
||||
provider._drain_writers = slow_drain
|
||||
|
||||
start = time.monotonic()
|
||||
provider.on_session_switch("new-sid")
|
||||
elapsed = time.monotonic() - start
|
||||
|
||||
# The caller returned promptly with state already rotated, even though the
|
||||
# drain is still parked on the finalizer thread.
|
||||
assert elapsed < 1.0, f"on_session_switch blocked the caller for {elapsed:.2f}s"
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._turn_count == 0
|
||||
assert drain_entered.wait(timeout=2.0), "finalizer never started draining"
|
||||
# No commit yet — drain is still blocked off-thread.
|
||||
provider._client.post.assert_not_called()
|
||||
# Let the finalizer finish so it doesn't leak past the test.
|
||||
release_drain.set()
|
||||
assert provider._drain_finalizers(timeout=5.0)
|
||||
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
|
||||
|
||||
|
||||
def test_on_session_switch_defers_old_commit_to_finalizer_thread():
|
||||
"""The switch path rotates session state synchronously (cheap, in-memory)
|
||||
but offloads the old-session drain + commit onto a daemon finalizer so the
|
||||
caller's command thread (/new, /branch, /resume) never blocks on the up-to
|
||||
-_DEFERRED_COMMIT_TIMEOUT drain or the commit POST. See hermes-agent#28296
|
||||
review (the #41945 'do not block the turn thread' contract)."""
|
||||
import threading
|
||||
|
||||
provider = _make_provider_with_session("old-sid", turn_count=2)
|
||||
|
|
@ -844,19 +911,21 @@ def test_on_session_switch_defers_old_commit_when_writers_finish_after_initial_d
|
|||
|
||||
def fake_drain(sid, timeout):
|
||||
drain_timeouts.append(timeout)
|
||||
return len(drain_timeouts) > 1
|
||||
return True
|
||||
|
||||
provider._client.post.side_effect = fake_post
|
||||
provider._drain_writers = fake_drain
|
||||
|
||||
provider.on_session_switch("new-sid")
|
||||
|
||||
# Rotation is synchronous and immediate — the new session is live at once.
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._turn_count == 0
|
||||
assert committed.wait(timeout=2.0), "old session was not finalized after writers drained"
|
||||
# The old-session commit lands on the finalizer thread, not inline.
|
||||
assert committed.wait(timeout=5.0), "old session was not finalized off-thread"
|
||||
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
|
||||
assert drain_timeouts[0] == 10.0
|
||||
assert drain_timeouts[1] > 10.0
|
||||
# The finalizer drains with the deferred (longer) budget, not inline 10s.
|
||||
assert drain_timeouts == [_DEFERRED_COMMIT_TIMEOUT]
|
||||
|
||||
|
||||
def test_sync_turn_tracks_writer_under_session_id():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue