mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-21 10:22:18 +00:00
fix(openviking): handle rewound session switches
This commit is contained in:
parent
00c045b43f
commit
3ac6551ba3
2 changed files with 128 additions and 11 deletions
|
|
@ -444,6 +444,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._deferred_commit_sids: Set[str] = set()
|
||||
self._deferred_commit_threads: Set[threading.Thread] = set()
|
||||
self._deferred_commit_lock = threading.Lock()
|
||||
self._committed_session_ids: Set[str] = set()
|
||||
self._committed_session_lock = threading.Lock()
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
|
|
@ -563,7 +565,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
return
|
||||
|
||||
# Drop prefetch results from older switch generations.
|
||||
gen = self._prefetch_generation
|
||||
with self._prefetch_lock:
|
||||
gen = self._prefetch_generation
|
||||
|
||||
def _run():
|
||||
try:
|
||||
|
|
@ -573,7 +576,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
)
|
||||
resp = client.post("/api/v1/search/find", {
|
||||
"query": query,
|
||||
"top_k": 5,
|
||||
"limit": 5,
|
||||
})
|
||||
result = resp.get("result", {})
|
||||
parts = []
|
||||
|
|
@ -695,9 +698,25 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
def _has_committed_session(self, sid: str) -> bool:
|
||||
with self._committed_session_lock:
|
||||
return sid in self._committed_session_ids
|
||||
|
||||
def _mark_session_committed(self, sid: str) -> None:
|
||||
with self._committed_session_lock:
|
||||
self._committed_session_ids.add(sid)
|
||||
|
||||
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
|
||||
if turn_count > 0:
|
||||
return True
|
||||
if self._has_committed_session(sid):
|
||||
return False
|
||||
return self._session_has_pending_tokens(sid)
|
||||
|
||||
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
|
||||
try:
|
||||
self._client.post(f"/api/v1/sessions/{sid}/commit")
|
||||
self._mark_session_committed(sid)
|
||||
logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
|
||||
return True
|
||||
except Exception as e:
|
||||
|
|
@ -740,6 +759,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._deferred_commit_threads.add(thread)
|
||||
thread.start()
|
||||
|
||||
def _invalidate_prefetch_state(self) -> None:
|
||||
# Bump the generation under the same lock used by prefetch workers so
|
||||
# late results from an older session are discarded deterministically.
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_generation += 1
|
||||
self._prefetch_result = ""
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Record the conversation turn in OpenViking's session (non-blocking)."""
|
||||
if not self._client:
|
||||
|
|
@ -798,7 +828,7 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
)
|
||||
return
|
||||
|
||||
if self._turn_count == 0 and not self._session_has_pending_tokens(sid):
|
||||
if not self._session_needs_commit(sid, self._turn_count):
|
||||
return
|
||||
|
||||
if self._commit_session(sid, self._turn_count, context="on session end"):
|
||||
|
|
@ -831,8 +861,17 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
if not new_id or not self._client:
|
||||
return
|
||||
|
||||
rewound = bool(kwargs.get("rewound"))
|
||||
old_session_id = self._session_id
|
||||
old_turn_count = self._turn_count
|
||||
if rewound or new_id == old_session_id:
|
||||
self._invalidate_prefetch_state()
|
||||
logger.debug(
|
||||
"OpenViking on_session_switch invalidated state without rotation: "
|
||||
"session=%s rewound=%s",
|
||||
old_session_id, rewound,
|
||||
)
|
||||
return
|
||||
|
||||
# Commit only after session writes drain.
|
||||
writers_drained = True
|
||||
|
|
@ -845,17 +884,16 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
old_session_id,
|
||||
)
|
||||
|
||||
if writers_drained and old_session_id and old_turn_count > 0:
|
||||
if (
|
||||
writers_drained
|
||||
and old_session_id
|
||||
and self._session_needs_commit(old_session_id, old_turn_count)
|
||||
):
|
||||
self._commit_session(old_session_id, old_turn_count, context="on switch")
|
||||
elif not writers_drained:
|
||||
self._schedule_deferred_commit(old_session_id, old_turn_count)
|
||||
|
||||
# Drop prefetch results from older switch generations.
|
||||
self._prefetch_generation += 1
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = ""
|
||||
self._invalidate_prefetch_state()
|
||||
|
||||
self._session_id = new_id
|
||||
self._turn_count = 0
|
||||
|
|
|
|||
|
|
@ -8,6 +8,11 @@ import pytest
|
|||
from plugins.memory.openviking import OpenVikingMemoryProvider, _VikingClient
|
||||
|
||||
|
||||
def _clear_openviking_tenant_env(monkeypatch):
|
||||
for name in ("OPENVIKING_ACCOUNT", "OPENVIKING_USER", "OPENVIKING_AGENT"):
|
||||
monkeypatch.delenv(name, raising=False)
|
||||
|
||||
|
||||
def test_tool_search_sorts_by_raw_score_across_buckets():
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
|
|
@ -386,7 +391,8 @@ def test_viking_client_headers_send_tenant_when_default():
|
|||
assert headers["Authorization"] == "Bearer test-key"
|
||||
|
||||
|
||||
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default():
|
||||
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(monkeypatch):
|
||||
_clear_openviking_tenant_env(monkeypatch)
|
||||
# Empty account/user strings fall back to "default" via the constructor.
|
||||
# Headers are sent even for the default value — ROOT API keys need them.
|
||||
client = _VikingClient(
|
||||
|
|
@ -417,6 +423,7 @@ def test_viking_client_headers_sent_with_real_tenant_values():
|
|||
|
||||
|
||||
def test_viking_client_health_sends_auth_headers(monkeypatch):
|
||||
_clear_openviking_tenant_env(monkeypatch)
|
||||
client = _VikingClient(
|
||||
"https://example.com",
|
||||
api_key="test-key",
|
||||
|
|
@ -470,6 +477,33 @@ def test_on_session_switch_skips_commit_for_empty_old_session():
|
|||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_on_session_switch_commits_pending_tokens_without_turn_count():
|
||||
provider = _make_provider_with_session("old-sid", turn_count=0)
|
||||
provider._client.get.return_value = {"result": {"pending_tokens": 42}}
|
||||
|
||||
provider.on_session_switch("new-sid")
|
||||
|
||||
provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid")
|
||||
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_on_session_switch_rewound_same_session_only_invalidates_prefetch():
|
||||
provider = _make_provider_with_session("same-sid", turn_count=3)
|
||||
provider._prefetch_generation = 9
|
||||
provider._prefetch_result = "stale recall"
|
||||
|
||||
provider.on_session_switch("same-sid", rewound=True)
|
||||
|
||||
provider._client.get.assert_not_called()
|
||||
provider._client.post.assert_not_called()
|
||||
assert provider._session_id == "same-sid"
|
||||
assert provider._turn_count == 3
|
||||
assert provider._prefetch_generation == 10
|
||||
assert provider._prefetch_result == ""
|
||||
|
||||
|
||||
def test_on_session_switch_clears_stale_prefetch_result():
|
||||
provider = _make_provider_with_session("old-sid", turn_count=1)
|
||||
provider._prefetch_result = "stale recall from old session"
|
||||
|
|
@ -698,6 +732,18 @@ def test_end_then_switch_does_not_double_commit():
|
|||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_end_then_switch_with_pending_tokens_does_not_double_commit():
|
||||
provider = _make_provider_with_session("old-sid", turn_count=0)
|
||||
provider._client.get.return_value = {"result": {"pending_tokens": 42}}
|
||||
|
||||
provider.on_session_end([])
|
||||
provider.on_session_switch("new-sid", parent_session_id="old-sid")
|
||||
|
||||
provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
|
||||
assert provider._session_id == "new-sid"
|
||||
assert provider._turn_count == 0
|
||||
|
||||
|
||||
def test_on_session_switch_swallows_commit_failure():
|
||||
"""Commit-on-switch must not propagate exceptions: a failing commit on the
|
||||
old session must still allow the rotate to the new session to complete,
|
||||
|
|
@ -971,3 +1017,36 @@ def test_queue_prefetch_drops_result_when_generation_changed_mid_flight():
|
|||
# The stale result from the pre-bump generation must NOT have been written
|
||||
# into the new generation's prefetch slot.
|
||||
assert provider._prefetch_result == ""
|
||||
|
||||
|
||||
def test_queue_prefetch_sends_limit_not_legacy_top_k():
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
provider._endpoint = "http://test"
|
||||
provider._api_key = ""
|
||||
provider._account = "acct"
|
||||
provider._user = "usr"
|
||||
provider._agent = "hermes"
|
||||
|
||||
captured_payloads = []
|
||||
|
||||
class StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
def post(self, path, payload=None, **kwargs):
|
||||
captured_payloads.append(payload)
|
||||
return {"result": {"memories": [], "resources": []}}
|
||||
|
||||
import plugins.memory.openviking as _mod
|
||||
real_client_cls = _mod._VikingClient
|
||||
_mod._VikingClient = StubClient
|
||||
try:
|
||||
provider.queue_prefetch("anything")
|
||||
if provider._prefetch_thread:
|
||||
provider._prefetch_thread.join(timeout=2.0)
|
||||
finally:
|
||||
_mod._VikingClient = real_client_cls
|
||||
|
||||
assert captured_payloads == [{"query": "anything", "limit": 5}]
|
||||
assert "top_k" not in captured_payloads[0]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue