mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-09 08:21:50 +00:00
A misconfigured/slow external memory provider could hold the agent in the 'running' state for minutes after the final response was delivered. MemoryManager.sync_all / queue_prefetch_all looped provider.sync_turn / queue_prefetch INLINE on the turn-completion path; a provider making a blocking network/daemon call (a broken Hindsight daemon was observed blocking ~298s before failing) blocked run_conversation from returning. Because every interface (CLI, TUI, gateway) marks the agent 'running' until run_conversation returns, the agent stayed busy for the full block and any follow-up message triggered an aggressive interrupt that dropped the message. Dispatch provider sync/prefetch to a lazily-created single-worker background executor. sync_all / queue_prefetch_all return immediately; work completes (or fails, logged) in the background. A single worker serializes writes so turn N lands before turn N+1. flush_pending() provides a barrier for session boundaries and deterministic tests. shutdown_all() drains the executor with a bounded timeout so a wedged provider can never hang teardown. Builtin-only / no-provider sessions spawn no executor (zero new threads in the common case).
857 lines
33 KiB
Python
857 lines
33 KiB
Python
"""MemoryManager — orchestrates memory providers for the agent.
|
|
|
|
Single integration point in run_agent.py. Replaces scattered per-backend
|
|
code with one manager that delegates to registered providers.
|
|
|
|
Only ONE external plugin provider is allowed at a time — attempting to
|
|
register a second external provider is rejected with a warning. This
|
|
prevents tool schema bloat and conflicting memory backends.
|
|
|
|
Usage in run_agent.py:
|
|
self._memory_manager = MemoryManager()
|
|
# Only ONE of these:
|
|
self._memory_manager.add_provider(plugin_provider)
|
|
|
|
# System prompt
|
|
prompt_parts.append(self._memory_manager.build_system_prompt())
|
|
|
|
# Pre-turn
|
|
context = self._memory_manager.prefetch_all(user_message)
|
|
|
|
# Post-turn
|
|
self._memory_manager.sync_all(user_msg, assistant_response)
|
|
self._memory_manager.queue_prefetch_all(user_msg)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
import inspect
|
|
import threading
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from agent.memory_provider import MemoryProvider
|
|
from tools.registry import tool_error
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# How long shutdown_all() waits for in-flight background sync/prefetch work
# to drain before abandoning it. A wedged provider must never block session
# teardown indefinitely, so the wait is bounded by this many seconds.
# NOTE(review): the worker is a concurrent.futures.ThreadPoolExecutor thread.
# Since Python 3.9 those workers are non-daemon and are joined at interpreter
# exit, so a task still running past this window may delay process exit even
# though shutdown_all() itself has returned — confirm against the supported
# Python versions.
_SYNC_DRAIN_TIMEOUT_S = 5.0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Context fencing helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# A lone opening or closing <memory-context> fence tag (whitespace tolerated).
_FENCE_TAG_RE = re.compile(r'</?\s*memory-context\s*>', re.IGNORECASE)
# A complete fenced block: opening tag, payload (non-greedy), closing tag.
_INTERNAL_CONTEXT_RE = re.compile(
    r'<\s*memory-context\s*>[\s\S]*?</\s*memory-context\s*>',
    re.IGNORECASE,
)
# The bracketed "[System note: ...]" line, in either of its two wordings,
# together with any whitespace that follows it.
_INTERNAL_NOTE_RE = re.compile(
    r'\[System note:\s*The following is recalled memory context,\s*NOT new user input\.\s*Treat as (?:informational background data|authoritative reference data[^\]]*)\.\]\s*',
    re.IGNORECASE,
)


def sanitize_context(text: str) -> str:
    """Remove memory-context fencing artifacts from provider output.

    Three passes run in a fixed order: whole fenced blocks are dropped first
    (payload included), then system-note lines, and finally any stray fence
    tags that were not part of a complete block.

    Args:
        text: Raw text returned by a memory provider.

    Returns:
        ``text`` with every match of the three patterns deleted.
    """
    # Order matters: complete blocks must go before the bare-tag pass, or the
    # tags would be stripped and the payload between them left behind.
    for pattern in (_INTERNAL_CONTEXT_RE, _INTERNAL_NOTE_RE, _FENCE_TAG_RE):
        text = pattern.sub('', text)
    return text
|
|
|
|
|
|
class StreamingContextScrubber:
    """Stateful scrubber for streaming text that may contain split memory-context spans.

    The one-shot ``sanitize_context`` regex cannot survive chunk boundaries:
    a ``<memory-context>`` opened in one delta and closed in a later delta
    leaks its payload to the UI because the non-greedy block regex needs
    both tags in one string. This scrubber runs a small state machine
    across deltas, holding back partial-tag tails and discarding
    everything inside a span (the system-note line included, since it sits
    inside the span).

    An opening tag only starts a span when it is "block-like": nothing but
    whitespace precedes it on its line and a newline follows it. Inline
    mentions of the tag are passed through as ordinary text.

    Usage::

        scrubber = StreamingContextScrubber()
        for delta in stream:
            visible = scrubber.feed(delta)
            if visible:
                emit(visible)
        trailing = scrubber.flush()  # at end of stream
        if trailing:
            emit(trailing)

    The scrubber carries state between ``feed()`` calls. Callers building a
    new top-level response (new turn) should create a fresh scrubber or call
    ``reset()``.
    """

    _OPEN_TAG = "<memory-context>"
    _CLOSE_TAG = "</memory-context>"
    # Length-preserving case fold (ASCII A-Z only). ``str.lower()`` can change
    # the length of a string (e.g. "İ" lowers to two code points), which
    # would make indices found in the lowered text point at the wrong place
    # in the original buffer. The tags are pure ASCII, so folding only ASCII
    # letters is sufficient for case-insensitive matching.
    _ASCII_LOWER = str.maketrans(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ", "abcdefghijklmnopqrstuvwxyz"
    )

    def __init__(self) -> None:
        # True while we are between an opening tag and its closing tag.
        self._in_span: bool = False
        # Held-back tail that might be the start of a tag.
        self._buf: str = ""
        # True when the text emitted so far ends at the start of a line
        # (only whitespace since the last newline, or nothing emitted yet).
        self._at_block_boundary: bool = True

    def reset(self) -> None:
        """Return the scrubber to its initial state (no span, empty buffer)."""
        self._in_span = False
        self._buf = ""
        self._at_block_boundary = True

    def feed(self, text: str) -> str:
        """Return the visible portion of ``text`` after scrubbing.

        Any trailing fragment that could be the start of an open/close tag
        is held back in the internal buffer and surfaced on the next
        ``feed()`` call or discarded/emitted by ``flush()``.

        Args:
            text: The next streamed delta.

        Returns:
            Text that is safe to show; may be empty.
        """
        if not text:
            return ""
        buf = self._buf + text
        self._buf = ""
        out: list[str] = []

        while buf:
            if self._in_span:
                idx = self._fold(buf).find(self._CLOSE_TAG)
                if idx == -1:
                    # Hold back a potential partial close tag; drop the rest
                    held = self._max_partial_suffix(buf, self._CLOSE_TAG)
                    self._buf = buf[-held:] if held else ""
                    return "".join(out)
                # Found close — skip span content + tag, continue
                buf = buf[idx + len(self._CLOSE_TAG):]
                self._in_span = False
            else:
                idx = self._find_boundary_open_tag(buf)
                if idx == -1:
                    # No confirmed open tag — hold back either a complete
                    # tag still waiting for its trailing newline, or a
                    # partial tag prefix at the very end of the buffer.
                    held = (
                        self._max_pending_open_suffix(buf)
                        or self._max_partial_suffix(buf, self._OPEN_TAG)
                    )
                    if held:
                        self._append_visible(out, buf[:-held])
                        self._buf = buf[-held:]
                    else:
                        self._append_visible(out, buf)
                    return "".join(out)
                # Emit text before the tag, enter span
                if idx > 0:
                    self._append_visible(out, buf[:idx])
                buf = buf[idx + len(self._OPEN_TAG):]
                self._in_span = True

        return "".join(out)

    def flush(self) -> str:
        """Emit any held-back buffer at end-of-stream.

        If we're still inside an unterminated span the remaining content is
        discarded (safer: leaking partial memory context is worse than a
        truncated answer). Otherwise the held-back partial-tag tail is
        emitted verbatim (it turned out not to be a real tag).
        """
        if self._in_span:
            self._buf = ""
            self._in_span = False
            return ""
        tail = self._buf
        self._buf = ""
        return tail

    @classmethod
    def _fold(cls, text: str) -> str:
        """Lowercase ASCII letters only, keeping ``len(text)`` unchanged."""
        return text.translate(cls._ASCII_LOWER)

    @staticmethod
    def _max_partial_suffix(buf: str, tag: str) -> int:
        """Return the length of the longest buf-suffix that is a tag-prefix.

        Case-insensitive. Only proper prefixes are considered (at most
        ``len(tag) - 1`` characters). Returns 0 if no suffix could start
        the tag.
        """
        fold = StreamingContextScrubber._fold
        tag_lower = fold(tag)
        buf_lower = fold(buf)
        max_check = min(len(buf_lower), len(tag_lower) - 1)
        for i in range(max_check, 0, -1):
            if tag_lower.startswith(buf_lower[-i:]):
                return i
        return 0

    def _find_boundary_open_tag(self, buf: str) -> int:
        """Find an opening fence only when it starts a block-like span.

        Returns the index of the first opening tag that is both at a block
        boundary and immediately followed by a newline, or -1 if none.
        """
        buf_lower = self._fold(buf)
        search_start = 0
        while True:
            idx = buf_lower.find(self._OPEN_TAG, search_start)
            if idx == -1:
                return -1
            if self._is_block_boundary(buf, idx) and self._has_block_opener_suffix(buf, idx):
                return idx
            search_start = idx + 1

    def _max_pending_open_suffix(self, buf: str) -> int:
        """Hold a complete boundary tag until the following char confirms it.

        Returns the tag length when ``buf`` ends with a complete opening tag
        that sits at a block boundary, otherwise 0.
        """
        if not self._fold(buf).endswith(self._OPEN_TAG):
            return 0
        idx = len(buf) - len(self._OPEN_TAG)
        if not self._is_block_boundary(buf, idx):
            return 0
        return len(self._OPEN_TAG)

    def _has_block_opener_suffix(self, buf: str, idx: int) -> bool:
        """Return True when the tag at ``idx`` is directly followed by CR or LF."""
        after_idx = idx + len(self._OPEN_TAG)
        if after_idx >= len(buf):
            # The character after the tag has not arrived yet.
            return False
        return buf[after_idx] in "\r\n"

    def _is_block_boundary(self, buf: str, idx: int) -> bool:
        """Return True when only whitespace precedes ``idx`` on its line.

        When no newline appears in ``buf`` before ``idx`` the line began in
        previously emitted text, so the remembered boundary state is
        consulted as well.
        """
        if idx == 0:
            return self._at_block_boundary
        preceding = buf[:idx]
        last_newline = preceding.rfind("\n")
        if last_newline == -1:
            return self._at_block_boundary and preceding.strip() == ""
        return preceding[last_newline + 1:].strip() == ""

    def _append_visible(self, out: list[str], text: str) -> None:
        """Append non-empty ``text`` to ``out`` and refresh boundary state."""
        if not text:
            return
        out.append(text)
        self._update_block_boundary(text)

    def _update_block_boundary(self, text: str) -> None:
        """Recompute whether emitted output currently ends at a line start."""
        last_newline = text.rfind("\n")
        if last_newline != -1:
            self._at_block_boundary = text[last_newline + 1:].strip() == ""
        else:
            self._at_block_boundary = self._at_block_boundary and text.strip() == ""
|
|
|
|
|
|
def build_memory_context_block(raw_context: str) -> str:
    """Wrap prefetched memory in a fenced block with system note.

    Args:
        raw_context: Text recalled by the memory providers.

    Returns:
        An empty string when ``raw_context`` is empty or whitespace-only;
        otherwise the sanitized text wrapped in ``<memory-context>`` tags and
        preceded by the system note. A warning is logged when sanitizing
        changed the input (the provider returned already-fenced content).
    """
    if not (raw_context and raw_context.strip()):
        return ""

    sanitized = sanitize_context(raw_context)
    if sanitized != raw_context:
        logger.warning("memory provider returned pre-wrapped context; stripped")

    system_note = (
        "[System note: The following is recalled memory context, "
        "NOT new user input. Treat as authoritative reference data — "
        "this is the agent's persistent memory and should inform all responses.]"
    )
    # Layout: open tag, note, blank line, payload, close tag.
    return "\n".join(
        ("<memory-context>", system_note, "", sanitized, "</memory-context>")
    )
|
|
|
|
|
|
class MemoryManager:
    """Orchestrates the built-in provider plus at most one external provider.

    Providers are kept in registration order and the fan-out methods walk
    them in that order. Only one non-builtin (external) provider is
    accepted. Fan-out hooks wrap each provider call so a failure in one
    provider is logged and does not stop the others.

    End-of-turn work (``sync_all`` / ``queue_prefetch_all``) is dispatched to
    a lazily created single-worker executor instead of running on the
    caller's thread.
    """

    def __init__(self) -> None:
        self._providers: List[MemoryProvider] = []
        self._tool_to_provider: Dict[str, MemoryProvider] = {}
        self._has_external: bool = False  # True once a non-builtin provider is added
        # Background executor for end-of-turn sync/prefetch. Lazily created on
        # first use so sessions that never dispatch background work spawn no
        # extra threads. A single worker serializes a provider's writes
        # (turn N must land before turn N+1) and caps thread growth at one
        # per manager.
        self._sync_executor: Optional[ThreadPoolExecutor] = None
        self._sync_executor_lock = threading.Lock()

    # -- Registration --------------------------------------------------------

    def add_provider(self, provider: MemoryProvider) -> None:
        """Register a memory provider.

        Built-in provider (name ``"builtin"``) is always accepted.
        Only **one** external (non-builtin) provider is allowed — a second
        attempt is rejected with a warning.

        Tool names exposed by the provider are indexed for routing. Names
        that collide with a reserved core tool, or that another provider
        already owns, are skipped with a warning.
        """
        is_builtin = provider.name == "builtin"

        if not is_builtin:
            if self._has_external:
                existing = next(
                    (p.name for p in self._providers if p.name != "builtin"), "unknown"
                )
                logger.warning(
                    "Rejected memory provider '%s' — external provider '%s' is "
                    "already registered. Only one external memory provider is "
                    "allowed at a time. Configure which one via memory.provider "
                    "in config.yaml.",
                    provider.name, existing,
                )
                return
            self._has_external = True

        self._providers.append(provider)

        # Core tool names are reserved — a memory provider must never register
        # a tool that shadows a built-in (e.g. ``clarify``, ``delegate_task``).
        # Reject it here so it never enters the routing table at all (#40466).
        from toolsets import _HERMES_CORE_TOOLS

        _core_tool_names = set(_HERMES_CORE_TOOLS)

        # Ask the provider once and reuse the result for both routing and the
        # registration log line.
        schemas = list(provider.get_tool_schemas())

        # Index tool names → provider for routing
        for schema in schemas:
            tool_name = schema.get("name", "")
            if tool_name in _core_tool_names:
                logger.warning(
                    "Memory provider '%s' tool '%s' shadows a reserved core "
                    "tool name; registration ignored. Core tools always win — "
                    "rename the provider's tool to something unique.",
                    provider.name, tool_name,
                )
                continue
            if tool_name and tool_name not in self._tool_to_provider:
                self._tool_to_provider[tool_name] = provider
            elif tool_name in self._tool_to_provider:
                logger.warning(
                    "Memory tool name conflict: '%s' already registered by %s, "
                    "ignoring from %s",
                    tool_name,
                    self._tool_to_provider[tool_name].name,
                    provider.name,
                )

        logger.info(
            "Memory provider '%s' registered (%d tools)",
            provider.name,
            len(schemas),
        )

    @property
    def providers(self) -> List[MemoryProvider]:
        """All registered providers in order (a copy of the internal list)."""
        return list(self._providers)

    def get_provider(self, name: str) -> Optional[MemoryProvider]:
        """Get a provider by name, or None if not registered."""
        for p in self._providers:
            if p.name == name:
                return p
        return None

    # -- System prompt -------------------------------------------------------

    def build_system_prompt(self) -> str:
        """Collect system prompt blocks from all providers.

        Returns the non-empty blocks joined by a blank line, or an empty
        string if no provider contributes. A provider whose
        ``system_prompt_block()`` raises is logged and skipped.
        """
        blocks = []
        for provider in self._providers:
            try:
                block = provider.system_prompt_block()
                if block and block.strip():
                    blocks.append(block)
            except Exception as e:
                logger.warning(
                    "Memory provider '%s' system_prompt_block() failed: %s",
                    provider.name, e,
                )
        return "\n\n".join(blocks)

    # -- Prefetch / recall ---------------------------------------------------

    def prefetch_all(self, query: str, *, session_id: str = "") -> str:
        """Collect prefetch context from all providers.

        Runs inline on the caller's thread. Returns the non-empty results
        joined by a blank line. Failures in one provider don't block others.
        """
        parts = []
        for provider in self._providers:
            try:
                result = provider.prefetch(query, session_id=session_id)
                if result and result.strip():
                    parts.append(result)
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' prefetch failed (non-fatal): %s",
                    provider.name, e,
                )
        return "\n\n".join(parts)

    def queue_prefetch_all(self, query: str, *, session_id: str = "") -> None:
        """Queue background prefetch on all providers for the next turn.

        Provider work is dispatched to the background worker so a slow
        provider does not hold up the caller. Returns immediately; does
        nothing when no providers are registered.
        """
        providers = list(self._providers)
        if not providers:
            return

        def _run() -> None:
            for provider in providers:
                try:
                    provider.queue_prefetch(query, session_id=session_id)
                except Exception as e:
                    logger.debug(
                        "Memory provider '%s' queue_prefetch failed (non-fatal): %s",
                        provider.name, e,
                    )

        self._submit_background(_run)

    # -- Sync ----------------------------------------------------------------

    @staticmethod
    def _provider_sync_accepts_messages(provider: MemoryProvider) -> bool:
        """Return whether sync_turn accepts a messages keyword.

        True when the signature has ``**kwargs`` or a parameter named
        ``messages``, and also when the signature cannot be inspected.
        """
        try:
            signature = inspect.signature(provider.sync_turn)
        except (TypeError, ValueError):
            return True
        params = list(signature.parameters.values())
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
            return True
        return "messages" in signature.parameters

    def sync_all(
        self,
        user_content: str,
        assistant_content: str,
        *,
        session_id: str = "",
        messages: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Sync a completed turn to all providers.

        Runs on a background worker thread, NOT inline on the
        turn-completion path. A provider's ``sync_turn`` may make a
        blocking network/daemon call; doing that inline held
        ``run_conversation`` open long after the user saw their response.
        Dispatching off-thread means the sync completes (or fails, logged)
        in the background.

        Writes are serialized through a single worker so turn N lands
        before turn N+1.

        ``messages`` is shallow-copied at call time so the worker sees the
        transcript as it was when the turn completed, even if the caller
        keeps appending to its list afterwards. It is only forwarded to
        providers whose ``sync_turn`` accepts a ``messages`` keyword.
        """
        providers = list(self._providers)
        if not providers:
            return

        # Snapshot now: the task runs later on another thread, and the
        # caller's list may have grown by then.
        turn_messages = list(messages) if messages is not None else None

        def _run() -> None:
            for provider in providers:
                try:
                    if turn_messages is not None and self._provider_sync_accepts_messages(provider):
                        provider.sync_turn(
                            user_content,
                            assistant_content,
                            session_id=session_id,
                            messages=turn_messages,
                        )
                    else:
                        provider.sync_turn(
                            user_content,
                            assistant_content,
                            session_id=session_id,
                        )
                except Exception as e:
                    logger.warning(
                        "Memory provider '%s' sync_turn failed: %s",
                        provider.name, e,
                    )

        self._submit_background(_run)

    # -- Background dispatch -------------------------------------------------

    def _submit_background(self, fn) -> None:
        """Run ``fn`` on the manager's background worker.

        The executor is created lazily and shared across calls. If the
        executor can't be created, or ``submit`` is refused because it was
        shut down in the meantime, ``fn`` runs inline as a last-resort
        fallback — losing the async benefit but never losing the write
        itself. ``fn`` must do its own per-provider error handling; this
        wrapper only guards executor plumbing.
        """
        executor = self._get_sync_executor()
        if executor is None:
            # Executor unavailable (creation failed) — run inline rather
            # than drop the work. Slow, but correct.
            try:
                fn()
            except Exception as e:  # pragma: no cover - fn guards internally
                logger.debug("Inline memory background task failed: %s", e)
            return
        try:
            executor.submit(fn)
        except RuntimeError:
            # Executor was shut down between the get and the submit
            # (teardown race). Fall back to inline.
            try:
                fn()
            except Exception as e:  # pragma: no cover - fn guards internally
                logger.debug("Inline memory background task failed: %s", e)

    def _get_sync_executor(self) -> Optional[ThreadPoolExecutor]:
        """Lazily create the single-worker background executor.

        Returns the existing executor when there is one; otherwise creates
        it under the lock. Returns None only if construction fails. After
        ``_drain_sync_executor`` has cleared the reference, the next call
        creates a fresh executor.
        """
        if self._sync_executor is not None:
            return self._sync_executor
        with self._sync_executor_lock:
            # Re-check under the lock: another thread may have created it.
            if self._sync_executor is None:
                try:
                    self._sync_executor = ThreadPoolExecutor(
                        max_workers=1,
                        thread_name_prefix="mem-sync",
                    )
                except Exception as e:  # pragma: no cover - resource exhaustion
                    logger.warning("Failed to create memory sync executor: %s", e)
                    return None
            return self._sync_executor

    def flush_pending(self, timeout: Optional[float] = None) -> bool:
        """Block until queued sync/prefetch work has drained.

        Single-worker executor means submitting a sentinel and waiting on
        it guarantees every previously-submitted task has run. Returns
        True if the barrier completed within ``timeout`` (or no executor
        exists, or it was already shut down), False on timeout or if the
        sentinel was cancelled.
        """
        executor = self._sync_executor
        if executor is None:
            return True
        try:
            fut = executor.submit(lambda: None)
        except RuntimeError:
            # Executor already shut down — nothing pending.
            return True
        try:
            fut.result(timeout=timeout)
            return True
        except Exception:
            return False

    # -- Tools ---------------------------------------------------------------

    def get_all_tool_schemas(self) -> List[Dict[str, Any]]:
        """Collect tool schemas from all providers.

        Reserved core tool names (``clarify``, ``delegate_task``, etc.) are
        skipped — they are rejected from the routing table in
        :meth:`add_provider`, so the manager must not advertise a schema it
        will never route (#40466). Duplicate names keep the first schema
        seen; unnamed schemas are dropped.
        """
        from toolsets import _HERMES_CORE_TOOLS

        _core_tool_names = set(_HERMES_CORE_TOOLS)
        schemas = []
        seen = set()
        for provider in self._providers:
            try:
                for schema in provider.get_tool_schemas():
                    name = schema.get("name", "")
                    if name in _core_tool_names:
                        continue
                    if name and name not in seen:
                        schemas.append(schema)
                        seen.add(name)
            except Exception as e:
                logger.warning(
                    "Memory provider '%s' get_tool_schemas() failed: %s",
                    provider.name, e,
                )
        return schemas

    def get_all_tool_names(self) -> set:
        """Return set of all routable tool names across all providers."""
        return set(self._tool_to_provider.keys())

    def has_tool(self, tool_name: str) -> bool:
        """Check if any provider handles this tool."""
        return tool_name in self._tool_to_provider

    def handle_tool_call(
        self, tool_name: str, args: Dict[str, Any], **kwargs
    ) -> str:
        """Route a tool call to the correct provider.

        Returns the provider's result. When no provider handles the tool,
        or the provider raises, a ``tool_error(...)`` payload describing the
        problem is returned instead — this method does not raise.
        """
        provider = self._tool_to_provider.get(tool_name)
        if provider is None:
            return tool_error(f"No memory provider handles tool '{tool_name}'")
        try:
            return provider.handle_tool_call(tool_name, args, **kwargs)
        except Exception as e:
            logger.error(
                "Memory provider '%s' handle_tool_call(%s) failed: %s",
                provider.name, tool_name, e,
            )
            return tool_error(f"Memory tool '{tool_name}' failed: {e}")

    # -- Lifecycle hooks -----------------------------------------------------

    def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
        """Notify all providers of a new turn.

        kwargs may include: remaining_tokens, model, platform, tool_count.
        """
        for provider in self._providers:
            try:
                provider.on_turn_start(turn_number, message, **kwargs)
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' on_turn_start failed: %s",
                    provider.name, e,
                )

    def on_session_end(self, messages: List[Dict[str, Any]]) -> None:
        """Notify all providers of session end."""
        for provider in self._providers:
            try:
                provider.on_session_end(messages)
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' on_session_end failed: %s",
                    provider.name, e,
                )

    def on_session_switch(
        self,
        new_session_id: str,
        *,
        parent_session_id: str = "",
        reset: bool = False,
        rewound: bool = False,
        **kwargs,
    ) -> None:
        """Notify all providers that the agent's session_id has rotated.

        Fires on ``/resume``, ``/branch``, ``/reset``, ``/new``, and
        context compression — any path that reassigns
        ``AIAgent.session_id`` without tearing the provider down.

        ``rewound=True`` signals that session_id is unchanged but the
        transcript was truncated; providers caching per-turn document
        state should invalidate.

        Does nothing when ``new_session_id`` is empty.
        """
        if not new_session_id:
            return
        # Only forward ``rewound`` when it's actually set. Passing it
        # unconditionally would inject ``rewound=False`` into every
        # provider's **kwargs on the common paths, polluting providers
        # that capture extra kwargs.
        if rewound:
            kwargs["rewound"] = True
        for provider in self._providers:
            try:
                provider.on_session_switch(
                    new_session_id,
                    parent_session_id=parent_session_id,
                    reset=reset,
                    **kwargs,
                )
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' on_session_switch failed: %s",
                    provider.name, e,
                )

    def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
        """Notify all providers before context compression.

        Returns combined text from providers to include in the compression
        summary prompt. Empty string if no provider contributes.
        """
        parts = []
        for provider in self._providers:
            try:
                result = provider.on_pre_compress(messages)
                if result and result.strip():
                    parts.append(result)
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' on_pre_compress failed: %s",
                    provider.name, e,
                )
        return "\n\n".join(parts)

    @staticmethod
    def _provider_memory_write_metadata_mode(provider: MemoryProvider) -> str:
        """Return how to pass metadata to a provider's memory-write hook.

        Returns:
            ``"keyword"`` when the hook takes ``**kwargs`` or a parameter
            named ``metadata`` (or cannot be inspected), ``"positional"``
            when it can take a fourth positional argument, and ``"legacy"``
            when it only takes ``(action, target, content)``.
        """
        try:
            signature = inspect.signature(provider.on_memory_write)
        except (TypeError, ValueError):
            return "keyword"

        params = list(signature.parameters.values())
        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
            return "keyword"
        if "metadata" in signature.parameters:
            return "keyword"

        # Only parameters that can actually be filled positionally count.
        # Keyword-only parameters cannot receive the metadata dict as a
        # fourth positional argument, so they must not push the hook into
        # "positional" mode.
        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        positional_count = sum(1 for p in params if p.kind in positional_kinds)
        if positional_count >= 4:
            return "positional"
        return "legacy"

    def on_memory_write(
        self,
        action: str,
        target: str,
        content: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Notify external providers when the built-in memory tool writes.

        Skips the builtin provider itself (it's the source of the write).
        Each provider receives its own copy of ``metadata`` when its hook
        accepts one.
        """
        for provider in self._providers:
            if provider.name == "builtin":
                continue
            try:
                metadata_mode = self._provider_memory_write_metadata_mode(provider)
                if metadata_mode == "keyword":
                    provider.on_memory_write(
                        action, target, content, metadata=dict(metadata or {})
                    )
                elif metadata_mode == "positional":
                    provider.on_memory_write(action, target, content, dict(metadata or {}))
                else:
                    provider.on_memory_write(action, target, content)
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' on_memory_write failed: %s",
                    provider.name, e,
                )

    def on_delegation(self, task: str, result: str, *,
                      child_session_id: str = "", **kwargs) -> None:
        """Notify all providers that a subagent completed."""
        for provider in self._providers:
            try:
                provider.on_delegation(
                    task, result, child_session_id=child_session_id, **kwargs
                )
            except Exception as e:
                logger.debug(
                    "Memory provider '%s' on_delegation failed: %s",
                    provider.name, e,
                )

    def shutdown_all(self) -> None:
        """Shut down all providers (reverse order for clean teardown).

        Drains the background sync/prefetch executor first (bounded by
        ``_SYNC_DRAIN_TIMEOUT_S``) so a turn's final sync has a chance to
        land before providers are torn down.

        NOTE(review): a task still running after the drain window keeps
        running on the executor's worker thread. Since Python 3.9 those
        workers are non-daemon and joined at interpreter exit, so such a
        task may still delay process exit — confirm.
        """
        self._drain_sync_executor()
        for provider in reversed(self._providers):
            try:
                provider.shutdown()
            except Exception as e:
                logger.warning(
                    "Memory provider '%s' shutdown failed: %s",
                    provider.name, e,
                )

    def _drain_sync_executor(self) -> None:
        """Shut down the background executor, waiting briefly for drain.

        The executor reference is detached under the lock, new work is
        refused and queued-but-unstarted tasks are cancelled, then the
        caller waits at most ``_SYNC_DRAIN_TIMEOUT_S`` seconds for the
        in-flight task via a helper thread.
        """
        with self._sync_executor_lock:
            executor = self._sync_executor
            self._sync_executor = None
        if executor is None:
            return
        try:
            # Stop accepting new work and drop anything still queued, but
            # do NOT block here — cancel_futures cancels not-yet-started
            # tasks; the in-flight one keeps running on its worker thread.
            executor.shutdown(wait=False, cancel_futures=True)
        except TypeError:
            # Older Python without cancel_futures kwarg.
            try:
                executor.shutdown(wait=False)
            except Exception as e:  # pragma: no cover
                logger.debug("Memory sync executor shutdown failed: %s", e)
                return
        except Exception as e:  # pragma: no cover
            logger.debug("Memory sync executor shutdown failed: %s", e)
            return
        # Give an in-flight sync a bounded chance to finish. The blocking
        # wait happens on a daemon helper thread so this method returns
        # after the timeout even if the task never completes.
        drainer = threading.Thread(
            target=lambda: self._bounded_executor_wait(executor),
            daemon=True,
            name="mem-sync-drain",
        )
        drainer.start()
        drainer.join(timeout=_SYNC_DRAIN_TIMEOUT_S)

    @staticmethod
    def _bounded_executor_wait(executor: ThreadPoolExecutor) -> None:
        """Block until ``executor`` has finished; log instead of raising."""
        try:
            executor.shutdown(wait=True)
        except Exception as e:  # pragma: no cover
            logger.debug("Memory sync executor drain wait failed: %s", e)

    def initialize_all(self, session_id: str, **kwargs) -> None:
        """Initialize all providers.

        Automatically injects ``hermes_home`` into *kwargs* (unless the
        caller supplied it) so that every provider can resolve
        profile-scoped storage paths without importing
        ``get_hermes_home()`` themselves.
        """
        if "hermes_home" not in kwargs:
            from hermes_constants import get_hermes_home
            kwargs["hermes_home"] = str(get_hermes_home())
        for provider in self._providers:
            try:
                provider.initialize(session_id=session_id, **kwargs)
            except Exception as e:
                logger.warning(
                    "Memory provider '%s' initialize failed: %s",
                    provider.name, e,
                )
|