diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index 3050eb9c43a..da7a10a9f13 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -31,10 +31,11 @@ import mimetypes import os import tempfile import threading +import time import uuid import zipfile from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Set from urllib.parse import urlparse from urllib.request import url2pathname @@ -46,6 +47,8 @@ logger = logging.getLogger(__name__) _DEFAULT_ENDPOINT = "http://127.0.0.1:1933" _TIMEOUT = 30.0 +_SESSION_DRAIN_TIMEOUT = 10.0 +_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0 _REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://") # Maps the viking_remember `category` enum to a viking:// subdirectory. @@ -432,10 +435,35 @@ class OpenVikingMemoryProvider(MemoryProvider): self._api_key = "" self._session_id = "" self._turn_count = 0 - self._sync_thread: Optional[threading.Thread] = None + # Guards the (_session_id, _turn_count) pair. sync_turn runs on the + # MemoryManager's background sync executor while on_session_end / + # on_session_switch run on the caller's thread, so the snapshot+reset + # of the turn counter and the session-id rotation must be atomic + # against a concurrent increment. See hermes-agent#28296 review. + self._session_state_lock = threading.Lock() + # Commit only after session writes drain. The set is keyed by the sid + # the writer is POSTing under (snapshotted at spawn), so on_session_end + # / on_session_switch see every still-alive writer for that sid even + # if later writes have replaced the latest-tracked thread. + self._inflight_writers: Dict[str, Set[threading.Thread]] = {} + self._inflight_lock = threading.Lock() + self._deferred_commit_sids: Set[str] = set() + self._deferred_commit_threads: Set[threading.Thread] = set() + self._deferred_commit_lock = threading.Lock() + self._committed_session_ids: Set[str] = set() + self._committed_session_lock = threading.Lock() self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread: Optional[threading.Thread] = None + # All prefetch threads ever spawned (daemon, short-lived). Tracked so + # shutdown() can drain them and rapid re-queues don't orphan a still- + # running thread by overwriting the single _prefetch_thread slot. + self._prefetch_threads: Set[threading.Thread] = set() + # Set on shutdown so deferred-commit / writer finalizers stop issuing + # network writes against a torn-down provider. + self._shutting_down = False + # Drop prefetch results from older switch generations. + self._prefetch_generation = 0 @property def name(self) -> str: @@ -549,6 +577,12 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client or not query: return + # Drop prefetch results from older switch generations. + with self._prefetch_lock: + gen = self._prefetch_generation + + holder: List[threading.Thread] = [] + def _run(): try: client = _VikingClient( @@ -557,7 +591,7 @@ class OpenVikingMemoryProvider(MemoryProvider): ) resp = client.post("/api/v1/search/find", { "query": query, - "top_k": 5, + "limit": 5, }) result = resp.get("result", {}) parts = [] @@ -571,14 +605,237 @@ class OpenVikingMemoryProvider(MemoryProvider): parts.append(f"- [{score:.2f}] {abstract} ({uri})") if parts: with self._prefetch_lock: + if gen != self._prefetch_generation: + return self._prefetch_result = "\n".join(parts) except Exception as e: logger.debug("OpenViking prefetch failed: %s", e) + finally: + with self._prefetch_lock: + if holder: + self._prefetch_threads.discard(holder[0]) - self._prefetch_thread = threading.Thread( + thread = threading.Thread( target=_run, daemon=True, name="openviking-prefetch" ) - self._prefetch_thread.start() + holder.append(thread) + with self._prefetch_lock: + self._prefetch_thread = thread + self._prefetch_threads.add(thread) + thread.start() + + def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None: + """Spawn a daemon writer tracked in _inflight_writers[sid]. + + Tracking is keyed by sid (not by a single latest-thread slot) so that + on_session_end / on_session_switch can drain every still-alive writer + for the session being committed. + """ + holder: List[threading.Thread] = [] + + def _wrapped(): + try: + target() + finally: + with self._inflight_lock: + workers = self._inflight_writers.get(sid) + if workers is not None: + workers.discard(holder[0]) + if not workers: + self._inflight_writers.pop(sid, None) + + thread = threading.Thread(target=_wrapped, daemon=True, name=name) + holder.append(thread) + with self._inflight_lock: + self._inflight_writers.setdefault(sid, set()).add(thread) + thread.start() + + def _drain_finalizers(self, timeout: float) -> bool: + """Join every in-flight async session finalizer within a timeout. + + The switch-path commit runs on a daemon finalizer thread so it never + blocks the caller's command thread; this lets shutdown and tests wait + for those commits deterministically. Returns True if all drained. + """ + deadline = time.monotonic() + timeout + while True: + with self._deferred_commit_lock: + workers = [t for t in self._deferred_commit_threads if t.is_alive()] + if not workers: + return True + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + for t in workers: + slice_left = deadline - time.monotonic() + if slice_left <= 0: + break + # Floor the per-join wait so a thread whose join() returns + # instantly while still reporting alive can't hot-spin this loop. + t.join(timeout=min(slice_left, 0.05)) + + def _drain_writers(self, sid: str, timeout: float) -> bool: + """Join every in-flight writer for sid within a shared timeout budget. + + Returns True if all writers drained, False if any are still alive when + the budget runs out. Callers use the False return to skip the commit. + """ + if not sid: + return True + deadline = time.monotonic() + timeout + while True: + with self._inflight_lock: + workers = [t for t in self._inflight_writers.get(sid, ()) if t.is_alive()] + if not workers: + return True + remaining = deadline - time.monotonic() + if remaining <= 0: + return False + for t in workers: + slice_left = deadline - time.monotonic() + if slice_left <= 0: + break + t.join(timeout=slice_left) + + def _new_client(self) -> _VikingClient: + return _VikingClient( + self._endpoint, + self._api_key, + account=self._account, + user=self._user, + agent=self._agent, + ) + + @staticmethod + def _text_part(content: str) -> Dict[str, str]: + return {"type": "text", "text": content} + + @classmethod + def _turn_batch_payload(cls, user_content: str, assistant_content: str) -> Dict[str, Any]: + return { + "messages": [ + {"role": "user", "parts": [cls._text_part(user_content)]}, + {"role": "assistant", "parts": [cls._text_part(assistant_content)]}, + ] + } + + @classmethod + def _post_session_turn( + cls, + client: _VikingClient, + sid: str, + user_content: str, + assistant_content: str, + ) -> None: + client.post( + f"/api/v1/sessions/{sid}/messages/batch", + cls._turn_batch_payload(user_content, assistant_content), + ) + + def _session_has_pending_tokens(self, sid: str) -> bool: + try: + response = self._client.get(f"/api/v1/sessions/{sid}") + except Exception: + return False + session = self._unwrap_result(response) + if not isinstance(session, dict): + return False + try: + return int(session.get("pending_tokens") or 0) > 0 + except (TypeError, ValueError): + return False + + def _has_committed_session(self, sid: str) -> bool: + with self._committed_session_lock: + return sid in self._committed_session_ids + + def _mark_session_committed(self, sid: str) -> None: + with self._committed_session_lock: + self._committed_session_ids.add(sid) + + def _session_needs_commit(self, sid: str, turn_count: int) -> bool: + # Already-committed sessions never need a second commit, regardless of + # the turn counter — a racing sync_turn can re-increment _turn_count + # after a commit+reset, so the committed-guard must win over turn_count. + if self._has_committed_session(sid): + return False + if turn_count > 0: + return True + return self._session_has_pending_tokens(sid) + + def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool: + try: + self._client.post(f"/api/v1/sessions/{sid}/commit") + self._mark_session_committed(sid) + logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count) + return True + except Exception as e: + logger.warning("OpenViking session commit failed for %s: %s", sid, e) + return False + + def _finalize_session_async(self, sid: str, turn_count: int, *, context: str) -> None: + """Drain the old session's writers and commit it on a daemon thread. + + Used by on_session_switch (and the deferred-commit fallback) so the + potentially-multi-second drain + pending-token GET + commit POST never + runs on the caller's command thread. Deduped by sid so a rapid second + switch can't stack two finalizers for the same session, and a no-op + once shutdown has begun so we don't POST against a torn-down client. + """ + if not sid: + return + with self._deferred_commit_lock: + if self._shutting_down or sid in self._deferred_commit_sids: + return + self._deferred_commit_sids.add(sid) + + holder: List[threading.Thread] = [] + + def _finalize() -> None: + try: + if self._shutting_down: + return + if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT): + logger.warning( + "OpenViking writer for %s still alive after drain — " + "leaving session uncommitted", + sid, + ) + return + if self._shutting_down: + return + if self._session_needs_commit(sid, turn_count): + self._commit_session(sid, turn_count, context=context) + finally: + with self._deferred_commit_lock: + self._deferred_commit_sids.discard(sid) + if holder: + self._deferred_commit_threads.discard(holder[0]) + + thread = threading.Thread( + target=_finalize, + daemon=True, + name=f"openviking-finalize-{sid}", + ) + holder.append(thread) + with self._deferred_commit_lock: + self._deferred_commit_threads.add(thread) + thread.start() + + def _invalidate_prefetch_state(self) -> None: + # Bump the generation under the same lock used by prefetch workers so + # late results from an older session are discarded deterministically. + with self._prefetch_lock: + self._prefetch_generation += 1 + self._prefetch_result = "" + # Join EVERY tracked prefetch thread, not just the latest slot — a + # rapid re-queue can leave an older thread for the abandoned session + # still running (consistent with shutdown()). + workers = [t for t in self._prefetch_threads if t.is_alive()] + for t in workers: + t.join(timeout=3.0) + with self._prefetch_lock: + self._prefetch_result = "" def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: """Record the conversation turn in OpenViking's session (non-blocking).""" @@ -589,37 +846,39 @@ class OpenVikingMemoryProvider(MemoryProvider): if not user_content: return - self._turn_count += 1 + # Snapshot the sid and bump the turn counter atomically so a + # concurrent on_session_switch/on_session_end can't interleave its + # snapshot+reset between the read and the increment (lost turn) and so + # the turn is unambiguously attributed to the session it targets. + with self._session_state_lock: + sid = str(session_id or self._session_id).strip() + if not sid: + return + self._turn_count += 1 def _sync(): try: - client = _VikingClient( - self._endpoint, self._api_key, - account=self._account, user=self._user, agent=self._agent, + client = self._new_client() + self._post_session_turn( + client, + sid, + user_content[:4000], + assistant_content[:4000], ) - sid = self._session_id - - # Add user message - client.post(f"/api/v1/sessions/{sid}/messages", { - "role": "user", - "content": user_content[:4000], # trim very long messages - }) - # Add assistant message - client.post(f"/api/v1/sessions/{sid}/messages", { - "role": "assistant", - "content": assistant_content[:4000], - }) except Exception as e: - logger.debug("OpenViking sync_turn failed: %s", e) + logger.debug("OpenViking sync_turn failed, reconnecting: %s", e) + try: + client = self._new_client() + self._post_session_turn( + client, + sid, + user_content[:4000], + assistant_content[:4000], + ) + except Exception as retry_error: + logger.warning("OpenViking sync_turn failed: %s", retry_error) - # Wait for any previous sync to finish before starting a new one - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=5.0) - - self._sync_thread = threading.Thread( - target=_sync, daemon=True, name="openviking-sync" - ) - self._sync_thread.start() + self._spawn_writer(sid, _sync, name="openviking-sync") def on_session_end(self, messages: List[Dict[str, Any]]) -> None: """Commit the session to trigger memory extraction. @@ -630,20 +889,98 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client: return - # Wait for any pending sync to finish first — do this before the - # turn_count check so the last turn's messages are flushed even if - # the count hasn't been incremented yet. - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=10.0) + # Snapshot sid + turn count atomically against a concurrent sync_turn + # increment. on_session_end runs at teardown so the drain+commit stays + # synchronous here (we want it to land before the process exits), but + # the counter read must still be consistent. + with self._session_state_lock: + sid = self._session_id + turn_count = self._turn_count - if self._turn_count == 0: + # Commit only after session writes drain. + if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT): + logger.warning( + "OpenViking writer for %s still alive after drain — skipping commit", + sid, + ) return - try: - self._client.post(f"/api/v1/sessions/{self._session_id}/commit") - logger.info("OpenViking session %s committed (%d turns)", self._session_id, self._turn_count) - except Exception as e: - logger.warning("OpenViking session commit failed: %s", e) + if not self._session_needs_commit(sid, turn_count): + return + + if self._commit_session(sid, turn_count, context="on session end"): + # Mark clean so a follow-up on_session_switch skips its own commit. + with self._session_state_lock: + if self._session_id == sid: + self._turn_count = 0 + + def on_session_switch( + self, + new_session_id: str, + *, + parent_session_id: str = "", + reset: bool = False, + **kwargs, + ) -> None: + """Commit the old session and rotate cached state to the new session_id. + + Fires on /resume, /branch, /reset, /new, and context compression. + Without this hook, ``_session_id`` stays stuck at the value + ``initialize()`` cached, so subsequent ``sync_turn()`` writes land in + the already-closed old session and ``on_session_end()`` tries to + commit it a second time. The new session never accumulates messages, + and memory extraction never fires for it. See hermes-agent#28296. + + Flushes any in-flight sync under the old session_id, commits the old + session if it has pending turns (same extraction semantics as + ``on_session_end``), drains and clears any stale prefetch result, + then rotates ``_session_id`` and resets ``_turn_count``. + """ + new_id = str(new_session_id or "").strip() + if not new_id or not self._client: + return + + rewound = bool(kwargs.get("rewound")) + + # Rotate cached session state synchronously (cheap, in-memory) and + # snapshot the old session under the lock so a concurrent sync_turn + # either lands fully before the rotation (counted under old) or fully + # after (counted under new) — never split. The OLD session's commit + # (drain + pending-token GET + commit POST, potentially many seconds) + # is then offloaded so /new, /branch, /resume, /undo never block the + # caller's command thread (cf. the end-of-turn-sync offload in #41945). + with self._session_state_lock: + old_session_id = self._session_id + old_turn_count = self._turn_count + rotate = not (rewound or new_id == old_session_id) + if rotate: + self._session_id = new_id + self._turn_count = 0 + + # Invalidate stale prefetch OUTSIDE the session lock — it takes its own + # _prefetch_lock and may join a prefetch thread for up to 3s, which we + # must not do while holding the session lock (would block sync_turn and + # risk lock-ordering coupling). + self._invalidate_prefetch_state() + + if not rotate: + # Same-session rewind (/undo) or no-op rotation: no commit, no + # counter reset — just the prefetch invalidation above. + logger.debug( + "OpenViking on_session_switch invalidated state without rotation: " + "session=%s rewound=%s", + old_session_id, rewound, + ) + return + + # Drain + commit the OLD session off the command thread. + if old_session_id: + self._finalize_session_async(old_session_id, old_turn_count, context="on switch") + + logger.debug( + "OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s", + old_session_id, new_id, parent_session_id, reset, + ) def _build_memory_uri(self, subdir: str) -> str: """Build a viking:// memory URI under the configured user/agent/subdir.""" @@ -704,11 +1041,28 @@ class OpenVikingMemoryProvider(MemoryProvider): return tool_error(str(e)) def shutdown(self) -> None: - # Wait for background threads to finish - for t in (self._sync_thread, self._prefetch_thread): - if t and t.is_alive(): + # Stop deferred finalizers from issuing new commits against a + # torn-down client, then drain everything still in flight. + self._shutting_down = True + # Wait for every in-flight writer across all tracked sessions. + with self._inflight_lock: + all_workers = [ + t for workers in self._inflight_writers.values() for t in workers + ] + with self._deferred_commit_lock: + deferred_workers = list(self._deferred_commit_threads) + with self._prefetch_lock: + prefetch_workers = list(self._prefetch_threads) + for t in all_workers: + if t.is_alive(): t.join(timeout=5.0) - # Clear atexit reference so it doesn't double-commit + for t in deferred_workers: + if t.is_alive(): + t.join(timeout=5.0) + for t in prefetch_workers: + if t.is_alive(): + t.join(timeout=5.0) + # Clear atexit reference so it doesn't double-commit. global _last_active_provider if _last_active_provider is self: _last_active_provider = None @@ -767,7 +1121,7 @@ class OpenVikingMemoryProvider(MemoryProvider): if args.get("scope"): payload["target_uri"] = args["scope"] if args.get("limit"): - payload["top_k"] = args["limit"] + payload["limit"] = args["limit"] resp = self._client.post("/api/v1/search/find", payload) result = resp.get("result", {}) diff --git a/tests/openviking_plugin/test_openviking.py b/tests/openviking_plugin/test_openviking.py index f374182812a..c37a15c0cda 100644 --- a/tests/openviking_plugin/test_openviking.py +++ b/tests/openviking_plugin/test_openviking.py @@ -150,7 +150,7 @@ class TestOpenVikingSkillQuerySafety: assert RecordingVikingClient.calls == [ ( "/api/v1/search/find", - {"query": "make a skill for release triage", "top_k": 5}, + {"query": "make a skill for release triage", "limit": 5}, ) ] @@ -181,7 +181,7 @@ class TestOpenVikingSkillQuerySafety: assert RecordingVikingClient.calls == [ ( "/api/v1/search/find", - {"query": "fix the failing retrieval test", "top_k": 5}, + {"query": "fix the failing retrieval test", "limit": 5}, ) ] @@ -223,17 +223,25 @@ class TestOpenVikingSkillQuerySafety: ) provider.sync_turn(skill_message, "Done.") - assert provider._sync_thread is not None - provider._sync_thread.join(timeout=5.0) + assert provider._drain_writers("session-1", timeout=5.0) assert RecordingVikingClient.calls == [ ( - "/api/v1/sessions/session-1/messages", - {"role": "user", "content": "make a skill for release triage"}, - ), - ( - "/api/v1/sessions/session-1/messages", - {"role": "assistant", "content": "Done."}, + "/api/v1/sessions/session-1/messages/batch", + { + "messages": [ + { + "role": "user", + "parts": [ + {"type": "text", "text": "make a skill for release triage"}, + ], + }, + { + "role": "assistant", + "parts": [{"type": "text", "text": "Done."}], + }, + ] + }, ), ] @@ -251,7 +259,8 @@ class TestOpenVikingSkillQuerySafety: provider.sync_turn(skill_message, "Done.") - assert provider._sync_thread is None + assert provider._turn_count == 0 + assert provider._inflight_writers == {} assert RecordingVikingClient.calls == [] diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index 3f609cd1d67..92f724a39a8 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -5,7 +5,16 @@ from unittest.mock import MagicMock import pytest -from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient +from plugins.memory.openviking import ( + OpenVikingMemoryProvider, + _DEFERRED_COMMIT_TIMEOUT, + _VikingClient, +) + + +def _clear_openviking_tenant_env(monkeypatch): + for name in ("OPENVIKING_ACCOUNT", "OPENVIKING_USER", "OPENVIKING_AGENT"): + monkeypatch.delenv(name, raising=False) def test_tool_search_sorts_by_raw_score_across_buckets(): @@ -66,6 +75,21 @@ def test_tool_search_sorts_missing_raw_score_after_negative_scores(): assert result["total"] == 3 +def test_tool_search_sends_limit_not_legacy_top_k(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._client.post.return_value = { + "result": {"memories": [], "resources": [], "skills": [], "total": 0} + } + + provider._tool_search({"query": "session switch", "limit": 7}) + + provider._client.post.assert_called_once() + payload = provider._client.post.call_args.args[1] + assert payload["limit"] == 7 + assert "top_k" not in payload + + def test_tool_add_resource_uploads_existing_local_file(tmp_path): sample = tmp_path / "sample.md" sample.write_text("# Local resource\n", encoding="utf-8") @@ -371,7 +395,8 @@ def test_viking_client_headers_send_tenant_when_default(): assert headers["Authorization"] == "Bearer test-key" -def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(): +def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(monkeypatch): + _clear_openviking_tenant_env(monkeypatch) # Empty account/user strings fall back to "default" via the constructor. # Headers are sent even for the default value — ROOT API keys need them. client = _VikingClient( @@ -402,6 +427,7 @@ def test_viking_client_headers_sent_with_real_tenant_values(): def test_viking_client_health_sends_auth_headers(monkeypatch): + _clear_openviking_tenant_env(monkeypatch) client = _VikingClient( "https://example.com", api_key="test-key", @@ -420,3 +446,676 @@ def test_viking_client_health_sends_auth_headers(monkeypatch): assert client.health() is True assert captured["url"] == "https://example.com/health" assert captured["headers"]["Authorization"] == "Bearer test-key" + + +# --------------------------------------------------------------------------- +# on_session_switch — flush + commit + rotate behavior (hermes-agent#28296) +# --------------------------------------------------------------------------- + +def _make_provider_with_session(session_id: str, turn_count: int): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._session_id = session_id + provider._turn_count = turn_count + return provider + + +def test_on_session_switch_commits_old_session_and_rotates_id(): + provider = _make_provider_with_session("old-sid", turn_count=3) + + provider.on_session_switch("new-sid", parent_session_id="old-sid") + + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_skips_commit_for_empty_old_session(): + """No turns accumulated → nothing to extract → no commit call.""" + provider = _make_provider_with_session("old-sid", turn_count=0) + + provider.on_session_switch("new-sid") + + provider._client.post.assert_not_called() + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_commits_pending_tokens_without_turn_count(): + provider = _make_provider_with_session("old-sid", turn_count=0) + provider._client.get.return_value = {"result": {"pending_tokens": 42}} + + provider.on_session_switch("new-sid") + + provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid") + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_rewound_same_session_only_invalidates_prefetch(): + provider = _make_provider_with_session("same-sid", turn_count=3) + provider._prefetch_generation = 9 + provider._prefetch_result = "stale recall" + + provider.on_session_switch("same-sid", rewound=True) + + provider._client.get.assert_not_called() + provider._client.post.assert_not_called() + assert provider._session_id == "same-sid" + assert provider._turn_count == 3 + assert provider._prefetch_generation == 10 + assert provider._prefetch_result == "" + + +def test_on_session_switch_clears_stale_prefetch_result(): + provider = _make_provider_with_session("old-sid", turn_count=1) + provider._prefetch_result = "stale recall from old session" + + provider.on_session_switch("new-sid") + + assert provider._prefetch_result == "" + + +def test_on_session_switch_waits_for_inflight_sync_thread(): + """In-flight sync_turn write must drain before the commit fires — + otherwise the commit can race the last message write.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + + join_calls = [] + + class FakeThread: + def __init__(self): + self._alive = True + + def is_alive(self): + return self._alive + + def join(self, timeout=None): + join_calls.append(timeout) + # Simulate a worker that finishes within the join window. + self._alive = False + + provider._inflight_writers["old-sid"] = {FakeThread()} + + provider.on_session_switch("new-sid") + + assert join_calls, "expected on_session_switch to join the in-flight sync thread" + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + + +def test_on_session_switch_noop_on_empty_new_id(): + provider = _make_provider_with_session("old-sid", turn_count=5) + + provider.on_session_switch("") + provider.on_session_switch(" ") + + provider._client.post.assert_not_called() + assert provider._session_id == "old-sid" + assert provider._turn_count == 5 + + +def test_on_session_switch_noop_when_client_missing(): + provider = OpenVikingMemoryProvider() + provider._client = None + provider._session_id = "old-sid" + provider._turn_count = 4 + + # Must not raise even though no client is configured. + provider.on_session_switch("new-sid") + + # State stays untouched — provider is effectively disabled. + assert provider._session_id == "old-sid" + assert provider._turn_count == 4 + + +def test_sync_turn_captures_session_id_before_worker_runs(): + """Worker must use the session id snapshotted at sync_turn() call time, not + re-read self._session_id later — otherwise a delayed worker can write the + previous turn's messages into the rotated-in NEW session.""" + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "old-sid" + + started = threading.Event() + release = threading.Event() + captured_paths = [] + captured_payloads = [] + + def fake_post(path, payload=None, **kwargs): + started.set() + release.wait(timeout=2.0) + captured_paths.append(path) + captured_payloads.append(payload) + return {} + + # Patch _VikingClient inside the worker by stubbing post on a client + # the constructor will produce. Easiest path: monkeypatch the class. + real_client_cls = _VikingClient + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + return fake_post(path, payload, **kwargs) + + import plugins.memory.openviking as _mod + _mod._VikingClient = StubClient + try: + provider.sync_turn("u", "a") + # Wait until the worker is parked inside the first post call. + assert started.wait(timeout=2.0), "worker never entered post()" + # Rotate the provider's session id while the worker is mid-flight. + provider._session_id = "new-sid" + release.set() + for t in list(provider._inflight_writers.get("old-sid", set())): + t.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + # The whole turn must target the OLD session id as a single ordered batch. + assert captured_paths == ["/api/v1/sessions/old-sid/messages/batch"] + assert captured_payloads == [{ + "messages": [ + {"role": "user", "parts": [{"type": "text", "text": "u"}]}, + {"role": "assistant", "parts": [{"type": "text", "text": "a"}]}, + ] + }] + + +def test_sync_turn_retries_batch_write_with_fresh_client(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "sid-1" + + clients = [] + captured = [] + + class StubClient: + def __init__(self, *a, **kw): + self.index = len(clients) + clients.append(self) + + def post(self, path, payload=None, **kwargs): + if self.index == 0: + raise RuntimeError("transient") + captured.append((path, payload)) + return {} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.sync_turn("u", "a") + assert provider._drain_writers("sid-1", timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + assert len(clients) == 2 + assert captured == [( + "/api/v1/sessions/sid-1/messages/batch", + { + "messages": [ + {"role": "user", "parts": [{"type": "text", "text": "u"}]}, + {"role": "assistant", "parts": [{"type": "text", "text": "a"}]}, + ] + }, + )] + + +def test_sync_turn_noop_when_session_id_blank(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._session_id = "" + + provider.sync_turn("u", "a") + + # No turn counted, no worker spawned. + assert provider._turn_count == 0 + assert provider._inflight_writers == {} + + +def test_on_session_end_marks_session_clean_after_successful_commit(): + """After a successful commit on_session_end must reset _turn_count so a + subsequent on_session_switch (fired by /new and compression right after + commit_memory_session) skips its commit instead of double-committing.""" + provider = _make_provider_with_session("old-sid", turn_count=3) + + provider.on_session_end([]) + + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._turn_count == 0 + + +def test_on_session_end_keeps_dirty_when_commit_fails(): + """If the commit fails, leave _turn_count > 0 so on_session_switch retries + rather than silently dropping extraction for the old session.""" + provider = _make_provider_with_session("old-sid", turn_count=3) + provider._client.post.side_effect = RuntimeError("commit boom") + + provider.on_session_end([]) + + assert provider._turn_count == 3 + + +def test_on_session_end_commits_pending_tokens_without_turn_count(): + provider = _make_provider_with_session("old-sid", turn_count=0) + provider._client.get.return_value = {"result": {"pending_tokens": 42}} + + provider.on_session_end([]) + + provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid") + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + + +def test_end_then_switch_does_not_double_commit(): + """Mirrors the /new and compression call order: commit_memory_session + (→ on_session_end) immediately followed by on_session_switch. The switch + must NOT issue a second commit on the same session id.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + + provider.on_session_end([]) + provider.on_session_switch("new-sid", parent_session_id="old-sid") + + # Exactly one commit call, on the OLD session, fired by on_session_end. + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_end_then_switch_with_pending_tokens_does_not_double_commit(): + provider = _make_provider_with_session("old-sid", turn_count=0) + provider._client.get.return_value = {"result": {"pending_tokens": 42}} + + provider.on_session_end([]) + provider.on_session_switch("new-sid", parent_session_id="old-sid") + + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_session_needs_commit_guard_wins_over_stale_turn_count(): + """Regression for hermes-agent#28296 review (M3): once a session is marked + committed, _session_needs_commit must return False even if turn_count is + still positive. A racing sync_turn can re-increment _turn_count after the + commit+reset; without the guard ordering, a follow-up finalizer would + double-commit the same session. The committed-guard must be checked BEFORE + the turn_count>0 shortcut.""" + provider = _make_provider_with_session("old-sid", turn_count=5) + provider._mark_session_committed("old-sid") + + # turn_count is a (stale) 5 but the session is already committed. + assert provider._session_needs_commit("old-sid", 5) is False + # An uncommitted session with turns still needs a commit. + assert provider._session_needs_commit("fresh-sid", 5) is True + + +def test_on_session_switch_swallows_commit_failure(): + """Commit-on-switch must not propagate exceptions: a failing commit on the + old session must still allow the rotate to the new session to complete, + otherwise subsequent sync_turn writes would land in the wrong session.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._client.post.side_effect = RuntimeError("commit boom") + + provider.on_session_switch("new-sid") + + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +# --------------------------------------------------------------------------- +# Hung-writer protection: the sync worker can outlive the bounded join +# because each OpenViking POST has _TIMEOUT=30s and there are two per turn. +# Committing while late writes are still in flight would orphan them past +# the commit boundary — they would never be extracted. +# --------------------------------------------------------------------------- + +class _HungThread: + """Thread stand-in that stays alive across joins.""" + + def is_alive(self): + return True + + def join(self, timeout=None): + # Pretend the join timed out — worker still running. + return None + + +def test_on_session_end_skips_commit_when_sync_worker_outlives_join(): + """If the sync worker is still alive after the 10s join, the commit must + be skipped — late writes from the worker would otherwise land in an + already-committed session and never be extracted. Leave _turn_count + intact so the session stays marked dirty.""" + provider = _make_provider_with_session("old-sid", turn_count=3) + provider._inflight_writers["old-sid"] = {_HungThread()} + + provider.on_session_end([]) + + provider._client.post.assert_not_called() + assert provider._turn_count == 3 + + +def test_on_session_switch_skips_commit_when_sync_worker_outlives_join(): + """Same hazard on the switch path. Rotation must still proceed (the new + session needs to start) but the old-session commit is skipped to avoid + orphaning the worker's late writes past commit.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._inflight_writers["old-sid"] = {_HungThread()} + + provider.on_session_switch("new-sid") + + provider._client.post.assert_not_called() + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +# --------------------------------------------------------------------------- +# Orphaned-writer hazard: commit must wait for ALL writers for the session, +# not just the latest tracked one. sync_turn's bounded rate-limit can drop a +# still-alive previous worker — that dropped writer keeps POSTing under the +# old sid and would otherwise land its writes past the commit boundary. +# --------------------------------------------------------------------------- + +def test_on_session_end_waits_for_all_writers_not_just_latest(): + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._inflight_writers["old-sid"] = {_HungThread()} + + provider.on_session_end([]) + + provider._client.post.assert_not_called() + assert provider._turn_count == 2 + + +def test_on_session_switch_waits_for_all_writers_not_just_latest(): + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._inflight_writers["old-sid"] = {_HungThread()} + + provider.on_session_switch("new-sid") + + provider._client.post.assert_not_called() + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_does_not_block_caller_on_slow_drain(): + """Regression for hermes-agent#28296 review (H1): on_session_switch must + NOT run the old-session drain/commit on the caller's thread. /new, /branch, + /resume, /undo call this synchronously on the command thread, so a slow + writer drain (up to _SESSION_DRAIN_TIMEOUT/_DEFERRED_COMMIT_TIMEOUT) or a + wedged commit POST must not stall the user-facing command. The rotation is + cheap and synchronous; the commit is offloaded. Mirrors the #41945 + 'do not block the turn thread' contract.""" + import threading + import time + + provider = _make_provider_with_session("old-sid", turn_count=2) + + drain_entered = threading.Event() + release_drain = threading.Event() + + def slow_drain(sid, timeout): + drain_entered.set() + # Simulate a writer that takes a long time to drain. + release_drain.wait(timeout=10.0) + return True + + provider._drain_writers = slow_drain + + start = time.monotonic() + provider.on_session_switch("new-sid") + elapsed = time.monotonic() - start + + # The caller returned promptly with state already rotated, even though the + # drain is still parked on the finalizer thread. + assert elapsed < 1.0, f"on_session_switch blocked the caller for {elapsed:.2f}s" + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + assert drain_entered.wait(timeout=2.0), "finalizer never started draining" + # No commit yet — drain is still blocked off-thread. + provider._client.post.assert_not_called() + # Let the finalizer finish so it doesn't leak past the test. + release_drain.set() + assert provider._drain_finalizers(timeout=5.0) + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + + +def test_on_session_switch_defers_old_commit_to_finalizer_thread(): + """The switch path rotates session state synchronously (cheap, in-memory) + but offloads the old-session drain + commit onto a daemon finalizer so the + caller's command thread (/new, /branch, /resume) never blocks on the up-to + -_DEFERRED_COMMIT_TIMEOUT drain or the commit POST. See hermes-agent#28296 + review (the #41945 'do not block the turn thread' contract).""" + import threading + + provider = _make_provider_with_session("old-sid", turn_count=2) + committed = threading.Event() + drain_timeouts = [] + + def fake_post(path): + committed.set() + return {} + + def fake_drain(sid, timeout): + drain_timeouts.append(timeout) + return True + + provider._client.post.side_effect = fake_post + provider._drain_writers = fake_drain + + provider.on_session_switch("new-sid") + + # Rotation is synchronous and immediate — the new session is live at once. + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + # The old-session commit lands on the finalizer thread, not inline. + assert committed.wait(timeout=5.0), "old session was not finalized off-thread" + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + # The finalizer drains with the deferred (longer) budget, not inline 10s. + assert drain_timeouts == [_DEFERRED_COMMIT_TIMEOUT] + + +def test_sync_turn_tracks_writer_under_session_id(): + """Every sync_turn writer must register under its captured sid so the + drain at end/switch sees it even if a later sync_turn replaces the + latest-tracked reference.""" + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "sid-1" + + release = threading.Event() + started = threading.Event() + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + started.set() + release.wait(timeout=2.0) + return {} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.sync_turn("u", "a") + assert started.wait(timeout=2.0), "worker never entered post()" + assert len(provider._inflight_writers.get("sid-1", set())) == 1 + release.set() + for t in list(provider._inflight_writers.get("sid-1", set())): + t.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + # Worker should have removed itself from the inflight set on exit. + assert provider._inflight_writers.get("sid-1", set()) == set() + + +# --------------------------------------------------------------------------- +# on_memory_write: explicit memory writes use content/write and stay outside +# the session transcript/commit boundary. +# --------------------------------------------------------------------------- + +def test_on_memory_write_uses_content_write_independent_of_session_rotation(): + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "old-sid" + + in_ctor = threading.Event() + release = threading.Event() + done = threading.Event() + captured_paths = [] + captured_payloads = [] + + class StubClient: + def __init__(self, *a, **kw): + in_ctor.set() + release.wait(timeout=2.0) + + def post(self, path, payload=None, **kwargs): + captured_paths.append(path) + captured_payloads.append(payload) + done.set() + return {} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.on_memory_write("add", "user", "remember this") + assert in_ctor.wait(timeout=2.0), "worker never entered ctor" + # Rotate provider's session id while the worker is parked. Memory writes + # must not become session messages in either the old or new session. + provider._session_id = "new-sid" + release.set() + assert done.wait(timeout=2.0), "worker never reached post()" + finally: + _mod._VikingClient = real_client_cls + + assert captured_paths == ["/api/v1/content/write"] + assert captured_payloads[0]["content"] == "remember this" + assert captured_payloads[0]["mode"] == "create" + assert captured_payloads[0]["uri"].startswith( + "viking://user/usr/agent/hermes/memories/preferences/mem_" + ) + + +# --------------------------------------------------------------------------- +# Prefetch staleness: a prefetch worker that finishes AFTER a session switch +# must drop its result instead of repopulating the new session with stale +# recall from the old generation. Bump the generation directly (rather than +# calling on_session_switch, whose own join blocks on the test worker) so +# the test isolates the generation-gating behavior. +# --------------------------------------------------------------------------- + +def test_queue_prefetch_drops_result_when_generation_changed_mid_flight(): + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "old-sid" + + started = threading.Event() + release = threading.Event() + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + started.set() + release.wait(timeout=2.0) + return { + "result": { + "memories": [ + {"uri": "viking://memories/old", "score": 0.9, + "abstract": "stale from old session"}, + ], + "resources": [], + } + } + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.queue_prefetch("anything") + assert started.wait(timeout=2.0), "prefetch worker never entered post()" + # Simulate a session switch by bumping the generation directly. + # The worker captured the pre-bump generation when it was spawned. + provider._prefetch_generation += 1 + release.set() + if provider._prefetch_thread: + provider._prefetch_thread.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + # The stale result from the pre-bump generation must NOT have been written + # into the new generation's prefetch slot. + assert provider._prefetch_result == "" + + +def test_queue_prefetch_sends_limit_not_legacy_top_k(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + + captured_payloads = [] + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + captured_payloads.append(payload) + return {"result": {"memories": [], "resources": []}} + + import plugins.memory.openviking as _mod + real_client_cls = _mod._VikingClient + _mod._VikingClient = StubClient + try: + provider.queue_prefetch("anything") + if provider._prefetch_thread: + provider._prefetch_thread.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + assert captured_payloads == [{"query": "anything", "limit": 5}] + assert "top_k" not in captured_payloads[0]