hermes-agent/plugins/memory/mem0/__init__.py

635 lines
26 KiB
Python

"""Mem0 memory plugin — MemoryProvider interface.

Server-side LLM fact extraction, semantic search, and automatic deduplication
via the Mem0 Platform API (cloud) or, when self-hosted (OSS), via Memory.

Original PR #2933 by kartik-mem0, adapted to MemoryProvider ABC.

Configuration
-------------
Secret (lives in $HERMES_HOME/.env or the environment):

    MEM0_API_KEY — Mem0 Platform API key (required for platform mode)

Behavioral settings (live in $HERMES_HOME/mem0.json, set via `hermes memory
setup`):

    mode     — Backend mode: "platform" (default) or "oss"
    user_id  — Canonical user identifier. When set, it is applied
               uniformly across every gateway (CLI, Telegram, Slack,
               Discord, …) so the same human gets one merged memory
               store. When unset, the gateway-native id (e.g. Telegram
               numeric id, Discord snowflake) is used instead.
    agent_id — Agent identifier (default: hermes)

The matching MEM0_MODE / MEM0_USER_ID / MEM0_AGENT_ID environment variables are
still read as a backward-compatible fallback, but mem0.json is the canonical
home for these non-secret settings.
"""
from __future__ import annotations
import atexit
import json
import logging
import os
import threading
import time
from typing import Any, Dict, List
from agent.memory_provider import MemoryProvider
from tools.registry import tool_error
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Circuit breaker: after this many consecutive failures, pause API calls
# for _BREAKER_COOLDOWN_SECS (seconds, measured with time.monotonic()) to
# avoid hammering a down server.
_BREAKER_THRESHOLD = 5
_BREAKER_COOLDOWN_SECS = 120
# Longest time (seconds) prefetch() waits on the background recall thread
# before giving up and injecting nothing for the current turn.
_PREFETCH_WAIT_SECS = 1.5
# Exception class names (compared by type name, not by identity) that
# _is_client_error() treats as caller mistakes rather than backend outages.
_CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError")
# Sentinel returned when neither MEM0_USER_ID nor a gateway-native id is
# available. Treated as "no operator-configured user_id" by initialize() so
# that legacy mem0.json files written by the setup wizard (which historically
# wrote this exact placeholder) still allow gateway-native ids to flow
# through instead of silently overriding them with the placeholder.
_DEFAULT_USER_ID = "hermes-user"
def _is_client_error(exc: Exception) -> bool:
    """Report whether *exc* looks caller-caused rather than a backend outage.

    Such errors (bad ID, memory not found) must NOT count towards the circuit
    breaker. An exception qualifies when its class name is listed in
    ``_CLIENT_ERROR_TYPES`` or when its lower-cased message contains one of
    the markers ``404``, ``not found`` or ``valid uuid``.
    """
    if type(exc).__name__ in _CLIENT_ERROR_TYPES:
        return True
    message = str(exc).lower()
    return any(marker in message for marker in ("404", "not found", "valid uuid"))
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def _load_config() -> dict:
    """Load config from env vars, with $HERMES_HOME/mem0.json overrides.

    Environment variables provide defaults; mem0.json (if present) overrides
    individual keys. Keys whose file value is ``None`` or ``""`` are skipped,
    so a file that lacks e.g. ``api_key`` does not wipe the value the user
    set in ``.env``.

    Returns:
        A dict that always contains ``mode``, ``api_key``, ``agent_id`` and
        ``oss``. ``user_id`` is present only when the operator configured one
        (environment or file). Any other key found in mem0.json is carried
        through unchanged.

    A missing file is not an error. An unreadable, malformed, or non-object
    mem0.json is logged as a warning and ignored, so the environment-derived
    defaults are still returned instead of the problem passing silently.
    """
    from hermes_constants import get_hermes_home

    config = {
        "mode": os.environ.get("MEM0_MODE", "platform"),
        "api_key": os.environ.get("MEM0_API_KEY", ""),
        "agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
        "oss": {},
    }

    # Only carry user_id when the operator explicitly configured one (env or
    # mem0.json). An absent key tells initialize() to fall back to the
    # gateway-native id from kwargs instead of overriding it with a placeholder.
    env_user_id = os.environ.get("MEM0_USER_ID")
    if env_user_id:
        config["user_id"] = env_user_id

    config_path = get_hermes_home() / "mem0.json"
    if not config_path.exists():
        return config

    try:
        file_cfg = json.loads(config_path.read_text(encoding="utf-8"))
    except Exception as exc:
        # Best effort: keep running on env defaults, but tell the operator
        # that their settings file was not applied.
        logger.warning("Ignoring unreadable Mem0 config %s: %s", config_path, exc)
        return config

    if not isinstance(file_cfg, dict):
        logger.warning(
            "Ignoring Mem0 config %s: expected a JSON object, got %s",
            config_path, type(file_cfg).__name__,
        )
        return config

    config.update(
        {key: value for key, value in file_cfg.items()
         if value is not None and value != ""}
    )
    return config
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
# Tool schema for mem0_list: paginated, unranked dump of the user's memories.
LIST_SCHEMA = {
    "name": "mem0_list",
    "description": (
        "List ALL stored memories about the user, unranked and paginated. "
        "Use for a full overview/audit at conversation start, or to browse "
        "everything when you don't have a specific query. For answering a "
        "specific question, prefer mem0_search."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "page": {
                "type": "integer",
                "description": "Page number (default: 1).",
            },
            "page_size": {
                "type": "integer",
                "description": "Results per page (default: 100, max: 200).",
            },
        },
        "required": [],
    },
}
# Tool schema for mem0_search: semantic search over the user's memories.
SEARCH_SCHEMA = {
    "name": "mem0_search",
    "description": (
        "Search the user's memories by meaning; returns facts ranked by "
        "relevance. Use this BEFORE answering any question that may depend on "
        "what you know about the user (preferences, facts, history, people, "
        "projects, past decisions). For multi-part or multi-hop questions, "
        "call it MULTIPLE times — vary the wording and run follow-up searches "
        "on what earlier results reveal; one search is rarely enough. Set "
        "rerank=true for higher accuracy on important queries."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "What to search for.",
            },
            "top_k": {
                "type": "integer",
                "description": "Max results (default: 10, max: 50).",
            },
            "rerank": {
                "type": "boolean",
                "description": "Rerank results for relevance (default: true, platform mode only).",
            },
        },
        "required": ["query"],
    },
}
# Tool schema for mem0_add: store one fact verbatim.
ADD_SCHEMA = {
    "name": "mem0_add",
    "description": (
        "Store a durable fact about the user, verbatim (no LLM extraction). "
        "Call this the moment the user states a lasting preference, correction, "
        "decision, or personal detail worth recalling on future turns — don't "
        "wait to be asked to remember. Skip transient chit-chat and facts you've "
        "already stored."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "content": {
                "type": "string",
                "description": "The fact to store.",
            },
        },
        "required": ["content"],
    },
}
# Tool schema for mem0_update: replace the text of a memory addressed by ID.
UPDATE_SCHEMA = {
    "name": "mem0_update",
    "description": (
        "Replace the text of an existing memory by its ID (take the ID from a "
        "mem0_search or mem0_list result). Use when a stored fact has changed "
        "or was wrong — correct it in place instead of adding a duplicate."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "memory_id": {
                "type": "string",
                "description": "Memory UUID to update.",
            },
            "text": {
                "type": "string",
                "description": "New text content.",
            },
        },
        "required": ["memory_id", "text"],
    },
}
# Tool schema for mem0_delete: remove a memory addressed by ID.
DELETE_SCHEMA = {
    "name": "mem0_delete",
    "description": (
        "Delete a memory by its ID (take the ID from a mem0_search or mem0_list "
        "result). Use when a stored fact is obsolete or the user asks you to "
        "forget it; prefer mem0_update if the fact merely changed."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "memory_id": {
                "type": "string",
                "description": "Memory UUID to delete.",
            },
        },
        "required": ["memory_id"],
    },
}
# ---------------------------------------------------------------------------
# MemoryProvider implementation
# ---------------------------------------------------------------------------
class Mem0MemoryProvider(MemoryProvider):
    """Mem0 memory with server-side extraction and semantic search.

    Supports Platform API (cloud) and OSS (self-hosted) modes via MEM0_MODE.

    Automatic recall (``prefetch``) and turn ingestion (``sync_turn``) run on
    daemon threads. A circuit breaker pauses backend calls for
    ``_BREAKER_COOLDOWN_SECS`` once ``_BREAKER_THRESHOLD`` consecutive
    failures have been recorded; errors classified by ``_is_client_error``
    (bad ID, not found) and malformed tool arguments never count towards it.
    """

    def __init__(self):
        self._config = None
        self._backend = None
        # Text of the last backend construction failure ("" when none).
        self._init_error = ""
        self._mode = "platform"
        self._api_key = ""
        self._user_id = _DEFAULT_USER_ID
        self._agent_id = "hermes"
        self._channel = "cli"  # gateway channel name (cli/telegram/discord/...)
        # Whether automatic recall asks the backend to rerank; overridden by
        # the "rerank" config key in initialize().
        self._rerank = True
        self._sync_thread = None
        self._prefetch_thread = None
        self._prefetch_query = ""
        self._prefetch_result = ""
        self._prefetch_done = False
        # Circuit breaker state
        self._consecutive_failures = 0
        self._breaker_open_until = 0.0
        self._breaker_lock = threading.Lock()
        self._sync_lock = threading.Lock()
        self._prefetch_lock = threading.Lock()
        self._atexit_registered = False

    @property
    def name(self) -> str:
        """Provider identifier."""
        return "mem0"

    # -- configuration ------------------------------------------------------

    def is_available(self) -> bool:
        """Return True when the configured mode has its minimum settings.

        OSS mode needs a non-empty ``oss.vector_store``; any other mode needs
        an API key.
        """
        cfg = _load_config()
        if cfg.get("mode", "platform") == "oss":
            oss_cfg = cfg.get("oss") or {}
            return isinstance(oss_cfg, dict) and bool(oss_cfg.get("vector_store"))
        return bool(cfg.get("api_key"))

    def save_config(self, values, hermes_home):
        """Merge *values* into $HERMES_HOME/mem0.json and write it atomically.

        Existing keys not present in *values* are preserved. An existing file
        that cannot be parsed, or that is not a JSON object, is replaced by
        *values* alone (a warning is logged for the unparsable case).
        """
        from pathlib import Path

        from utils import atomic_json_write

        config_path = Path(hermes_home) / "mem0.json"
        existing = {}
        if config_path.exists():
            try:
                # Same encoding as _load_config() so both sides agree on the
                # file's contents regardless of the platform default.
                loaded = json.loads(config_path.read_text(encoding="utf-8"))
            except Exception as exc:
                logger.warning(
                    "Replacing unreadable Mem0 config %s: %s", config_path, exc,
                )
            else:
                if isinstance(loaded, dict):
                    existing = loaded
        existing.update(values)
        atomic_json_write(config_path, existing, mode=0o600)

    def get_config_schema(self):
        """Describe the settings offered by the setup flow.

        The API key is marked required unless the currently configured mode
        is ``oss``.
        """
        cfg = _load_config()
        api_key_required = cfg.get("mode", "platform") != "oss"
        return [
            {
                "key": "api_key",
                "description": "Mem0 Platform API key",
                "secret": True,
                "required": api_key_required,
                "env_var": "MEM0_API_KEY",
                "url": "https://app.mem0.ai",
            },
            {"key": "user_id", "description": "User identifier", "default": _DEFAULT_USER_ID},
            {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
            {
                "key": "rerank",
                "description": "Enable reranking for recall",
                "default": "true",
                "choices": ["true", "false"],
            },
        ]

    def post_setup(self, hermes_home: str, config: dict) -> None:
        """Delegate post-setup work to the plugin's ``_setup`` module."""
        from ._setup import post_setup

        post_setup(hermes_home, config)

    # -- small internal helpers ---------------------------------------------

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a flag that may arrive as a bool or as a string.

        Strings are true unless they read ``false``, ``0`` or ``no``
        (case-insensitive); every other value goes through ``bool()``.
        """
        if isinstance(value, str):
            return value.lower() not in ("false", "0", "no")
        return bool(value)

    @staticmethod
    def _int_arg(args: dict, key: str, default: int) -> int:
        """Read an integer tool argument, treating a missing/null value as *default*.

        Raises:
            TypeError, ValueError: when the supplied value is not an integer.
        """
        value = args.get(key)
        return default if value is None else int(value)

    def _oss_provider(self, default: str) -> str:
        """Return the configured OSS vector-store provider name, or *default*.

        Tolerates a missing config and non-dict ``oss`` / ``vector_store``
        values so that building an error hint can never itself raise.
        """
        oss_cfg = (self._config or {}).get("oss") or {}
        vector_store = oss_cfg.get("vector_store") if isinstance(oss_cfg, dict) else None
        if isinstance(vector_store, dict):
            return vector_store.get("provider", default)
        return default

    def _create_backend(self):
        """Instantiate the backend for the current mode; return None on failure.

        On failure the error text is kept in ``self._init_error`` so that
        handle_tool_call() can report it.
        """
        # Forget any error from an earlier attempt.
        self._init_error = ""

        # Lazy-install the mem0 SDK on demand before either backend imports
        # it. ensure() honors security.allow_lazy_installs (default true) and,
        # on a sealed Docker venv, redirects the install to the durable
        # target. On failure we fall through so the import inside the backend
        # produces the canonical error, captured below.
        try:
            from tools.lazy_deps import ensure as _lazy_ensure

            _lazy_ensure("memory.mem0", prompt=False)
        except Exception as exc:
            logger.debug("Mem0 lazy dependency install skipped: %s", exc)

        try:
            if self._mode == "oss":
                from ._backend import OSSBackend

                return OSSBackend(self._config.get("oss", {}))
            from ._backend import PlatformBackend

            return PlatformBackend(self._api_key)
        except Exception as e:
            logger.error("Mem0 backend failed to initialize (%s mode): %s", self._mode, e)
            self._init_error = str(e)
            return None

    # -- circuit breaker ----------------------------------------------------

    def _is_breaker_open(self) -> bool:
        """Return True if the circuit breaker is tripped (too many failures).

        Once the cooldown has elapsed the failure counter is reset and the
        breaker reports closed again.
        """
        with self._breaker_lock:
            if self._consecutive_failures < _BREAKER_THRESHOLD:
                return False
            if time.monotonic() >= self._breaker_open_until:
                self._consecutive_failures = 0
                return False
            return True

    def _format_error(self, prefix: str, exc: Exception) -> str:
        """Build a tool error message, adding a vector-store hint in OSS mode.

        The hint is appended only when the error text mentions a connection
        problem (``connection``, ``refused`` or ``timeout``).
        """
        msg = f"{prefix}: {exc}"
        if self._mode == "oss":
            err_str = str(exc).lower()
            if any(word in err_str for word in ("connection", "refused", "timeout")):
                msg += f" (check that {self._oss_provider('vector store')} is running)"
        return msg

    def _record_success(self):
        """Reset the consecutive-failure counter."""
        with self._breaker_lock:
            self._consecutive_failures = 0

    def _record_failure(self):
        """Count one failure; (re)arm the cooldown at or past the threshold."""
        with self._breaker_lock:
            self._consecutive_failures += 1
            count = self._consecutive_failures
            tripped = count >= _BREAKER_THRESHOLD
            if tripped:
                self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
        if not tripped:
            return

        # Log outside the lock so a slow log handler cannot stall other threads.
        hint = ""
        if self._mode == "oss":
            provider = self._oss_provider("unknown")
            hint = f" Check that your {provider} vector store is running and reachable."
        logger.warning(
            "Mem0 circuit breaker tripped after %d consecutive failures. "
            "Pausing API calls for %ds.%s",
            count, _BREAKER_COOLDOWN_SECS, hint,
        )

    # -- lifecycle ----------------------------------------------------------

    def initialize(self, session_id: str, **kwargs) -> None:
        """Load configuration, resolve identity and create the backend.

        Keyword arguments read: ``user_id`` (gateway-native id) and
        ``platform`` (gateway channel name, default ``cli``).
        """
        self._config = _load_config()
        self._mode = self._config.get("mode", "platform")
        self._api_key = self._config.get("api_key", "")
        # Resolution order for user_id:
        #   1. Operator-configured MEM0_USER_ID (env or $HERMES_HOME/mem0.json) —
        #      the canonical principal, applied across every gateway so the same
        #      human gets one merged memory store.
        #   2. Gateway-native id from kwargs (Telegram numeric id, Discord
        #      snowflake, etc.) — preserves per-platform isolation when no
        #      override is configured.
        #   3. Hardcoded fallback _DEFAULT_USER_ID (CLI with no auth).
        # The literal _DEFAULT_USER_ID string is treated as unset so users who
        # ran the setup wizard with the suggested default still get gateway-
        # native ids instead of being silently bucketed together.
        configured = self._config.get("user_id")
        if configured == _DEFAULT_USER_ID:
            configured = None
        self._user_id = configured or kwargs.get("user_id") or _DEFAULT_USER_ID
        self._agent_id = self._config.get("agent_id", "hermes")
        self._channel = kwargs.get("platform") or "cli"
        # Honor the "rerank" setting advertised by get_config_schema() for
        # automatic recall; absent means enabled.
        self._rerank = self._as_bool(self._config.get("rerank", True))
        self._backend = self._create_backend()
        if self._backend and not self._atexit_registered:
            atexit.register(self._shutdown_backend)
            self._atexit_registered = True

    def _read_filters(self) -> Dict[str, Any]:
        """Filters applied to every read (search / list)."""
        # Scoped to user_id only — by design — so recall surfaces memories
        # written from any gateway/agent under this principal. Writes attach
        # agent_id (and metadata.channel) so per-agent / per-channel views are
        # still possible at query time when needed; reads default to the wider
        # cross-agent recall.
        return {"user_id": self._user_id}

    def _write_metadata(self) -> Dict[str, Any]:
        """Metadata attached to every write."""
        # Tag every write with the gateway channel so the dashboard can offer
        # per-channel filtered views without coupling identity to the channel.
        return {"channel": self._channel} if self._channel else {}

    def system_prompt_block(self) -> str:
        """Return the system-prompt section describing the memory tools."""
        mode_label = "platform (cloud API)" if self._mode == "platform" else "OSS (self-hosted)"
        rerank_note = " Rerank is available on search." if self._mode == "platform" else ""
        return (
            "# Mem0 Memory\n"
            f"Active. Mode: {mode_label}. User: {self._user_id}.\n"
            "You have persistent memory of this user from past conversations. "
            "ALWAYS call mem0_search before answering anything that could depend "
            "on prior context (the user's preferences, facts, history, people, "
            "projects, or earlier decisions) — do not rely on the chat window "
            "alone, and do not assume you have no memory.\n"
            "For multi-part or multi-hop questions, run SEVERAL searches with "
            "different wording/angles and follow-up searches on what the first "
            "results surface; one search is rarely enough. Keep searching until "
            "you have every fact the question needs before you answer.\n"
            "Tools: mem0_search to find memories, mem0_add to store facts, "
            f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}"
        )

    # -- automatic recall ---------------------------------------------------

    def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
        """Kick off background recall for the incoming message."""
        self._start_prefetch(message)

    def _consume_prefetch_result(self, query: str) -> str | None:
        """Return and clear the finished prefetch result for *query*.

        Returns None when no finished result for exactly this query exists.
        """
        with self._prefetch_lock:
            if self._prefetch_query != query or not self._prefetch_done:
                return None
            result = self._prefetch_result
            self._prefetch_result = ""
            self._prefetch_done = False
            return result

    def _start_prefetch(self, query: str) -> None:
        """Start a background search for *query* unless one is pending or done."""
        if not query or self._backend is None or self._is_breaker_open():
            return
        backend = self._backend
        rerank = self._rerank

        def _run():
            body = ""
            try:
                results = backend.search(
                    query, filters=self._read_filters(), top_k=10, rerank=rerank,
                )
                lines = [r.get("memory", "") for r in (results or []) if r.get("memory")]
                if lines:
                    body = "## Mem0 Memory\n" + "\n".join(f"- {line}" for line in lines)
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.debug("Mem0 prefetch failed: %s", e)
            with self._prefetch_lock:
                # Drop the result if a newer query replaced this one meanwhile.
                if self._prefetch_query == query:
                    self._prefetch_result = body
                    self._prefetch_done = True

        with self._prefetch_lock:
            if self._prefetch_query == query:
                if self._prefetch_done:
                    return
                if self._prefetch_thread and self._prefetch_thread.is_alive():
                    return
            self._prefetch_query = query
            self._prefetch_result = ""
            self._prefetch_done = False
            worker = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
            self._prefetch_thread = worker
            # Start while still holding the lock: prefetch() may join() the
            # published thread, and joining an unstarted thread raises
            # RuntimeError. _run only takes this lock at its very end.
            worker.start()

    def prefetch(self, query: str, *, session_id: str = "") -> str:
        """Recall memories for the CURRENT question with a short hot-path wait.

        Returns the formatted memory block, or ``""`` when nothing was found
        or the backend did not answer within ``_PREFETCH_WAIT_SECS``.
        """
        cached = self._consume_prefetch_result(query)
        if cached is not None:
            return cached
        self._start_prefetch(query)
        with self._prefetch_lock:
            thread = self._prefetch_thread if self._prefetch_query == query else None
        if thread:
            thread.join(timeout=_PREFETCH_WAIT_SECS)
            cached = self._consume_prefetch_result(query)
            if cached is not None:
                return cached
        # Slow backend: skip injection; mem0_search tool remains the backstop.
        return ""

    # -- ingestion ----------------------------------------------------------

    def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
        """Send the turn to Mem0 for server-side fact extraction on a daemon thread.

        If the previous sync is still running, this call waits up to five
        seconds for it; if it has not finished by then, the current turn is
        skipped.
        """
        if self._backend is None or self._is_breaker_open():
            return

        def _sync():
            backend = self._backend
            if backend is None:
                return
            try:
                messages = [
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": assistant_content},
                ]
                backend.add(
                    messages,
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                    infer=True,
                    metadata=self._write_metadata(),
                )
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.warning("Mem0 sync failed: %s", e)

        with self._sync_lock:
            if self._sync_thread and self._sync_thread.is_alive():
                self._sync_thread.join(timeout=5.0)
                # If still alive after timeout, skip to avoid duplicate ingestion.
                if self._sync_thread.is_alive():
                    return
            self._sync_thread = threading.Thread(target=_sync, daemon=True, name="mem0-sync")
            self._sync_thread.start()

    # -- tools --------------------------------------------------------------

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        """Return the schemas of the five mem0_* tools."""
        return [LIST_SCHEMA, SEARCH_SCHEMA, ADD_SCHEMA, UPDATE_SCHEMA, DELETE_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        """Execute one mem0_* tool call and return its JSON/tool-error string.

        Returns an error payload without touching the backend when the
        backend failed to initialize or the circuit breaker is open.
        """
        if self._backend is None:
            err = getattr(self, "_init_error", "") or "unknown error"
            hint = ""
            if self._mode == "oss":
                provider = self._oss_provider("vector store")
                hint = f" Check that {provider} is running and reachable."
            return json.dumps({"error": f"Mem0 backend not initialized: {err}.{hint}"})

        if self._is_breaker_open():
            msg = "Mem0 temporarily unavailable (multiple consecutive failures). Will retry automatically."
            if self._mode == "oss":
                msg += f" Check that your {self._oss_provider('vector store')} is running."
            return json.dumps({"error": msg})

        handlers = {
            "mem0_list": self._tool_list,
            "mem0_search": self._tool_search,
            "mem0_add": self._tool_add,
            "mem0_update": self._tool_update,
            "mem0_delete": self._tool_delete,
        }
        handler = handlers.get(tool_name)
        if handler is None:
            return tool_error(f"Unknown tool: {tool_name}")
        return handler(args or {})

    def _tool_list(self, args: dict) -> str:
        """mem0_list: return one page of the user's memories."""
        # Validate arguments before touching the backend so that malformed
        # input is reported as such and never counts as a backend failure.
        try:
            page = max(1, self._int_arg(args, "page", 1))
            page_size = min(max(1, self._int_arg(args, "page_size", 100)), 200)
        except (TypeError, ValueError):
            return tool_error("Invalid parameter: page and page_size must be integers")
        try:
            response = self._backend.get_all(
                filters=self._read_filters(), page=page, page_size=page_size,
            )
            self._record_success()
            results = response.get("results", [])
            if not results:
                return json.dumps({"result": "No memories stored yet."})
            items = [{"id": m.get("id"), "memory": m.get("memory", "")} for m in results]
            return json.dumps({
                "results": items,
                "count": response.get("count", len(items)),
                "page": page, "page_size": page_size,
            })
        except Exception as e:
            if not _is_client_error(e):
                self._record_failure()
            return tool_error(self._format_error("Failed to list memories", e))

    def _tool_search(self, args: dict) -> str:
        """mem0_search: semantic search over the user's memories."""
        query = args.get("query", "")
        if not query:
            return tool_error("Missing required parameter: query")
        try:
            top_k = max(1, min(self._int_arg(args, "top_k", 10), 50))
        except (TypeError, ValueError):
            return tool_error("Invalid parameter: top_k must be an integer")
        rerank = self._as_bool(args.get("rerank", True))
        try:
            results = self._backend.search(
                query, filters=self._read_filters(), top_k=top_k, rerank=rerank,
            )
            self._record_success()
            if not results:
                return json.dumps({"result": "No relevant memories found."})
            items = [
                {"id": r.get("id"), "memory": r.get("memory", ""), "score": r.get("score", 0)}
                for r in results
            ]
            return json.dumps({"results": items, "count": len(items)})
        except Exception as e:
            if not _is_client_error(e):
                self._record_failure()
            return tool_error(self._format_error("Search failed", e))

    def _tool_add(self, args: dict) -> str:
        """mem0_add: store one fact verbatim (``infer=False``)."""
        content = args.get("content", "")
        if not content:
            return tool_error("Missing required parameter: content")
        try:
            result = self._backend.add(
                [{"role": "user", "content": content}],
                user_id=self._user_id,
                agent_id=self._agent_id,
                infer=False,
                metadata=self._write_metadata(),
            )
            self._record_success()
            event_id = result.get("event_id") if isinstance(result, dict) else None
            msg = "Fact stored." if self._mode == "oss" else "Fact queued for storage."
            return json.dumps({"result": msg, "event_id": event_id})
        except Exception as e:
            # Same rule as the other tools: caller mistakes do not count
            # towards the circuit breaker.
            if not _is_client_error(e):
                self._record_failure()
            return tool_error(self._format_error("Failed to store", e))

    def _tool_update(self, args: dict) -> str:
        """mem0_update: replace the text of the memory with the given ID."""
        memory_id = args.get("memory_id", "")
        text = args.get("text", "")
        if not memory_id:
            return tool_error("Missing required parameter: memory_id")
        if not text:
            return tool_error("Missing required parameter: text")
        try:
            result = self._backend.update(memory_id, text)
            self._record_success()
            return json.dumps(result)
        except Exception as e:
            if _is_client_error(e):
                return tool_error(f"Memory not found: {memory_id}")
            self._record_failure()
            return tool_error(self._format_error("Update failed", e))

    def _tool_delete(self, args: dict) -> str:
        """mem0_delete: delete the memory with the given ID."""
        memory_id = args.get("memory_id", "")
        if not memory_id:
            return tool_error("Missing required parameter: memory_id")
        try:
            result = self._backend.delete(memory_id)
            self._record_success()
            return json.dumps(result)
        except Exception as e:
            if _is_client_error(e):
                return tool_error(f"Memory not found: {memory_id}")
            self._record_failure()
            return tool_error(self._format_error("Delete failed", e))

    # -- shutdown -----------------------------------------------------------

    def _shutdown_backend(self):
        """Close the backend (best effort) and always drop the reference.

        Also registered with ``atexit`` by initialize().
        """
        # Detach first so the reference is cleared even when close() raises.
        backend, self._backend = self._backend, None
        if not backend:
            return
        try:
            backend.close()
        except Exception as exc:
            logger.debug("Mem0 backend close failed: %s", exc)

    def shutdown(self) -> None:
        """Wait up to five seconds for each worker thread, then close the backend."""
        for t in (self._prefetch_thread, self._sync_thread):
            if t and t.is_alive():
                t.join(timeout=5.0)
        self._shutdown_backend()
def register(ctx) -> None:
    """Plugin entry point: hand a fresh Mem0 provider to the plugin context."""
    provider = Mem0MemoryProvider()
    ctx.register_memory_provider(provider)