fix(openviking): harden session writes and switch commits

This commit is contained in:
Hao Zhe 2026-06-17 13:16:03 +08:00
parent f3b813c027
commit 00c045b43f
2 changed files with 238 additions and 37 deletions

View file

@ -47,6 +47,8 @@ logger = logging.getLogger(__name__)
_DEFAULT_ENDPOINT = "http://127.0.0.1:1933"
_TIMEOUT = 30.0
_SESSION_DRAIN_TIMEOUT = 10.0
_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
# Maps the viking_remember `category` enum to a viking:// subdirectory.
@ -439,6 +441,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
# if later writes have replaced the latest-tracked thread.
self._inflight_writers: Dict[str, Set[threading.Thread]] = {}
self._inflight_lock = threading.Lock()
self._deferred_commit_sids: Set[str] = set()
self._deferred_commit_threads: Set[threading.Thread] = set()
self._deferred_commit_lock = threading.Lock()
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
@ -625,6 +630,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
Returns True if all writers drained, False if any are still alive when
the budget runs out. Callers use the False return to skip the commit.
"""
if not sid:
return True
deadline = time.monotonic() + timeout
while True:
with self._inflight_lock:
@ -640,6 +647,99 @@ class OpenVikingMemoryProvider(MemoryProvider):
break
t.join(timeout=slice_left)
def _new_client(self) -> _VikingClient:
    """Create a new ``_VikingClient`` from this provider's stored settings.

    The endpoint and API key are passed positionally; the account, user,
    and agent values are passed as keyword arguments.
    """
    endpoint, api_key = self._endpoint, self._api_key
    identity = {
        "account": self._account,
        "user": self._user,
        "agent": self._agent,
    }
    return _VikingClient(endpoint, api_key, **identity)
@staticmethod
def _text_part(content: str) -> Dict[str, str]:
return {"type": "text", "text": content}
@classmethod
def _turn_batch_payload(cls, user_content: str, assistant_content: str) -> Dict[str, Any]:
return {
"messages": [
{"role": "user", "parts": [cls._text_part(user_content)]},
{"role": "assistant", "parts": [cls._text_part(assistant_content)]},
]
}
@classmethod
def _post_session_turn(
    cls,
    client: _VikingClient,
    sid: str,
    user_content: str,
    assistant_content: str,
) -> None:
    """POST one user/assistant turn to the session's batch-messages path.

    The request body is built by ``_turn_batch_payload``; the response from
    ``client.post`` is discarded.
    """
    batch_path = f"/api/v1/sessions/{sid}/messages/batch"
    body = cls._turn_batch_payload(user_content, assistant_content)
    client.post(batch_path, body)
def _session_has_pending_tokens(self, sid: str) -> bool:
try:
response = self._client.get(f"/api/v1/sessions/{sid}")
except Exception:
return False
session = self._unwrap_result(response)
if not isinstance(session, dict):
return False
try:
return int(session.get("pending_tokens") or 0) > 0
except (TypeError, ValueError):
return False
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
    """POST the commit request for ``sid`` and report whether it succeeded.

    ``context`` and ``turn_count`` are only used in the success log line.
    Any exception is logged as a warning and reported as False rather than
    propagated.
    """
    committed = False
    try:
        self._client.post(f"/api/v1/sessions/{sid}/commit")
        logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
        committed = True
    except Exception as err:
        logger.warning("OpenViking session commit failed for %s: %s", sid, err)
    return committed
def _schedule_deferred_commit(self, sid: str, turn_count: int) -> None:
if not sid or turn_count <= 0:
return
with self._deferred_commit_lock:
if sid in self._deferred_commit_sids:
return
self._deferred_commit_sids.add(sid)
holder: List[threading.Thread] = []
def _finalize() -> None:
try:
if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after deferred drain — "
"leaving session uncommitted",
sid,
)
return
self._commit_session(sid, turn_count, context="after deferred drain")
finally:
with self._deferred_commit_lock:
self._deferred_commit_sids.discard(sid)
if holder:
self._deferred_commit_threads.discard(holder[0])
thread = threading.Thread(
target=_finalize,
daemon=True,
name=f"openviking-finalize-{sid}",
)
holder.append(thread)
with self._deferred_commit_lock:
self._deferred_commit_threads.add(thread)
thread.start()
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
@ -658,20 +758,25 @@ class OpenVikingMemoryProvider(MemoryProvider):
def _sync():
try:
client = _VikingClient(
self._endpoint, self._api_key,
account=self._account, user=self._user, agent=self._agent,
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "user",
"content": user_content[:4000],
})
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "assistant",
"content": assistant_content[:4000],
})
except Exception as e:
logger.debug("OpenViking sync_turn failed: %s", e)
logger.debug("OpenViking sync_turn failed, reconnecting: %s", e)
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
except Exception as retry_error:
logger.warning("OpenViking sync_turn failed: %s", retry_error)
self._spawn_writer(sid, _sync, name="openviking-sync")
@ -686,23 +791,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
sid = self._session_id
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=10.0):
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after drain — skipping commit",
sid,
)
return
if self._turn_count == 0:
if self._turn_count == 0 and not self._session_has_pending_tokens(sid):
return
try:
self._client.post(f"/api/v1/sessions/{sid}/commit")
logger.info("OpenViking session %s committed (%d turns)", sid, self._turn_count)
if self._commit_session(sid, self._turn_count, context="on session end"):
# Mark clean so a follow-up on_session_switch skips its own commit.
self._turn_count = 0
except Exception as e:
logger.warning("OpenViking session commit failed: %s", e)
def on_session_switch(
self,
@ -736,7 +837,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
# Commit only after session writes drain.
writers_drained = True
if old_session_id:
writers_drained = self._drain_writers(old_session_id, timeout=10.0)
writers_drained = self._drain_writers(old_session_id, timeout=_SESSION_DRAIN_TIMEOUT)
if not writers_drained:
logger.warning(
"OpenViking writer for %s still alive after drain — "
@ -745,17 +846,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
if writers_drained and old_session_id and old_turn_count > 0:
try:
self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
logger.info(
"OpenViking session %s committed on switch (%d turns)",
old_session_id, old_turn_count,
)
except Exception as e:
logger.warning(
"OpenViking commit-on-switch failed for %s: %s",
old_session_id, e,
)
self._commit_session(old_session_id, old_turn_count, context="on switch")
elif not writers_drained:
self._schedule_deferred_commit(old_session_id, old_turn_count)
# Drop prefetch results from older switch generations.
self._prefetch_generation += 1
@ -835,9 +928,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
all_workers = [
t for workers in self._inflight_writers.values() for t in workers
]
with self._deferred_commit_lock:
deferred_workers = list(self._deferred_commit_threads)
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
for t in deferred_workers:
if t.is_alive():
t.join(timeout=5.0)
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
@ -899,7 +997,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
if args.get("scope"):
payload["target_uri"] = args["scope"]
if args.get("limit"):
payload["top_k"] = args["limit"]
payload["limit"] = args["limit"]
resp = self._client.post("/api/v1/search/find", payload)
result = resp.get("result", {})

View file

@ -66,6 +66,21 @@ def test_tool_search_sorts_missing_raw_score_after_negative_scores():
assert result["total"] == 3
def test_tool_search_sends_limit_not_legacy_top_k():
    """The search request must carry ``limit`` and omit the legacy ``top_k`` key."""
    empty_result = {"memories": [], "resources": [], "skills": [], "total": 0}
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._client.post.return_value = {"result": empty_result}

    provider._tool_search({"query": "session switch", "limit": 7})

    post = provider._client.post
    post.assert_called_once()
    # The payload is the second positional argument of the POST call.
    sent = post.call_args.args[1]
    assert sent["limit"] == 7
    assert "top_k" not in sent
def test_tool_add_resource_uploads_existing_local_file(tmp_path):
sample = tmp_path / "sample.md"
sample.write_text("# Local resource\n", encoding="utf-8")
@ -534,11 +549,13 @@ def test_sync_turn_captures_session_id_before_worker_runs():
started = threading.Event()
release = threading.Event()
captured_paths = []
captured_payloads = []
def fake_post(path, payload=None, **kwargs):
started.set()
release.wait(timeout=2.0)
captured_paths.append(path)
captured_payloads.append(payload)
return {}
# Patch _VikingClient inside the worker by stubbing post on a client
@ -566,11 +583,59 @@ def test_sync_turn_captures_session_id_before_worker_runs():
finally:
_mod._VikingClient = real_client_cls
# Both writes must target the OLD session id captured at call time.
assert captured_paths == [
"/api/v1/sessions/old-sid/messages",
"/api/v1/sessions/old-sid/messages",
]
# The whole turn must target the OLD session id as a single ordered batch.
assert captured_paths == ["/api/v1/sessions/old-sid/messages/batch"]
assert captured_payloads == [{
"messages": [
{"role": "user", "parts": [{"type": "text", "text": "u"}]},
{"role": "assistant", "parts": [{"type": "text", "text": "a"}]},
]
}]
def test_sync_turn_retries_batch_write_with_fresh_client():
    """A failed batch write is retried once on a newly constructed client."""
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "sid-1"

    created = []
    delivered = []

    class StubClient:
        """Client stand-in: the first instance fails, later ones record the call."""

        def __init__(self, *a, **kw):
            self.index = len(created)
            created.append(self)

        def post(self, path, payload=None, **kwargs):
            if self.index == 0:
                raise RuntimeError("transient")
            delivered.append((path, payload))
            return {}

    import plugins.memory.openviking as _mod

    # Swap the module-level client class for the stub, restoring it afterwards.
    original_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        provider.sync_turn("u", "a")
        assert provider._drain_writers("sid-1", timeout=2.0)
    finally:
        _mod._VikingClient = original_cls

    expected_payload = {
        "messages": [
            {"role": "user", "parts": [{"type": "text", "text": "u"}]},
            {"role": "assistant", "parts": [{"type": "text", "text": "a"}]},
        ]
    }
    assert len(created) == 2
    assert delivered == [("/api/v1/sessions/sid-1/messages/batch", expected_payload)]
def test_sync_turn_noop_when_session_id_blank():
@ -608,6 +673,16 @@ def test_on_session_end_keeps_dirty_when_commit_fails():
assert provider._turn_count == 3
def test_on_session_end_commits_pending_tokens_without_turn_count():
    """With a zero turn count, a session reporting pending tokens is still committed."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    client = provider._client
    client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_end([])

    client.get.assert_called_once_with("/api/v1/sessions/old-sid")
    client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_end_then_switch_does_not_double_commit():
"""Mirrors the /new and compression call order: commit_memory_session
( on_session_end) immediately followed by on_session_switch. The switch
@ -710,6 +785,34 @@ def test_on_session_switch_waits_for_all_writers_not_just_latest():
assert provider._turn_count == 0
def test_on_session_switch_defers_old_commit_when_writers_finish_after_initial_drain():
    """When the first drain fails, the old session is committed after a later, longer drain."""
    import threading

    provider = _make_provider_with_session("old-sid", turn_count=2)
    committed = threading.Event()
    drain_timeouts = []

    def fake_post(path):
        committed.set()
        return {}

    def fake_drain(sid, timeout):
        # Report "not drained" on the first call only, "drained" on every later one.
        is_first_call = not drain_timeouts
        drain_timeouts.append(timeout)
        return not is_first_call

    provider._client.post.side_effect = fake_post
    provider._drain_writers = fake_drain

    provider.on_session_switch("new-sid")

    # The switch itself must complete without waiting for the deferred commit.
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    assert committed.wait(timeout=2.0), "old session was not finalized after writers drained"
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
    initial_timeout, deferred_timeout = drain_timeouts[0], drain_timeouts[1]
    assert initial_timeout == 10.0
    assert deferred_timeout > 10.0
def test_sync_turn_tracks_writer_under_session_id():
"""Every sync_turn writer must register under its captured sid so the
drain at end/switch sees it even if a later sync_turn replaces the