fix(openviking): implement on_session_switch hook (#28296)

OpenVikingMemoryProvider only overrides on_session_end and inherits the
base-class no-op for on_session_switch. When the agent rotates session_id
(via /new, /branch, /reset, /resume, or context compression), the
provider's cached _session_id stays at the value initialize() captured.
All subsequent sync_turn writes then land in the already-closed old
session, and on_session_end tries to commit it a second time — the new
session never accumulates messages and never triggers memory extraction.

The fix mirrors the pattern Hindsight uses (#17508):

  1. Wait for any in-flight sync thread to drain under the OLD _session_id
     before we mutate it, otherwise the commit below races the last
     message write.
  2. Commit the old session if it accumulated turns — same extraction
     semantics as on_session_end. Skip if empty (nothing to extract).
  3. Drain in-flight prefetch from the old session and clear its cached
     result so the new session doesn't see stale recall.
  4. Rotate _session_id to the new value and reset _turn_count.

Commit failures are swallowed (logged at WARN) so a flaky server can't
strand the provider on the old session forever — same posture as the
existing on_session_end commit.
This commit is contained in:
harshitAgr 2026-05-19 07:07:27 +03:00
parent 206f595f66
commit a1e7185e8a
2 changed files with 169 additions and 0 deletions

View file

@ -607,6 +607,73 @@ class OpenVikingMemoryProvider(MemoryProvider):
except Exception as e:
logger.warning("OpenViking session commit failed: %s", e)
def on_session_switch(
    self,
    new_session_id: str,
    *,
    parent_session_id: str = "",
    reset: bool = False,
    **kwargs,
) -> None:
    """Commit the old session and rotate cached state to the new session_id.

    Fires on /resume, /branch, /reset, /new, and context compression.
    Without this hook, ``_session_id`` stays stuck at the value
    ``initialize()`` cached, so subsequent ``sync_turn()`` writes land in
    the already-closed old session and ``on_session_end()`` tries to
    commit it a second time. The new session never accumulates messages,
    and memory extraction never fires for it. See hermes-agent#28296.

    Flushes any in-flight sync under the old session_id, commits the old
    session if it has pending turns (same extraction semantics as
    ``on_session_end``), drains and clears any stale prefetch result,
    then rotates ``_session_id`` and resets ``_turn_count``.

    Both thread joins are bounded by a timeout. If a thread is still
    running when its timeout expires, the switch proceeds anyway (so a
    hung request cannot strand the provider on the old session), but the
    condition is logged instead of being silently ignored.

    Args:
        new_session_id: Session id to rotate to. Blank or whitespace-only
            values make the call a no-op.
        parent_session_id: Only used in the debug log line.
        reset: Only used in the debug log line.
        **kwargs: Accepted and ignored.
    """
    sync_join_timeout = 10.0
    prefetch_join_timeout = 3.0

    new_id = str(new_session_id or "").strip()
    if not new_id or not self._client:
        return

    # Snapshot the old session id BEFORE rotation so the join+commit
    # below always target the session whose writes we want to flush.
    old_session_id = self._session_id
    old_turn_count = self._turn_count

    # 1. Wait for any in-flight sync_turn to finish writing under the
    #    OLD session id — otherwise it races the commit below.
    sync_thread = self._sync_thread
    if sync_thread and sync_thread.is_alive():
        sync_thread.join(timeout=sync_join_timeout)
        if sync_thread.is_alive():
            # join() returns normally on timeout; surface it, because the
            # commit below may now miss the write that is still running.
            logger.warning(
                "OpenViking sync for session %s still running after %.0fs; "
                "commit-on-switch may miss the last turn",
                old_session_id, sync_join_timeout,
            )

    # 2. Commit the old session if it accumulated turns — same
    #    extraction semantics as on_session_end. Skip if empty (nothing
    #    to extract) or if no session id was ever cached.
    if old_session_id and old_turn_count > 0:
        try:
            self._client.post(f"/api/v1/sessions/{old_session_id}/commit")
            logger.info(
                "OpenViking session %s committed on switch (%d turns)",
                old_session_id, old_turn_count,
            )
        except Exception as e:
            # Deliberately best-effort: a failed commit must not block
            # the rotation below.
            logger.warning(
                "OpenViking commit-on-switch failed for %s: %s",
                old_session_id, e,
            )

    # 3. Drain in-flight prefetch from the old session and drop its
    #    cached result so the new session doesn't see stale recall.
    prefetch_thread = self._prefetch_thread
    if prefetch_thread and prefetch_thread.is_alive():
        prefetch_thread.join(timeout=prefetch_join_timeout)
        if prefetch_thread.is_alive():
            logger.debug(
                "OpenViking prefetch for session %s still running after "
                "%.0fs; it may repopulate the cache after the switch",
                old_session_id, prefetch_join_timeout,
            )
    with self._prefetch_lock:
        self._prefetch_result = ""

    # 4. Rotate to the new session.
    self._session_id = new_id
    self._turn_count = 0
    logger.debug(
        "OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s",
        old_session_id, new_id, parent_session_id, reset,
    )
def on_memory_write(self, action: str, target: str, content: str) -> None:
"""Mirror built-in memory writes to OpenViking as explicit memories."""
if not self._client or action != "add" or not content:

View file

@ -420,3 +420,105 @@ def test_viking_client_health_sends_auth_headers(monkeypatch):
assert client.health() is True
assert captured["url"] == "https://example.com/health"
assert captured["headers"]["Authorization"] == "Bearer test-key"
# ---------------------------------------------------------------------------
# on_session_switch — flush + commit + rotate behavior (hermes-agent#28296)
# ---------------------------------------------------------------------------
def _make_provider_with_session(session_id: str, turn_count: int):
    """Return a provider with a mock client and the given session state preset."""
    stub = OpenVikingMemoryProvider()
    stub._client = MagicMock()
    stub._session_id, stub._turn_count = session_id, turn_count
    return stub
def test_on_session_switch_commits_old_session_and_rotates_id():
    """A session with pending turns is committed exactly once, then state rotates."""
    subject = _make_provider_with_session("old-sid", turn_count=3)

    subject.on_session_switch("new-sid", parent_session_id="old-sid")

    commit_path = "/api/v1/sessions/old-sid/commit"
    subject._client.post.assert_called_once_with(commit_path)
    assert subject._session_id == "new-sid"
    assert subject._turn_count == 0
def test_on_session_switch_skips_commit_for_empty_old_session():
    """An old session with zero turns is rotated away from without a commit."""
    subject = _make_provider_with_session("old-sid", turn_count=0)

    subject.on_session_switch("new-sid")

    # Nothing was accumulated, so no request should have been issued.
    subject._client.post.assert_not_called()
    assert subject._session_id == "new-sid"
    assert subject._turn_count == 0
def test_on_session_switch_clears_stale_prefetch_result():
    """A cached prefetch result from the old session is emptied on switch."""
    subject = _make_provider_with_session("old-sid", turn_count=1)
    subject._prefetch_result = "stale recall from old session"

    subject.on_session_switch("new-sid")

    assert subject._prefetch_result == ""
def test_on_session_switch_waits_for_inflight_sync_thread():
    """The switch must join a live sync thread before committing.

    Otherwise the commit can race the last message write.
    """
    join_calls = []

    class FakeThread:
        """Thread stand-in that always reports alive and records join timeouts."""

        def is_alive(self):
            return True

        def join(self, timeout=None):
            join_calls.append(timeout)

    subject = _make_provider_with_session("old-sid", turn_count=2)
    subject._sync_thread = FakeThread()

    subject.on_session_switch("new-sid")

    assert join_calls, "expected on_session_switch to join the in-flight sync thread"
    subject._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_noop_on_empty_new_id():
    """Empty or whitespace-only ids leave the provider untouched."""
    subject = _make_provider_with_session("old-sid", turn_count=5)

    for blank_id in ("", "   "):
        subject.on_session_switch(blank_id)

    subject._client.post.assert_not_called()
    assert subject._session_id == "old-sid"
    assert subject._turn_count == 5
def test_on_session_switch_noop_when_client_missing():
    """With no client configured, the switch neither raises nor mutates state."""
    subject = OpenVikingMemoryProvider()
    subject._session_id, subject._turn_count = "old-sid", 4
    subject._client = None

    # Must complete without raising even though there is no client.
    subject.on_session_switch("new-sid")

    # The provider is effectively disabled, so the cached state is unchanged.
    assert subject._session_id == "old-sid"
    assert subject._turn_count == 4
def test_on_session_switch_swallows_commit_failure():
    """A failing commit on the old session must not abort the rotation.

    If the exception propagated, the provider would stay on the old id and
    subsequent sync_turn writes would land in the wrong session.
    """
    subject = _make_provider_with_session("old-sid", turn_count=2)
    subject._client.post.side_effect = RuntimeError("commit boom")

    subject.on_session_switch("new-sid")

    assert subject._session_id == "new-sid"
    assert subject._turn_count == 0