mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-18 09:51:59 +00:00
fix(openviking): track writers per-session so commit waits for all
sync_turn's bounded join could drop a still-alive previous worker by replacing the single _sync_thread slot. The dropped worker kept POSTing under the old sid but was no longer visible to on_session_end / on_session_switch, so the commit could fire while orphaned writes were still in flight — those writes landed past the commit boundary and were never extracted. Replace the single _sync_thread slot with _inflight_writers: Dict[sid, Set[Thread]]. Writers self-register on spawn (sync_turn, on_memory_write) and self-deregister on exit. The commit path drains _drain_writers(sid, 10.0) and skips the commit if any writer for that sid is still alive after the bounded budget. Also trim inline review-rationale comments to short invariants per reviewer style ask: "commit only after session writes drain" and "drop prefetch results from older switch generations." Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
3791a87dbe
commit
7537ee6f5b
2 changed files with 170 additions and 92 deletions
|
|
@ -31,10 +31,11 @@ import mimetypes
|
|||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Set
|
||||
from urllib.parse import urlparse
|
||||
from urllib.request import url2pathname
|
||||
|
||||
|
|
@ -399,14 +400,16 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._api_key = ""
|
||||
self._session_id = ""
|
||||
self._turn_count = 0
|
||||
self._sync_thread: Optional[threading.Thread] = None
|
||||
# Commit only after session writes drain. The set is keyed by the sid
|
||||
# the writer is POSTing under (snapshotted at spawn), so on_session_end
|
||||
# / on_session_switch see every still-alive writer for that sid even
|
||||
# if later writes have replaced the latest-tracked thread.
|
||||
self._inflight_writers: Dict[str, Set[threading.Thread]] = {}
|
||||
self._inflight_lock = threading.Lock()
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
# Monotonic counter incremented on every session switch. Prefetch
|
||||
# workers capture the value when spawned and refuse to write their
|
||||
# result if the generation has advanced — otherwise a slow worker
|
||||
# from session N can repopulate session N+1 with stale recall.
|
||||
# Drop prefetch results from older switch generations.
|
||||
self._prefetch_generation = 0
|
||||
|
||||
@property
|
||||
|
|
@ -520,9 +523,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not self._client or not query:
|
||||
return
|
||||
|
||||
# Snapshot the generation at spawn time. If on_session_switch bumps it
|
||||
# before this worker finishes, the worker drops its result instead of
|
||||
# repopulating the new session with stale recall from the old one.
|
||||
# Drop prefetch results from older switch generations.
|
||||
gen = self._prefetch_generation
|
||||
|
||||
def _run():
|
||||
|
|
@ -558,15 +559,59 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
)
|
||||
self._prefetch_thread.start()
|
||||
|
||||
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
    """Spawn a daemon writer thread tracked in ``_inflight_writers[sid]``.

    Tracking is keyed by sid (not by a single latest-thread slot) so that
    a drain for a session can see every still-alive writer for it.

    Args:
        sid: Session id the writer posts under; used as the tracking key.
        target: Zero-argument callable executed on the new thread.
        name: Name given to the thread.

    Raises:
        Exception: Whatever ``Thread.start()`` raises (e.g. ``RuntimeError``
            when no thread can be created). The registration is rolled back
            before the exception propagates.
    """

    def _wrapped() -> None:
        try:
            target()
        finally:
            # Self-deregister on exit, whether target returned or raised.
            with self._inflight_lock:
                workers = self._inflight_writers.get(sid)
                if workers is not None:
                    workers.discard(thread)
                    if not workers:
                        # Drop the empty set so finished sessions don't pile up.
                        self._inflight_writers.pop(sid, None)

    thread = threading.Thread(target=_wrapped, daemon=True, name=name)

    # Register AND start while holding the lock. _drain_writers filters on
    # is_alive(), which is False for a thread that has not been started yet;
    # starting outside the lock would leave a window in which a registered
    # writer looks "already drained". The new thread only needs the lock in
    # its finally clause, so it simply waits there if it finishes early.
    with self._inflight_lock:
        workers = self._inflight_writers.setdefault(sid, set())
        workers.add(thread)
        try:
            thread.start()
        except Exception:
            # Never leave a thread that will not run in the tracking map.
            workers.discard(thread)
            if not workers:
                self._inflight_writers.pop(sid, None)
            raise
|
||||
|
||||
def _drain_writers(self, sid: str, timeout: float) -> bool:
    """Wait for all live writers registered under ``sid``, sharing one budget.

    The set of live writers is re-read under ``_inflight_lock`` on every
    pass, and each ``join`` receives only the time left until the deadline.

    Args:
        sid: Session id whose writers should be waited on.
        timeout: Total number of seconds to wait across all writers.

    Returns:
        True when no writer for ``sid`` is alive any more; False when the
        budget ran out while at least one writer was still alive.
    """
    cutoff = time.monotonic() + timeout
    while True:
        # Snapshot under the lock; the joins below run without holding it.
        with self._inflight_lock:
            tracked = self._inflight_writers.get(sid, ())
            pending = [writer for writer in tracked if writer.is_alive()]
        if not pending:
            return True
        if cutoff - time.monotonic() <= 0:
            return False
        for writer in pending:
            budget = cutoff - time.monotonic()
            if budget <= 0:
                # Out of time mid-pass; the next loop iteration decides the result.
                break
            writer.join(timeout=budget)
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Record the conversation turn in OpenViking's session (non-blocking)."""
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
# Capture the target session id NOW, not inside the worker. Otherwise
|
||||
# a delayed worker can read self._session_id after on_session_switch
|
||||
# has rotated it (the switch's join on _sync_thread is bounded), and
|
||||
# the OLD turn's content lands in the NEW session.
|
||||
# Snapshot the sid so a delayed worker can't write into a rotated session.
|
||||
sid = str(session_id or self._session_id).strip()
|
||||
if not sid:
|
||||
return
|
||||
|
|
@ -579,13 +624,10 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
|
||||
# Add user message
|
||||
client.post(f"/api/v1/sessions/{sid}/messages", {
|
||||
"role": "user",
|
||||
"content": user_content[:4000], # trim very long messages
|
||||
"content": user_content[:4000],
|
||||
})
|
||||
# Add assistant message
|
||||
client.post(f"/api/v1/sessions/{sid}/messages", {
|
||||
"role": "assistant",
|
||||
"content": assistant_content[:4000],
|
||||
|
|
@ -593,14 +635,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
except Exception as e:
|
||||
logger.debug("OpenViking sync_turn failed: %s", e)
|
||||
|
||||
# Wait for any previous sync to finish before starting a new one
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
|
||||
self._sync_thread = threading.Thread(
|
||||
target=_sync, daemon=True, name="openviking-sync"
|
||||
)
|
||||
self._sync_thread.start()
|
||||
self._spawn_writer(sid, _sync, name="openviking-sync")
|
||||
|
||||
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Commit the session to trigger memory extraction.
|
||||
|
|
@ -611,35 +646,22 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not self._client:
|
||||
return
|
||||
|
||||
# Wait for any pending sync to finish first — do this before the
|
||||
# turn_count check so the last turn's messages are flushed even if
|
||||
# the count hasn't been incremented yet.
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=10.0)
|
||||
if self._sync_thread.is_alive():
|
||||
# Worker outlived the bounded join — each POST has _TIMEOUT=30s
|
||||
# and there are two of them per turn. Committing now would
|
||||
# orphan the worker's late writes past the commit boundary
|
||||
# (they'd land in an already-committed session and never be
|
||||
# extracted). Skip the commit; leave _turn_count untouched so
|
||||
# the session stays marked dirty for any retry path.
|
||||
logger.warning(
|
||||
"OpenViking sync worker still alive after 10s join — "
|
||||
"skipping commit on session %s to avoid orphaning late writes",
|
||||
self._session_id,
|
||||
)
|
||||
return
|
||||
sid = self._session_id
|
||||
# Commit only after session writes drain.
|
||||
if not self._drain_writers(sid, timeout=10.0):
|
||||
logger.warning(
|
||||
"OpenViking writer for %s still alive after drain — skipping commit",
|
||||
sid,
|
||||
)
|
||||
return
|
||||
|
||||
if self._turn_count == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
self._client.post(f"/api/v1/sessions/{self._session_id}/commit")
|
||||
logger.info("OpenViking session %s committed (%d turns)", self._session_id, self._turn_count)
|
||||
# Mark the session clean so a subsequent on_session_switch (fired
|
||||
# by /new and compression right after commit_memory_session) skips
|
||||
# its commit instead of double-committing. On commit failure we
|
||||
# leave the count intact so the switch hook gets a retry.
|
||||
self._client.post(f"/api/v1/sessions/{sid}/commit")
|
||||
logger.info("OpenViking session %s committed (%d turns)", sid, self._turn_count)
|
||||
# Mark clean so a follow-up on_session_switch skips its own commit.
|
||||
self._turn_count = 0
|
||||
except Exception as e:
|
||||
logger.warning("OpenViking session commit failed: %s", e)
|
||||
|
|
@ -670,33 +692,21 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not new_id or not self._client:
|
||||
return
|
||||
|
||||
# Snapshot the old session id BEFORE rotation so the join+commit
|
||||
# below always target the session whose writes we want to flush.
|
||||
old_session_id = self._session_id
|
||||
old_turn_count = self._turn_count
|
||||
sync_worker_drained = True
|
||||
|
||||
# 1. Wait for any in-flight sync_turn to finish writing under the
|
||||
# OLD session id — otherwise it races the commit below.
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=10.0)
|
||||
if self._sync_thread.is_alive():
|
||||
# Same hazard as on_session_end: worker outlived the bounded
|
||||
# join. Skip the commit so its late writes aren't orphaned
|
||||
# past a commit boundary they can't recover from.
|
||||
sync_worker_drained = False
|
||||
# Commit only after session writes drain.
|
||||
writers_drained = True
|
||||
if old_session_id:
|
||||
writers_drained = self._drain_writers(old_session_id, timeout=10.0)
|
||||
if not writers_drained:
|
||||
logger.warning(
|
||||
"OpenViking sync worker still alive after 10s join — "
|
||||
"skipping commit-on-switch for session %s; late writes "
|
||||
"will remain in the uncommitted old session",
|
||||
"OpenViking writer for %s still alive after drain — "
|
||||
"skipping commit-on-switch",
|
||||
old_session_id,
|
||||
)
|
||||
|
||||
# 2. Commit the old session if it accumulated turns — same
|
||||
# extraction semantics as on_session_end. Skip if empty (nothing
|
||||
# to extract), if the provider was never initialized, or if the
|
||||
# sync worker is still mid-flight.
|
||||
if sync_worker_drained and old_session_id and old_turn_count > 0:
|
||||
if writers_drained and old_session_id and old_turn_count > 0:
|
||||
try:
|
||||
self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
|
||||
logger.info(
|
||||
|
|
@ -709,17 +719,13 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
old_session_id, e,
|
||||
)
|
||||
|
||||
# 3. Bump the prefetch generation so any in-flight prefetch worker
|
||||
# finishing AFTER this point drops its result. Then drain the
|
||||
# current worker and clear the cached result so the new session
|
||||
# doesn't see stale recall from the old one.
|
||||
# Drop prefetch results from older switch generations.
|
||||
self._prefetch_generation += 1
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = ""
|
||||
|
||||
# 4. Rotate to the new session.
|
||||
self._session_id = new_id
|
||||
self._turn_count = 0
|
||||
logger.debug(
|
||||
|
|
@ -732,10 +738,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not self._client or action != "add" or not content:
|
||||
return
|
||||
|
||||
# Snapshot the target session id at call time — see sync_turn() for
|
||||
# the rationale. A delayed worker that reads self._session_id after
|
||||
# on_session_switch has rotated it would land the memory note in the
|
||||
# NEW session.
|
||||
# Snapshot the sid so a delayed worker can't write into a rotated session.
|
||||
sid = str(self._session_id or "").strip()
|
||||
if not sid:
|
||||
return
|
||||
|
|
@ -746,8 +749,6 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._endpoint, self._api_key,
|
||||
account=self._account, user=self._user, agent=self._agent,
|
||||
)
|
||||
# Add as a user message with memory context so the commit
|
||||
# picks it up as an explicit memory during extraction
|
||||
client.post(f"/api/v1/sessions/{sid}/messages", {
|
||||
"role": "user",
|
||||
"parts": [
|
||||
|
|
@ -757,8 +758,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
except Exception as e:
|
||||
logger.debug("OpenViking memory mirror failed: %s", e)
|
||||
|
||||
t = threading.Thread(target=_write, daemon=True, name="openviking-memwrite")
|
||||
t.start()
|
||||
self._spawn_writer(sid, _write, name="openviking-memwrite")
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [SEARCH_SCHEMA, READ_SCHEMA, BROWSE_SCHEMA, REMEMBER_SCHEMA, ADD_RESOURCE_SCHEMA]
|
||||
|
|
@ -783,11 +783,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
return tool_error(str(e))
|
||||
|
||||
def shutdown(self) -> None:
|
||||
# Wait for background threads to finish
|
||||
for t in (self._sync_thread, self._prefetch_thread):
|
||||
if t and t.is_alive():
|
||||
# Wait for every in-flight writer across all tracked sessions.
|
||||
with self._inflight_lock:
|
||||
all_workers = [
|
||||
t for workers in self._inflight_writers.values() for t in workers
|
||||
]
|
||||
for t in all_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
# Clear atexit reference so it doesn't double-commit
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=5.0)
|
||||
# Clear atexit reference so it doesn't double-commit.
|
||||
global _last_active_provider
|
||||
if _last_active_provider is self:
|
||||
_last_active_provider = None
|
||||
|
|
|
|||
|
|
@ -483,7 +483,7 @@ def test_on_session_switch_waits_for_inflight_sync_thread():
|
|||
# Simulate a worker that finishes within the join window.
|
||||
self._alive = False
|
||||
|
||||
provider._sync_thread = FakeThread()
|
||||
provider._inflight_writers["old-sid"] = {FakeThread()}
|
||||
|
||||
provider.on_session_switch("new-sid")
|
||||
|
||||
|
|
@ -561,8 +561,8 @@ def test_sync_turn_captures_session_id_before_worker_runs():
|
|||
# Rotate the provider's session id while the worker is mid-flight.
|
||||
provider._session_id = "new-sid"
|
||||
release.set()
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=2.0)
|
||||
for t in list(provider._inflight_writers.get("old-sid", set())):
|
||||
t.join(timeout=2.0)
|
||||
finally:
|
||||
_mod._VikingClient = real_client_cls
|
||||
|
||||
|
|
@ -582,7 +582,7 @@ def test_sync_turn_noop_when_session_id_blank():
|
|||
|
||||
# No turn counted, no worker spawned.
|
||||
assert provider._turn_count == 0
|
||||
assert provider._sync_thread is None
|
||||
assert provider._inflight_writers == {}
|
||||
|
||||
|
||||
def test_on_session_end_marks_session_clean_after_successful_commit():
|
||||
|
|
@ -660,7 +660,7 @@ def test_on_session_end_skips_commit_when_sync_worker_outlives_join():
|
|||
already-committed session and never be extracted. Leave _turn_count
|
||||
intact so the session stays marked dirty."""
|
||||
provider = _make_provider_with_session("old-sid", turn_count=3)
|
||||
provider._sync_thread = _HungThread()
|
||||
provider._inflight_writers["old-sid"] = {_HungThread()}
|
||||
|
||||
provider.on_session_end([])
|
||||
|
||||
|
|
@ -673,7 +673,7 @@ def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
|
|||
session needs to start) but the old-session commit is skipped to avoid
|
||||
orphaning the worker's late writes past commit."""
|
||||
provider = _make_provider_with_session("old-sid", turn_count=2)
|
||||
provider._sync_thread = _HungThread()
|
||||
provider._inflight_writers["old-sid"] = {_HungThread()}
|
||||
|
||||
provider.on_session_switch("new-sid")
|
||||
|
||||
|
|
@ -682,6 +682,78 @@ def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
|
|||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Orphaned-writer hazard: commit must wait for ALL writers for the session,
|
||||
# not just the latest tracked one. sync_turn's bounded rate-limit can drop a
|
||||
# still-alive previous worker — that dropped writer keeps POSTing under the
|
||||
# old sid and would otherwise land its writes past the commit boundary.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_on_session_end_waits_for_all_writers_not_just_latest():
    """on_session_end must not commit while a writer is still registered for the sid."""
    sid = "old-sid"
    provider = _make_provider_with_session(sid, turn_count=2)
    # NOTE(review): _HungThread is defined elsewhere in this test module;
    # presumably it reports itself alive forever — confirm.
    stuck_writer = _HungThread()
    provider._inflight_writers[sid] = {stuck_writer}

    provider.on_session_end([])

    # No POST at all (so no commit), and the turn count is left as it was.
    provider._client.post.assert_not_called()
    assert provider._turn_count == 2
|
||||
|
||||
|
||||
def test_on_session_switch_waits_for_all_writers_not_just_latest():
    """on_session_switch still rotates, but skips the commit while a writer is registered."""
    old_sid = "old-sid"
    provider = _make_provider_with_session(old_sid, turn_count=2)
    # NOTE(review): _HungThread is defined elsewhere in this test module;
    # presumably it reports itself alive forever — confirm.
    stuck_writer = _HungThread()
    provider._inflight_writers[old_sid] = {stuck_writer}

    provider.on_session_switch("new-sid")

    # No POST for the old session, yet the provider moved on to the new one.
    provider._client.post.assert_not_called()
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_sync_turn_tracks_writer_under_session_id():
    """A sync_turn writer is tracked under its sid while running and removed on exit.

    The worker is parked inside a stub client's ``post()`` so the test can
    inspect ``_inflight_writers`` while the writer is guaranteed to be alive.
    """
    import threading

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    for attr, value in (
        ("_endpoint", "http://test"),
        ("_api_key", ""),
        ("_account", "acct"),
        ("_user", "usr"),
        ("_agent", "hermes"),
        ("_session_id", "sid-1"),
    ):
        setattr(provider, attr, value)

    unblock = threading.Event()
    entered = threading.Event()

    class StubClient:
        """Stand-in client whose post() parks until the test releases it."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            entered.set()
            unblock.wait(timeout=2.0)
            return {}

    import plugins.memory.openviking as _mod

    # Swap the module-level client class for the stub; restored in finally.
    original_client_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        provider.sync_turn("u", "a")
        assert entered.wait(timeout=2.0), "worker never entered post()"
        # While the worker is parked, exactly one writer is tracked for the sid.
        assert len(provider._inflight_writers.get("sid-1", set())) == 1
        unblock.set()
        for writer in list(provider._inflight_writers.get("sid-1", set())):
            writer.join(timeout=2.0)
    finally:
        _mod._VikingClient = original_client_cls

    # Worker should have removed itself from the inflight set on exit.
    assert provider._inflight_writers.get("sid-1", set()) == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# on_memory_write: same late-capture hazard as sync_turn — worker must use
|
||||
# the session id snapshotted at call time, not re-read self._session_id.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue