fix(openviking): harden session writes and switch commits

This commit is contained in:
Hao Zhe 2026-06-17 13:16:03 +08:00
parent f3b813c027
commit 00c045b43f
2 changed files with 238 additions and 37 deletions

View file

@ -47,6 +47,8 @@ logger = logging.getLogger(__name__)
_DEFAULT_ENDPOINT = "http://127.0.0.1:1933"
_TIMEOUT = 30.0
_SESSION_DRAIN_TIMEOUT = 10.0
_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
# Maps the viking_remember `category` enum to a viking:// subdirectory.
@ -439,6 +441,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
# if later writes have replaced the latest-tracked thread.
self._inflight_writers: Dict[str, Set[threading.Thread]] = {}
self._inflight_lock = threading.Lock()
self._deferred_commit_sids: Set[str] = set()
self._deferred_commit_threads: Set[threading.Thread] = set()
self._deferred_commit_lock = threading.Lock()
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
@ -625,6 +630,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
Returns True if all writers drained, False if any are still alive when
the budget runs out. Callers use the False return to skip the commit.
"""
if not sid:
return True
deadline = time.monotonic() + timeout
while True:
with self._inflight_lock:
@ -640,6 +647,99 @@ class OpenVikingMemoryProvider(MemoryProvider):
break
t.join(timeout=slice_left)
def _new_client(self) -> _VikingClient:
    """Build a new ``_VikingClient`` from this provider's stored settings.

    The endpoint and API key are passed positionally; account, user and
    agent are passed as keyword arguments.
    """
    endpoint, api_key = self._endpoint, self._api_key
    identity = {
        "account": self._account,
        "user": self._user,
        "agent": self._agent,
    }
    return _VikingClient(endpoint, api_key, **identity)
@staticmethod
def _text_part(content: str) -> Dict[str, str]:
    """Wrap *content* in a message part mapping of type ``"text"``."""
    part: Dict[str, str] = {"type": "text"}
    part["text"] = content
    return part
@classmethod
def _turn_batch_payload(cls, user_content: str, assistant_content: str) -> Dict[str, Any]:
    """Build the request body for one user/assistant turn.

    Returns a mapping with a single ``"messages"`` list holding the user
    message followed by the assistant message; each message carries its
    text as one part produced by ``cls._text_part``.
    """
    turn = (("user", user_content), ("assistant", assistant_content))
    messages = [
        {"role": role, "parts": [cls._text_part(text)]}
        for role, text in turn
    ]
    return {"messages": messages}
@classmethod
def _post_session_turn(
    cls,
    client: _VikingClient,
    sid: str,
    user_content: str,
    assistant_content: str,
) -> None:
    """POST one user/assistant turn to the batch-messages endpoint of *sid*.

    The body is built by ``cls._turn_batch_payload``; any exception raised
    by ``client.post`` propagates to the caller.
    """
    path = f"/api/v1/sessions/{sid}/messages/batch"
    body = cls._turn_batch_payload(user_content, assistant_content)
    client.post(path, body)
def _session_has_pending_tokens(self, sid: str) -> bool:
    """Return True when the server reports ``pending_tokens > 0`` for *sid*.

    This is a best-effort probe: an empty *sid*, a failed GET, a result
    that is not a dict, or a ``pending_tokens`` value that cannot be
    converted to ``int`` all yield False.
    """
    if not sid:
        # Nothing to look up; also avoids a GET on the bare
        # "/api/v1/sessions/" URL.
        return False
    try:
        response = self._client.get(f"/api/v1/sessions/{sid}")
    except Exception as e:
        # Still best-effort, but leave a trace instead of failing silently.
        logger.debug("OpenViking pending-token check failed for %s: %s", sid, e)
        return False
    session = self._unwrap_result(response)
    if not isinstance(session, dict):
        return False
    try:
        # `or 0` maps a missing/None/empty value to zero before conversion.
        return int(session.get("pending_tokens") or 0) > 0
    except (TypeError, ValueError):
        return False
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
    """POST a commit for session *sid* and report whether it succeeded.

    On success an info line is logged with *context* and *turn_count* and
    True is returned. Any exception is logged as a warning and turned into
    a False return instead of propagating.
    """
    commit_path = f"/api/v1/sessions/{sid}/commit"
    try:
        self._client.post(commit_path)
        logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
    except Exception as exc:
        logger.warning("OpenViking session commit failed for %s: %s", sid, exc)
        return False
    return True
def _schedule_deferred_commit(self, sid: str, turn_count: int) -> None:
    """Start a daemon thread that drains writers for *sid*, then commits it.

    Does nothing when *sid* is empty, when *turn_count* is not positive,
    or when a deferred commit for *sid* is already registered. The worker
    waits up to ``_DEFERRED_COMMIT_TIMEOUT`` for the session's writers via
    ``_drain_writers``; if they drain it calls ``_commit_session``,
    otherwise it logs a warning and leaves the session uncommitted. The
    worker always removes its own registration when it finishes.

    Raises:
        Whatever ``threading.Thread.start`` raises. The registration is
        rolled back first so a later call for the same *sid* can retry.
    """
    if not sid or turn_count <= 0:
        return

    def _finalize() -> None:
        try:
            if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT):
                logger.warning(
                    "OpenViking writer for %s still alive after deferred drain — "
                    "leaving session uncommitted",
                    sid,
                )
                return
            self._commit_session(sid, turn_count, context="after deferred drain")
        finally:
            # `thread` is bound in the enclosing scope before start() runs,
            # so the closure can refer to it directly.
            with self._deferred_commit_lock:
                self._deferred_commit_sids.discard(sid)
                self._deferred_commit_threads.discard(thread)

    thread = threading.Thread(
        target=_finalize,
        daemon=True,
        name=f"openviking-finalize-{sid}",
    )
    # Register the sid and the thread in one critical section so the two
    # sets never disagree about whether a deferred commit is pending.
    with self._deferred_commit_lock:
        if sid in self._deferred_commit_sids:
            return
        self._deferred_commit_sids.add(sid)
        self._deferred_commit_threads.add(thread)
    try:
        thread.start()
    except Exception:
        # The worker never ran, so its `finally` cleanup will not happen.
        # Undo the registration here; otherwise this sid would be skipped
        # by every later call.
        with self._deferred_commit_lock:
            self._deferred_commit_sids.discard(sid)
            self._deferred_commit_threads.discard(thread)
        raise
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
@ -658,20 +758,25 @@ class OpenVikingMemoryProvider(MemoryProvider):
def _sync():
try:
client = _VikingClient(
self._endpoint, self._api_key,
account=self._account, user=self._user, agent=self._agent,
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "user",
"content": user_content[:4000],
})
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "assistant",
"content": assistant_content[:4000],
})
except Exception as e:
logger.debug("OpenViking sync_turn failed: %s", e)
logger.debug("OpenViking sync_turn failed, reconnecting: %s", e)
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
except Exception as retry_error:
logger.warning("OpenViking sync_turn failed: %s", retry_error)
self._spawn_writer(sid, _sync, name="openviking-sync")
@ -686,23 +791,19 @@ class OpenVikingMemoryProvider(MemoryProvider):
sid = self._session_id
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=10.0):
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after drain — skipping commit",
sid,
)
return
if self._turn_count == 0:
if self._turn_count == 0 and not self._session_has_pending_tokens(sid):
return
try:
self._client.post(f"/api/v1/sessions/{sid}/commit")
logger.info("OpenViking session %s committed (%d turns)", sid, self._turn_count)
if self._commit_session(sid, self._turn_count, context="on session end"):
# Mark clean so a follow-up on_session_switch skips its own commit.
self._turn_count = 0
except Exception as e:
logger.warning("OpenViking session commit failed: %s", e)
def on_session_switch(
self,
@ -736,7 +837,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
# Commit only after session writes drain.
writers_drained = True
if old_session_id:
writers_drained = self._drain_writers(old_session_id, timeout=10.0)
writers_drained = self._drain_writers(old_session_id, timeout=_SESSION_DRAIN_TIMEOUT)
if not writers_drained:
logger.warning(
"OpenViking writer for %s still alive after drain — "
@ -745,17 +846,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
if writers_drained and old_session_id and old_turn_count > 0:
try:
self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
logger.info(
"OpenViking session %s committed on switch (%d turns)",
old_session_id, old_turn_count,
)
except Exception as e:
logger.warning(
"OpenViking commit-on-switch failed for %s: %s",
old_session_id, e,
)
self._commit_session(old_session_id, old_turn_count, context="on switch")
elif not writers_drained:
self._schedule_deferred_commit(old_session_id, old_turn_count)
# Drop prefetch results from older switch generations.
self._prefetch_generation += 1
@ -835,9 +928,14 @@ class OpenVikingMemoryProvider(MemoryProvider):
all_workers = [
t for workers in self._inflight_writers.values() for t in workers
]
with self._deferred_commit_lock:
deferred_workers = list(self._deferred_commit_threads)
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
for t in deferred_workers:
if t.is_alive():
t.join(timeout=5.0)
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
@ -899,7 +997,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
if args.get("scope"):
payload["target_uri"] = args["scope"]
if args.get("limit"):
payload["top_k"] = args["limit"]
payload["limit"] = args["limit"]
resp = self._client.post("/api/v1/search/find", payload)
result = resp.get("result", {})