Merge pull request #48042 from kshitijk4poor/salvage-47662

fix(openviking): implement on_session_switch hook + harden session writes (salvage #47662)
This commit is contained in:
kshitij 2026-06-18 02:34:27 +05:30 committed by GitHub
commit 7fbb8c9df5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 1121 additions and 59 deletions

View file

@ -31,10 +31,11 @@ import mimetypes
import os
import tempfile
import threading
import time
import uuid
import zipfile
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Set
from urllib.parse import urlparse
from urllib.request import url2pathname
@ -46,6 +47,8 @@ logger = logging.getLogger(__name__)
_DEFAULT_ENDPOINT = "http://127.0.0.1:1933"
_TIMEOUT = 30.0
_SESSION_DRAIN_TIMEOUT = 10.0
_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
# Maps the viking_remember `category` enum to a viking:// subdirectory.
@ -432,10 +435,35 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._api_key = ""
self._session_id = ""
self._turn_count = 0
self._sync_thread: Optional[threading.Thread] = None
# Guards the (_session_id, _turn_count) pair. sync_turn runs on the
# MemoryManager's background sync executor while on_session_end /
# on_session_switch run on the caller's thread, so the snapshot+reset
# of the turn counter and the session-id rotation must be atomic
# against a concurrent increment. See hermes-agent#28296 review.
self._session_state_lock = threading.Lock()
# Commit only after session writes drain. The set is keyed by the sid
# the writer is POSTing under (snapshotted at spawn), so on_session_end
# / on_session_switch see every still-alive writer for that sid even
# if later writes have replaced the latest-tracked thread.
self._inflight_writers: Dict[str, Set[threading.Thread]] = {}
self._inflight_lock = threading.Lock()
self._deferred_commit_sids: Set[str] = set()
self._deferred_commit_threads: Set[threading.Thread] = set()
self._deferred_commit_lock = threading.Lock()
self._committed_session_ids: Set[str] = set()
self._committed_session_lock = threading.Lock()
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
# shutdown() can drain them and rapid re-queues don't orphan a still-
# running thread by overwriting the single _prefetch_thread slot.
self._prefetch_threads: Set[threading.Thread] = set()
# Set on shutdown so deferred-commit / writer finalizers stop issuing
# network writes against a torn-down provider.
self._shutting_down = False
# Drop prefetch results from older switch generations.
self._prefetch_generation = 0
@property
def name(self) -> str:
@ -549,6 +577,12 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client or not query:
return
# Drop prefetch results from older switch generations.
with self._prefetch_lock:
gen = self._prefetch_generation
holder: List[threading.Thread] = []
def _run():
try:
client = _VikingClient(
@ -557,7 +591,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
)
resp = client.post("/api/v1/search/find", {
"query": query,
"top_k": 5,
"limit": 5,
})
result = resp.get("result", {})
parts = []
@ -571,14 +605,237 @@ class OpenVikingMemoryProvider(MemoryProvider):
parts.append(f"- [{score:.2f}] {abstract} ({uri})")
if parts:
with self._prefetch_lock:
if gen != self._prefetch_generation:
return
self._prefetch_result = "\n".join(parts)
except Exception as e:
logger.debug("OpenViking prefetch failed: %s", e)
finally:
with self._prefetch_lock:
if holder:
self._prefetch_threads.discard(holder[0])
self._prefetch_thread = threading.Thread(
thread = threading.Thread(
target=_run, daemon=True, name="openviking-prefetch"
)
self._prefetch_thread.start()
holder.append(thread)
with self._prefetch_lock:
self._prefetch_thread = thread
self._prefetch_threads.add(thread)
thread.start()
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
"""Spawn a daemon writer tracked in _inflight_writers[sid].
Tracking is keyed by sid (not by a single latest-thread slot) so that
on_session_end / on_session_switch can drain every still-alive writer
for the session being committed.
"""
holder: List[threading.Thread] = []
def _wrapped():
try:
target()
finally:
with self._inflight_lock:
workers = self._inflight_writers.get(sid)
if workers is not None:
workers.discard(holder[0])
if not workers:
self._inflight_writers.pop(sid, None)
thread = threading.Thread(target=_wrapped, daemon=True, name=name)
holder.append(thread)
with self._inflight_lock:
self._inflight_writers.setdefault(sid, set()).add(thread)
thread.start()
def _drain_finalizers(self, timeout: float) -> bool:
"""Join every in-flight async session finalizer within a timeout.
The switch-path commit runs on a daemon finalizer thread so it never
blocks the caller's command thread; this lets shutdown and tests wait
for those commits deterministically. Returns True if all drained.
"""
deadline = time.monotonic() + timeout
while True:
with self._deferred_commit_lock:
workers = [t for t in self._deferred_commit_threads if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
# Floor the per-join wait so a thread whose join() returns
# instantly while still reporting alive can't hot-spin this loop.
t.join(timeout=min(slice_left, 0.05))
def _drain_writers(self, sid: str, timeout: float) -> bool:
"""Join every in-flight writer for sid within a shared timeout budget.
Returns True if all writers drained, False if any are still alive when
the budget runs out. Callers use the False return to skip the commit.
"""
if not sid:
return True
deadline = time.monotonic() + timeout
while True:
with self._inflight_lock:
workers = [t for t in self._inflight_writers.get(sid, ()) if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
t.join(timeout=slice_left)
def _new_client(self) -> _VikingClient:
    """Build a fresh ``_VikingClient`` from this provider's connection settings."""
    endpoint, api_key = self._endpoint, self._api_key
    tenant = {
        "account": self._account,
        "user": self._user,
        "agent": self._agent,
    }
    return _VikingClient(endpoint, api_key, **tenant)
@staticmethod
def _text_part(content: str) -> Dict[str, str]:
return {"type": "text", "text": content}
@classmethod
def _turn_batch_payload(cls, user_content: str, assistant_content: str) -> Dict[str, Any]:
return {
"messages": [
{"role": "user", "parts": [cls._text_part(user_content)]},
{"role": "assistant", "parts": [cls._text_part(assistant_content)]},
]
}
@classmethod
def _post_session_turn(
    cls,
    client: _VikingClient,
    sid: str,
    user_content: str,
    assistant_content: str,
) -> None:
    """POST one user/assistant turn to the batch-messages endpoint of ``sid``.

    The body comes from ``cls._turn_batch_payload``; whatever
    ``client.post`` raises propagates to the caller.
    """
    endpoint = f"/api/v1/sessions/{sid}/messages/batch"
    body = cls._turn_batch_payload(user_content, assistant_content)
    client.post(endpoint, body)
def _session_has_pending_tokens(self, sid: str) -> bool:
try:
response = self._client.get(f"/api/v1/sessions/{sid}")
except Exception:
return False
session = self._unwrap_result(response)
if not isinstance(session, dict):
return False
try:
return int(session.get("pending_tokens") or 0) > 0
except (TypeError, ValueError):
return False
def _has_committed_session(self, sid: str) -> bool:
with self._committed_session_lock:
return sid in self._committed_session_ids
def _mark_session_committed(self, sid: str) -> None:
with self._committed_session_lock:
self._committed_session_ids.add(sid)
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
# Already-committed sessions never need a second commit, regardless of
# the turn counter — a racing sync_turn can re-increment _turn_count
# after a commit+reset, so the committed-guard must win over turn_count.
if self._has_committed_session(sid):
return False
if turn_count > 0:
return True
return self._session_has_pending_tokens(sid)
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
    """POST the commit for ``sid`` and record it as committed.

    Args:
        sid: Session id to commit.
        turn_count: Number of turns, used only in the log line.
        context: Short phrase for the log line describing the trigger.

    Returns:
        True when the commit was posted and recorded, False when anything
        in that sequence raised (the error is logged, never propagated).
    """
    commit_path = f"/api/v1/sessions/{sid}/commit"
    try:
        self._client.post(commit_path)
        self._mark_session_committed(sid)
        logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
    except Exception as exc:
        logger.warning("OpenViking session commit failed for %s: %s", sid, exc)
        return False
    return True
def _finalize_session_async(self, sid: str, turn_count: int, *, context: str) -> None:
    """Drain the writers of ``sid`` and commit it on a daemon thread.

    Keeps the drain, the pending-token lookup and the commit POST off the
    caller's thread. Calls are de-duplicated per sid: while a finalizer for
    ``sid`` is registered, further calls for the same sid return without
    doing anything. Once ``self._shutting_down`` is set, no new finalizer
    is started and a running one stops before its next step.

    Args:
        sid: Session to finalize; an empty sid is ignored.
        turn_count: Turn count snapshot taken for that session.
        context: Phrase forwarded to ``_commit_session`` for logging.
    """
    if not sid:
        return
    with self._deferred_commit_lock:
        if self._shutting_down or sid in self._deferred_commit_sids:
            return
        self._deferred_commit_sids.add(sid)

    def _run_finalizer() -> None:
        try:
            if self._shutting_down:
                return
            drained = self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT)
            if not drained:
                logger.warning(
                    "OpenViking writer for %s still alive after drain — "
                    "leaving session uncommitted",
                    sid,
                )
                return
            # Shutdown may have begun while the drain was waiting.
            if self._shutting_down:
                return
            if self._session_needs_commit(sid, turn_count):
                self._commit_session(sid, turn_count, context=context)
        finally:
            # This closure only runs on the thread created below, so the
            # current thread is the one that was registered.
            me = threading.current_thread()
            with self._deferred_commit_lock:
                self._deferred_commit_sids.discard(sid)
                self._deferred_commit_threads.discard(me)

    finalizer = threading.Thread(
        target=_run_finalizer,
        daemon=True,
        name=f"openviking-finalize-{sid}",
    )
    # Register before start() so _drain_finalizers / shutdown can see it.
    with self._deferred_commit_lock:
        self._deferred_commit_threads.add(finalizer)
    finalizer.start()
def _invalidate_prefetch_state(self) -> None:
# Bump the generation under the same lock used by prefetch workers so
# late results from an older session are discarded deterministically.
with self._prefetch_lock:
self._prefetch_generation += 1
self._prefetch_result = ""
# Join EVERY tracked prefetch thread, not just the latest slot — a
# rapid re-queue can leave an older thread for the abandoned session
# still running (consistent with shutdown()).
workers = [t for t in self._prefetch_threads if t.is_alive()]
for t in workers:
t.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
@ -589,37 +846,39 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not user_content:
return
self._turn_count += 1
# Snapshot the sid and bump the turn counter atomically so a
# concurrent on_session_switch/on_session_end can't interleave its
# snapshot+reset between the read and the increment (lost turn) and so
# the turn is unambiguously attributed to the session it targets.
with self._session_state_lock:
sid = str(session_id or self._session_id).strip()
if not sid:
return
self._turn_count += 1
def _sync():
try:
client = _VikingClient(
self._endpoint, self._api_key,
account=self._account, user=self._user, agent=self._agent,
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
sid = self._session_id
# Add user message
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "user",
"content": user_content[:4000], # trim very long messages
})
# Add assistant message
client.post(f"/api/v1/sessions/{sid}/messages", {
"role": "assistant",
"content": assistant_content[:4000],
})
except Exception as e:
logger.debug("OpenViking sync_turn failed: %s", e)
logger.debug("OpenViking sync_turn failed, reconnecting: %s", e)
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
except Exception as retry_error:
logger.warning("OpenViking sync_turn failed: %s", retry_error)
# Wait for any previous sync to finish before starting a new one
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=5.0)
self._sync_thread = threading.Thread(
target=_sync, daemon=True, name="openviking-sync"
)
self._sync_thread.start()
self._spawn_writer(sid, _sync, name="openviking-sync")
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
"""Commit the session to trigger memory extraction.
@ -630,20 +889,98 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client:
return
# Wait for any pending sync to finish first — do this before the
# turn_count check so the last turn's messages are flushed even if
# the count hasn't been incremented yet.
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=10.0)
# Snapshot sid + turn count atomically against a concurrent sync_turn
# increment. on_session_end runs at teardown so the drain+commit stays
# synchronous here (we want it to land before the process exits), but
# the counter read must still be consistent.
with self._session_state_lock:
sid = self._session_id
turn_count = self._turn_count
if self._turn_count == 0:
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after drain — skipping commit",
sid,
)
return
try:
self._client.post(f"/api/v1/sessions/{self._session_id}/commit")
logger.info("OpenViking session %s committed (%d turns)", self._session_id, self._turn_count)
except Exception as e:
logger.warning("OpenViking session commit failed: %s", e)
if not self._session_needs_commit(sid, turn_count):
return
if self._commit_session(sid, turn_count, context="on session end"):
# Mark clean so a follow-up on_session_switch skips its own commit.
with self._session_state_lock:
if self._session_id == sid:
self._turn_count = 0
def on_session_switch(
    self,
    new_session_id: str,
    *,
    parent_session_id: str = "",
    reset: bool = False,
    **kwargs,
) -> None:
    """Rotate cached session state to ``new_session_id`` and finalize the old one.

    Without this hook ``_session_id`` would keep its previous value, so
    later ``sync_turn()`` writes would keep landing in the old session.
    See hermes-agent#28296.

    Steps, in order:

    1. Under ``_session_state_lock``, snapshot the old session id and turn
       count and, unless this is a rewind or the id is unchanged, install
       the new id and reset the turn counter.
    2. Invalidate prefetch state (outside the session lock).
    3. If rotated and there was an old session, hand it to
       ``_finalize_session_async`` so the drain and commit run on a
       background thread rather than on the caller's thread.

    Args:
        new_session_id: Session to switch to. Blank values make this a no-op.
        parent_session_id: Only used in the debug log line.
        reset: Only used in the debug log line.
        **kwargs: ``rewound`` (truthy) marks a rewind: prefetch state is
            invalidated but nothing is rotated or committed.
    """
    new_id = str(new_session_id or "").strip()
    if not new_id:
        return
    if not self._client:
        return

    rewound = bool(kwargs.get("rewound"))

    # Snapshot and rotate in one critical section so a concurrent sync_turn
    # is counted entirely under the old session or entirely under the new.
    with self._session_state_lock:
        old_session_id = self._session_id
        old_turn_count = self._turn_count
        rotate = (not rewound) and new_id != old_session_id
        if rotate:
            self._session_id = new_id
            self._turn_count = 0

    # Done outside the session lock: this takes _prefetch_lock and may join
    # prefetch threads for seconds, which must not block sync_turn.
    self._invalidate_prefetch_state()

    if not rotate:
        # Rewind or unchanged id: no commit and no counter reset.
        logger.debug(
            "OpenViking on_session_switch invalidated state without rotation: "
            "session=%s rewound=%s",
            old_session_id, rewound,
        )
        return

    # Drain + commit of the OLD session happens on a finalizer thread.
    if old_session_id:
        self._finalize_session_async(old_session_id, old_turn_count, context="on switch")

    logger.debug(
        "OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s",
        old_session_id, new_id, parent_session_id, reset,
    )
def _build_memory_uri(self, subdir: str) -> str:
"""Build a viking:// memory URI under the configured user/agent/subdir."""
@ -704,11 +1041,28 @@ class OpenVikingMemoryProvider(MemoryProvider):
return tool_error(str(e))
def shutdown(self) -> None:
# Wait for background threads to finish
for t in (self._sync_thread, self._prefetch_thread):
if t and t.is_alive():
# Stop deferred finalizers from issuing new commits against a
# torn-down client, then drain everything still in flight.
self._shutting_down = True
# Wait for every in-flight writer across all tracked sessions.
with self._inflight_lock:
all_workers = [
t for workers in self._inflight_writers.values() for t in workers
]
with self._deferred_commit_lock:
deferred_workers = list(self._deferred_commit_threads)
with self._prefetch_lock:
prefetch_workers = list(self._prefetch_threads)
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit
for t in deferred_workers:
if t.is_alive():
t.join(timeout=5.0)
for t in prefetch_workers:
if t.is_alive():
t.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
global _last_active_provider
if _last_active_provider is self:
_last_active_provider = None
@ -767,7 +1121,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
if args.get("scope"):
payload["target_uri"] = args["scope"]
if args.get("limit"):
payload["top_k"] = args["limit"]
payload["limit"] = args["limit"]
resp = self._client.post("/api/v1/search/find", payload)
result = resp.get("result", {})

View file

@ -150,7 +150,7 @@ class TestOpenVikingSkillQuerySafety:
assert RecordingVikingClient.calls == [
(
"/api/v1/search/find",
{"query": "make a skill for release triage", "top_k": 5},
{"query": "make a skill for release triage", "limit": 5},
)
]
@ -181,7 +181,7 @@ class TestOpenVikingSkillQuerySafety:
assert RecordingVikingClient.calls == [
(
"/api/v1/search/find",
{"query": "fix the failing retrieval test", "top_k": 5},
{"query": "fix the failing retrieval test", "limit": 5},
)
]
@ -223,17 +223,25 @@ class TestOpenVikingSkillQuerySafety:
)
provider.sync_turn(skill_message, "Done.")
assert provider._sync_thread is not None
provider._sync_thread.join(timeout=5.0)
assert provider._drain_writers("session-1", timeout=5.0)
assert RecordingVikingClient.calls == [
(
"/api/v1/sessions/session-1/messages",
{"role": "user", "content": "make a skill for release triage"},
),
(
"/api/v1/sessions/session-1/messages",
{"role": "assistant", "content": "Done."},
"/api/v1/sessions/session-1/messages/batch",
{
"messages": [
{
"role": "user",
"parts": [
{"type": "text", "text": "make a skill for release triage"},
],
},
{
"role": "assistant",
"parts": [{"type": "text", "text": "Done."}],
},
]
},
),
]
@ -251,7 +259,8 @@ class TestOpenVikingSkillQuerySafety:
provider.sync_turn(skill_message, "Done.")
assert provider._sync_thread is None
assert provider._turn_count == 0
assert provider._inflight_writers == {}
assert RecordingVikingClient.calls == []

View file

@ -5,7 +5,16 @@ from unittest.mock import MagicMock
import pytest
from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient
from plugins.memory.openviking import (
OpenVikingMemoryProvider,
_DEFERRED_COMMIT_TIMEOUT,
_VikingClient,
)
def _clear_openviking_tenant_env(monkeypatch):
for name in ("OPENVIKING_ACCOUNT", "OPENVIKING_USER", "OPENVIKING_AGENT"):
monkeypatch.delenv(name, raising=False)
def test_tool_search_sorts_by_raw_score_across_buckets():
@ -66,6 +75,21 @@ def test_tool_search_sorts_missing_raw_score_after_negative_scores():
assert result["total"] == 3
def test_tool_search_sends_limit_not_legacy_top_k():
    """The search tool forwards ``limit`` under that key and never as ``top_k``."""
    empty_result = {"memories": [], "resources": [], "skills": [], "total": 0}
    client = MagicMock()
    client.post.return_value = {"result": empty_result}
    provider = OpenVikingMemoryProvider()
    provider._client = client

    provider._tool_search({"query": "session switch", "limit": 7})

    client.post.assert_called_once()
    # Second positional argument of post() is the request payload.
    sent = client.post.call_args.args[1]
    assert sent["limit"] == 7
    assert "top_k" not in sent
def test_tool_add_resource_uploads_existing_local_file(tmp_path):
sample = tmp_path / "sample.md"
sample.write_text("# Local resource\n", encoding="utf-8")
@ -371,7 +395,8 @@ def test_viking_client_headers_send_tenant_when_default():
assert headers["Authorization"] == "Bearer test-key"
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default():
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(monkeypatch):
_clear_openviking_tenant_env(monkeypatch)
# Empty account/user strings fall back to "default" via the constructor.
# Headers are sent even for the default value — ROOT API keys need them.
client = _VikingClient(
@ -402,6 +427,7 @@ def test_viking_client_headers_sent_with_real_tenant_values():
def test_viking_client_health_sends_auth_headers(monkeypatch):
_clear_openviking_tenant_env(monkeypatch)
client = _VikingClient(
"https://example.com",
api_key="test-key",
@ -420,3 +446,676 @@ def test_viking_client_health_sends_auth_headers(monkeypatch):
assert client.health() is True
assert captured["url"] == "https://example.com/health"
assert captured["headers"]["Authorization"] == "Bearer test-key"
# ---------------------------------------------------------------------------
# on_session_switch — flush + commit + rotate behavior (hermes-agent#28296)
# ---------------------------------------------------------------------------
def _make_provider_with_session(session_id: str, turn_count: int):
    """Return a provider with a ``MagicMock`` client and the given session state."""
    prov = OpenVikingMemoryProvider()
    prov._client = MagicMock()
    prov._session_id, prov._turn_count = session_id, turn_count
    return prov
def test_on_session_switch_commits_old_session_and_rotates_id():
    """Switching away from a session with turns commits it and adopts the new id."""
    provider = _make_provider_with_session("old-sid", turn_count=3)

    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    # Rotation happens synchronously on the calling thread.
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # The old-session commit runs on a daemon finalizer thread; wait for it
    # so the assertion below cannot race the POST.
    assert provider._drain_finalizers(timeout=5.0)
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_skips_commit_for_empty_old_session():
    """No turns accumulated → nothing to extract → no commit call."""
    provider = _make_provider_with_session("old-sid", turn_count=0)

    provider.on_session_switch("new-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # Let the background finalizer finish first; otherwise the "not called"
    # check could pass merely because the finalizer has not run yet.
    assert provider._drain_finalizers(timeout=5.0)
    provider._client.post.assert_not_called()
def test_on_session_switch_commits_pending_tokens_without_turn_count():
    """With zero local turns, server-side pending tokens still trigger a commit."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    provider._client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_switch("new-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # The lookup and the commit both happen on the finalizer thread; wait
    # for it before inspecting the client.
    assert provider._drain_finalizers(timeout=5.0)
    provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid")
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_rewound_same_session_only_invalidates_prefetch():
    """A rewind on the same session bumps the prefetch generation and nothing else."""
    provider = _make_provider_with_session("same-sid", turn_count=3)
    provider._prefetch_generation = 9
    provider._prefetch_result = "stale recall"

    provider.on_session_switch("same-sid", rewound=True)

    # Prefetch state was invalidated...
    assert provider._prefetch_generation == 10
    assert provider._prefetch_result == ""
    # ...while the session state and the server were left alone.
    assert provider._session_id == "same-sid"
    assert provider._turn_count == 3
    provider._client.get.assert_not_called()
    provider._client.post.assert_not_called()
def test_on_session_switch_clears_stale_prefetch_result():
    """A cached prefetch result from the old session does not survive a switch."""
    stale = "stale recall from old session"
    provider = _make_provider_with_session("old-sid", turn_count=1)
    provider._prefetch_result = stale

    provider.on_session_switch("new-sid")

    assert provider._prefetch_result == ""
def test_on_session_switch_waits_for_inflight_sync_thread():
    """In-flight sync_turn write must drain before the commit fires —
    otherwise the commit can race the last message write."""
    provider = _make_provider_with_session("old-sid", turn_count=2)
    join_calls = []

    class FakeThread:
        def __init__(self):
            self._alive = True

        def is_alive(self):
            return self._alive

        def join(self, timeout=None):
            join_calls.append(timeout)
            # Simulate a worker that finishes within the join window.
            self._alive = False

    provider._inflight_writers["old-sid"] = {FakeThread()}

    provider.on_session_switch("new-sid")

    # The writer drain and the commit run on the finalizer thread; wait for
    # it so both assertions observe its finished work.
    assert provider._drain_finalizers(timeout=5.0)
    assert join_calls, "expected on_session_switch to join the in-flight sync thread"
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_noop_on_empty_new_id():
    """Empty or whitespace-only ids leave the session state and server untouched."""
    provider = _make_provider_with_session("old-sid", turn_count=5)

    for blank in ("", " "):
        provider.on_session_switch(blank)

    provider._client.post.assert_not_called()
    assert (provider._session_id, provider._turn_count) == ("old-sid", 5)
def test_on_session_switch_noop_when_client_missing():
    """Without a client the hook neither raises nor rotates any state."""
    provider = OpenVikingMemoryProvider()
    provider._client = None
    provider._session_id, provider._turn_count = "old-sid", 4

    # Must return quietly even though no client is configured.
    provider.on_session_switch("new-sid")

    # The provider is effectively disabled, so nothing was rotated.
    assert provider._session_id == "old-sid"
    assert provider._turn_count == 4
def test_sync_turn_captures_session_id_before_worker_runs():
    """The worker must use the session id snapshotted when sync_turn() was
    called, not re-read ``self._session_id`` later — otherwise a delayed
    worker can write the previous turn's messages into the rotated-in NEW
    session."""
    import threading

    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "old-sid"

    entered_post = threading.Event()
    may_finish = threading.Event()
    seen_paths = []
    seen_payloads = []

    class StubClient:
        """Stands in for the client class the worker instantiates."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            # Park inside the POST so the test can rotate the session id
            # while the write is still in flight.
            entered_post.set()
            may_finish.wait(timeout=2.0)
            seen_paths.append(path)
            seen_payloads.append(payload)
            return {}

    original_cls = _VikingClient
    _mod._VikingClient = StubClient
    try:
        provider.sync_turn("u", "a")
        # Wait until the worker is parked inside the post call.
        assert entered_post.wait(timeout=2.0), "worker never entered post()"
        # Rotate the provider's session id while the worker is mid-flight.
        provider._session_id = "new-sid"
        may_finish.set()
        for t in list(provider._inflight_writers.get("old-sid", set())):
            t.join(timeout=2.0)
    finally:
        _mod._VikingClient = original_cls

    # The whole turn must target the OLD session id as a single ordered batch.
    assert seen_paths == ["/api/v1/sessions/old-sid/messages/batch"]
    expected_batch = {
        "messages": [
            {"role": "user", "parts": [{"type": "text", "text": "u"}]},
            {"role": "assistant", "parts": [{"type": "text", "text": "a"}]},
        ]
    }
    assert seen_payloads == [expected_batch]
def test_sync_turn_retries_batch_write_with_fresh_client():
    """A failed batch write is retried once, on a newly constructed client."""
    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "sid-1"

    created = []
    delivered = []

    class StubClient:
        """Client whose first instance always fails and later ones record."""

        def __init__(self, *a, **kw):
            self.index = len(created)
            created.append(self)

        def post(self, path, payload=None, **kwargs):
            if self.index == 0:
                raise RuntimeError("transient")
            delivered.append((path, payload))
            return {}

    original_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        provider.sync_turn("u", "a")
        assert provider._drain_writers("sid-1", timeout=2.0)
    finally:
        _mod._VikingClient = original_cls

    # One client for the failed attempt, a second one for the retry.
    assert len(created) == 2
    expected_batch = {
        "messages": [
            {"role": "user", "parts": [{"type": "text", "text": "u"}]},
            {"role": "assistant", "parts": [{"type": "text", "text": "a"}]},
        ]
    }
    assert delivered == [("/api/v1/sessions/sid-1/messages/batch", expected_batch)]
def test_sync_turn_noop_when_session_id_blank():
    """With no session id, sync_turn counts no turn and spawns no writer."""
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._session_id = ""

    provider.sync_turn("u", "a")

    assert provider._inflight_writers == {}
    assert provider._turn_count == 0
def test_on_session_end_marks_session_clean_after_successful_commit():
    """A successful commit in on_session_end resets ``_turn_count``, so an
    on_session_switch fired right afterwards (as /new and compression do)
    skips its own commit instead of committing the session twice."""
    provider = _make_provider_with_session("old-sid", turn_count=3)

    provider.on_session_end([])

    assert provider._turn_count == 0
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_end_keeps_dirty_when_commit_fails():
    """A failed commit leaves ``_turn_count`` positive, so a later
    on_session_switch retries instead of silently dropping extraction for
    the old session."""
    turns = 3
    provider = _make_provider_with_session("old-sid", turn_count=turns)
    provider._client.post.side_effect = RuntimeError("commit boom")

    provider.on_session_end([])

    assert provider._turn_count == turns
def test_on_session_end_commits_pending_tokens_without_turn_count():
    """With zero local turns, server-side pending tokens still trigger the commit."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    client = provider._client
    client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_end([])

    client.get.assert_called_once_with("/api/v1/sessions/old-sid")
    client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_end_then_switch_does_not_double_commit():
    """Mirrors the /new and compression call order: commit_memory_session
    (which calls on_session_end) immediately followed by on_session_switch.
    The switch must NOT issue a second commit on the same session id."""
    provider = _make_provider_with_session("old-sid", turn_count=2)

    provider.on_session_end([])
    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # A second commit could only come from the switch's background
    # finalizer, so wait for it before counting the commit calls.
    assert provider._drain_finalizers(timeout=5.0)
    # Exactly one commit call, on the OLD session, fired by on_session_end.
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_end_then_switch_with_pending_tokens_does_not_double_commit():
    """Same end-then-switch order, but the commit is triggered by server-side
    pending tokens rather than by the local turn counter."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    provider._client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_end([])
    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # A second commit could only come from the switch's background
    # finalizer, so wait for it before counting the commit calls.
    assert provider._drain_finalizers(timeout=5.0)
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_session_needs_commit_guard_wins_over_stale_turn_count():
    """Regression for hermes-agent#28296 review (M3): a session marked as
    committed never needs another commit, even when ``turn_count`` is still
    positive. A racing sync_turn can re-increment ``_turn_count`` after the
    commit+reset, so the committed check has to run BEFORE the
    ``turn_count > 0`` shortcut or a follow-up finalizer would commit the
    same session twice."""
    stale_turns = 5
    provider = _make_provider_with_session("old-sid", turn_count=stale_turns)
    provider._mark_session_committed("old-sid")

    # Already committed: the stale positive counter must be ignored.
    assert provider._session_needs_commit("old-sid", stale_turns) is False
    # Never committed and has turns: a commit is still due.
    assert provider._session_needs_commit("fresh-sid", stale_turns) is True
def test_on_session_switch_swallows_commit_failure():
    """A failing old-session commit must not abort the session rotation.

    on_session_switch may not propagate commit exceptions: if the rotate to
    the new session did not complete, subsequent sync_turn writes would land
    in the wrong session.
    """
    memory = _make_provider_with_session("old-sid", turn_count=2)
    # Every POST (including the commit) blows up.
    memory._client.post.side_effect = RuntimeError("commit boom")

    memory.on_session_switch("new-sid")

    rotated_sid = memory._session_id
    reset_turns = memory._turn_count
    assert rotated_sid == "new-sid"
    assert reset_turns == 0
# ---------------------------------------------------------------------------
# Hung-writer protection: the sync worker can outlive the bounded join
# because each OpenViking POST has _TIMEOUT=30s and there are two per turn.
# Committing while late writes are still in flight would orphan them past
# the commit boundary — they would never be extracted.
# ---------------------------------------------------------------------------
class _HungThread:
"""Thread stand-in that stays alive across joins."""
def is_alive(self):
return True
def join(self, timeout=None):
# Pretend the join timed out — worker still running.
return None
def test_on_session_end_skips_commit_when_sync_worker_outlives_join():
    """on_session_end must skip the commit while a writer is still alive.

    If the sync worker is still running after the 10s join, committing would
    let the worker's late writes land in an already-committed session, where
    they would never be extracted. _turn_count must be left intact so the
    session stays marked dirty.
    """
    memory = _make_provider_with_session("old-sid", turn_count=3)
    stuck_writer = _HungThread()
    memory._inflight_writers["old-sid"] = {stuck_writer}

    memory.on_session_end([])

    # No commit was attempted and the dirty marker (turn count) is unchanged.
    memory._client.post.assert_not_called()
    assert memory._turn_count == 3
def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
    """on_session_switch rotates but skips the commit while a writer is alive.

    Same hazard as on the end path. The rotation must still go ahead because
    the new session has to start, but the old-session commit is skipped so the
    worker's late writes are not orphaned past the commit boundary.
    """
    memory = _make_provider_with_session("old-sid", turn_count=2)
    stuck_writer = _HungThread()
    memory._inflight_writers["old-sid"] = {stuck_writer}

    memory.on_session_switch("new-sid")

    memory._client.post.assert_not_called()
    # Rotation happened regardless of the skipped commit.
    assert memory._session_id == "new-sid"
    assert memory._turn_count == 0
# ---------------------------------------------------------------------------
# Orphaned-writer hazard: commit must wait for ALL writers for the session,
# not just the latest tracked one. sync_turn's bounded rate-limit can drop a
# still-alive previous worker — that dropped writer keeps POSTing under the
# old sid and would otherwise land its writes past the commit boundary.
# ---------------------------------------------------------------------------
def test_on_session_end_waits_for_all_writers_not_just_latest():
    """on_session_end must not commit while ANY writer for the session is alive.

    Models the orphaned-writer case: one writer registered under the session
    id has already finished, while another writer for the same session id is
    still running. A drain that only looked at a single (finished) writer
    would commit; the commit must instead be skipped and _turn_count left
    untouched so the session stays dirty.
    """

    class _FinishedThread:
        """Thread stand-in for a writer that has already exited."""

        def is_alive(self):
            return False

        def join(self, timeout=None):
            # Nothing to wait for: the worker is already done.
            return None

    provider = _make_provider_with_session("old-sid", turn_count=2)
    # Register BOTH a finished writer and a still-running one, so the test
    # only passes when every tracked writer is taken into account.
    provider._inflight_writers["old-sid"] = {_FinishedThread(), _HungThread()}

    provider.on_session_end([])

    provider._client.post.assert_not_called()
    assert provider._turn_count == 2
def test_on_session_switch_waits_for_all_writers_not_just_latest():
    """on_session_switch must not commit while ANY old-session writer is alive.

    Models the orphaned-writer case on the switch path: one writer registered
    under the old session id has already finished, while another writer for
    the same id is still running. The rotation to the new session must still
    happen, but the old-session commit must be skipped.
    """

    class _FinishedThread:
        """Thread stand-in for a writer that has already exited."""

        def is_alive(self):
            return False

        def join(self, timeout=None):
            # Nothing to wait for: the worker is already done.
            return None

    provider = _make_provider_with_session("old-sid", turn_count=2)
    # Register BOTH a finished writer and a still-running one, so the test
    # only passes when every tracked writer is taken into account.
    provider._inflight_writers["old-sid"] = {_FinishedThread(), _HungThread()}

    provider.on_session_switch("new-sid")

    provider._client.post.assert_not_called()
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
def test_on_session_switch_does_not_block_caller_on_slow_drain():
    """on_session_switch must not drain/commit on the caller's thread.

    Regression for hermes-agent#28296 review (H1). /new, /branch, /resume and
    /undo call on_session_switch synchronously on the command thread, so a
    slow writer drain (up to _SESSION_DRAIN_TIMEOUT/_DEFERRED_COMMIT_TIMEOUT)
    or a wedged commit POST must not stall the user-facing command. The
    rotation is cheap and synchronous; the commit is offloaded. Mirrors the
    #41945 'do not block the turn thread' contract.
    """
    import threading
    import time

    provider = _make_provider_with_session("old-sid", turn_count=2)
    drain_entered = threading.Event()
    release_drain = threading.Event()

    def slow_drain(sid, timeout):
        drain_entered.set()
        # Simulate a writer that takes a long time to drain.
        release_drain.wait(timeout=10.0)
        return True

    provider._drain_writers = slow_drain

    try:
        start = time.monotonic()
        provider.on_session_switch("new-sid")
        elapsed = time.monotonic() - start

        # The caller returned promptly with state already rotated, even though
        # the drain is still parked on the finalizer thread.
        assert elapsed < 1.0, f"on_session_switch blocked the caller for {elapsed:.2f}s"
        assert provider._session_id == "new-sid"
        assert provider._turn_count == 0
        assert drain_entered.wait(timeout=2.0), "finalizer never started draining"
        # No commit yet: the drain is still blocked off-thread.
        provider._client.post.assert_not_called()
    finally:
        # Always unpark the finalizer, even when an assertion above fails;
        # otherwise it would sit in release_drain.wait() for up to 10s and
        # leak past this test.
        release_drain.set()

    assert provider._drain_finalizers(timeout=5.0)
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_defers_old_commit_to_finalizer_thread():
    """The old-session commit runs off-thread with the deferred drain budget.

    on_session_switch rotates session state synchronously (cheap, in-memory)
    and hands the old-session drain + commit to a daemon finalizer, so the
    caller's command thread (/new, /branch, /resume) never blocks on the drain
    (bounded by _DEFERRED_COMMIT_TIMEOUT) or on the commit POST. See
    hermes-agent#28296 review (the #41945 'do not block the turn thread'
    contract).
    """
    import threading

    memory = _make_provider_with_session("old-sid", turn_count=2)
    commit_seen = threading.Event()
    seen_budgets = []

    def record_drain(sid, timeout):
        # Capture the drain budget and report the writers as drained.
        seen_budgets.append(timeout)
        return True

    def record_commit(path):
        # Signal that the commit POST was reached.
        commit_seen.set()
        return {}

    memory._drain_writers = record_drain
    memory._client.post.side_effect = record_commit

    memory.on_session_switch("new-sid")

    # Rotation is synchronous and immediate: the new session is live at once.
    assert memory._session_id == "new-sid"
    assert memory._turn_count == 0
    # The old-session commit lands on the finalizer thread, not inline.
    assert commit_seen.wait(timeout=5.0), "old session was not finalized off-thread"
    memory._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
    # The finalizer drains with the deferred (longer) budget, not inline 10s.
    assert seen_budgets == [_DEFERRED_COMMIT_TIMEOUT]
def test_sync_turn_tracks_writer_under_session_id():
    """sync_turn registers its writer under the session id it was spawned for.

    Every writer must be tracked under its captured sid so the drain at
    end/switch still sees it even if a later sync_turn replaces the
    latest-tracked reference. The writer must also unregister itself on exit.
    """
    import threading

    import plugins.memory.openviking as _mod

    memory = OpenVikingMemoryProvider()
    for attr, value in (
        ("_client", MagicMock()),
        ("_endpoint", "http://test"),
        ("_api_key", ""),
        ("_account", "acct"),
        ("_user", "usr"),
        ("_agent", "hermes"),
        ("_session_id", "sid-1"),
    ):
        setattr(memory, attr, value)

    post_entered = threading.Event()
    unblock_post = threading.Event()

    class StubClient:
        """Client double whose post() parks until the test releases it."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            post_entered.set()
            unblock_post.wait(timeout=2.0)
            return {}

    original_client_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        memory.sync_turn("u", "a")
        assert post_entered.wait(timeout=2.0), "worker never entered post()"
        # While the POST is parked, exactly one writer is tracked for sid-1.
        assert len(memory._inflight_writers.get("sid-1", set())) == 1
        unblock_post.set()
        # Iterate over a snapshot: the worker mutates the set as it exits.
        tracked = list(memory._inflight_writers.get("sid-1", set()))
        for writer in tracked:
            writer.join(timeout=2.0)
    finally:
        _mod._VikingClient = original_client_cls

    # Worker should have removed itself from the inflight set on exit.
    assert memory._inflight_writers.get("sid-1", set()) == set()
# ---------------------------------------------------------------------------
# on_memory_write: explicit memory writes use content/write and stay outside
# the session transcript/commit boundary.
# ---------------------------------------------------------------------------
def test_on_memory_write_uses_content_write_independent_of_session_rotation():
    """on_memory_write posts to content/write, not to a session transcript.

    The worker is parked inside the client constructor while the provider's
    session id is rotated from "old-sid" to "new-sid". The single resulting
    POST must still target /api/v1/content/write with mode "create" and a
    memories/preferences URI, i.e. the write becomes a session message in
    neither the old nor the new session.
    """
    import threading

    import plugins.memory.openviking as _mod

    memory = OpenVikingMemoryProvider()
    for attr, value in (
        ("_client", MagicMock()),
        ("_endpoint", "http://test"),
        ("_api_key", ""),
        ("_account", "acct"),
        ("_user", "usr"),
        ("_agent", "hermes"),
        ("_session_id", "old-sid"),
    ):
        setattr(memory, attr, value)

    ctor_entered = threading.Event()
    unblock_ctor = threading.Event()
    post_done = threading.Event()
    recorded_calls = []  # (path, payload) per post() call, in order

    class StubClient:
        """Client double: the constructor parks, post() records its arguments."""

        def __init__(self, *a, **kw):
            ctor_entered.set()
            unblock_ctor.wait(timeout=2.0)

        def post(self, path, payload=None, **kwargs):
            recorded_calls.append((path, payload))
            post_done.set()
            return {}

    original_client_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        memory.on_memory_write("add", "user", "remember this")
        assert ctor_entered.wait(timeout=2.0), "worker never entered ctor"
        # Rotate provider's session id while the worker is parked. Memory
        # writes must not become session messages in either the old or new
        # session.
        memory._session_id = "new-sid"
        unblock_ctor.set()
        assert post_done.wait(timeout=2.0), "worker never reached post()"
    finally:
        _mod._VikingClient = original_client_cls

    posted_paths = [path for path, _ in recorded_calls]
    assert posted_paths == ["/api/v1/content/write"]

    first_payload = recorded_calls[0][1]
    assert first_payload["content"] == "remember this"
    assert first_payload["mode"] == "create"
    assert first_payload["uri"].startswith(
        "viking://user/usr/agent/hermes/memories/preferences/mem_"
    )
# ---------------------------------------------------------------------------
# Prefetch staleness: a prefetch worker that finishes AFTER a session switch
# must drop its result instead of repopulating the new session with stale
# recall from the old generation. Bump the generation directly (rather than
# calling on_session_switch, whose own join blocks on the test worker) so
# the test isolates the generation-gating behavior.
# ---------------------------------------------------------------------------
def test_queue_prefetch_drops_result_when_generation_changed_mid_flight():
    """A prefetch finishing after a generation bump must discard its result.

    The worker is parked inside post() while the test bumps
    _prefetch_generation directly, standing in for a session switch. Once
    released, the worker's (now stale) recall must not be stored in
    _prefetch_result.
    """
    import threading

    import plugins.memory.openviking as _mod

    memory = OpenVikingMemoryProvider()
    for attr, value in (
        ("_client", MagicMock()),
        ("_endpoint", "http://test"),
        ("_api_key", ""),
        ("_account", "acct"),
        ("_user", "usr"),
        ("_agent", "hermes"),
        ("_session_id", "old-sid"),
    ):
        setattr(memory, attr, value)

    post_entered = threading.Event()
    unblock_post = threading.Event()

    class StubClient:
        """Client double whose post() parks, then returns one stale memory."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            post_entered.set()
            unblock_post.wait(timeout=2.0)
            stale_memory = {
                "uri": "viking://memories/old",
                "score": 0.9,
                "abstract": "stale from old session",
            }
            return {"result": {"memories": [stale_memory], "resources": []}}

    original_client_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        memory.queue_prefetch("anything")
        assert post_entered.wait(timeout=2.0), "prefetch worker never entered post()"
        # Simulate a session switch by bumping the generation directly.
        # The worker captured the pre-bump generation when it was spawned.
        memory._prefetch_generation += 1
        unblock_post.set()
        prefetch_worker = memory._prefetch_thread
        if prefetch_worker:
            prefetch_worker.join(timeout=2.0)
    finally:
        _mod._VikingClient = original_client_cls

    # The stale result from the pre-bump generation must NOT have been written
    # into the new generation's prefetch slot.
    assert memory._prefetch_result == ""
def test_queue_prefetch_sends_limit_not_legacy_top_k():
    """queue_prefetch must send ``limit`` in its payload, never ``top_k``.

    The single POST issued by the prefetch worker is expected to carry
    exactly ``{"query": "anything", "limit": 5}``.
    """
    import plugins.memory.openviking as _mod

    memory = OpenVikingMemoryProvider()
    for attr, value in (
        ("_client", MagicMock()),
        ("_endpoint", "http://test"),
        ("_api_key", ""),
        ("_account", "acct"),
        ("_user", "usr"),
        ("_agent", "hermes"),
    ):
        setattr(memory, attr, value)

    sent_payloads = []

    class StubClient:
        """Client double that records every payload and returns no hits."""

        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            sent_payloads.append(payload)
            return {"result": {"memories": [], "resources": []}}

    original_client_cls = _mod._VikingClient
    _mod._VikingClient = StubClient
    try:
        memory.queue_prefetch("anything")
        prefetch_worker = memory._prefetch_thread
        if prefetch_worker:
            # Wait for the background request before inspecting the payloads.
            prefetch_worker.join(timeout=2.0)
    finally:
        _mod._VikingClient = original_client_cls

    assert sent_payloads == [{"query": "anything", "limit": 5}]
    assert "top_k" not in sent_payloads[0]