fix(openviking): handle rewound session switches

This commit is contained in:
Hao Zhe 2026-06-17 14:46:06 +08:00
parent 00c045b43f
commit 3ac6551ba3
2 changed files with 128 additions and 11 deletions

View file

@ -444,6 +444,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._deferred_commit_sids: Set[str] = set()
self._deferred_commit_threads: Set[threading.Thread] = set()
self._deferred_commit_lock = threading.Lock()
self._committed_session_ids: Set[str] = set()
self._committed_session_lock = threading.Lock()
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
@ -563,7 +565,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
return
# Drop prefetch results from older switch generations.
gen = self._prefetch_generation
with self._prefetch_lock:
gen = self._prefetch_generation
def _run():
try:
@ -573,7 +576,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
resp = client.post("/api/v1/search/find", {
"query": query,
"top_k": 5,
"limit": 5,
})
result = resp.get("result", {})
parts = []
@ -695,9 +698,25 @@ class OpenVikingMemoryProvider(MemoryProvider):
except (TypeError, ValueError):
return False
def _has_committed_session(self, sid: str) -> bool:
with self._committed_session_lock:
return sid in self._committed_session_ids
def _mark_session_committed(self, sid: str) -> None:
with self._committed_session_lock:
self._committed_session_ids.add(sid)
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
if turn_count > 0:
return True
if self._has_committed_session(sid):
return False
return self._session_has_pending_tokens(sid)
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
try:
self._client.post(f"/api/v1/sessions/{sid}/commit")
self._mark_session_committed(sid)
logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
return True
except Exception as e:
@ -740,6 +759,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._deferred_commit_threads.add(thread)
thread.start()
def _invalidate_prefetch_state(self) -> None:
# Bump the generation under the same lock used by prefetch workers so
# late results from an older session are discarded deterministically.
with self._prefetch_lock:
self._prefetch_generation += 1
self._prefetch_result = ""
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
@ -798,7 +828,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
return
if self._turn_count == 0 and not self._session_has_pending_tokens(sid):
if not self._session_needs_commit(sid, self._turn_count):
return
if self._commit_session(sid, self._turn_count, context="on session end"):
@ -831,8 +861,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not new_id or not self._client:
return
rewound = bool(kwargs.get("rewound"))
old_session_id = self._session_id
old_turn_count = self._turn_count
if rewound or new_id == old_session_id:
self._invalidate_prefetch_state()
logger.debug(
"OpenViking on_session_switch invalidated state without rotation: "
"session=%s rewound=%s",
old_session_id, rewound,
)
return
# Commit only after session writes drain.
writers_drained = True
@ -845,17 +884,16 @@ class OpenVikingMemoryProvider(MemoryProvider):
old_session_id,
)
if writers_drained and old_session_id and old_turn_count > 0:
if (
writers_drained
and old_session_id
and self._session_needs_commit(old_session_id, old_turn_count)
):
self._commit_session(old_session_id, old_turn_count, context="on switch")
elif not writers_drained:
self._schedule_deferred_commit(old_session_id, old_turn_count)
# Drop prefetch results from older switch generations.
self._prefetch_generation += 1
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
self._invalidate_prefetch_state()
self._session_id = new_id
self._turn_count = 0

View file

@ -8,6 +8,11 @@ import pytest
from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient
def _clear_openviking_tenant_env(monkeypatch):
for name in ("OPENVIKING_ACCOUNT", "OPENVIKING_USER", "OPENVIKING_AGENT"):
monkeypatch.delenv(name, raising=False)
def test_tool_search_sorts_by_raw_score_across_buckets():
provider = OpenVikingMemoryProvider()
provider._client = MagicMock()
@ -386,7 +391,8 @@ def test_viking_client_headers_send_tenant_when_default():
assert headers["Authorization"] == "Bearer test-key"
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default():
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(monkeypatch):
_clear_openviking_tenant_env(monkeypatch)
# Empty account/user strings fall back to "default" via the constructor.
# Headers are sent even for the default value — ROOT API keys need them.
client = _VikingClient(
@ -417,6 +423,7 @@ def test_viking_client_headers_sent_with_real_tenant_values():
def test_viking_client_health_sends_auth_headers(monkeypatch):
_clear_openviking_tenant_env(monkeypatch)
client = _VikingClient(
"https://example.com",
api_key="test-key",
@ -470,6 +477,33 @@ def test_on_session_switch_skips_commit_for_empty_old_session():
assert provider._turn_count == 0
def test_on_session_switch_commits_pending_tokens_without_turn_count():
    """A zero-turn session that reports pending tokens is committed on switch."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    provider._client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_switch("new-sid")

    client = provider._client
    # One session lookup and one commit, both for the old session.
    client.get.assert_called_once_with("/api/v1/sessions/old-sid")
    client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
    assert (provider._session_id, provider._turn_count) == ("new-sid", 0)
def test_on_session_switch_rewound_same_session_only_invalidates_prefetch():
    """A rewound switch to the same session only resets prefetch state.

    No client calls are made, the session id and turn count are unchanged,
    the prefetch generation advances by one and the cached result is cleared.
    """
    provider = _make_provider_with_session("same-sid", turn_count=3)
    provider._prefetch_generation = 9
    provider._prefetch_result = "stale recall"

    provider.on_session_switch("same-sid", rewound=True)

    client = provider._client
    client.get.assert_not_called()
    client.post.assert_not_called()
    # Session identity and turn counter are left alone.
    assert (provider._session_id, provider._turn_count) == ("same-sid", 3)
    # Only the prefetch bookkeeping moves.
    assert provider._prefetch_generation == 10
    assert provider._prefetch_result == ""
def test_on_session_switch_clears_stale_prefetch_result():
provider = _make_provider_with_session("old-sid", turn_count=1)
provider._prefetch_result = "stale recall from old session"
@ -698,6 +732,18 @@ def test_end_then_switch_does_not_double_commit():
assert provider._turn_count == 0
def test_end_then_switch_with_pending_tokens_does_not_double_commit():
    """Ending then switching a pending-token session commits it only once."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    provider._client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_end([])
    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    # Exactly one commit POST across both lifecycle calls.
    provider._client.post.assert_called_once_with(
        "/api/v1/sessions/old-sid/commit"
    )
    assert (provider._session_id, provider._turn_count) == ("new-sid", 0)
def test_on_session_switch_swallows_commit_failure():
"""Commit-on-switch must not propagate exceptions: a failing commit on the
old session must still allow the rotate to the new session to complete,
@ -971,3 +1017,36 @@ def test_queue_prefetch_drops_result_when_generation_changed_mid_flight():
# The stale result from the pre-bump generation must NOT have been written
# into the new generation's prefetch slot.
assert provider._prefetch_result == ""
def test_queue_prefetch_sends_limit_not_legacy_top_k():
    """The prefetch search payload uses ``limit`` and carries no ``top_k``."""
    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    connection_settings = {
        "_endpoint": "http://test",
        "_api_key": "",
        "_account": "acct",
        "_user": "usr",
        "_agent": "hermes",
    }
    for attr, value in connection_settings.items():
        setattr(provider, attr, value)

    seen_payloads = []

    class StubClient:
        """Stand-in client that records every POST payload."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            seen_payloads.append(payload)
            return {"result": {"memories": [], "resources": []}}

    # Swap the module-level client class for the stub, and always restore it.
    original_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        provider.queue_prefetch("anything")
        worker = provider._prefetch_thread
        if worker:
            worker.join(timeout=2.0)
    finally:
        _mod._VikingClient = original_cls

    assert seen_payloads == [{"query": "anything", "limit": 5}]
    assert "top_k" not in seen_payloads[0]