hermes-agent/plugins/memory/openviking/__init__.py
kshitijk4poor c835448908 fix(openviking): don't block the command thread on session switch; lock turn state
Follow-up hardening on @ehz0ah / @harshitAgr's session-switch work (#28296):

- on_session_switch no longer runs the old-session writer-drain + pending-token
  GET + commit POST inline on the caller's command thread. /new, /branch,
  /resume, /undo call it synchronously, so a slow drain (up to 10s) or wedged
  commit blocked the user-facing command — the same hazard #41945 fixed for
  end-of-turn sync. State now rotates synchronously (cheap) and the old-session
  commit is offloaded to a daemon finalizer (generalized _finalize_session_async).
- Guard the (_session_id, _turn_count) pair with _session_state_lock: sync_turn
  runs on the memory-manager executor thread while the session hooks run on the
  command thread, so the snapshot+reset vs increment was a cross-thread race.
- _session_needs_commit checks the committed-session guard BEFORE the
  turn_count>0 shortcut, closing a double-commit window when a racing sync_turn
  re-increments after commit+reset.
- Add a _shutting_down flag so deferred finalizers stop POSTing against a
  torn-down client; track all prefetch threads in a set so invalidate/shutdown
  join every one, not just the latest slot.

Tests: regression for the non-blocking switch (asserts the caller returns while
a slow drain is parked off-thread) and the committed-guard ordering; updated the
deferred-commit test to the unified finalizer contract.
2026-06-18 00:21:21 +05:30

1351 lines
53 KiB
Python

"""OpenViking memory plugin — full bidirectional MemoryProvider interface.
Context database by Volcengine (ByteDance) that organizes agent knowledge
into a filesystem hierarchy (viking:// URIs) with tiered context loading,
automatic memory extraction, and session management.
Original PR #3369 by Mibayy, rewritten to use the full OpenViking session
lifecycle instead of read-only search endpoints.
Config via environment variables (profile-scoped via each profile's .env):
OPENVIKING_ENDPOINT — Server URL (default: http://127.0.0.1:1933)
OPENVIKING_API_KEY — API key (required for authenticated servers)
OPENVIKING_ACCOUNT — Tenant account (default: default)
OPENVIKING_USER — Tenant user (default: default)
OPENVIKING_AGENT — Tenant agent (default: hermes)
Capabilities:
- Automatic memory extraction on session commit (6 categories)
- Tiered context: L0 (~100 tokens), L1 (~2k), L2 (full)
- Semantic search with hierarchical directory retrieval
- Filesystem-style browsing via viking:// URIs
- Resource ingestion (URLs, docs, code)
"""
from __future__ import annotations
import atexit
import json
import logging
import mimetypes
import os
import tempfile
import threading
import time
import uuid
import zipfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set
from urllib.parse import urlparse
from urllib.request import url2pathname
from agent.memory_provider import MemoryProvider
from agent.skill_commands import extract_user_instruction_from_skill_message
from tools.registry import tool_error
logger = logging.getLogger(__name__)
# Endpoint used by initialize() when OPENVIKING_ENDPOINT is not set.
_DEFAULT_ENDPOINT = "http://127.0.0.1:1933"
# Timeout handed to every httpx request made through _VikingClient.get/post
# and upload_temp_file (presumably seconds, per httpx convention).
_TIMEOUT = 30.0
# Budget on_session_end() gives in-flight writers to drain before it skips
# the commit.
_SESSION_DRAIN_TIMEOUT = 10.0
# Budget the off-thread session finalizer gives in-flight writers to drain;
# derived from two request timeouts plus slack.
_DEFERRED_COMMIT_TIMEOUT = (_TIMEOUT * 2) + 5.0
# A source string starting with one of these prefixes is treated as remote
# (see _is_remote_resource_source) rather than as a local path.
_REMOTE_RESOURCE_PREFIXES = ("http://", "https://", "git@", "ssh://", "git://")
# Maps the viking_remember `category` enum to a viking:// subdirectory.
# Keep in sync with REMEMBER_SCHEMA.parameters.properties.category.enum.
_CATEGORY_SUBDIR_MAP = {
"preference": "preferences",
"entity": "entities",
"event": "events",
"case": "cases",
"pattern": "patterns",
}
# Subdirectory used when a lookup in one of the maps below/above misses
# (on_memory_write falls back to it for unknown targets).
_DEFAULT_MEMORY_SUBDIR = "preferences"
# Maps the built-in memory tool's `target` ("user" vs "memory") to a subdir
# for on_memory_write mirroring. User profile facts → preferences; agent
# notes / observations → patterns. Anything unknown falls back to the default.
_MEMORY_WRITE_TARGET_SUBDIR_MAP = {
"user": "preferences",
"memory": "patterns",
}
def _derive_openviking_user_text(content: Any) -> str:
    """Return the user's own instruction with slash-skill scaffolding removed.

    The work is delegated to ``extract_user_instruction_from_skill_message``
    from ``agent.skill_commands``; this wrapper only normalises a falsy
    result (``None``, ``""``) to the empty string so callers can rely on
    getting a ``str`` back.

    The provider hooks call this on every piece of user text before it is
    sent to OpenViking, so the provider stays correct even when a hook is
    invoked with raw, un-stripped content.
    """
    instruction = extract_user_instruction_from_skill_message(content)
    if instruction:
        return instruction
    return ""
# ---------------------------------------------------------------------------
# Process-level atexit safety net — ensures pending sessions are committed
# even if shutdown_memory_provider is never called (e.g. gateway crash,
# SIGKILL, or exception in the session expiry watcher preventing shutdown).
# ---------------------------------------------------------------------------
# Most recently initialized provider; OpenVikingMemoryProvider.initialize()
# stores itself here so the atexit hook below can reach it.
_last_active_provider: Optional["OpenVikingMemoryProvider"] = None


def _atexit_commit_sessions():
    """Run ``on_session_end([])`` on the last active provider at exit.

    The module-level slot is cleared before the hook is invoked so the
    provider is handled at most once. Any exception raised by the provider
    is deliberately swallowed: this is a best-effort safety net that runs
    during interpreter shutdown.
    """
    global _last_active_provider
    # Take ownership of the provider and clear the slot in one step.
    provider, _last_active_provider = _last_active_provider, None
    if provider is None:
        return
    try:
        provider.on_session_end([])
    except Exception:
        pass  # best-effort at shutdown time


atexit.register(_atexit_commit_sessions)
# ---------------------------------------------------------------------------
# HTTP helper — uses httpx to avoid requiring the openviking SDK
# ---------------------------------------------------------------------------
def _get_httpx():
    """Import ``httpx`` on demand.

    Returns the ``httpx`` module, or ``None`` when the package is not
    installed, so callers can decide how to degrade.
    """
    try:
        import httpx as module
    except ImportError:
        module = None
    return module
class _VikingClient:
    """Small synchronous wrapper around the OpenViking REST API (via httpx)."""

    def __init__(self, endpoint: str, api_key: str = "",
                 account: str = "", user: str = "", agent: str = ""):
        """Store connection settings and resolve the httpx module.

        Empty ``account`` / ``user`` / ``agent`` values fall back to the
        OPENVIKING_ACCOUNT / OPENVIKING_USER / OPENVIKING_AGENT environment
        variables, then to ``"default"`` / ``"default"`` / ``"hermes"``.

        Raises:
            ImportError: if httpx is not installed.
        """
        env = os.environ
        self._endpoint = endpoint.rstrip("/")
        self._api_key = api_key
        self._account = account or env.get("OPENVIKING_ACCOUNT", "default")
        self._user = user or env.get("OPENVIKING_USER", "default")
        self._agent = agent or env.get("OPENVIKING_AGENT", "hermes")
        self._httpx = _get_httpx()
        if self._httpx is None:
            raise ImportError("httpx is required for OpenViking: pip install httpx")

    def _headers(self) -> dict:
        """Build the JSON request headers, including tenant and auth headers."""
        # Always send tenant headers when account/user are configured.
        # OpenViking 0.3.x requires X-OpenViking-Account and X-OpenViking-User
        # for ROOT API key requests to tenant-scoped APIs — omitting them
        # causes INVALID_ARGUMENT errors even when account="default".
        # User-level keys can omit them (server derives tenancy from the key),
        # but ROOT keys must always include them explicitly.
        headers = {
            "Content-Type": "application/json",
            "X-OpenViking-Agent": self._agent,
        }
        conditional = (
            ("X-OpenViking-Account", self._account),
            ("X-OpenViking-User", self._user),
            ("X-API-Key", self._api_key),
        )
        for header_name, header_value in conditional:
            if header_value:
                headers[header_name] = header_value
        if self._api_key:
            # The key is sent both as X-API-Key and as a bearer token.
            headers["Authorization"] = "Bearer " + self._api_key
        return headers

    def _url(self, path: str) -> str:
        """Join ``path`` onto the (slash-trimmed) endpoint."""
        return "{}{}".format(self._endpoint, path)

    def _multipart_headers(self) -> dict:
        """Same as ``_headers()`` minus Content-Type, for file uploads."""
        return {
            key: value
            for key, value in self._headers().items()
            if key != "Content-Type"
        }

    def _parse_response(self, resp) -> dict:
        """Decode a response body and raise on any reported failure.

        Returns the decoded JSON, or ``{}`` when the body is not JSON.

        Raises:
            RuntimeError: when the body carries an ``error`` object (on an
                HTTP error status) or ``status == "error"``.
            Whatever ``resp.raise_for_status()`` raises for other HTTP
            error statuses.
        """
        try:
            data = resp.json()
        except Exception:
            data = None
        body = data if isinstance(data, dict) else None

        if resp.status_code >= 400:
            if body is not None:
                err = body.get("error")
                if isinstance(err, dict):
                    code = err.get("code", "HTTP_ERROR")
                    message = err.get("message", resp.text)
                    raise RuntimeError(f"{code}: {message}")
                if body.get("status") == "error":
                    raise RuntimeError(str(body))
            resp.raise_for_status()

        # A 2xx/3xx response can still report an application-level error.
        if body is not None and body.get("status") == "error":
            err = body.get("error")
            if isinstance(err, dict):
                code = err.get("code", "OPENVIKING_ERROR")
                message = err.get("message", "")
                raise RuntimeError(f"{code}: {message}")
            raise RuntimeError(str(body))

        return {} if data is None else data

    def get(self, path: str, **kwargs) -> dict:
        """GET ``path`` and return the parsed body (see ``_parse_response``)."""
        response = self._httpx.get(
            self._url(path), headers=self._headers(), timeout=_TIMEOUT, **kwargs
        )
        return self._parse_response(response)

    def post(self, path: str, payload: dict = None, **kwargs) -> dict:
        """POST ``payload`` (``{}`` when falsy) as JSON and return the parsed body."""
        body = payload or {}
        response = self._httpx.post(
            self._url(path), json=body, headers=self._headers(),
            timeout=_TIMEOUT, **kwargs
        )
        return self._parse_response(response)

    def upload_temp_file(self, file_path: Path) -> str:
        """Upload a local file to the temp_upload endpoint.

        Returns the ``temp_file_id`` reported by the server.

        Raises:
            RuntimeError: if the response contains no ``temp_file_id``.
        """
        guessed_type = mimetypes.guess_type(file_path.name)[0]
        mime_type = guessed_type or "application/octet-stream"
        with file_path.open("rb") as handle:
            response = self._httpx.post(
                self._url("/api/v1/resources/temp_upload"),
                files={"file": (file_path.name, handle, mime_type)},
                headers=self._multipart_headers(),
                timeout=_TIMEOUT,
            )
        parsed = self._parse_response(response)
        temp_file_id = parsed.get("result", {}).get("temp_file_id", "")
        if not temp_file_id:
            raise RuntimeError("OpenViking temp upload did not return temp_file_id")
        return temp_file_id

    def health(self) -> bool:
        """Return True when GET /health answers 200; False on any failure."""
        try:
            response = self._httpx.get(
                self._url("/health"), headers=self._headers(), timeout=3.0
            )
            return response.status_code == 200
        except Exception:
            return False
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
# Tool schema for `viking_search`. Exposed through get_tool_schemas(); calls
# with this name are routed by handle_tool_call() to _tool_search. Only
# `query` is required; `mode`, `scope` and `limit` are optional refinements.
SEARCH_SCHEMA = {
"name": "viking_search",
"description": (
"Semantic search over the OpenViking knowledge base. "
"Returns ranked results with viking:// URIs for deeper reading. "
"Use mode='deep' for complex queries that need reasoning across "
"multiple sources, 'fast' for simple lookups."
),
"parameters": {
"type": "object",
"properties": {
"query": {"type": "string", "description": "Search query."},
"mode": {
"type": "string", "enum": ["auto", "fast", "deep"],
"description": "Search depth (default: auto).",
},
"scope": {
"type": "string",
"description": "Viking URI prefix to scope search (e.g. 'viking://resources/docs/').",
},
"limit": {"type": "integer", "description": "Max results (default: 10)."},
},
"required": ["query"],
},
}
# Tool schema for `viking_read`. Exposed through get_tool_schemas(); calls
# with this name are routed by handle_tool_call() to _tool_read. The three
# `level` values correspond to the L0/L1/L2 tiers named in the description.
READ_SCHEMA = {
"name": "viking_read",
"description": (
"Read content at a viking:// URI. Three detail levels:\n"
" abstract — ~100 token summary (L0)\n"
" overview — ~2k token key points (L1)\n"
" full — complete content (L2)\n"
"Start with abstract/overview, only use full when you need details."
),
"parameters": {
"type": "object",
"properties": {
"uri": {"type": "string", "description": "viking:// URI to read."},
"level": {
"type": "string", "enum": ["abstract", "overview", "full"],
"description": "Detail level (default: overview).",
},
},
"required": ["uri"],
},
}
# Tool schema for `viking_browse`. Exposed through get_tool_schemas(); calls
# with this name are routed by handle_tool_call() to _tool_browse. `action`
# is required; `path` is optional.
BROWSE_SCHEMA = {
"name": "viking_browse",
"description": (
"Browse the OpenViking knowledge store like a filesystem.\n"
" list — show directory contents\n"
" tree — show hierarchy\n"
" stat — show metadata for a URI"
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string", "enum": ["tree", "list", "stat"],
"description": "Browse action.",
},
"path": {
"type": "string",
"description": "Viking URI path (default: viking://). Examples: 'viking://resources/', 'viking://user/memories/'.",
},
},
"required": ["action"],
},
}
# Tool schema for `viking_remember`. Exposed through get_tool_schemas(); calls
# with this name are routed by handle_tool_call() to _tool_remember.
# The `category` enum below must stay in sync with the keys of
# _CATEGORY_SUBDIR_MAP, which maps each category to a viking:// subdirectory.
REMEMBER_SCHEMA = {
"name": "viking_remember",
"description": (
"Explicitly store a fact or memory in the OpenViking knowledge base. "
"Use for important information the agent should remember long-term. "
"The system automatically categorizes and indexes the memory."
),
"parameters": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "The information to remember."},
"category": {
"type": "string",
"enum": ["preference", "entity", "event", "case", "pattern"],
"description": "Memory category (default: auto-detected).",
},
},
"required": ["content"],
},
}
# Tool schema for `viking_add_resource`. Exposed through get_tool_schemas();
# calls with this name are routed by handle_tool_call() to _tool_add_resource.
# Only `url` is required. Per the field descriptions, `to` and `parent` are
# mutually exclusive and `timeout` only applies when `wait` is true.
ADD_RESOURCE_SCHEMA = {
"name": "viking_add_resource",
"description": (
"Add a remote URL or local file/directory to the OpenViking knowledge base. "
"Remote resources must be public http(s), git, or ssh URLs. "
"Local files are uploaded first using OpenViking temp_upload. "
"The system automatically parses, indexes, and generates summaries."
),
"parameters": {
"type": "object",
"properties": {
"url": {"type": "string", "description": "Remote URL or local file/directory path to add."},
"reason": {
"type": "string",
"description": "Why this resource is relevant (improves search).",
},
"to": {
"type": "string",
"description": "Optional target viking:// URI for the resource.",
},
"parent": {
"type": "string",
"description": "Optional parent viking:// URI. Cannot be used with to.",
},
"instruction": {
"type": "string",
"description": "Optional processing instruction for semantic extraction.",
},
"wait": {
"type": "boolean",
"description": "Whether to wait for processing to complete.",
},
"timeout": {
"type": "number",
"description": "Timeout in seconds when wait is true.",
},
},
"required": ["url"],
},
}
def _zip_directory(dir_path: Path) -> Path:
    """Pack a directory tree into a fresh zip file in the temp directory.

    Only regular files are archived. Symlinks are skipped, as is any file
    whose resolved location falls outside the resolved ``dir_path`` (so a
    link cannot smuggle in content from elsewhere). Archive member names
    are relative to ``dir_path`` and always use ``/`` separators.

    The archive being written is itself excluded, which matters when
    ``dir_path`` is (or contains) the temp directory. If anything fails
    while the archive is being built, the partial zip is removed before
    the exception propagates, so failed uploads do not litter the temp
    directory.

    Args:
        dir_path: Directory to archive.

    Returns:
        Path of the created zip file. The caller owns it and is
        responsible for deleting it.

    Raises:
        OSError / zipfile errors from reading the tree or writing the zip.
    """
    root = dir_path.resolve()
    zip_path = Path(tempfile.gettempdir()) / f"openviking_upload_{uuid.uuid4().hex}.zip"
    try:
        with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
            # Resolve once the file exists so the comparison below is exact
            # even when the temp directory is reached through a symlink.
            archive_location = zip_path.resolve()
            for file_path in dir_path.rglob("*"):
                if file_path.is_symlink() or not file_path.is_file():
                    continue
                resolved = file_path.resolve()
                if resolved == archive_location:
                    # Never pack the archive into itself.
                    continue
                try:
                    resolved.relative_to(root)
                except ValueError:
                    # Resolves outside the tree being archived.
                    continue
                arcname = str(file_path.relative_to(dir_path)).replace("\\", "/")
                zipf.write(file_path, arcname=arcname)
    except BaseException:
        # Do not leave a truncated archive behind on failure.
        try:
            zip_path.unlink()
        except OSError:
            pass
        raise
    return zip_path
def _is_windows_absolute_path(value: str) -> bool:
    """Return True when ``value`` starts like a drive-letter path.

    That is: an alphabetic character, a colon, then ``/`` or ``\\``
    (e.g. ``C:\\data`` or ``d:/data``). Strings shorter than three
    characters never match.
    """
    if len(value) < 3:
        return False
    drive, colon, separator = value[0], value[1], value[2]
    if not drive.isalpha():
        return False
    if colon != ":":
        return False
    return separator in ("/", "\\")
def _is_remote_resource_source(value: str) -> bool:
    """Return True when ``value`` begins with one of ``_REMOTE_RESOURCE_PREFIXES``."""
    return any(value.startswith(prefix) for prefix in _REMOTE_RESOURCE_PREFIXES)
def _is_local_path_reference(value: str) -> bool:
    """Heuristically decide whether ``value`` names a local filesystem path.

    Empty strings, multi-line strings and remote sources are rejected.
    Windows drive-letter paths are accepted outright; anything else counts
    as a path when it starts with a common path prefix or contains a
    path separator (``/`` or ``\\``).
    """
    if not value:
        return False
    if "\n" in value or "\r" in value:
        return False
    if _is_remote_resource_source(value):
        return False
    if _is_windows_absolute_path(value):
        return True
    path_like_prefixes = ("/", "./", "../", "~/", ".\\", "..\\", "~\\")
    if value.startswith(path_like_prefixes):
        return True
    return "/" in value or "\\" in value
def _path_from_file_uri(uri: str) -> Path | str:
    """Convert a ``file://`` URI into a local ``Path``.

    Only URIs with an empty host or ``localhost`` are accepted. For any
    other host the function returns an error message string instead of a
    ``Path`` — callers must check the type of the result.
    """
    parts = urlparse(uri)
    host = parts.netloc
    if host != "" and host != "localhost":
        return f"Unsupported non-local file URI: {uri}"
    local_path = url2pathname(parts.path)
    return Path(local_path).expanduser()
# ---------------------------------------------------------------------------
# MemoryProvider implementation
# ---------------------------------------------------------------------------
class OpenVikingMemoryProvider(MemoryProvider):
"""Full bidirectional memory via OpenViking context database."""
def __init__(self):
self._client: Optional[_VikingClient] = None
self._endpoint = ""
self._api_key = ""
self._session_id = ""
self._turn_count = 0
# Guards the (_session_id, _turn_count) pair. sync_turn runs on the
# MemoryManager's background sync executor while on_session_end /
# on_session_switch run on the caller's thread, so the snapshot+reset
# of the turn counter and the session-id rotation must be atomic
# against a concurrent increment. See hermes-agent#28296 review.
self._session_state_lock = threading.Lock()
# Commit only after session writes drain. The set is keyed by the sid
# the writer is POSTing under (snapshotted at spawn), so on_session_end
# / on_session_switch see every still-alive writer for that sid even
# if later writes have replaced the latest-tracked thread.
self._inflight_writers: Dict[str, Set[threading.Thread]] = {}
self._inflight_lock = threading.Lock()
self._deferred_commit_sids: Set[str] = set()
self._deferred_commit_threads: Set[threading.Thread] = set()
self._deferred_commit_lock = threading.Lock()
self._committed_session_ids: Set[str] = set()
self._committed_session_lock = threading.Lock()
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread: Optional[threading.Thread] = None
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
# shutdown() can drain them and rapid re-queues don't orphan a still-
# running thread by overwriting the single _prefetch_thread slot.
self._prefetch_threads: Set[threading.Thread] = set()
# Set on shutdown so deferred-commit / writer finalizers stop issuing
# network writes against a torn-down provider.
self._shutting_down = False
# Drop prefetch results from older switch generations.
self._prefetch_generation = 0
@property
def name(self) -> str:
return "openviking"
def is_available(self) -> bool:
"""Check if OpenViking endpoint is configured. No network calls."""
return bool(os.environ.get("OPENVIKING_ENDPOINT"))
def get_config_schema(self):
    """Describe the provider's configuration fields.

    Returns a list of dicts, one per setting, in the order endpoint,
    api_key, account, user, agent. Each entry names the environment
    variable that backs it.
    """
    endpoint_field = {
        "key": "endpoint",
        "description": "OpenViking server URL",
        "required": True,
        "default": _DEFAULT_ENDPOINT,
        "env_var": "OPENVIKING_ENDPOINT",
    }
    api_key_field = {
        "key": "api_key",
        "description": "OpenViking API key (leave blank for local dev mode)",
        "secret": True,
        "env_var": "OPENVIKING_API_KEY",
    }
    account_field = {
        "key": "account",
        "description": "OpenViking tenant account ID ([default], used when local mode, OPENVIKING_API_KEY is empty)",
        "default": "default",
        "env_var": "OPENVIKING_ACCOUNT",
    }
    user_field = {
        "key": "user",
        "description": "OpenViking user ID within the account ([default], used when local mode, OPENVIKING_API_KEY is empty)",
        "default": "default",
        "env_var": "OPENVIKING_USER",
    }
    agent_field = {
        "key": "agent",
        "description": "OpenViking agent ID within the account ([hermes], useful in multi-agent mode)",
        "default": "hermes",
        "env_var": "OPENVIKING_AGENT",
    }
    return [endpoint_field, api_key_field, account_field, user_field, agent_field]
def initialize(self, session_id: str, **kwargs) -> None:
    """Load settings from the environment and connect to the server.

    Reads OPENVIKING_ENDPOINT / _API_KEY / _ACCOUNT / _USER / _AGENT,
    records ``session_id`` as the active session with a zero turn count,
    and builds a client. The client is discarded (``_client = None``,
    which disables the provider's hooks) when the health probe fails or
    httpx is missing. Finally the provider registers itself as the
    module's last active provider for the atexit safety net.
    """
    global _last_active_provider

    env = os.environ
    self._endpoint = env.get("OPENVIKING_ENDPOINT", _DEFAULT_ENDPOINT)
    self._api_key = env.get("OPENVIKING_API_KEY", "")
    self._account = env.get("OPENVIKING_ACCOUNT", "default")
    self._user = env.get("OPENVIKING_USER", "default")
    self._agent = env.get("OPENVIKING_AGENT", "hermes")
    self._session_id = session_id
    self._turn_count = 0

    try:
        self._client = _VikingClient(
            self._endpoint, self._api_key,
            account=self._account, user=self._user, agent=self._agent,
        )
        reachable = self._client.health()
        if not reachable:
            logger.warning("OpenViking server at %s is not reachable", self._endpoint)
            self._client = None
    except ImportError:
        logger.warning("httpx not installed — OpenViking plugin disabled")
        self._client = None

    # Register as the last active provider for atexit safety net
    _last_active_provider = self
def system_prompt_block(self) -> str:
if not self._client:
return ""
# Provide brief info about the knowledge base
try:
# Check what's in the knowledge base via a root listing
resp = self._client.get("/api/v1/fs/ls", params={"uri": "viking://"})
result = resp.get("result", [])
children = len(result) if isinstance(result, list) else 0
if children == 0:
return ""
return (
"# OpenViking Knowledge Base\n"
f"Active. Endpoint: {self._endpoint}\n"
"Use viking_search to find information, viking_read for details "
"(abstract/overview/full), viking_browse to explore.\n"
"Use viking_remember to store facts, viking_add_resource to index URLs/docs."
)
except Exception as e:
logger.warning("OpenViking system_prompt_block failed: %s", e)
return (
"# OpenViking Knowledge Base\n"
f"Active. Endpoint: {self._endpoint}\n"
"Use viking_search, viking_read, viking_browse, "
"viking_remember, viking_add_resource."
)
def prefetch(self, query: str, *, session_id: str = "") -> str:
"""Return prefetched results from the background thread."""
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
with self._prefetch_lock:
result = self._prefetch_result
self._prefetch_result = ""
if not result:
return ""
return f"## OpenViking Context\n{result}"
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
    """Start a daemon thread that searches OpenViking for ``query``.

    The query is first stripped of skill scaffolding; nothing is started
    when the provider has no client or the cleaned query is empty. The
    worker uses its own client, keeps at most three abstracts each from
    the "memories" and "resources" result groups, and stores them as the
    cached prefetch result — unless the prefetch generation changed since
    the worker was queued, in which case the result is dropped.
    ``session_id`` is accepted for interface compatibility and unused.
    """
    query = _derive_openviking_user_text(query)
    if not (self._client and query):
        return

    # Remember which generation queued this search so a session switch
    # in the meantime invalidates its result.
    with self._prefetch_lock:
        queued_generation = self._prefetch_generation

    # Filled with the worker thread before it starts, so the worker can
    # untrack itself when it finishes.
    self_ref: List[threading.Thread] = []

    def _search_worker():
        try:
            client = _VikingClient(
                self._endpoint, self._api_key,
                account=self._account, user=self._user, agent=self._agent,
            )
            response = client.post("/api/v1/search/find", {
                "query": query,
                "limit": 5,
            })
            found = response.get("result", {})
            lines = []
            for group in ("memories", "resources"):
                for entry in found.get(group, [])[:3]:
                    uri = entry.get("uri", "")
                    abstract = entry.get("abstract", "")
                    score = entry.get("score", 0)
                    if abstract:
                        lines.append(f"- [{score:.2f}] {abstract} ({uri})")
            if lines:
                with self._prefetch_lock:
                    if queued_generation == self._prefetch_generation:
                        self._prefetch_result = "\n".join(lines)
        except Exception as e:
            logger.debug("OpenViking prefetch failed: %s", e)
        finally:
            with self._prefetch_lock:
                if self_ref:
                    self._prefetch_threads.discard(self_ref[0])

    worker = threading.Thread(
        target=_search_worker, daemon=True, name="openviking-prefetch"
    )
    self_ref.append(worker)
    with self._prefetch_lock:
        self._prefetch_thread = worker
        self._prefetch_threads.add(worker)
    worker.start()
def _spawn_writer(self, sid: str, target: Callable[[], None], name: str) -> None:
"""Spawn a daemon writer tracked in _inflight_writers[sid].
Tracking is keyed by sid (not by a single latest-thread slot) so that
on_session_end / on_session_switch can drain every still-alive writer
for the session being committed.
"""
holder: List[threading.Thread] = []
def _wrapped():
try:
target()
finally:
with self._inflight_lock:
workers = self._inflight_writers.get(sid)
if workers is not None:
workers.discard(holder[0])
if not workers:
self._inflight_writers.pop(sid, None)
thread = threading.Thread(target=_wrapped, daemon=True, name=name)
holder.append(thread)
with self._inflight_lock:
self._inflight_writers.setdefault(sid, set()).add(thread)
thread.start()
def _drain_finalizers(self, timeout: float) -> bool:
"""Join every in-flight async session finalizer within a timeout.
The switch-path commit runs on a daemon finalizer thread so it never
blocks the caller's command thread; this lets shutdown and tests wait
for those commits deterministically. Returns True if all drained.
"""
deadline = time.monotonic() + timeout
while True:
with self._deferred_commit_lock:
workers = [t for t in self._deferred_commit_threads if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
# Floor the per-join wait so a thread whose join() returns
# instantly while still reporting alive can't hot-spin this loop.
t.join(timeout=min(slice_left, 0.05))
def _drain_writers(self, sid: str, timeout: float) -> bool:
"""Join every in-flight writer for sid within a shared timeout budget.
Returns True if all writers drained, False if any are still alive when
the budget runs out. Callers use the False return to skip the commit.
"""
if not sid:
return True
deadline = time.monotonic() + timeout
while True:
with self._inflight_lock:
workers = [t for t in self._inflight_writers.get(sid, ()) if t.is_alive()]
if not workers:
return True
remaining = deadline - time.monotonic()
if remaining <= 0:
return False
for t in workers:
slice_left = deadline - time.monotonic()
if slice_left <= 0:
break
t.join(timeout=slice_left)
def _new_client(self) -> _VikingClient:
    """Build a fresh client from the provider's current connection settings."""
    tenant = {
        "account": self._account,
        "user": self._user,
        "agent": self._agent,
    }
    return _VikingClient(self._endpoint, self._api_key, **tenant)
@staticmethod
def _text_part(content: str) -> Dict[str, str]:
return {"type": "text", "text": content}
@classmethod
def _turn_batch_payload(cls, user_content: str, assistant_content: str) -> Dict[str, Any]:
return {
"messages": [
{"role": "user", "parts": [cls._text_part(user_content)]},
{"role": "assistant", "parts": [cls._text_part(assistant_content)]},
]
}
@classmethod
def _post_session_turn(
cls,
client: _VikingClient,
sid: str,
user_content: str,
assistant_content: str,
) -> None:
client.post(
f"/api/v1/sessions/{sid}/messages/batch",
cls._turn_batch_payload(user_content, assistant_content),
)
def _session_has_pending_tokens(self, sid: str) -> bool:
try:
response = self._client.get(f"/api/v1/sessions/{sid}")
except Exception:
return False
session = self._unwrap_result(response)
if not isinstance(session, dict):
return False
try:
return int(session.get("pending_tokens") or 0) > 0
except (TypeError, ValueError):
return False
def _has_committed_session(self, sid: str) -> bool:
with self._committed_session_lock:
return sid in self._committed_session_ids
def _mark_session_committed(self, sid: str) -> None:
with self._committed_session_lock:
self._committed_session_ids.add(sid)
def _session_needs_commit(self, sid: str, turn_count: int) -> bool:
# Already-committed sessions never need a second commit, regardless of
# the turn counter — a racing sync_turn can re-increment _turn_count
# after a commit+reset, so the committed-guard must win over turn_count.
if self._has_committed_session(sid):
return False
if turn_count > 0:
return True
return self._session_has_pending_tokens(sid)
def _commit_session(self, sid: str, turn_count: int, *, context: str) -> bool:
    """POST the commit for session ``sid`` and record it as committed.

    ``turn_count`` and ``context`` are only used in the log line.
    Returns True on success; any exception is logged as a warning and
    reported as False, never raised.
    """
    committed = False
    try:
        self._client.post(f"/api/v1/sessions/{sid}/commit")
        self._mark_session_committed(sid)
        logger.info("OpenViking session %s committed %s (%d turns)", sid, context, turn_count)
        committed = True
    except Exception as e:
        logger.warning("OpenViking session commit failed for %s: %s", sid, e)
    return committed
def _finalize_session_async(self, sid: str, turn_count: int, *, context: str) -> None:
"""Drain the old session's writers and commit it on a daemon thread.
Used by on_session_switch (and the deferred-commit fallback) so the
potentially-multi-second drain + pending-token GET + commit POST never
runs on the caller's command thread. Deduped by sid so a rapid second
switch can't stack two finalizers for the same session, and a no-op
once shutdown has begun so we don't POST against a torn-down client.
"""
if not sid:
return
with self._deferred_commit_lock:
if self._shutting_down or sid in self._deferred_commit_sids:
return
self._deferred_commit_sids.add(sid)
holder: List[threading.Thread] = []
def _finalize() -> None:
try:
if self._shutting_down:
return
if not self._drain_writers(sid, timeout=_DEFERRED_COMMIT_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after drain — "
"leaving session uncommitted",
sid,
)
return
if self._shutting_down:
return
if self._session_needs_commit(sid, turn_count):
self._commit_session(sid, turn_count, context=context)
finally:
with self._deferred_commit_lock:
self._deferred_commit_sids.discard(sid)
if holder:
self._deferred_commit_threads.discard(holder[0])
thread = threading.Thread(
target=_finalize,
daemon=True,
name=f"openviking-finalize-{sid}",
)
holder.append(thread)
with self._deferred_commit_lock:
self._deferred_commit_threads.add(thread)
thread.start()
def _invalidate_prefetch_state(self) -> None:
# Bump the generation under the same lock used by prefetch workers so
# late results from an older session are discarded deterministically.
with self._prefetch_lock:
self._prefetch_generation += 1
self._prefetch_result = ""
# Join EVERY tracked prefetch thread, not just the latest slot — a
# rapid re-queue can leave an older thread for the abandoned session
# still running (consistent with shutdown()).
workers = [t for t in self._prefetch_threads if t.is_alive()]
for t in workers:
t.join(timeout=3.0)
with self._prefetch_lock:
self._prefetch_result = ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Record the conversation turn in OpenViking's session (non-blocking)."""
if not self._client:
return
user_content = _derive_openviking_user_text(user_content)
if not user_content:
return
# Snapshot the sid and bump the turn counter atomically so a
# concurrent on_session_switch/on_session_end can't interleave its
# snapshot+reset between the read and the increment (lost turn) and so
# the turn is unambiguously attributed to the session it targets.
with self._session_state_lock:
sid = str(session_id or self._session_id).strip()
if not sid:
return
self._turn_count += 1
def _sync():
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
except Exception as e:
logger.debug("OpenViking sync_turn failed, reconnecting: %s", e)
try:
client = self._new_client()
self._post_session_turn(
client,
sid,
user_content[:4000],
assistant_content[:4000],
)
except Exception as retry_error:
logger.warning("OpenViking sync_turn failed: %s", retry_error)
self._spawn_writer(sid, _sync, name="openviking-sync")
def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
"""Commit the session to trigger memory extraction.
OpenViking automatically extracts 6 categories of memories:
profile, preferences, entities, events, cases, and patterns.
"""
if not self._client:
return
# Snapshot sid + turn count atomically against a concurrent sync_turn
# increment. on_session_end runs at teardown so the drain+commit stays
# synchronous here (we want it to land before the process exits), but
# the counter read must still be consistent.
with self._session_state_lock:
sid = self._session_id
turn_count = self._turn_count
# Commit only after session writes drain.
if not self._drain_writers(sid, timeout=_SESSION_DRAIN_TIMEOUT):
logger.warning(
"OpenViking writer for %s still alive after drain — skipping commit",
sid,
)
return
if not self._session_needs_commit(sid, turn_count):
return
if self._commit_session(sid, turn_count, context="on session end"):
# Mark clean so a follow-up on_session_switch skips its own commit.
with self._session_state_lock:
if self._session_id == sid:
self._turn_count = 0
def on_session_switch(
self,
new_session_id: str,
*,
parent_session_id: str = "",
reset: bool = False,
**kwargs,
) -> None:
"""Commit the old session and rotate cached state to the new session_id.
Fires on /resume, /branch, /reset, /new, and context compression.
Without this hook, ``_session_id`` stays stuck at the value
``initialize()`` cached, so subsequent ``sync_turn()`` writes land in
the already-closed old session and ``on_session_end()`` tries to
commit it a second time. The new session never accumulates messages,
and memory extraction never fires for it. See hermes-agent#28296.
Flushes any in-flight sync under the old session_id, commits the old
session if it has pending turns (same extraction semantics as
``on_session_end``), drains and clears any stale prefetch result,
then rotates ``_session_id`` and resets ``_turn_count``.
"""
new_id = str(new_session_id or "").strip()
if not new_id or not self._client:
return
rewound = bool(kwargs.get("rewound"))
# Rotate cached session state synchronously (cheap, in-memory) and
# snapshot the old session under the lock so a concurrent sync_turn
# either lands fully before the rotation (counted under old) or fully
# after (counted under new) — never split. The OLD session's commit
# (drain + pending-token GET + commit POST, potentially many seconds)
# is then offloaded so /new, /branch, /resume, /undo never block the
# caller's command thread (cf. the end-of-turn-sync offload in #41945).
with self._session_state_lock:
old_session_id = self._session_id
old_turn_count = self._turn_count
rotate = not (rewound or new_id == old_session_id)
if rotate:
self._session_id = new_id
self._turn_count = 0
# Invalidate stale prefetch OUTSIDE the session lock — it takes its own
# _prefetch_lock and may join a prefetch thread for up to 3s, which we
# must not do while holding the session lock (would block sync_turn and
# risk lock-ordering coupling).
self._invalidate_prefetch_state()
if not rotate:
# Same-session rewind (/undo) or no-op rotation: no commit, no
# counter reset — just the prefetch invalidation above.
logger.debug(
"OpenViking on_session_switch invalidated state without rotation: "
"session=%s rewound=%s",
old_session_id, rewound,
)
return
# Drain + commit the OLD session off the command thread.
if old_session_id:
self._finalize_session_async(old_session_id, old_turn_count, context="on switch")
logger.debug(
"OpenViking on_session_switch: old=%s new=%s parent=%s reset=%s",
old_session_id, new_id, parent_session_id, reset,
)
def _build_memory_uri(self, subdir: str) -> str:
    """Return a fresh ``viking://`` URI for a new memory file in *subdir*.

    The path is rooted at the configured user and agent. The file name is
    ``mem_`` followed by the first 12 hex characters of a random UUID4, so
    every call yields a different URI.
    """
    token = uuid.uuid4().hex[:12]
    memories_root = f"viking://user/{self._user}/agent/{self._agent}/memories"
    return f"{memories_root}/{subdir}/mem_{token}.md"
def on_memory_write(
    self,
    action: str,
    target: str,
    content: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> None:
    """Mirror a built-in memory ``add`` into OpenViking via content/write.

    Nothing happens when there is no client, when *action* is anything other
    than ``"add"``, or when *content* is empty. Otherwise the write is issued
    from a daemon thread that builds its own ``_VikingClient``; the thread is
    started and not joined here. Failures are only logged at debug level.
    *metadata* is accepted but not used.
    """
    if not self._client:
        return
    if action != "add":
        return
    if not content:
        return

    # Unknown targets fall back to the default memory sub-directory.
    destination = self._build_memory_uri(
        _MEMORY_WRITE_TARGET_SUBDIR_MAP.get(target, _DEFAULT_MEMORY_SUBDIR)
    )

    def _mirror() -> None:
        # Best-effort: a failed mirror must never surface to the caller.
        try:
            writer = _VikingClient(
                self._endpoint,
                self._api_key,
                account=self._account,
                user=self._user,
                agent=self._agent,
            )
            writer.post(
                "/api/v1/content/write",
                {"uri": destination, "content": content, "mode": "create"},
            )
        except Exception as exc:
            logger.debug("OpenViking memory mirror failed: %s", exc)

    threading.Thread(
        target=_mirror, daemon=True, name="openviking-memwrite"
    ).start()
def get_tool_schemas(self) -> List[Dict[str, Any]]:
    """Return the schemas of the five OpenViking tools, in a fixed order."""
    schemas = [
        SEARCH_SCHEMA,
        READ_SCHEMA,
        BROWSE_SCHEMA,
        REMEMBER_SCHEMA,
        ADD_RESOURCE_SCHEMA,
    ]
    return schemas
def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
    """Dispatch a tool invocation to the matching ``_tool_*`` handler.

    Returns the handler's result string. A ``tool_error`` payload is returned
    instead when there is no client, when *tool_name* is not recognised, or
    when the handler raises (the exception text becomes the error message).
    Extra keyword arguments are accepted and ignored.
    """
    if not self._client:
        return tool_error("OpenViking server not connected")

    # (tool name, handler attribute) pairs, checked in order.
    routes = (
        ("viking_search", "_tool_search"),
        ("viking_read", "_tool_read"),
        ("viking_browse", "_tool_browse"),
        ("viking_remember", "_tool_remember"),
        ("viking_add_resource", "_tool_add_resource"),
    )
    try:
        for name, handler_attr in routes:
            if tool_name == name:
                return getattr(self, handler_attr)(args)
        return tool_error(f"Unknown tool: {tool_name}")
    except Exception as exc:
        return tool_error(str(exc))
def shutdown(self) -> None:
    """Flag shutdown, wait for background threads, and drop the atexit ref.

    Each live writer, deferred-commit and prefetch thread is joined with a
    5 second timeout, in that order.
    """
    global _last_active_provider

    # Raise the flag first so deferred finalizers stop issuing new commits
    # against a torn-down client.
    self._shutting_down = True

    # Snapshot every tracked thread under its own lock before joining, so no
    # lock is held while waiting.
    with self._inflight_lock:
        writers = [
            worker
            for session_workers in self._inflight_writers.values()
            for worker in session_workers
        ]
    with self._deferred_commit_lock:
        finalizers = list(self._deferred_commit_threads)
    with self._prefetch_lock:
        prefetchers = list(self._prefetch_threads)

    for worker in (*writers, *finalizers, *prefetchers):
        if worker.is_alive():
            worker.join(timeout=5.0)

    # Clear atexit reference so it doesn't double-commit.
    if _last_active_provider is self:
        _last_active_provider = None
# -- Tool implementations ------------------------------------------------
@staticmethod
def _unwrap_result(resp: Any) -> Any:
"""Return OpenViking payload body regardless of wrapped/unwrapped shape."""
if isinstance(resp, dict) and "result" in resp:
return resp.get("result")
return resp
@staticmethod
def _normalize_summary_uri(uri: str) -> str:
"""Map pseudo summary files to their parent directory URI for L0/L1 reads."""
if not uri:
return uri
for suffix in ("/.abstract.md", "/.overview.md", "/.read.md", "/.full.md"):
if uri.endswith(suffix):
return uri[: -len(suffix)] or "viking://"
return uri
def _is_directory_uri(self, uri: str) -> bool | None:
"""Probe fs/stat to decide if a URI is a directory.
Returns True/False when the server answers cleanly, and None when the
probe itself fails (network error, unexpected shape). Callers should
treat None as "unknown" and fall back to the exception-based path.
"""
try:
resp = self._client.get("/api/v1/fs/stat", params={"uri": uri})
except Exception:
return None
result = self._unwrap_result(resp)
if isinstance(result, dict):
if "isDir" in result:
return bool(result.get("isDir"))
if "is_dir" in result:
return bool(result.get("is_dir"))
if result.get("type") == "dir":
return True
if result.get("type") == "file":
return False
return None
def _tool_search(self, args: dict) -> str:
    """Run a search against OpenViking and return the hits as a JSON string.

    Args:
        args: Tool arguments. ``query`` is required. ``mode`` is forwarded
            only when it is not ``"auto"``; ``scope`` is forwarded as
            ``target_uri``; ``limit`` is forwarded when truthy.

    Returns:
        JSON with ``results`` (hits from the memories, resources and skills
        buckets, sorted by descending score) and ``total``, or a
        ``tool_error`` payload when ``query`` is missing.
    """
    query = args.get("query", "")
    if not query:
        return tool_error("query is required")

    payload: Dict[str, Any] = {"query": query}
    mode = args.get("mode", "auto")
    if mode != "auto":
        payload["mode"] = mode
    if args.get("scope"):
        payload["target_uri"] = args["scope"]
    if args.get("limit"):
        payload["limit"] = args["limit"]

    resp = self._client.post("/api/v1/search/find", payload)
    # Accept both wrapped and unwrapped bodies, like the other tools do, and
    # tolerate a null/odd body instead of crashing on ``.get``.
    result = self._unwrap_result(resp)
    if not isinstance(result, dict):
        result = {}

    # Explicit singular labels: ``"memories".rstrip("s")`` yields "memorie"
    # because rstrip removes a character set, it does not singularize.
    type_labels = {"memories": "memory", "resources": "resource", "skills": "skill"}

    # Format results for the model — keep it concise.
    scored_entries = []
    for bucket, label in type_labels.items():
        # A bucket may be absent or null; both mean "no hits".
        for item in result.get(bucket) or []:
            raw_score = item.get("score")
            has_score = raw_score is not None
            entry = {
                "uri": item.get("uri", ""),
                "type": label,
                "score": round(raw_score, 3) if has_score else 0.0,
                "abstract": item.get("abstract", ""),
            }
            relations = item.get("relations")
            if relations:
                # Only the first three related URIs are surfaced.
                entry["related"] = [rel.get("uri") for rel in relations[:3]]
            scored_entries.append((raw_score if has_score else 0.0, entry))

    # Stable sort: equal scores keep bucket order.
    scored_entries.sort(key=lambda pair: pair[0], reverse=True)
    formatted = [entry for _, entry in scored_entries]

    return json.dumps(
        {
            "results": formatted,
            "total": result.get("total", len(formatted)),
        },
        ensure_ascii=False,
    )
def _tool_read(self, args: dict) -> str:
    """Read the content behind a URI and return it as a JSON string.

    ``level`` (default ``"overview"``) picks the endpoint: ``abstract`` and
    ``overview`` use the summary endpoints, any other value uses
    content/read. When a summary read cannot be served for a plain URI the
    call falls back to content/read and marks the payload with
    ``"fallback": "content/read"``. Content is truncated to 1200 / 4000 /
    8000 characters for abstract / overview / other levels.

    Returns a ``tool_error`` payload when ``uri`` is missing; errors that
    are not eligible for the fallback propagate to the caller.
    """
    uri = args.get("uri", "")
    if not uri:
        return tool_error("uri is required")

    read_endpoint = "/api/v1/content/read"
    level = args.get("level", "overview")
    wants_summary = level in {"abstract", "overview"}

    # OpenViking expects the directory URI for pseudo summary files
    # (e.g. viking://user/hermes/.overview.md), so strip the file name.
    target_uri = self._normalize_summary_uri(uri) if wants_summary else uri
    fell_back = False

    # The abstract/overview endpoints are directory-only on OpenViking
    # (v0.3.x returns 500/412 for file URIs). For a summary request on a
    # non-pseudo URI, probe fs/stat first and send known files straight to
    # content/read rather than paying for a failing round-trip. A pseudo URI
    # already resolved to a directory, so it is not probed.
    if wants_summary and target_uri == uri:
        if self._is_directory_uri(uri) is False:
            fell_back = True

    # Map our level names to OpenViking GET endpoints.
    if fell_back:
        endpoint = read_endpoint
    else:
        endpoint = {
            "abstract": "/api/v1/content/abstract",
            "overview": "/api/v1/content/overview",
        }.get(level, read_endpoint)

    try:
        resp = self._client.get(endpoint, params={"uri": target_uri})
    except Exception:
        # OpenViking may answer HTTP 500 for abstract/overview reads on
        # ordinary file URIs (mem_*.md); only that case retries as a full
        # read. Everything else is re-raised untouched.
        retryable = wants_summary and target_uri == uri and not fell_back
        if not retryable:
            raise
        resp = self._client.get(read_endpoint, params={"uri": uri})
        fell_back = True

    # Content endpoints may return either plain strings or objects.
    body = self._unwrap_result(resp)
    if isinstance(body, str):
        text = body
    elif isinstance(body, dict):
        text = body.get("content", "") or body.get("text", "")
    else:
        text = ""

    # Truncate long content to avoid flooding context.
    limit = {"overview": 4000, "abstract": 1200}.get(level, 8000)
    if len(text) > limit:
        text = text[:limit] + "\n\n[... truncated, use a more specific URI or full level]"

    response = {
        "uri": uri,
        "resolved_uri": target_uri,
        "level": level,
        "content": text,
    }
    if fell_back:
        response["fallback"] = "content/read"
    return json.dumps(response, ensure_ascii=False)
def _tool_browse(self, args: dict) -> str:
    """Browse the viking:// filesystem and return the result as JSON.

    ``action`` is ``list`` (default), ``tree`` or ``stat``; any other value
    is sent to the ``ls`` endpoint. ``path`` defaults to ``viking://``. For
    ``list`` and ``tree`` the first 50 entries are reduced to name, uri,
    type and abstract; otherwise the unwrapped response is returned as-is.
    """
    action = args.get("action", "list")
    path = args.get("path", "viking://")

    # Every fs endpoint is a GET that takes the target as ``uri=``.
    endpoints = {
        "tree": "/api/v1/fs/tree",
        "list": "/api/v1/fs/ls",
        "stat": "/api/v1/fs/stat",
    }
    resp = self._client.get(
        endpoints.get(action, "/api/v1/fs/ls"), params={"uri": path}
    )
    result = self._unwrap_result(resp)

    def _summarize(node: dict) -> dict:
        # Prefer an explicit name; otherwise use the last URI segment.
        node_uri = node.get("uri", "")
        label = (
            node.get("rel_path")
            or node.get("name")
            or (node_uri.rsplit("/", 1)[-1] if node_uri else "")
        )
        looks_like_dir = any(
            (node.get("isDir"), node.get("is_dir"), node.get("type") == "dir")
        )
        return {
            "name": label,
            "uri": node_uri,
            "type": "dir" if looks_like_dir else "file",
            "abstract": node.get("abstract", ""),
        }

    if action in {"list", "tree"}:
        if isinstance(result, dict):
            # Listings may arrive under any of these keys.
            raw_entries = (
                result.get("entries")
                or result.get("items")
                or result.get("children")
                or []
            )
        else:
            raw_entries = result
        if isinstance(raw_entries, list):
            # Cap at 50 entries to keep the reply readable.
            entries = [_summarize(node) for node in raw_entries[:50]]
            return json.dumps({"path": path, "entries": entries}, ensure_ascii=False)

    return json.dumps(result, ensure_ascii=False)
def _tool_remember(self, args: dict) -> str:
    """Store a memory directly through the content/write API.

    Args:
        args: Tool arguments. ``content`` is required. ``category`` selects
            the memory sub-directory; unknown or missing categories use the
            default one.

    Returns:
        JSON with ``status: "stored"`` and a message carrying the reported
        byte count, or a ``tool_error`` payload when ``content`` is missing
        or the write request itself raises.
    """
    content = args.get("content", "")
    if not content:
        return tool_error("content is required")

    category = args.get("category", "")
    subdir = _CATEGORY_SUBDIR_MAP.get(category, _DEFAULT_MEMORY_SUBDIR)
    uri = self._build_memory_uri(subdir)

    # Write directly via content/write API.
    # This creates the file, stores the content, and queues vector indexing
    # in a single call — no dependency on session commit / VLM extraction.
    # Only the request is guarded: if the POST succeeded, an unexpected
    # response shape must not be reported as a failed write, since that
    # invites a retry and a duplicate memory.
    try:
        resp = self._client.post("/api/v1/content/write", {
            "uri": uri,
            "content": content,
            "mode": "create",
        })
    except Exception as e:
        logger.error("OpenViking content/write failed: %s", e)
        return tool_error(f"Failed to store memory: {e}")

    # The byte count is informational; fall back to 0 when it is absent.
    body = resp.get("result") if isinstance(resp, dict) else None
    written = (body.get("written_bytes") if isinstance(body, dict) else None) or 0
    return json.dumps({
        "status": "stored",
        "message": f"Memory stored ({written}b) and queued for vector indexing.",
    })
def _tool_add_resource(self, args: dict) -> str:
    """Register a URL or local path as an OpenViking resource.

    Args:
        args: Tool arguments. ``url`` is required. ``reason``, ``to``,
            ``parent``, ``instruction``, ``wait`` and ``timeout`` are
            forwarded when present and neither ``None`` nor empty; ``to``
            and ``parent`` are mutually exclusive.

    Returns:
        JSON with ``status: "added"`` and the ``root_uri`` reported by the
        server (empty when absent), or a ``tool_error`` payload for invalid
        arguments or an unusable local path.

    Existing local files are uploaded as a temp file; existing directories
    are zipped first and the archive is removed afterwards. Anything else is
    passed to the server as ``path``.
    """
    url = args.get("url", "")
    if not url:
        return tool_error("url is required")
    if args.get("to") and args.get("parent"):
        return tool_error("Cannot specify both 'to' and 'parent'")

    payload: Dict[str, Any] = {}
    for key in ("reason", "to", "parent", "instruction", "wait", "timeout"):
        if key in args and args[key] not in {None, ""}:
            payload[key] = args[key]

    # Decide whether ``url`` names something on the local filesystem.
    # ``source_path`` stays None when it should be sent to the server as-is.
    parsed_url = urlparse(url)
    if _is_remote_resource_source(url):
        source_path = None
    elif parsed_url.scheme == "file":
        source_path = _path_from_file_uri(url)
        # A string return is treated as an error message, not a path.
        if isinstance(source_path, str):
            return tool_error(source_path)
    elif parsed_url.scheme and not _is_windows_absolute_path(url):
        # Some other scheme that is not a Windows drive letter: not local.
        source_path = None
    else:
        source_path = Path(url).expanduser()

    cleanup_path: Optional[Path] = None
    try:
        if source_path is not None:
            if source_path.exists():
                if source_path.is_dir():
                    # Directories are shipped as a temporary zip archive.
                    payload["source_name"] = source_path.name
                    cleanup_path = _zip_directory(source_path)
                    upload_path = cleanup_path
                elif source_path.is_file():
                    payload["source_name"] = source_path.name
                    upload_path = source_path
                else:
                    return tool_error(f"Unsupported local resource path: {url}")
                payload["temp_file_id"] = self._client.upload_temp_file(upload_path)
            elif _is_local_path_reference(url):
                return tool_error(f"Local resource path does not exist: {url}")
            else:
                payload["path"] = url
        else:
            payload["path"] = url

        resp = self._client.post("/api/v1/resources", payload)
    finally:
        # Remove the temporary archive whether or not the request succeeded.
        if cleanup_path:
            cleanup_path.unlink(missing_ok=True)

    # The POST succeeded at this point. A null or unexpected body must not
    # turn that into an error, so parse defensively and accept both the
    # wrapped and unwrapped response shapes.
    result = self._unwrap_result(resp)
    if not isinstance(result, dict):
        result = {}

    return json.dumps({
        "status": "added",
        "root_uri": result.get("root_uri", ""),
        "message": "Resource queued for processing. Use viking_search after a moment to find it.",
    }, ensure_ascii=False)
# ---------------------------------------------------------------------------
# Plugin entry point
# ---------------------------------------------------------------------------
def register(ctx) -> None:
    """Plugin entry point: hand a new OpenViking provider to *ctx*."""
    provider = OpenVikingMemoryProvider()
    ctx.register_memory_provider(provider)