mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-18 09:51:59 +00:00
fix(openviking): close remaining session-boundary races on switch
Three follow-ups from review on #28296:
1. Sync worker outliving the bounded join. Each sync_turn POST has
_TIMEOUT=30s and there are two per turn, but on_session_end and
on_session_switch only join for 10s. If the worker is still alive
after the join, committing the old session orphans the worker's
late writes past the commit boundary — they land in an already-
committed session and never get extracted. Both hooks now re-check
is_alive() after the join and skip the commit when the worker
hasn't drained.
2. on_memory_write late session_id capture. Same shape as the
pre-fix sync_turn: f-string for the post path read self._session_id
inside the worker, so a switch between thread spawn and post call
landed the memory note in the new session. Snapshot sid at call
time, same pattern as sync_turn.
3. Stale prefetch repopulating the new session. The pre-switch
drain+clear only protects against workers that finish before the
join completes; one finishing after the clear would write its
result into the new generation's slot. Added a monotonic
_prefetch_generation; workers capture it at spawn and refuse to
write if it has advanced.
Tests: existing in-flight-sync test updated to drain (it tested the
join-before-commit happy path); four new tests cover hung-writer skip
on end + switch, on_memory_write sid capture, and prefetch generation
gating. 177/177 memory tests pass.
(cherry picked from commit 3791a87dbe)
This commit is contained in:
parent
a30b40c73a
commit
eddbf291a4
2 changed files with 212 additions and 2 deletions
|
|
@ -436,6 +436,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
# Monotonic counter incremented on every session switch. Prefetch
|
||||
# workers capture the value when spawned and refuse to write their
|
||||
# result if the generation has advanced — otherwise a slow worker
|
||||
# from session N can repopulate session N+1 with stale recall.
|
||||
self._prefetch_generation = 0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -549,6 +554,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not self._client or not query:
|
||||
return
|
||||
|
||||
# Snapshot the generation at spawn time. If on_session_switch bumps it
|
||||
# before this worker finishes, the worker drops its result instead of
|
||||
# repopulating the new session with stale recall from the old one.
|
||||
gen = self._prefetch_generation
|
||||
|
||||
def _run():
|
||||
try:
|
||||
client = _VikingClient(
|
||||
|
|
@ -571,6 +581,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
parts.append(f"- [{score:.2f}] {abstract} ({uri})")
|
||||
if parts:
|
||||
with self._prefetch_lock:
|
||||
if gen != self._prefetch_generation:
|
||||
return
|
||||
self._prefetch_result = "\n".join(parts)
|
||||
except Exception as e:
|
||||
logger.debug("OpenViking prefetch failed: %s", e)
|
||||
|
|
@ -642,6 +654,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
# the count hasn't been incremented yet.
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=10.0)
|
||||
if self._sync_thread.is_alive():
|
||||
# Worker outlived the bounded join — each POST has _TIMEOUT=30s
|
||||
# and there are two of them per turn. Committing now would
|
||||
# orphan the worker's late writes past the commit boundary
|
||||
# (they'd land in an already-committed session and never be
|
||||
# extracted). Skip the commit; leave _turn_count untouched so
|
||||
# the session stays marked dirty for any retry path.
|
||||
logger.warning(
|
||||
"OpenViking sync worker still alive after 10s join — "
|
||||
"skipping commit on session %s to avoid orphaning late writes",
|
||||
self._session_id,
|
||||
)
|
||||
return
|
||||
|
||||
if self._turn_count == 0:
|
||||
return
|
||||
|
|
@ -685,11 +710,27 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
|
||||
old_session_id = self._session_id
|
||||
old_turn_count = self._turn_count
|
||||
sync_worker_drained = True
|
||||
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=10.0)
|
||||
if self._sync_thread.is_alive():
|
||||
# Same hazard as on_session_end: worker outlived the bounded
|
||||
# join. Skip the commit so its late writes aren't orphaned
|
||||
# past a commit boundary they can't recover from.
|
||||
sync_worker_drained = False
|
||||
logger.warning(
|
||||
"OpenViking sync worker still alive after 10s join — "
|
||||
"skipping commit-on-switch for session %s; late writes "
|
||||
"will remain in the uncommitted old session",
|
||||
old_session_id,
|
||||
)
|
||||
|
||||
if old_session_id and old_turn_count > 0:
|
||||
# 2. Commit the old session if it accumulated turns — same
|
||||
# extraction semantics as on_session_end. Skip if empty (nothing
|
||||
# to extract), if the provider was never initialized, or if the
|
||||
# sync worker is still mid-flight.
|
||||
if sync_worker_drained and old_session_id and old_turn_count > 0:
|
||||
try:
|
||||
self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
|
||||
logger.info(
|
||||
|
|
@ -702,6 +743,11 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
old_session_id, e,
|
||||
)
|
||||
|
||||
# 3. Bump the prefetch generation so any in-flight prefetch worker
|
||||
# finishing AFTER this point drops its result. Then drain the
|
||||
# current worker and clear the cached result so the new session
|
||||
# doesn't see stale recall from the old one.
|
||||
self._prefetch_generation += 1
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
|
|
|
|||
|
|
@ -472,11 +472,16 @@ def test_on_session_switch_waits_for_inflight_sync_thread():
|
|||
join_calls = []
|
||||
|
||||
class FakeThread:
|
||||
def __init__(self):
|
||||
self._alive = True
|
||||
|
||||
def is_alive(self):
|
||||
return True
|
||||
return self._alive
|
||||
|
||||
def join(self, timeout=None):
|
||||
join_calls.append(timeout)
|
||||
# Simulate a worker that finishes within the join window.
|
||||
self._alive = False
|
||||
|
||||
provider._sync_thread = FakeThread()
|
||||
|
||||
|
|
@ -629,3 +634,162 @@ def test_on_session_switch_swallows_commit_failure():
|
|||
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hung-writer protection: the sync worker can outlive the bounded join
|
||||
# because each OpenViking POST has _TIMEOUT=30s and there are two per turn.
|
||||
# Committing while late writes are still in flight would orphan them past
|
||||
# the commit boundary — they would never be extracted.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class _HungThread:
|
||||
"""Thread stand-in that stays alive across joins."""
|
||||
|
||||
def is_alive(self):
|
||||
return True
|
||||
|
||||
def join(self, timeout=None):
|
||||
# Pretend the join timed out — worker still running.
|
||||
return None
|
||||
|
||||
|
||||
def test_on_session_end_skips_commit_when_sync_worker_outlives_join():
|
||||
"""If the sync worker is still alive after the 10s join, the commit must
|
||||
be skipped — late writes from the worker would otherwise land in an
|
||||
already-committed session and never be extracted. Leave _turn_count
|
||||
intact so the session stays marked dirty."""
|
||||
provider = _make_provider_with_session("old-sid", turn_count=3)
|
||||
provider._sync_thread = _HungThread()
|
||||
|
||||
provider.on_session_end([])
|
||||
|
||||
provider._client.post.assert_not_called()
|
||||
assert provider._turn_count == 3
|
||||
|
||||
|
||||
def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
|
||||
"""Same hazard on the switch path. Rotation must still proceed (the new
|
||||
session needs to start) but the old-session commit is skipped to avoid
|
||||
orphaning the worker's late writes past commit."""
|
||||
provider = _make_provider_with_session("old-sid", turn_count=2)
|
||||
provider._sync_thread = _HungThread()
|
||||
|
||||
provider.on_session_switch("new-sid")
|
||||
|
||||
provider._client.post.assert_not_called()
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_memory_write: same late-capture hazard as sync_turn — worker must use
|
||||
# the session id snapshotted at call time, not re-read self._session_id.
|
||||
# Block inside the stub ctor (BEFORE the f-string for the post path is
|
||||
# evaluated) so the rotation deterministically beats the f-string.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_on_memory_write_captures_session_id_at_call_time():
|
||||
import threading
|
||||
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
provider._endpoint = "http://test"
|
||||
provider._api_key = ""
|
||||
provider._account = "acct"
|
||||
provider._user = "usr"
|
||||
provider._agent = "hermes"
|
||||
provider._session_id = "old-sid"
|
||||
|
||||
in_ctor = threading.Event()
|
||||
release = threading.Event()
|
||||
done = threading.Event()
|
||||
captured_paths = []
|
||||
|
||||
class StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
in_ctor.set()
|
||||
release.wait(timeout=2.0)
|
||||
|
||||
def post(self, path, payload=None, **kwargs):
|
||||
captured_paths.append(path)
|
||||
done.set()
|
||||
return {}
|
||||
|
||||
import plugins.memory.openviking as _mod
|
||||
real_client_cls = _mod._VikingClient
|
||||
_mod._VikingClient = StubClient
|
||||
try:
|
||||
provider.on_memory_write("add", "viking://memories/x", "remember this")
|
||||
assert in_ctor.wait(timeout=2.0), "worker never entered ctor"
|
||||
# Rotate provider's session id while the worker is parked in the ctor,
|
||||
# BEFORE it evaluates the f-string for the post path. If the worker
|
||||
# reads self._session_id inside the closure, it will now see "new-sid".
|
||||
provider._session_id = "new-sid"
|
||||
release.set()
|
||||
assert done.wait(timeout=2.0), "worker never reached post()"
|
||||
finally:
|
||||
_mod._VikingClient = real_client_cls
|
||||
|
||||
# The write must target the OLD session id captured at call time.
|
||||
assert captured_paths == ["/api/v1/sessions/old-sid/messages"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Prefetch staleness: a prefetch worker that finishes AFTER a session switch
|
||||
# must drop its result instead of repopulating the new session with stale
|
||||
# recall from the old generation. Bump the generation directly (rather than
|
||||
# calling on_session_switch, whose own join blocks on the test worker) so
|
||||
# the test isolates the generation-gating behavior.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_queue_prefetch_drops_result_when_generation_changed_mid_flight():
|
||||
import threading
|
||||
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
provider._endpoint = "http://test"
|
||||
provider._api_key = ""
|
||||
provider._account = "acct"
|
||||
provider._user = "usr"
|
||||
provider._agent = "hermes"
|
||||
provider._session_id = "old-sid"
|
||||
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
|
||||
class StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
def post(self, path, payload=None, **kwargs):
|
||||
started.set()
|
||||
release.wait(timeout=2.0)
|
||||
return {
|
||||
"result": {
|
||||
"memories": [
|
||||
{"uri": "viking://memories/old", "score": 0.9,
|
||||
"abstract": "stale from old session"},
|
||||
],
|
||||
"resources": [],
|
||||
}
|
||||
}
|
||||
|
||||
import plugins.memory.openviking as _mod
|
||||
real_client_cls = _mod._VikingClient
|
||||
_mod._VikingClient = StubClient
|
||||
try:
|
||||
provider.queue_prefetch("anything")
|
||||
assert started.wait(timeout=2.0), "prefetch worker never entered post()"
|
||||
# Simulate a session switch by bumping the generation directly.
|
||||
# The worker captured the pre-bump generation when it was spawned.
|
||||
provider._prefetch_generation += 1
|
||||
release.set()
|
||||
if provider._prefetch_thread:
|
||||
provider._prefetch_thread.join(timeout=2.0)
|
||||
finally:
|
||||
_mod._VikingClient = real_client_cls
|
||||
|
||||
# The stale result from the pre-bump generation must NOT have been written
|
||||
# into the new generation's prefetch slot.
|
||||
assert provider._prefetch_result == ""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue