hermes-agent/tests/plugins/memory/test_openviking_provider.py
kshitijk4poor c835448908 fix(openviking): don't block the command thread on session switch; lock turn state
Follow-up hardening on @ehz0ah / @harshitAgr's session-switch work (#28296):

- on_session_switch no longer runs the old-session writer-drain + pending-token
  GET + commit POST inline on the caller's command thread. /new, /branch,
  /resume, /undo call it synchronously, so a slow drain (up to 10s) or a wedged
  commit used to block the user-facing command — the same hazard #41945 fixed for
  end-of-turn sync. State now rotates synchronously (cheap) and the old-session
  commit is offloaded to a daemon finalizer (generalized _finalize_session_async).
- Guard the (_session_id, _turn_count) pair with _session_state_lock: sync_turn
  runs on the memory-manager executor thread while the session hooks run on the
  command thread, so the snapshot+reset vs increment was a cross-thread race.
- _session_needs_commit checks the committed-session guard BEFORE the
  turn_count>0 shortcut, closing a double-commit window when a racing sync_turn
  re-increments after commit+reset.
- Add a _shutting_down flag so deferred finalizers stop POSTing against a
  torn-down client; track all prefetch threads in a set so invalidate/shutdown
  join every one, not just the latest slot.

Tests: regression for the non-blocking switch (asserts the caller returns while
a slow drain is parked off-thread) and the committed-guard ordering; updated the
deferred-commit test to the unified finalizer contract.
2026-06-18 00:21:21 +05:30

1121 lines
39 KiB
Python

import json
import zipfile
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from plugins.memory.openviking import (
OpenVikingMemoryProvider,
_DEFERRED_COMMIT_TIMEOUT,
_VikingClient,
)
def _clear_openviking_tenant_env(monkeypatch):
for name in ("OPENVIKING_ACCOUNT", "OPENVIKING_USER", "OPENVIKING_AGENT"):
monkeypatch.delenv(name, raising=False)
def test_tool_search_sorts_by_raw_score_across_buckets():
    """Merged search hits are ranked by raw score, not by the rounded score.

    All three hits come back rounded to 0.9, so the expected
    resource > memory > skill order can only come from the unrounded values.
    """
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    buckets = {
        "memories": [
            {"uri": "viking://memories/1", "score": 0.9003, "abstract": "memory result"},
        ],
        "resources": [
            {"uri": "viking://resources/1", "score": 0.9004, "abstract": "resource result"},
        ],
        "skills": [
            {"uri": "viking://skills/1", "score": 0.8999, "abstract": "skill result"},
        ],
        "total": 3,
    }
    provider._client.post.return_value = {"result": buckets}

    payload = json.loads(provider._tool_search({"query": "ranking"}))
    hits = payload["results"]

    expected_order = [
        "viking://resources/1",
        "viking://memories/1",
        "viking://skills/1",
    ]
    assert [hit["uri"] for hit in hits] == expected_order
    assert [hit["score"] for hit in hits] == [0.9, 0.9, 0.9]
    assert payload["total"] == 3
def test_tool_search_sorts_missing_raw_score_after_negative_scores():
    """A hit without a score is ranked as 0.0.

    Despite the test name, the asserted order is positive, missing, negative.
    The scoreless hit therefore lands below positive scores and above
    negative ones, and is reported with a score of 0.0.
    """
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    buckets = {
        "memories": [
            {"uri": "viking://memories/missing", "abstract": "missing score"},
        ],
        "resources": [
            {"uri": "viking://resources/negative", "score": -0.25, "abstract": "negative score"},
        ],
        "skills": [
            {"uri": "viking://skills/positive", "score": 0.1, "abstract": "positive score"},
        ],
        "total": 3,
    }
    provider._client.post.return_value = {"result": buckets}

    payload = json.loads(provider._tool_search({"query": "ranking"}))
    hits = payload["results"]

    expected_order = [
        "viking://skills/positive",
        "viking://memories/missing",
        "viking://resources/negative",
    ]
    assert [hit["uri"] for hit in hits] == expected_order
    assert [hit["score"] for hit in hits] == [0.1, 0.0, -0.25]
    assert payload["total"] == 3
def test_tool_search_sends_limit_not_legacy_top_k():
    """The search request carries ``limit`` and omits the legacy ``top_k`` key."""
    provider = OpenVikingMemoryProvider()
    client = provider._client = MagicMock()
    client.post.return_value = {
        "result": {"memories": [], "resources": [], "skills": [], "total": 0}
    }

    provider._tool_search({"query": "session switch", "limit": 7})

    client.post.assert_called_once()
    # post(path, payload): the request body is the second positional arg.
    sent_payload = client.post.call_args.args[1]
    assert sent_payload["limit"] == 7
    assert "top_k" not in sent_payload
def test_tool_add_resource_uploads_existing_local_file(tmp_path):
    """A local file path is uploaded as a temp file and referenced by its id."""
    local_file = tmp_path / "sample.md"
    local_file.write_text("# Local resource\n", encoding="utf-8")

    client = MagicMock()
    client.upload_temp_file.return_value = "upload_sample.md"
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/sample"},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    request = {"url": str(local_file), "reason": "local test", "wait": True}
    result = json.loads(provider._tool_add_resource(request))

    client.upload_temp_file.assert_called_once_with(local_file)
    expected_body = {
        "reason": "local test",
        "wait": True,
        "source_name": "sample.md",
        "temp_file_id": "upload_sample.md",
    }
    client.post.assert_called_once_with("/api/v1/resources", expected_body)
    assert result["status"] == "added"
    assert result["root_uri"] == "viking://resources/sample"
def test_tool_add_resource_uploads_file_uri(tmp_path):
    """A ``file://`` URI is resolved to its local path and uploaded.

    No ``wait`` is given, so none is forwarded in the add request.
    """
    local_file = tmp_path / "sample.md"
    local_file.write_text("# Local resource\n", encoding="utf-8")

    client = MagicMock()
    client.upload_temp_file.return_value = "upload_sample.md"
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/sample"},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    raw = provider._tool_add_resource({
        "url": local_file.as_uri(),
        "reason": "file uri test",
    })
    result = json.loads(raw)

    client.upload_temp_file.assert_called_once_with(local_file)
    client.post.assert_called_once_with("/api/v1/resources", {
        "reason": "file uri test",
        "source_name": "sample.md",
        "temp_file_id": "upload_sample.md",
    })
    assert result["status"] == "added"
    assert result["root_uri"] == "viking://resources/sample"
def test_tool_add_resource_uploads_existing_local_directory_and_cleans_zip(tmp_path):
    """A local directory is zipped and uploaded, and the temp zip is removed."""
    docs = tmp_path / "docs"
    nested = docs / "nested"
    nested.mkdir(parents=True)
    (docs / "guide.md").write_text("# Guide\n", encoding="utf-8")
    (nested / "api.md").write_text("# API\n", encoding="utf-8")

    uploaded = []

    def record_upload(path):
        uploaded.append(path)
        return "upload_docs.zip"

    client = MagicMock()
    client.upload_temp_file.side_effect = record_upload
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/docs"},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    result = json.loads(provider._tool_add_resource({
        "url": str(docs),
        "reason": "directory test",
        "wait": True,
    }))

    assert uploaded, "directory was never uploaded"
    archive_path = uploaded[0]
    assert archive_path.suffix == ".zip"
    assert not archive_path.exists(), "temporary zip was left on disk"
    client.post.assert_called_once_with("/api/v1/resources", {
        "reason": "directory test",
        "wait": True,
        "source_name": "docs",
        "temp_file_id": "upload_docs.zip",
    })
    assert result["status"] == "added"
    assert result["root_uri"] == "viking://resources/docs"
def test_tool_add_resource_directory_zip_skips_symlink_escape(tmp_path):
    """A symlink pointing outside the directory is not packed into the zip.

    Skipped where the platform or filesystem cannot create symlinks.
    """
    outside = tmp_path / "outside-secret.txt"
    outside.write_text("do not upload\n", encoding="utf-8")
    docs = tmp_path / "docs"
    docs.mkdir()
    (docs / "guide.md").write_text("# Guide\n", encoding="utf-8")
    try:
        (docs / "leak.txt").symlink_to(outside)
    except OSError as exc:
        pytest.skip(f"symlinks unavailable in test environment: {exc}")

    client = MagicMock()
    provider = OpenVikingMemoryProvider()
    provider._client = client
    seen = {}

    def inspect_upload(path):
        # Read the archive while it still exists; the provider deletes it later.
        with zipfile.ZipFile(path) as archive:
            names = archive.namelist()
            seen["names"] = names
            seen["payloads"] = {name: archive.read(name) for name in names}
        return "upload_docs.zip"

    client.upload_temp_file.side_effect = inspect_upload
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/docs"},
    }

    json.loads(provider._tool_add_resource({"url": str(docs)}))

    assert seen["names"] == ["guide.md"]
    assert b"do not upload" not in b"".join(seen["payloads"].values())
def test_tool_add_resource_cleans_local_directory_zip_when_add_fails(tmp_path):
    """The temporary zip is deleted even when the add request raises."""
    docs = tmp_path / "docs"
    docs.mkdir()
    (docs / "guide.md").write_text("# Guide\n", encoding="utf-8")

    uploaded = []

    def record_upload(path):
        uploaded.append(path)
        return "upload_docs.zip"

    client = MagicMock()
    client.upload_temp_file.side_effect = record_upload
    client.post.side_effect = RuntimeError("add failed")
    provider = OpenVikingMemoryProvider()
    provider._client = client

    with pytest.raises(RuntimeError, match="add failed"):
        provider._tool_add_resource({"url": str(docs)})

    assert uploaded
    assert not uploaded[0].exists()
def test_tool_add_resource_cleans_local_directory_zip_when_upload_fails(tmp_path):
    """A failed upload still removes the temp zip and never issues the add POST."""
    docs = tmp_path / "docs"
    docs.mkdir()
    (docs / "guide.md").write_text("# Guide\n", encoding="utf-8")

    attempted = []

    def failing_upload(path):
        attempted.append(path)
        raise RuntimeError("upload failed")

    client = MagicMock()
    client.upload_temp_file.side_effect = failing_upload
    provider = OpenVikingMemoryProvider()
    provider._client = client

    with pytest.raises(RuntimeError, match="upload failed"):
        provider._tool_add_resource({"url": str(docs)})

    assert attempted
    assert not attempted[0].exists()
    client.post.assert_not_called()
def test_tool_add_resource_rejects_missing_local_path(tmp_path):
    """A nonexistent local path yields an error payload and no client calls."""
    missing = tmp_path / "missing.md"
    client = MagicMock()
    provider = OpenVikingMemoryProvider()
    provider._client = client

    response = json.loads(provider._tool_add_resource({"url": str(missing)}))

    assert response["error"] == f"Local resource path does not exist: {missing}"
    for client_method in (client.upload_temp_file, client.post):
        client_method.assert_not_called()
def test_tool_add_resource_sends_remote_url_as_path():
    """An https URL is forwarded as ``path`` with no local upload."""
    remote_url = "https://example.com/doc.md"
    client = MagicMock()
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/remote"},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    provider._tool_add_resource({"url": remote_url})

    client.upload_temp_file.assert_not_called()
    client.post.assert_called_once_with("/api/v1/resources", {"path": remote_url})
@pytest.mark.parametrize("url", [
    "git@github.com:org/repo.git",
    "git@ssh.dev.azure.com:v3/org/project/repo",
    "ssh://git@github.com/org/repo.git",
    "git://github.com/org/repo.git",
])
def test_tool_add_resource_sends_git_remote_sources_as_path(url):
    """Git remotes (scp-style, ssh://, git://) are forwarded as ``path``.

    The scp-style forms have a ':' but no URL scheme. They must not be
    treated as local paths, so nothing is uploaded.
    """
    client = MagicMock()
    client.post.return_value = {
        "status": "ok",
        "result": {"root_uri": "viking://resources/repo"},
    }
    provider = OpenVikingMemoryProvider()
    provider._client = client

    provider._tool_add_resource({"url": url})

    client.upload_temp_file.assert_not_called()
    client.post.assert_called_once_with("/api/v1/resources", {"path": url})
def test_viking_client_upload_temp_file_uses_multipart_identity_headers(tmp_path, monkeypatch):
    """upload_temp_file sends a multipart ``files`` body plus auth/tenant headers.

    No explicit ``Content-Type`` may be set. Presumably this lets httpx
    supply the multipart boundary itself (verify against the client code).
    """
    local_file = tmp_path / "sample.md"
    local_file.write_text("# Local resource\n", encoding="utf-8")
    client = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="test-account",
        user="test-user",
        agent="test-agent",
    )
    sent = {}

    def fake_post(url, **kwargs):
        sent.update(kwargs)
        return SimpleNamespace(
            status_code=200,
            text="",
            json=lambda: {"status": "ok", "result": {"temp_file_id": "upload_sample.md"}},
            raise_for_status=lambda: None,
        )

    monkeypatch.setattr(client._httpx, "post", fake_post)

    assert client.upload_temp_file(local_file) == "upload_sample.md"
    assert "files" in sent
    assert "json" not in sent
    headers = sent["headers"]
    expected_headers = {
        "X-OpenViking-Account": "test-account",
        "X-OpenViking-User": "test-user",
        "X-OpenViking-Agent": "test-agent",
        "X-API-Key": "test-key",
    }
    for header_name, header_value in expected_headers.items():
        assert headers[header_name] == header_value
    assert "Content-Type" not in headers
def test_viking_client_raises_structured_server_error():
    """A structured error body is raised as a RuntimeError that names its code."""
    # Bypass __init__: only _parse_response is exercised here.
    client = _VikingClient.__new__(_VikingClient)
    forbidden = SimpleNamespace(
        status_code=403,
        text='{"status":"error"}',
        json=lambda: {
            "status": "error",
            "error": {
                "code": "PERMISSION_DENIED",
                "message": "direct host filesystem paths are not allowed",
            },
        },
        raise_for_status=lambda: None,
    )

    with pytest.raises(RuntimeError, match="PERMISSION_DENIED"):
        client._parse_response(forbidden)
def test_viking_client_headers_include_bearer_when_api_key_set():
    """An API key is sent both as X-API-Key and as a Bearer Authorization."""
    headers = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="acct",
        user="usr",
        agent="hermes",
    )._headers()

    assert headers["X-API-Key"] == "test-key"
    assert headers["Authorization"] == "Bearer test-key"
def test_viking_client_headers_send_tenant_when_default():
    """Tenant headers are sent even when account/user are literally "default".

    OpenViking 0.3.x requires X-OpenViking-Account and X-OpenViking-User on
    ROOT API key requests to tenant-scoped APIs. Omitting them causes
    INVALID_ARGUMENT errors even when account="default".
    """
    headers = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="default",
        user="default",
        agent="hermes",
    )._headers()

    expected = {
        "X-OpenViking-Account": "default",
        "X-OpenViking-User": "default",
        "X-OpenViking-Agent": "hermes",
        "Authorization": "Bearer test-key",
    }
    for header_name, header_value in expected.items():
        assert headers[header_name] == header_value
def test_viking_client_headers_send_tenant_when_empty_falls_back_to_default(monkeypatch):
    """Empty account/user fall back to "default", and the headers are still sent.

    ROOT API keys need the tenant headers even for the default tenant. With
    an empty API key, neither auth header is emitted.
    """
    _clear_openviking_tenant_env(monkeypatch)
    client = _VikingClient(
        "https://example.com",
        api_key="",
        account="",
        user="",
        agent="hermes",
    )

    headers = client._headers()

    assert headers["X-OpenViking-Account"] == "default"
    assert headers["X-OpenViking-User"] == "default"
    for auth_header in ("Authorization", "X-API-Key"):
        assert auth_header not in headers
def test_viking_client_headers_sent_with_real_tenant_values():
    """Explicit tenant values are forwarded verbatim in the tenant headers."""
    client = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="real-account",
        user="real-user",
        agent="hermes",
    )

    headers = client._headers()

    assert (headers["X-OpenViking-Account"], headers["X-OpenViking-User"]) == (
        "real-account",
        "real-user",
    )
def test_viking_client_health_sends_auth_headers(monkeypatch):
    """health() GETs ``/health`` with the Bearer header and returns True on 200."""
    _clear_openviking_tenant_env(monkeypatch)
    client = _VikingClient(
        "https://example.com",
        api_key="test-key",
        account="",
        user="",
        agent="hermes",
    )
    seen = {}

    def fake_get(url, **kwargs):
        seen["url"] = url
        seen["headers"] = kwargs.get("headers") or {}
        return SimpleNamespace(status_code=200)

    monkeypatch.setattr(client._httpx, "get", fake_get)

    assert client.health() is True
    assert seen["url"] == "https://example.com/health"
    assert seen["headers"]["Authorization"] == "Bearer test-key"
# ---------------------------------------------------------------------------
# on_session_switch — flush + commit + rotate behavior (hermes-agent#28296)
# ---------------------------------------------------------------------------
def _make_provider_with_session(session_id: str, turn_count: int):
    """Build a provider with a MagicMock client, positioned mid-session.

    The session id and turn counter are assigned directly. Tests can then
    start from an arbitrary point in a session without going through
    sync_turn.
    """
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._session_id, provider._turn_count = session_id, turn_count
    return provider
def test_on_session_switch_commits_old_session_and_rotates_id():
    """Switching rotates to the new id at once and commits the old session.

    The old-session commit runs on a deferred finalizer thread. The test
    therefore waits for finalizers before inspecting the client mock;
    asserting right after on_session_switch returns would race that thread.
    """
    provider = _make_provider_with_session("old-sid", turn_count=3)

    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    # Rotation happens synchronously on the caller's thread.
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # The commit is asynchronous; let the finalizer finish before checking it.
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_skips_commit_for_empty_old_session():
    """No turns accumulated → nothing to extract → no commit call.

    Any deferred finalizer is drained first. Otherwise the "not called" check
    could pass simply because the finalizer had not posted yet.
    """
    provider = _make_provider_with_session("old-sid", turn_count=0)

    provider.on_session_switch("new-sid")

    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.post.assert_not_called()
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
def test_on_session_switch_commits_pending_tokens_without_turn_count():
    """Server-side pending tokens force a commit even when _turn_count is 0."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    provider._client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_switch("new-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # The pending-token GET and the commit POST both run on the deferred
    # finalizer, so wait for it before inspecting the client mock.
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.get.assert_called_once_with("/api/v1/sessions/old-sid")
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_rewound_same_session_only_invalidates_prefetch():
    """A rewind within the same session only invalidates prefetched recall.

    The prefetch generation is bumped and the cached result cleared. The
    session is neither queried nor committed, and the turn count is kept.
    """
    provider = _make_provider_with_session("same-sid", turn_count=3)
    provider._prefetch_generation = 9
    provider._prefetch_result = "stale recall"

    provider.on_session_switch("same-sid", rewound=True)

    # Drain first so the "no GET/POST" checks would catch a wrongly
    # scheduled finalizer instead of passing before it runs.
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.get.assert_not_called()
    provider._client.post.assert_not_called()
    assert provider._session_id == "same-sid"
    assert provider._turn_count == 3
    assert provider._prefetch_generation == 10
    assert provider._prefetch_result == ""
def test_on_session_switch_clears_stale_prefetch_result():
    """Recall cached for the old session is dropped when switching sessions."""
    provider = _make_provider_with_session("old-sid", turn_count=1)
    provider._prefetch_result = "stale recall from old session"

    provider.on_session_switch("new-sid")

    assert provider._prefetch_result == ""
    # The old session (1 turn) is committed by a deferred finalizer; let it
    # finish so it does not outlive this test.
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
def test_on_session_switch_waits_for_inflight_sync_thread():
    """In-flight sync_turn writes must drain before the commit fires.

    Otherwise the commit can race the last message write. Both the drain and
    the commit run on the deferred finalizer, so the test waits for it before
    checking either.
    """
    provider = _make_provider_with_session("old-sid", turn_count=2)
    join_calls = []

    class FakeThread:
        def __init__(self):
            self._alive = True

        def is_alive(self):
            return self._alive

        def join(self, timeout=None):
            join_calls.append(timeout)
            # Simulate a worker that finishes within the join window.
            self._alive = False

    provider._inflight_writers["old-sid"] = {FakeThread()}

    provider.on_session_switch("new-sid")

    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    assert join_calls, "expected on_session_switch to join the in-flight sync thread"
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_noop_on_empty_new_id():
    """Blank or whitespace-only ids are ignored: no commit and no rotation."""
    provider = _make_provider_with_session("old-sid", turn_count=5)

    for blank_id in ("", " "):
        provider.on_session_switch(blank_id)

    provider._client.post.assert_not_called()
    assert (provider._session_id, provider._turn_count) == ("old-sid", 5)
def test_on_session_switch_noop_when_client_missing():
    """With no client the provider is disabled; switching leaves state alone."""
    provider = OpenVikingMemoryProvider()
    provider._client = None
    provider._session_id, provider._turn_count = "old-sid", 4

    # Must not raise, even though no client is configured.
    provider.on_session_switch("new-sid")

    assert provider._session_id == "old-sid"
    assert provider._turn_count == 4
def test_sync_turn_captures_session_id_before_worker_runs():
    """The worker uses the session id captured when sync_turn() was called.

    If the worker re-read self._session_id later, a delayed worker could
    write the previous turn's messages into the rotated-in NEW session.
    """
    import threading

    import plugins.memory.openviking as viking_module

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "old-sid"

    entered_post = threading.Event()
    unblock_post = threading.Event()
    posted = []  # (path, payload) pairs in call order

    class ParkingClient:
        """Client stand-in whose post() blocks until the test releases it."""

        def __init__(self, *args, **kwargs):
            pass

        def post(self, path, payload=None, **kwargs):
            entered_post.set()
            unblock_post.wait(timeout=2.0)
            posted.append((path, payload))
            return {}

    # The worker constructs its own client, so swap the class on the module.
    original_client_cls = _VikingClient
    viking_module._VikingClient = ParkingClient
    try:
        provider.sync_turn("u", "a")
        # Wait until the worker is parked inside its first post() call.
        assert entered_post.wait(timeout=2.0), "worker never entered post()"
        # Rotate the provider's session id while the worker is mid-flight.
        provider._session_id = "new-sid"
        unblock_post.set()
        for worker in list(provider._inflight_writers.get("old-sid", set())):
            worker.join(timeout=2.0)
    finally:
        viking_module._VikingClient = original_client_cls

    # The whole turn targets the OLD session id as a single ordered batch.
    assert [path for path, _ in posted] == ["/api/v1/sessions/old-sid/messages/batch"]
    assert [body for _, body in posted] == [{
        "messages": [
            {"role": "user", "parts": [{"type": "text", "text": "u"}]},
            {"role": "assistant", "parts": [{"type": "text", "text": "a"}]},
        ]
    }]
def test_sync_turn_retries_batch_write_with_fresh_client():
    """A failed batch write is retried on a newly constructed client."""
    import plugins.memory.openviking as viking_module

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "sid-1"

    created = []
    writes = []

    class FlakyOnceClient:
        """The first instance fails every post; later instances record them."""

        def __init__(self, *args, **kwargs):
            self.index = len(created)
            created.append(self)

        def post(self, path, payload=None, **kwargs):
            if self.index == 0:
                raise RuntimeError("transient")
            writes.append((path, payload))
            return {}

    saved_client_cls = viking_module._VikingClient
    viking_module._VikingClient = FlakyOnceClient
    try:
        provider.sync_turn("u", "a")
        assert provider._drain_writers("sid-1", timeout=2.0)
    finally:
        viking_module._VikingClient = saved_client_cls

    assert len(created) == 2
    expected_batch = {
        "messages": [
            {"role": "user", "parts": [{"type": "text", "text": "u"}]},
            {"role": "assistant", "parts": [{"type": "text", "text": "a"}]},
        ]
    }
    assert writes == [("/api/v1/sessions/sid-1/messages/batch", expected_batch)]
def test_sync_turn_noop_when_session_id_blank():
    """Without an active session id, sync_turn neither counts nor writes."""
    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._session_id = ""

    provider.sync_turn("u", "a")

    # No turn counted and no writer thread registered.
    assert provider._turn_count == 0
    assert provider._inflight_writers == {}
def test_on_session_end_marks_session_clean_after_successful_commit():
    """A successful commit in on_session_end resets _turn_count.

    /new and compression fire on_session_switch right after
    commit_memory_session. The reset lets that switch skip its own commit
    instead of committing the same session twice.
    """
    provider = _make_provider_with_session("old-sid", turn_count=3)

    provider.on_session_end([])

    commit_path = "/api/v1/sessions/old-sid/commit"
    provider._client.post.assert_called_once_with(commit_path)
    assert provider._turn_count == 0
def test_on_session_end_keeps_dirty_when_commit_fails():
    """A failed commit leaves _turn_count untouched so a later switch retries.

    Resetting it would silently drop extraction for the old session.
    """
    provider = _make_provider_with_session("old-sid", turn_count=3)
    provider._client.post.side_effect = RuntimeError("commit boom")

    provider.on_session_end([])

    assert provider._turn_count == 3
def test_on_session_end_commits_pending_tokens_without_turn_count():
    """on_session_end commits when the server reports pending tokens, even at 0 turns."""
    provider = _make_provider_with_session("old-sid", turn_count=0)
    client = provider._client
    client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_end([])

    client.get.assert_called_once_with("/api/v1/sessions/old-sid")
    client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_end_then_switch_does_not_double_commit():
    """An end followed by a switch commits the old session exactly once.

    This mirrors the /new and compression call order: commit_memory_session
    (→ on_session_end), then on_session_switch. Any second commit would come
    from the switch's deferred finalizer, so wait for it before counting.
    """
    provider = _make_provider_with_session("old-sid", turn_count=2)

    provider.on_session_end([])
    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    # Exactly one commit, on the OLD session, fired by on_session_end.
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
def test_end_then_switch_with_pending_tokens_does_not_double_commit():
    """A pending-token commit in on_session_end is not repeated by the switch.

    The switch's finalizer runs off-thread, so drain it before counting
    commit calls.
    """
    provider = _make_provider_with_session("old-sid", turn_count=0)
    provider._client.get.return_value = {"result": {"pending_tokens": 42}}

    provider.on_session_end([])
    provider.on_session_switch("new-sid", parent_session_id="old-sid")

    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
def test_session_needs_commit_guard_wins_over_stale_turn_count():
    """The committed-session guard outranks a stale positive turn count.

    Regression for hermes-agent#28296 review (M3). A racing sync_turn can
    re-increment _turn_count after commit+reset. Unless the committed guard
    is checked BEFORE the turn_count>0 shortcut, a follow-up finalizer would
    commit the same session a second time.
    """
    provider = _make_provider_with_session("old-sid", turn_count=5)
    provider._mark_session_committed("old-sid")

    # Already committed: a stale count of 5 must not trigger another commit.
    assert provider._session_needs_commit("old-sid", 5) is False
    # Never committed and has turns: a commit is still needed.
    assert provider._session_needs_commit("fresh-sid", 5) is True
def test_on_session_switch_swallows_commit_failure():
    """A failing old-session commit must not undo or block the rotation.

    Otherwise subsequent sync_turn writes would land in the wrong session.
    The commit runs on the deferred finalizer, so the test drains it. That
    ensures the failing POST actually happens while this test is running.
    """
    provider = _make_provider_with_session("old-sid", turn_count=2)
    provider._client.post.side_effect = RuntimeError("commit boom")

    provider.on_session_switch("new-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    assert provider._client.post.called, "old-session commit was never attempted"
# ---------------------------------------------------------------------------
# Hung-writer protection: the sync worker can outlive the bounded join
# because each OpenViking POST has _TIMEOUT=30s and a failed batch write is
# retried on a fresh client, so one turn can issue more than one POST.
# Committing while late writes are still in flight would orphan them past
# the commit boundary — they would never be extracted.
# ---------------------------------------------------------------------------
class _HungThread:
"""Thread stand-in that stays alive across joins."""
def is_alive(self):
return True
def join(self, timeout=None):
# Pretend the join timed out — worker still running.
return None
def test_on_session_end_skips_commit_when_sync_worker_outlives_join():
    """The commit is skipped while a sync worker is still alive after the join.

    Late writes from that worker would otherwise land in an already-committed
    session and never be extracted. _turn_count stays intact, so the session
    remains marked dirty.
    """
    provider = _make_provider_with_session("old-sid", turn_count=3)
    provider._inflight_writers["old-sid"] = {_HungThread()}

    provider.on_session_end([])

    provider._client.post.assert_not_called()
    assert provider._turn_count == 3
def test_on_session_switch_skips_commit_when_sync_worker_outlives_join():
    """On the switch path, a hung writer blocks only the old-session commit.

    Rotation must still proceed so the new session can start. The
    old-session commit is skipped so the worker's late writes are not
    orphaned past the commit. The commit decision is made on the deferred
    finalizer, so the test drains it first. Only then does "not called" mean
    "skipped" rather than "not issued yet".
    """
    provider = _make_provider_with_session("old-sid", turn_count=2)
    provider._inflight_writers["old-sid"] = {_HungThread()}

    provider.on_session_switch("new-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.post.assert_not_called()
# ---------------------------------------------------------------------------
# Orphaned-writer hazard: commit must wait for ALL writers for the session,
# not just the latest tracked one. sync_turn's bounded rate-limit can drop a
# still-alive previous worker — that dropped writer keeps POSTing under the
# old sid and would otherwise land its writes past the commit boundary.
# ---------------------------------------------------------------------------
def test_on_session_end_waits_for_all_writers_not_just_latest():
    """Any still-alive writer tracked for the session blocks the end commit."""
    provider = _make_provider_with_session("old-sid", turn_count=2)
    provider._inflight_writers["old-sid"] = {_HungThread()}

    provider.on_session_end([])

    provider._client.post.assert_not_called()
    # Left dirty so the commit can be retried later.
    assert provider._turn_count == 2
def test_on_session_switch_waits_for_all_writers_not_just_latest():
    """Any still-alive writer for the old session blocks its deferred commit.

    Rotation is synchronous. The commit decision happens on the finalizer,
    which is drained before asserting that no commit was sent.
    """
    provider = _make_provider_with_session("old-sid", turn_count=2)
    provider._inflight_writers["old-sid"] = {_HungThread()}

    provider.on_session_switch("new-sid")

    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.post.assert_not_called()
def test_on_session_switch_does_not_block_caller_on_slow_drain():
    """on_session_switch must not drain or commit on the caller's thread.

    Regression for hermes-agent#28296 review (H1). /new, /branch, /resume and
    /undo call this synchronously on the command thread. A slow writer drain
    (up to _SESSION_DRAIN_TIMEOUT/_DEFERRED_COMMIT_TIMEOUT) or a wedged
    commit POST must not stall the user-facing command. The rotation is
    cheap and synchronous; the commit is offloaded. This mirrors the #41945
    'do not block the turn thread' contract.
    """
    import threading
    import time

    provider = _make_provider_with_session("old-sid", turn_count=2)
    drain_entered = threading.Event()
    release_drain = threading.Event()

    def slow_drain(sid, timeout):
        drain_entered.set()
        # Simulate a writer that takes a long time to drain.
        release_drain.wait(timeout=10.0)
        return True

    provider._drain_writers = slow_drain
    try:
        start = time.monotonic()
        provider.on_session_switch("new-sid")
        elapsed = time.monotonic() - start

        # The caller returned promptly with state already rotated, even though
        # the drain is still parked on the finalizer thread.
        assert elapsed < 1.0, f"on_session_switch blocked the caller for {elapsed:.2f}s"
        assert provider._session_id == "new-sid"
        assert provider._turn_count == 0
        assert drain_entered.wait(timeout=2.0), "finalizer never started draining"
        # No commit yet: the drain is still blocked off-thread.
        provider._client.post.assert_not_called()
    finally:
        # Always unpark the finalizer, even if an assertion above failed.
        # Otherwise it would sit in slow_drain for up to 10s after the test
        # and could still POST against this provider's mock.
        release_drain.set()

    assert provider._drain_finalizers(timeout=5.0)
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
def test_on_session_switch_defers_old_commit_to_finalizer_thread():
    """Session state rotates inline; the old-session drain + commit do not.

    Rotation is cheap and in-memory. The old-session drain and commit are
    offloaded to a daemon finalizer so the caller's command thread (/new,
    /branch, /resume) never blocks on the drain (up to
    _DEFERRED_COMMIT_TIMEOUT) or on the commit POST. See hermes-agent#28296
    review (the #41945 'do not block the turn thread' contract).
    """
    import threading

    provider = _make_provider_with_session("old-sid", turn_count=2)
    committed = threading.Event()
    drain_timeouts = []

    def fake_post(path):
        committed.set()
        return {}

    def fake_drain(sid, timeout):
        drain_timeouts.append(timeout)
        return True

    provider._client.post.side_effect = fake_post
    provider._drain_writers = fake_drain

    provider.on_session_switch("new-sid")

    # Rotation is synchronous and immediate: the new session is live at once.
    assert provider._session_id == "new-sid"
    assert provider._turn_count == 0
    # The old-session commit lands on the finalizer thread, not inline.
    assert committed.wait(timeout=5.0), "old session was not finalized off-thread"
    # Let the finalizer run to completion, so it cannot add calls behind the
    # assertions below or outlive the test.
    assert provider._drain_finalizers(timeout=5.0), "session finalizer did not finish"
    provider._client.post.assert_called_once_with("/api/v1/sessions/old-sid/commit")
    # The finalizer drains with the deferred (longer) budget, not the inline 10s.
    assert drain_timeouts == [_DEFERRED_COMMIT_TIMEOUT]
def test_sync_turn_tracks_writer_under_session_id():
    """Each sync_turn writer is registered under the session id it captured.

    The end/switch drain can therefore still find it, even if a later
    sync_turn replaces the latest-tracked reference. The worker removes
    itself from the set when it exits.
    """
    import threading

    import plugins.memory.openviking as viking_module

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "sid-1"

    let_go = threading.Event()
    in_post = threading.Event()

    class BlockingClient:
        """Client stand-in whose post() waits until the test lets it return."""

        def __init__(self, *args, **kwargs):
            pass

        def post(self, path, payload=None, **kwargs):
            in_post.set()
            let_go.wait(timeout=2.0)
            return {}

    saved_client_cls = viking_module._VikingClient
    viking_module._VikingClient = BlockingClient
    try:
        provider.sync_turn("u", "a")
        assert in_post.wait(timeout=2.0), "worker never entered post()"
        assert len(provider._inflight_writers.get("sid-1", set())) == 1
        let_go.set()
        for worker in list(provider._inflight_writers.get("sid-1", set())):
            worker.join(timeout=2.0)
    finally:
        viking_module._VikingClient = saved_client_cls

    # The worker should have removed itself from the in-flight set on exit.
    assert provider._inflight_writers.get("sid-1", set()) == set()
# ---------------------------------------------------------------------------
# on_memory_write: explicit memory writes use content/write and stay outside
# the session transcript/commit boundary.
# ---------------------------------------------------------------------------
def test_on_memory_write_uses_content_write_independent_of_session_rotation(monkeypatch):
    """An explicit memory write POSTs to content/write even if the session id
    rotates while the write worker is in flight.

    The worker is parked inside the stub client's constructor, the provider's
    session id is changed, and then the worker is released. Exactly one
    content/write POST must result.
    """
    import threading

    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "old-sid"

    in_ctor = threading.Event()
    release = threading.Event()
    done = threading.Event()
    captured_paths = []
    captured_payloads = []
    workers = []

    class StubClient:
        def __init__(self, *a, **kw):
            # Record the worker thread so the test can join it before asserting.
            workers.append(threading.current_thread())
            in_ctor.set()
            release.wait(timeout=2.0)

        def post(self, path, payload=None, **kwargs):
            captured_paths.append(path)
            captured_payloads.append(payload)
            done.set()
            return {}

    # monkeypatch restores the real client class at teardown, even on failure.
    monkeypatch.setattr(_mod, "_VikingClient", StubClient)
    try:
        provider.on_memory_write("add", "user", "remember this")
        assert in_ctor.wait(timeout=2.0), "worker never entered ctor"
        # Rotate provider's session id while the worker is parked. Memory writes
        # must not become session messages in either the old or new session.
        provider._session_id = "new-sid"
    finally:
        # Never leave the worker parked if an assertion above failed.
        release.set()

    assert done.wait(timeout=2.0), "worker never reached post()"
    # Let the worker run to completion so a stray follow-up POST would be
    # observed by the assertions below.
    for worker in workers:
        worker.join(timeout=2.0)
        assert not worker.is_alive(), "memory-write worker did not finish"

    assert captured_paths == ["/api/v1/content/write"]
    assert captured_payloads[0]["content"] == "remember this"
    assert captured_payloads[0]["mode"] == "create"
    assert captured_payloads[0]["uri"].startswith(
        "viking://user/usr/agent/hermes/memories/preferences/mem_"
    )
# ---------------------------------------------------------------------------
# Prefetch staleness: a prefetch worker that finishes AFTER a session switch
# must drop its result instead of repopulating the new session with stale
# recall from the old generation. Bump the generation directly (rather than
# calling on_session_switch, whose own join blocks on the test worker) so
# the test isolates the generation-gating behavior.
# ---------------------------------------------------------------------------
def test_queue_prefetch_drops_result_when_generation_changed_mid_flight(monkeypatch):
    """A prefetch result produced under a stale generation is discarded.

    The worker is parked inside ``post()`` and the generation is bumped. The
    worker is then released and joined. The exact worker thread is captured
    inside the stub, so the final assertion cannot pass merely because the
    worker had not written yet.
    """
    import threading

    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"
    provider._session_id = "old-sid"

    started = threading.Event()
    release = threading.Event()
    workers = []

    class StubClient:
        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            workers.append(threading.current_thread())
            started.set()
            release.wait(timeout=2.0)
            return {
                "result": {
                    "memories": [
                        {"uri": "viking://memories/old", "score": 0.9,
                         "abstract": "stale from old session"},
                    ],
                    "resources": [],
                }
            }

    # monkeypatch restores the real client class at teardown, even on failure.
    monkeypatch.setattr(_mod, "_VikingClient", StubClient)
    try:
        provider.queue_prefetch("anything")
        assert started.wait(timeout=2.0), "prefetch worker never entered post()"
        # Simulate a session switch by bumping the generation directly.
        # The worker captured the pre-bump generation when it was spawned.
        provider._prefetch_generation += 1
    finally:
        # Never leave the worker parked if an assertion above failed.
        release.set()

    # Wait for the worker to finish processing its (stale) result. Otherwise
    # the assertion below could pass vacuously.
    for worker in workers:
        worker.join(timeout=2.0)
        assert not worker.is_alive(), "prefetch worker did not finish"

    # The stale result from the pre-bump generation must NOT have been written
    # into the new generation's prefetch slot.
    assert provider._prefetch_result == ""
def test_queue_prefetch_sends_limit_not_legacy_top_k(monkeypatch):
    """The prefetch search payload uses ``limit``, not the legacy ``top_k`` key."""
    import threading

    import plugins.memory.openviking as _mod

    provider = OpenVikingMemoryProvider()
    provider._client = MagicMock()
    provider._endpoint = "http://test"
    provider._api_key = ""
    provider._account = "acct"
    provider._user = "usr"
    provider._agent = "hermes"

    captured_payloads = []
    workers = []
    posted = threading.Event()

    class StubClient:
        def __init__(self, *a, **kw):
            pass

        def post(self, path, payload=None, **kwargs):
            workers.append(threading.current_thread())
            captured_payloads.append(payload)
            posted.set()
            return {"result": {"memories": [], "resources": []}}

    # monkeypatch restores the real client class at teardown, even on failure.
    monkeypatch.setattr(_mod, "_VikingClient", StubClient)
    provider.queue_prefetch("anything")
    # Wait on the stub itself rather than an optional provider slot, so the
    # payload check below cannot race the background worker.
    assert posted.wait(timeout=2.0), "prefetch worker never reached post()"
    for worker in workers:
        worker.join(timeout=2.0)

    assert captured_payloads == [{"query": "anything", "limit": 5}]
    assert "top_k" not in captured_payloads[0]