diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index 3050eb9c43a..391bcabe794 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -645,6 +645,63 @@ class OpenVikingMemoryProvider(MemoryProvider): except Exception as e: logger.warning("OpenViking session commit failed: %s", e) + def on_session_switch( + self, + new_session_id: str, + *, + parent_session_id: str = "", + reset: bool = False, + **kwargs, + ) -> None: + """Commit the old session and rotate cached state to the new session_id. + + Fires on /resume, /branch, /reset, /new, and context compression. + Without this hook, ``_session_id`` stays stuck at the value + ``initialize()`` cached, so subsequent ``sync_turn()`` writes land in + the already-closed old session and ``on_session_end()`` tries to + commit it a second time. The new session never accumulates messages, + and memory extraction never fires for it. See hermes-agent#28296. + + Flushes any in-flight sync under the old session_id, commits the old + session if it has pending turns (same extraction semantics as + ``on_session_end``), drains and clears any stale prefetch result, + then rotates ``_session_id`` and resets ``_turn_count``. + """ + new_id = str(new_session_id or "").strip() + if not new_id or not self._client: + return + + old_session_id = self._session_id + old_turn_count = self._turn_count + + if self._sync_thread and self._sync_thread.is_alive(): + self._sync_thread.join(timeout=10.0) + + if old_session_id and old_turn_count > 0: + try: + self._client.post(f"/api/v1/sessions/{old_session_id}/commit") + logger.info( + "OpenViking session %s committed on switch (%d turns)", + old_session_id, old_turn_count, + ) + except Exception as e: + logger.warning( + "OpenViking commit-on-switch failed for %s: %s", + old_session_id, e, + ) + + if self._prefetch_thread and self._prefetch_thread.is_alive(): + self._prefetch_thread.join(timeout=3.0) + with self._prefetch_lock: + self._prefetch_result = "" + + self._session_id = new_id + self._turn_count = 0 + logger.debug( + "OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s", + old_session_id, new_id, parent_session_id, reset, + ) + def _build_memory_uri(self, subdir: str) -> str: """Build a viking:// memory URI under the configured user/agent/subdir.""" slug = uuid.uuid4().hex[:12] diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index 3f609cd1d67..2324cea0132 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -420,3 +420,105 @@ def test_viking_client_health_sends_auth_headers(monkeypatch): assert client.health() is True assert captured["url"] == "https://example.com/health" assert captured["headers"]["Authorization"] == "Bearer test-key" + + +# --------------------------------------------------------------------------- +# on_session_switch — flush + commit + rotate behavior (hermes-agent#28296) +# --------------------------------------------------------------------------- + +def _make_provider_with_session(session_id: str, turn_count: int): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._session_id = session_id + provider._turn_count = turn_count + return provider + + +def test_on_session_switch_commits_old_session_and_rotates_id(): + provider = _make_provider_with_session("old-sid", turn_count=3) + + provider.on_session_switch("new-sid", parent_session_id="old-sid") + + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_skips_commit_for_empty_old_session(): + """No turns accumulated → nothing to extract → no commit call.""" + provider = _make_provider_with_session("old-sid", turn_count=0) + + provider.on_session_switch("new-sid") + + provider._client.post.assert_not_called() + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + +def test_on_session_switch_clears_stale_prefetch_result(): + provider = _make_provider_with_session("old-sid", turn_count=1) + provider._prefetch_result = "stale recall from old session" + + provider.on_session_switch("new-sid") + + assert provider._prefetch_result == "" + + +def test_on_session_switch_waits_for_inflight_sync_thread(): + """In-flight sync_turn write must drain before the commit fires — + otherwise the commit can race the last message write.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + + join_calls = [] + + class FakeThread: + def is_alive(self): + return True + + def join(self, timeout=None): + join_calls.append(timeout) + + provider._sync_thread = FakeThread() + + provider.on_session_switch("new-sid") + + assert join_calls, "expected on_session_switch to join the in-flight sync thread" + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + + +def test_on_session_switch_noop_on_empty_new_id(): + provider = _make_provider_with_session("old-sid", turn_count=5) + + provider.on_session_switch("") + provider.on_session_switch(" ") + + provider._client.post.assert_not_called() + assert provider._session_id == "old-sid" + assert provider._turn_count == 5 + + +def test_on_session_switch_noop_when_client_missing(): + provider = OpenVikingMemoryProvider() + provider._client = None + provider._session_id = "old-sid" + provider._turn_count = 4 + + # Must not raise even though no client is configured. + provider.on_session_switch("new-sid") + + # State stays untouched — provider is effectively disabled. + assert provider._session_id == "old-sid" + assert provider._turn_count == 4 + + +def test_on_session_switch_swallows_commit_failure(): + """Commit-on-switch must not propagate exceptions: a failing commit on the + old session must still allow the rotate to the new session to complete, + otherwise subsequent sync_turn writes would land in the wrong session.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + provider._client.post.side_effect = RuntimeError("commit boom") + + provider.on_session_switch("new-sid") + + assert provider._session_id == "new-sid" + assert provider._turn_count == 0