From 2ea8d5c537bccad814c483d36ddd79904bf6b55c Mon Sep 17 00:00:00 2001 From: harshitAgr <28730481+harshitAgr@users.noreply.github.com> Date: Wed, 20 May 2026 13:19:44 +0300 Subject: [PATCH] fix(openviking): close session-boundary races on sync_turn and on_session_end Two hardening fixes prompted by review on #28296: 1. sync_turn() now snapshots the target session id before spawning the worker. The previous code read self._session_id inside the worker, so a worker delayed past on_session_switch's bounded join could read the rotated-in NEW id and write the OLD turn's messages into the wrong session. 2. on_session_end() resets _turn_count to 0 after a successful commit, making the old-session commit path idempotent with the new switch hook. /new and compression call commit_memory_session() (which fires on_session_end) immediately before on_session_switch; without this, the old session would be committed twice. On commit failure we leave _turn_count > 0 so on_session_switch retries. Co-Authored-By: Claude Opus 4.7 (1M context) --- plugins/memory/openviking/__init__.py | 14 ++- .../memory/test_openviking_provider.py | 107 ++++++++++++++++++ 2 files changed, 120 insertions(+), 1 deletion(-) diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index ad31e7af50d..125218caebb 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -551,6 +551,14 @@ class OpenVikingMemoryProvider(MemoryProvider): if not self._client: return + # Capture the target session id NOW, not inside the worker. Otherwise + # a delayed worker can read self._session_id after on_session_switch + # has rotated it (the switch's join on _sync_thread is bounded), and + # the OLD turn's content lands in the NEW session. + sid = str(session_id or self._session_id).strip() + if not sid: + return + self._turn_count += 1 def _sync(): @@ -559,7 +567,6 @@ class OpenVikingMemoryProvider(MemoryProvider): self._endpoint, self._api_key, account=self._account, user=self._user, agent=self._agent, ) - sid = self._session_id # Add user message client.post(f"/api/v1/sessions/{sid}/messages", { @@ -604,6 +611,11 @@ class OpenVikingMemoryProvider(MemoryProvider): try: self._client.post(f"/api/v1/sessions/{self._session_id}/commit") logger.info("OpenViking session %s committed (%d turns)", self._session_id, self._turn_count) + # Mark the session clean so a subsequent on_session_switch (fired + # by /new and compression right after commit_memory_session) skips + # its commit instead of double-committing. On commit failure we + # leave the count intact so the switch hook gets a retry. + self._turn_count = 0 except Exception as e: logger.warning("OpenViking session commit failed: %s", e) diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index 2324cea0132..78c672851d7 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -511,6 +511,113 @@ def test_on_session_switch_noop_when_client_missing(): assert provider._turn_count == 4 +def test_sync_turn_captures_session_id_before_worker_runs(): + """Worker must use the session id snapshotted at sync_turn() call time, not + re-read self._session_id later — otherwise a delayed worker can write the + previous turn's messages into the rotated-in NEW session.""" + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + provider._session_id = "old-sid" + + started = threading.Event() + release = threading.Event() + captured_paths = [] + + def fake_post(path, payload=None, **kwargs): + started.set() + release.wait(timeout=2.0) + captured_paths.append(path) + return {} + + # Patch _VikingClient inside the worker by stubbing post on a client + # the constructor will produce. Easiest path: monkeypatch the class. + real_client_cls = _VikingClient + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + return fake_post(path, payload, **kwargs) + + import plugins.memory.openviking as _mod + _mod._VikingClient = StubClient + try: + provider.sync_turn("u", "a") + # Wait until the worker is parked inside the first post call. + assert started.wait(timeout=2.0), "worker never entered post()" + # Rotate the provider's session id while the worker is mid-flight. + provider._session_id = "new-sid" + release.set() + if provider._sync_thread: + provider._sync_thread.join(timeout=2.0) + finally: + _mod._VikingClient = real_client_cls + + # Both writes must target the OLD session id captured at call time. + assert captured_paths == [ + "/api/v1/sessions/old-sid/messages", + "/api/v1/sessions/old-sid/messages", + ] + + +def test_sync_turn_noop_when_session_id_blank(): + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._session_id = "" + + provider.sync_turn("u", "a") + + # No turn counted, no worker spawned. + assert provider._turn_count == 0 + assert provider._sync_thread is None + + +def test_on_session_end_marks_session_clean_after_successful_commit(): + """After a successful commit on_session_end must reset _turn_count so a + subsequent on_session_switch (fired by /new and compression right after + commit_memory_session) skips its commit instead of double-committing.""" + provider = _make_provider_with_session("old-sid", turn_count=3) + + provider.on_session_end([]) + + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._turn_count == 0 + + +def test_on_session_end_keeps_dirty_when_commit_fails(): + """If the commit fails, leave _turn_count > 0 so on_session_switch retries + rather than silently dropping extraction for the old session.""" + provider = _make_provider_with_session("old-sid", turn_count=3) + provider._client.post.side_effect = RuntimeError("commit boom") + + provider.on_session_end([]) + + assert provider._turn_count == 3 + + +def test_end_then_switch_does_not_double_commit(): + """Mirrors the /new and compression call order: commit_memory_session + (→ on_session_end) immediately followed by on_session_switch. The switch + must NOT issue a second commit on the same session id.""" + provider = _make_provider_with_session("old-sid", turn_count=2) + + provider.on_session_end([]) + provider.on_session_switch("new-sid", parent_session_id="old-sid") + + # Exactly one commit call, on the OLD session, fired by on_session_end. + provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit") + assert provider._session_id == "new-sid" + assert provider._turn_count == 0 + + def test_on_session_switch_swallows_commit_failure(): """Commit-on-switch must not propagate exceptions: a failing commit on the old session must still allow the rotate to the new session to complete,