From 3ac6551ba3d3c35e67cd808a0b9363de1b13d29f Mon Sep 17 00:00:00 2001 From: Hao Zhe Date: Wed, 17 Jun 2026 14:46:06 +0800 Subject: [PATCH] fix(openviking): handle rewound session switches --- plugins/memory/openviking/__init__.py | 58 ++++++++++--- .../memory/test_openviking_provider.py | 81 ++++++++++++++++++- 2 files changed, 128 insertions(+), 11 deletions(-) diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index 9d0e0cb199c..9bd10209d61 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -444,6 +444,8 @@ class OpenVikingMemoryProvider(MemoryProvider): self._deferred_commit_sids: Set[str] = set() self._deferred_commit_threads: Set[threading.Thread] = set() self._deferred_commit_lock = threading.Lock() + self._committed_session_ids: Set[str] = set() + self._committed_session_lock = threading.Lock() self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread: Optional[threading.Thread] = None @@ -563,7 +565,8 @@ class OpenVikingMemoryProvider(MemoryProvider): return # Drop prefetch results from older switch generations. - gen = self._prefetch_generation + with self._prefetch_lock: + gen = self._prefetch_generation def _run(): try: @@ -573,7 +576,7 @@ class OpenVikingMemoryProvider(MemoryProvider): ) resp = client.post("/api/v1/search/find", { "query": query, - "top_k": 5, + "limit": 5, }) result = resp.get("result", {}) parts = [] @@ -695,9 +698,25 @@ class OpenVikingMemoryProvider(MemoryProvider): except (TypeError, ValueError): return False + def _has_committed_session(self, sid: str) -> bool: + with self._committed_session_lock: + return sid in self._committed_session_ids + + def _mark_session_committed(self, sid: str) -> None: + with self._committed_session_lock: + self._committed_session_ids.add(sid) + + def _session_needs_commit(self, sid: str, turn_count: int) -> bool: + if turn_count > 0: + return True + if self._has_committed_session(sid): + return False + return self._session_has_pending_tokens(sid) + def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool: try: self._client.post(f"/api/v1/sessions/{sid}/commit") + self._mark_session_committed(sid) logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count) return True except Exception as e: @@ -740,6 +759,17 @@ class OpenVikingMemoryProvider(MemoryProvider): self._deferred_commit_threads.add(thread) thread.start() + def _invalidate_prefetch_state(self) -> None: + # Bump the generation under the same lock used by prefetch workers so + # late results from an older session are discarded deterministically. + with self._prefetch_lock: + self._prefetch_generation += 1 + self._prefetch_result = "" + if self._prefetch_thread and self._prefetch_thread.is_alive(): + self._prefetch_thread.join(timeout=3.0) + with self._prefetch_lock: + self._prefetch_result = "" + def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: """Record the conversation turn in OpenViking's session (non-blocking).""" if not self._client: @@ -798,7 +828,7 @@ class OpenVikingMemoryProvider(MemoryProvider): ) return - if self._turn_count == 0 and not self._session_has_pending_tokens(sid): + if not self._session_needs_commit(sid, self._turn_count): return if self._commit_session(sid, self._turn_count, context="on session end"): @@ -831,8 +861,17 @@ class OpenVikingMemoryProvider(MemoryProvider): if not new_id or not self._client: return + rewound = bool(kwargs.get("rewound")) old_session_id = self._session_id old_turn_count = self._turn_count + if rewound or new_id == old_session_id: + self._invalidate_prefetch_state() + logger.debug( + "OpenViking on_session_switch invalidated state without rotation: " + "session=%s rewound=%s", + old_session_id, rewound, + ) + return # Commit only after session writes drain. writers_drained = True @@ -845,17 +884,16 @@ class OpenVikingMemoryProvider(MemoryProvider): old_session_id, ) - if writers_drained and old_session_id and old_turn_count > 0: + if ( + writers_drained + and old_session_id + and self._session_needs_commit(old_session_id, old_turn_count) + ): self._commit_session(old_session_id, old_turn_count, context="on switch") elif not writers_drained: self._schedule_deferred_commit(old_session_id, old_turn_count) - # Drop prefetch results from older switch generations. - self._prefetch_generation += 1 - if self._prefetch_thread and self._prefetch_thread.is_alive(): - self._prefetch_thread.join(timeout=3.0) - with self._prefetch_lock: - self._prefetch_result = "" + self._invalidate_prefetch_state() self._session_id = new_id self._turn_count = 0 diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index c76b08fadfa..2a58cc43cba 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -8,6 +8,11 @@ import pytest from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient +def _clear_openviking_tenant_env(monkeypatch): + for name in ("OPENVIKING_ACCOUNT", "OPENVIKING_USER", "OPENVIKING_AGENT"): + monkeypatch.delenv(name, raising=False) + + def test_tool_search_sorts_by_raw_score_across_buckets(): provider = OpenVikingMemoryProvider() provider._client = MagicMock() @@ -386,7 +391,8 @@ def test_viking_client_headers_send_tenant_when_default(): assert headers["Authorization"] == "Bearer test-key" -def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(): +def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(monkeypatch): + _clear_openviking_tenant_env(monkeypatch) # Empty account/user strings fall back to "default" via the constructor. # Headers are sent even for the default value — ROOT API keys need them. client = _VikingClient( @@ -417,6 +423,7 @@ def test_viking_client_headers_sent_with_real_tenant_values(): def test_viking_client_health_sends_auth_headers(monkeypatch): + _clear_openviking_tenant_env(monkeypatch) client = _VikingClient( "https://example.com", api_key="test-key", @@ -470,6 +477,33 @@ def test_on_session_switch_skips_commit_for_empty_old_session(): assert provider._turn_count == 0 +def test_on_session_switch_commits_pending_tokens_without_turn_count(): + provider = _make_provider_with_session("old-sid", turn_count=0) + provider._client.get.return_value = {"result": {"pending_tokens": 42}} + + provider.on_session_switch("new-sid") + + provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid") + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_rewound_same_session_only_invalidates_prefetch(): + provider = _make_provider_with_session("same-sid", turn_count=3) + provider._prefetch_generation = 9 + provider._prefetch_result = "stale recall" + + provider.on_session_switch("same-sid", rewound=True) + + provider._client.get.assert_not_called() + provider._client.post.assert_not_called() + assert provider._session_id == "same-sid" + assert provider._turn_count == 3 + assert provider._prefetch_generation == 10 + assert provider._prefetch_result == "" + + def test_on_session_switch_clears_stale_prefetch_result(): provider = _make_provider_with_session("old-sid", turn_count=1) provider._prefetch_result = "stale recall from old session" @@ -698,6 +732,18 @@ def test_end_then_switch_does_not_double_commit(): assert provider._turn_count == 0 +def test_end_then_switch_with_pending_tokens_does_not_double_commit(): + provider = _make_provider_with_session("old-sid", turn_count=0) + provider._client.get.return_value = {"result": {"pending_tokens": 42}} + + provider.on_session_end([]) + provider.on_session_switch("new-sid", parent_session_id="old-sid") + + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + def test_on_session_switch_swallows_commit_failure(): """Commit-on-switch must not propagate exceptions: a failing commit on the old session must still allow the rotate to the new session to complete, @@ -971,3 +1017,36 @@ def test_queue_prefetch_drops_result_when_generation_changed_mid_flight(): # The stale result from the pre-bump generation must NOT have been written # into the new generation's prefetch slot. assert provider._prefetch_result == "" + + +def test_queue_prefetch_sends_limit_not_legacy_top_k(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + + captured_payloads = [] + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + captured_payloads.append(payload) + return {"result": {"memories": [], "resources": []}} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.queue_prefetch("anything") + if provider._prefetch_thread: + provider._prefetch_thread.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + assert captured_payloads == [{"query": "anything", "limit": 5}] + assert "top_k" not in captured_payloads[0]