From 00c045b43f309360724e6b35be5a6526831795f8 Mon Sep 17 00:00:00 2001 From: Hao Zhe Date: Wed, 17 Jun 2026 13:16:03 +0800 Subject: [PATCH] fix(openviking): harden session writes and switch commits --- plugins/memory/openviking/__init__.py | 162 ++++++++++++++---- .../memory/test_openviking_provider.py | 113 +++++++++++- 2 files changed, 238 insertions(+), 37 deletions(-) diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index df0a0cd9a42..9d0e0cb199c 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -47,6 +47,8 @@ logger = logging.getLogger(__name__) _DEFAULT_ENDPOINT = "http://127.0.0.1:1933" _TIMEOUT = 30.0 +_SESSION_DRAIN_TIMEOUT = 10.0 +_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0 _REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://") # Maps the viking_remember `category` enum to a viking:// subdirectory. @@ -439,6 +441,9 @@ class OpenVikingMemoryProvider(MemoryProvider): # if later writes have replaced the latest-tracked thread. self._inflight_writers: Dict[str, Set[threading.Thread]] = {} self._inflight_lock = threading.Lock() + self._deferred_commit_sids: Set[str] = set() + self._deferred_commit_threads: Set[threading.Thread] = set() + self._deferred_commit_lock = threading.Lock() self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread: Optional[threading.Thread] = None @@ -625,6 +630,8 @@ class OpenVikingMemoryProvider(MemoryProvider): Returns True if all writers drained, False if any are still alive when the budget runs out. Callers use the False return to skip the commit. """ + if not sid: + return True deadline = time.monotonic() + timeout while True: with self._inflight_lock: @@ -640,6 +647,99 @@ class OpenVikingMemoryProvider(MemoryProvider): break t.join(timeout=slice_left) + def _new_client(self) -> _VikingClient: + return _VikingClient( + self._endpoint, + self._api_key, + account=self._account, + user=self._user, + agent=self._agent, + ) + + @staticmethod + def _text_part(content: str) -> Dict[str, str]: + return {"type": "text", "text": content} + + @classmethod + def _turn_batch_payload(cls, user_content: str, assistant_content: str) -> Dict[str, Any]: + return { + "messages": [ + {"role": "user", "parts": [cls._text_part(user_content)]}, + {"role": "assistant", "parts": [cls._text_part(assistant_content)]}, + ] + } + + @classmethod + def _post_session_turn( + cls, + client: _VikingClient, + sid: str, + user_content: str, + assistant_content: str, + ) -> None: + client.post( + f"/api/v1/sessions/{sid}/messages/batch", + cls._turn_batch_payload(user_content, assistant_content), + ) + + def _session_has_pending_tokens(self, sid: str) -> bool: + try: + response = self._client.get(f"/api/v1/sessions/{sid}") + except Exception: + return False + session = self._unwrap_result(response) + if not isinstance(session, dict): + return False + try: + return int(session.get("pending_tokens") or 0) > 0 + except (TypeError, ValueError): + return False + + def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool: + try: + self._client.post(f"/api/v1/sessions/{sid}/commit") + logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count) + return True + except Exception as e: + logger.warning("OpenViking session commit failed for %s: %s", sid, e) + return False + + def _schedule_deferred_commit(self, sid: str, turn_count: int) -> None: + if not sid or turn_count <= 0: + return + with self._deferred_commit_lock: + if sid in self._deferred_commit_sids: + return + self._deferred_commit_sids.add(sid) + + holder: List[threading.Thread] = [] + + def _finalize() -> None: + try: + if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT): + logger.warning( + "OpenViking writer for %s still alive after deferred drain — " + "leaving session uncommitted", + sid, + ) + return + self._commit_session(sid, turn_count, context="after deferred drain") + finally: + with self._deferred_commit_lock: + self._deferred_commit_sids.discard(sid) + if holder: + self._deferred_commit_threads.discard(holder[0]) + + thread = threading.Thread( + target=_finalize, + daemon=True, + name=f"openviking-finalize-{sid}", + ) + holder.append(thread) + with self._deferred_commit_lock: + self._deferred_commit_threads.add(thread) + thread.start() + def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: """Record the conversation turn in OpenViking's session (non-blocking).""" if not self._client: @@ -658,20 +758,25 @@ class OpenVikingMemoryProvider(MemoryProvider): def _sync(): try: - client = _VikingClient( - self._endpoint, self._api_key, - account=self._account, user=self._user, agent=self._agent, + client = self._new_client() + self._post_session_turn( + client, + sid, + user_content[:4000], + assistant_content[:4000], ) - client.post(f"/api/v1/sessions/{sid}/messages", { - "role": "user", - "content": user_content[:4000], - }) - client.post(f"/api/v1/sessions/{sid}/messages", { - "role": "assistant", - "content": assistant_content[:4000], - }) except Exception as e: - logger.debug("OpenViking sync_turn failed: %s", e) + logger.debug("OpenViking sync_turn failed, reconnecting: %s", e) + try: + client = self._new_client() + self._post_session_turn( + client, + sid, + user_content[:4000], + assistant_content[:4000], + ) + except Exception as retry_error: + logger.warning("OpenViking sync_turn failed: %s", retry_error) self._spawn_writer(sid, _sync, name="openviking-sync") @@ -686,23 +791,19 @@ class OpenVikingMemoryProvider(MemoryProvider): sid = self._session_id # Commit only after session writes drain. - if not self._drain_writers(sid, timeout=10.0): + if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT): logger.warning( "OpenViking writer for %s still alive after drain — skipping commit", sid, ) return - if self._turn_count == 0: + if self._turn_count == 0 and not self._session_has_pending_tokens(sid): return - try: - self._client.post(f"/api/v1/sessions/{sid}/commit") - logger.info("OpenViking session %s committed (%d turns)", sid, self._turn_count) + if self._commit_session(sid, self._turn_count, context="on session end"): # Mark clean so a follow-up on_session_switch skips its own commit. self._turn_count = 0 - except Exception as e: - logger.warning("OpenViking session commit failed: %s", e) def on_session_switch( self, @@ -736,7 +837,7 @@ class OpenVikingMemoryProvider(MemoryProvider): # Commit only after session writes drain. writers_drained = True if old_session_id: - writers_drained = self._drain_writers(old_session_id, timeout=10.0) + writers_drained = self._drain_writers(old_session_id, timeout=_SESSION_DRAIN_TIMEOUT) if not writers_drained: logger.warning( "OpenViking writer for %s still alive after drain — " @@ -745,17 +846,9 @@ class OpenVikingMemoryProvider(MemoryProvider): ) if writers_drained and old_session_id and old_turn_count > 0: - try: - self._client.post(f"/api/v1/sessions/{old_session_id}/commit") - logger.info( - "OpenViking session %s committed on switch (%d turns)", - old_session_id, old_turn_count, - ) - except Exception as e: - logger.warning( - "OpenViking commit-on-switch failed for %s: %s", - old_session_id, e, - ) + self._commit_session(old_session_id, old_turn_count, context="on switch") + elif not writers_drained: + self._schedule_deferred_commit(old_session_id, old_turn_count) # Drop prefetch results from older switch generations. self._prefetch_generation += 1 @@ -835,9 +928,14 @@ class OpenVikingMemoryProvider(MemoryProvider): all_workers = [ t for workers in self._inflight_writers.values() for t in workers ] + with self._deferred_commit_lock: + deferred_workers = list(self._deferred_commit_threads) for t in all_workers: if t.is_alive(): t.join(timeout=5.0) + for t in deferred_workers: + if t.is_alive(): + t.join(timeout=5.0) if self._prefetch_thread and self._prefetch_thread.is_alive(): self._prefetch_thread.join(timeout=5.0) # Clear atexit reference so it doesn't double-commit. @@ -899,7 +997,7 @@ class OpenVikingMemoryProvider(MemoryProvider): if args.get("scope"): payload["target_uri"] = args["scope"] if args.get("limit"): - payload["top_k"] = args["limit"] + payload["limit"] = args["limit"] resp = self._client.post("/api/v1/search/find", payload) result = resp.get("result", {}) diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index 0479d45aae5..c76b08fadfa 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -66,6 +66,21 @@ def test_tool_search_sorts_missing_raw_score_after_negative_scores(): assert result["total"] == 3 +def test_tool_search_sends_limit_not_legacy_top_k(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._client.post.return_value = { + "result": {"memories": [], "resources": [], "skills": [], "total": 0} + } + + provider._tool_search({"query": "session switch", "limit": 7}) + + provider._client.post.assert_called_once() + payload = provider._client.post.call_args.args[1] + assert payload["limit"] == 7 + assert "top_k" not in payload + + def test_tool_add_resource_uploads_existing_local_file(tmp_path): sample = tmp_path / "sample.md" sample.write_text("# Local resource\n", encoding="utf-8") @@ -534,11 +549,13 @@ def test_sync_turn_captures_session_id_before_worker_runs(): started = threading.Event() release = threading.Event() captured_paths = [] + captured_payloads = [] def fake_post(path, payload=None, **kwargs): started.set() release.wait(timeout=2.0) captured_paths.append(path) + captured_payloads.append(payload) return {} # Patch _VikingClient inside the worker by stubbing post on a client @@ -566,11 +583,59 @@ def test_sync_turn_captures_session_id_before_worker_runs(): finally: _mod._VikingClient = real_client_cls - # Both writes must target the OLD session id captured at call time. - assert captured_paths == [ - "/api/v1/sessions/old-sid/messages", - "/api/v1/sessions/old-sid/messages", - ] + # The whole turn must target the OLD session id as a single ordered batch. + assert captured_paths == ["/api/v1/sessions/old-sid/messages/batch"] + assert captured_payloads == [{ + "messages": [ + {"role": "user", "parts": [{"type": "text", "text": "u"}]}, + {"role": "assistant", "parts": [{"type": "text", "text": "a"}]}, + ] + }] + + +def test_sync_turn_retries_batch_write_with_fresh_client(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "sid-1" + + clients = [] + captured = [] + + class StubClient: + def __init__(self, *a, **kw): + self.index = len(clients) + clients.append(self) + + def post(self, path, payload=None, **kwargs): + if self.index == 0: + raise RuntimeError("transient") + captured.append((path, payload)) + return {} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.sync_turn("u", "a") + assert provider._drain_writers("sid-1", timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + assert len(clients) == 2 + assert captured == [( + "/api/v1/sessions/sid-1/messages/batch", + { + "messages": [ + {"role": "user", "parts": [{"type": "text", "text": "u"}]}, + {"role": "assistant", "parts": [{"type": "text", "text": "a"}]}, + ] + }, + )] def test_sync_turn_noop_when_session_id_blank(): @@ -608,6 +673,16 @@ def test_on_session_end_keeps_dirty_when_commit_fails(): assert provider._turn_count == 3 +def test_on_session_end_commits_pending_tokens_without_turn_count(): + provider = _make_provider_with_session("old-sid", turn_count=0) + provider._client.get.return_value = {"result": {"pending_tokens": 42}} + + provider.on_session_end([]) + + provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid") + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + + def test_end_then_switch_does_not_double_commit(): """Mirrors the /new and compression call order: commit_memory_session (→ on_session_end) immediately followed by on_session_switch. The switch @@ -710,6 +785,34 @@ def test_on_session_switch_waits_for_all_writers_not_just_latest(): assert provider._turn_count == 0 +def test_on_session_switch_defers_old_commit_when_writers_finish_after_initial_drain(): + import threading + + provider = _make_provider_with_session("old-sid", turn_count=2) + committed = threading.Event() + drain_timeouts = [] + + def fake_post(path): + committed.set() + return {} + + def fake_drain(sid, timeout): + drain_timeouts.append(timeout) + return len(drain_timeouts) > 1 + + provider._client.post.side_effect = fake_post + provider._drain_writers = fake_drain + + provider.on_session_switch("new-sid") + + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + assert committed.wait(timeout=2.0), "old session was not finalized after writers drained" + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert drain_timeouts[0] == 10.0 + assert drain_timeouts[1] > 10.0 + + def test_sync_turn_tracks_writer_under_session_id(): """Every sync_turn writer must register under its captured sid so the drain at end/switch sees it even if a later sync_turn replaces the