mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-24 10:52:21 +00:00
* fix: update to version 3 endpoints and add update and delete tools
* chore: removing the test md file
* fix: prevent circuit breaker on client errors in Mem0 provider
* chore: add telemetry for platform version
* feat: add OSS mode support to Mem0 memory provider
* chore: bump mem0ai dependency to >=2.0.1 in memory plugin
* refactor: enhance dependency checks and embedder config in mem0 backend
* refactor: adjust fact storage message for OSS mode
* refactor: expand user paths, add collection recreation on dimension change for Qdrant
* fix(mem0): make MEM0_USER_ID override gateway-native ids and tag writes with channel
When MEM0_USER_ID was configured (env or mem0.json), the gateway-native id
from kwargs (Telegram numeric id, Discord snowflake, ...) still won, so the
same human ended up under different user_ids per channel and memories never
merged across CLI / Telegram / Slack / Discord. Mirrors openclaw's cfg.userId
pattern: configured override wins, gateway-native id is the fallback.
The legacy "hermes-user" placeholder default written by the setup wizard is
treated as unset to avoid silently bucketing every gateway user together.
Also tag every write with metadata.channel (cli/telegram/discord/...) so the
dashboard can offer per-channel filtered views without coupling identity to
the channel; document the read/write filter asymmetry as intentional
(reads scope to user_id only for cross-agent recall).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
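
A minimal sketch of the resolution order described above, assuming a standalone helper. The names `resolve_user_id` and `LEGACY_PLACEHOLDER` are hypothetical and only illustrate the rule; the plugin itself applies the same rule inline in `initialize()`.

```python
from typing import Optional

# Placeholder the setup wizard historically wrote; treated as "not configured".
LEGACY_PLACEHOLDER = "hermes-user"


def resolve_user_id(configured: Optional[str], gateway_id: Optional[str]) -> str:
    """Configured override wins, the gateway-native id is the fallback."""
    if configured == LEGACY_PLACEHOLDER:
        configured = None
    return configured or gateway_id or LEGACY_PLACEHOLDER
```

With this rule, an operator who sets MEM0_USER_ID gets one merged store across channels, while an untouched wizard default still lets each gateway supply its own id.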
* refactor: improve Mem0 memory provider backend, pagination, config, and error handling
* refactor: update mem0 telemetry code, docs, and bump version
* fix(mem0): make get_config_schema() return unified schema with mode-aware required flag
Schema always includes api_key field so picker shows "API key / local" for
both modes. In OSS mode api_key.required=False so status won't mislead.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
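
The mode-aware flag can be sketched roughly as follows. `build_config_schema` is a hypothetical free function used only for illustration, and the field list is trimmed to two entries.

```python
from typing import Any, Dict, List


def build_config_schema(mode: str) -> List[Dict[str, Any]]:
    """Return one schema for both modes; only api_key.required depends on mode."""
    return [
        {
            "key": "api_key",
            "secret": True,
            # Platform mode needs the key; OSS (self-hosted) mode does not.
            "required": mode != "oss",
            "env_var": "MEM0_API_KEY",
        },
        {"key": "agent_id", "default": "hermes"},
    ]
```

Keeping the `api_key` entry present in both modes means a picker that inspects the schema sees the same shape either way.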
* refactor: improve mem0 telemetry, add env var key and OSS mode detection
* chore: bump mem0ai lower bound to 2.0.4 (latest SDK release)
* refactor: set telemetry sample rate to 1.0 and update docs for opt-out
* fix(mem0): resolve 15 correctness, thread-safety, and resource bugs
Thread safety:
- Protect circuit breaker counters with _breaker_lock (race between
prefetch/sync daemon threads and main thread)
- Wrap sync_turn thread creation in _sync_lock; skip if previous sync
is still alive after 5 s join to prevent duplicate memory ingestion
- Guard _schedule_flush timer creation under _queue_lock (TOCTOU race)
- Capture local `backend` reference in prefetch/sync closures so
shutdown() nulling self._backend cannot crash in-flight threads
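
As an illustration of the first item, here is a small lock-protected breaker. The class name `CircuitBreaker` and its method names are assumptions made for this sketch, not the plugin's API; the plugin keeps the equivalent state directly on the provider.

```python
import threading
import time


class CircuitBreaker:
    """Consecutive-failure counter guarded by a lock (illustrative only)."""

    def __init__(self, threshold: int = 5, cooldown_secs: float = 120.0):
        self._threshold = threshold
        self._cooldown_secs = cooldown_secs
        self._failures = 0
        self._open_until = 0.0
        self._lock = threading.Lock()

    def record_success(self) -> None:
        with self._lock:
            self._failures = 0

    def record_failure(self) -> None:
        # The increment and the threshold check happen under one lock, so two
        # threads failing at the same time cannot lose an update.
        with self._lock:
            self._failures += 1
            if self._failures >= self._threshold:
                self._open_until = time.monotonic() + self._cooldown_secs

    def is_open(self) -> bool:
        with self._lock:
            if self._failures < self._threshold:
                return False
            if time.monotonic() >= self._open_until:
                # Cooldown elapsed: reset and let calls through again.
                self._failures = 0
                return False
            return True
```

Without the lock, `self._failures += 1` is a read-modify-write that background threads and the main thread could interleave.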
Correctness:
- Fix bool("false")==True for rerank param; parse string values explicitly
- Guard page/top_k with max(1,...) and move int() inside try blocks
- Fix fact_count=0 always in OSS mode (Memory.add returns list, not dict)
- Fix prefetch() not clearing result when thread still alive after timeout
- Fix atexit.register accumulating on repeated initialize() calls
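
The first two correctness items can be shown in isolation. `parse_flag` and `clamp_int` are hypothetical helper names for this sketch; note that `clamp_int` falls back to a default on bad input, which is a simplification rather than the plugin's behavior.

```python
def parse_flag(value, default: bool = True) -> bool:
    """Parse a boolean that may arrive as a string such as "false"."""
    if value is None:
        return default
    if isinstance(value, str):
        # bool("false") is True because any non-empty string is truthy,
        # so string values have to be compared explicitly.
        return value.lower() not in ("false", "0", "no")
    return bool(value)


def clamp_int(raw, default: int, low: int, high: int) -> int:
    """Convert to int inside the try block, then clamp into [low, high]."""
    try:
        value = int(raw)
    except (TypeError, ValueError):
        return default
    return max(low, min(value, high))
```

For example, `parse_flag("false")` yields `False`, whereas a bare `bool("false")` yields `True`.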
Backend / setup:
- Handle Qdrant named-vector collections in _recreate_collection_if_dims_changed
(vectors is a dict; .size access raised AttributeError, swallowed silently)
- Wrap QdrantClient and psycopg2 conn/cursor in try/finally to prevent leaks
- Resolve ollama_bin at top of _ensure_ollama; use it for ollama pull
- Fix embedder key lookup when LLM provider has no env_var (e.g. ollama)
Also: remove _telemetry_enabled cache (env var check is cheap), bump
required mem0ai to >=2.0.7, minor README wording fix.
* fix(mem0): fix brittle qdrant path test + add telemetry sample-rate docs
- Replace generator-throw lambda with a proper def in
test_qdrant_path_not_writable; use tmp_path instead of a hardcoded
/nonexistent path so the test is root-safe
- Add MEM0_TELEMETRY_SAMPLE_RATE to memory-providers.md (was only
in the plugin README, not the user-guide docs)
* revert: remove MEM0_TELEMETRY_SAMPLE_RATE from user-guide docs
* refactor: remove telemetry from mem0 plugin and update documentation
* fix(mem0): set stdin=DEVNULL on setup subprocess calls
The TUI stdin guard (scripts/check_subprocess_stdin.py) requires every
subprocess call in plugin code to set stdin= so it can't inherit the
gateway's JSON-RPC stdin fd. Muzzle the docker/ollama calls in the OSS
setup wizard with stdin=subprocess.DEVNULL (none need interactive input).
Also covers the docker-inspect call the linter's regex misses.
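
A minimal sketch of the pattern, using a hypothetical `run_quiet` wrapper and a child Python process in place of the docker/ollama commands. With `stdin=subprocess.DEVNULL` the child sees an immediate end-of-file instead of the parent's stdin.

```python
import subprocess
import sys


def run_quiet(cmd):
    """Run a command that must never read from the parent's stdin."""
    return subprocess.run(
        cmd,
        stdin=subprocess.DEVNULL,  # child gets EOF, not the caller's stdin fd
        capture_output=True,
        text=True,
        check=False,
    )


# The child tries to read stdin and reports what it received.
result = run_quiet([sys.executable, "-c", "import sys; print(repr(sys.stdin.read()))"])
```

Here the child reads an empty string and exits, so it cannot consume bytes intended for a JSON-RPC stream on the parent's stdin.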
---------
Co-authored-by: chaithanyak42 <chaithanya.kumar42a@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
562 lines
23 KiB
Python
"""Mem0 memory plugin — MemoryProvider interface.

Server-side LLM fact extraction, semantic search, and automatic deduplication
via the Mem0 Platform API (cloud) or OSS (self-hosted) via Memory.

Original PR #2933 by kartik-mem0, adapted to MemoryProvider ABC.

Configuration
-------------
Secret (lives in $HERMES_HOME/.env or the environment):
  MEM0_API_KEY       — Mem0 Platform API key (required for platform mode)

Behavioral settings (live in $HERMES_HOME/mem0.json, set via `hermes memory
setup`):
  mode               — Backend mode: "platform" (default) or "oss"
  user_id            — Canonical user identifier. When set, it is applied
                       uniformly across every gateway (CLI, Telegram, Slack,
                       Discord, …) so the same human gets one merged memory
                       store. When unset, the gateway-native id (e.g. Telegram
                       numeric id, Discord snowflake) is used instead.
  agent_id           — Agent identifier (default: hermes)

The matching MEM0_MODE / MEM0_USER_ID / MEM0_AGENT_ID environment variables are
still read as a backward-compatible fallback, but mem0.json is the canonical
home for these non-secret settings.
"""

from __future__ import annotations

import atexit
import json
import logging
import os
import threading
import time
from typing import Any, Dict, List

from agent.memory_provider import MemoryProvider
from tools.registry import tool_error

logger = logging.getLogger(__name__)

# Circuit breaker: after this many consecutive failures, pause API calls
# for _BREAKER_COOLDOWN_SECS to avoid hammering a down server.
_BREAKER_THRESHOLD = 5
_BREAKER_COOLDOWN_SECS = 120

_CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError")

# Sentinel returned when neither MEM0_USER_ID nor a gateway-native id is
# available. Treated as "no operator-configured user_id" by initialize() so
# that legacy mem0.json files written by the setup wizard (which historically
# wrote this exact placeholder) still allow gateway-native ids to flow
# through instead of silently overriding them with the placeholder.
_DEFAULT_USER_ID = "hermes-user"


def _is_client_error(exc: Exception) -> bool:
    """True for user-caused errors (bad ID, not found) that should NOT trip circuit breaker."""
    etype = type(exc).__name__
    if etype in _CLIENT_ERROR_TYPES:
        return True
    err_str = str(exc).lower()
    return "404" in err_str or "not found" in err_str or "valid uuid" in err_str


# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------

def _load_config() -> dict:
    """Load config from env vars, with $HERMES_HOME/mem0.json overrides.

    Environment variables provide defaults; mem0.json (if present) overrides
    individual keys. This avoids a silent failure when the JSON file exists
    but is missing fields like ``api_key`` that the user set in ``.env``.
    """
    from hermes_constants import get_hermes_home

    config = {
        "mode": os.environ.get("MEM0_MODE", "platform"),
        "api_key": os.environ.get("MEM0_API_KEY", ""),
        "agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
        "oss": {},
    }
    # Only carry user_id when the operator explicitly configured one (env or
    # mem0.json). An absent key tells initialize() to fall back to the
    # gateway-native id from kwargs instead of overriding it with a placeholder.
    env_user_id = os.environ.get("MEM0_USER_ID")
    if env_user_id:
        config["user_id"] = env_user_id

    config_path = get_hermes_home() / "mem0.json"
    if config_path.exists():
        try:
            file_cfg = json.loads(config_path.read_text(encoding="utf-8"))
            config.update({k: v for k, v in file_cfg.items()
                           if v is not None and v != ""})
        except Exception:
            pass

    return config


# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------

LIST_SCHEMA = {
    "name": "mem0_list",
    "description": (
        "List all stored memories about the user. "
        "Use at conversation start for full overview."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "page": {"type": "integer", "description": "Page number (default: 1)."},
            "page_size": {"type": "integer", "description": "Results per page (default: 100, max: 200)."},
        },
        "required": [],
    },
}

SEARCH_SCHEMA = {
    "name": "mem0_search",
    "description": (
        "Search memories by meaning. Returns relevant facts ranked by relevance."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "query": {"type": "string", "description": "What to search for."},
            "top_k": {"type": "integer", "description": "Max results (default: 10, max: 50)."},
            "rerank": {"type": "boolean", "description": "Rerank results for relevance (default: true, platform mode only)."},
        },
        "required": ["query"],
    },
}

ADD_SCHEMA = {
    "name": "mem0_add",
    "description": (
        "Store a durable fact about the user. Stored verbatim (no LLM extraction). "
        "Use for explicit preferences, corrections, or decisions."
    ),
    "parameters": {
        "type": "object",
        "properties": {
            "content": {"type": "string", "description": "The fact to store."},
        },
        "required": ["content"],
    },
}

UPDATE_SCHEMA = {
    "name": "mem0_update",
    "description": "Update an existing memory's text by its ID.",
    "parameters": {
        "type": "object",
        "properties": {
            "memory_id": {"type": "string", "description": "Memory UUID to update."},
            "text": {"type": "string", "description": "New text content."},
        },
        "required": ["memory_id", "text"],
    },
}

DELETE_SCHEMA = {
    "name": "mem0_delete",
    "description": "Delete a memory by its ID.",
    "parameters": {
        "type": "object",
        "properties": {
            "memory_id": {"type": "string", "description": "Memory UUID to delete."},
        },
        "required": ["memory_id"],
    },
}


# ---------------------------------------------------------------------------
# MemoryProvider implementation
# ---------------------------------------------------------------------------

class Mem0MemoryProvider(MemoryProvider):
    """Mem0 memory with server-side extraction and semantic search.

    Supports Platform API (cloud) and OSS (self-hosted) modes via MEM0_MODE.
    """

    def __init__(self):
        self._config = None
        self._backend = None
        self._mode = "platform"
        self._api_key = ""
        self._user_id = _DEFAULT_USER_ID
        self._agent_id = "hermes"
        self._channel = "cli"  # gateway channel name (cli/telegram/discord/...)
        self._prefetch_result = ""
        self._prefetch_lock = threading.Lock()
        self._prefetch_thread = None
        self._sync_thread = None
        # Circuit breaker state
        self._consecutive_failures = 0
        self._breaker_open_until = 0.0
        self._breaker_lock = threading.Lock()
        self._sync_lock = threading.Lock()
        self._atexit_registered = False

    @property
    def name(self) -> str:
        return "mem0"

    def is_available(self) -> bool:
        cfg = _load_config()
        mode = cfg.get("mode", "platform")
        if mode == "oss":
            return bool(cfg.get("oss", {}).get("vector_store"))
        return bool(cfg.get("api_key"))

    def save_config(self, values, hermes_home):
        """Write config to $HERMES_HOME/mem0.json."""
        import json
        from pathlib import Path
        config_path = Path(hermes_home) / "mem0.json"
        existing = {}
        if config_path.exists():
            try:
                existing = json.loads(config_path.read_text())
            except Exception:
                pass
        existing.update(values)
        from utils import atomic_json_write
        atomic_json_write(config_path, existing, mode=0o600)

    def get_config_schema(self):
        cfg = _load_config()
        mode = cfg.get("mode", "platform")
        api_key_required = mode != "oss"
        return [
            {"key": "api_key", "description": "Mem0 Platform API key", "secret": True, "required": api_key_required, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
            {"key": "user_id", "description": "User identifier", "default": "hermes-user"},
            {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
            {"key": "rerank", "description": "Enable reranking for recall", "default": "true", "choices": ["true", "false"]},
        ]

    def post_setup(self, hermes_home: str, config: dict) -> None:
        from ._setup import post_setup
        post_setup(hermes_home, config)

    def _create_backend(self):
        try:
            if self._mode == "oss":
                from ._backend import OSSBackend
                return OSSBackend(self._config.get("oss", {}))
            from ._backend import PlatformBackend
            return PlatformBackend(self._api_key)
        except Exception as e:
            logger.error("Mem0 backend failed to initialize (%s mode): %s", self._mode, e)
            self._init_error = str(e)
            return None

    def _is_breaker_open(self) -> bool:
        """Return True if the circuit breaker is tripped (too many failures)."""
        with self._breaker_lock:
            if self._consecutive_failures < _BREAKER_THRESHOLD:
                return False
            if time.monotonic() >= self._breaker_open_until:
                self._consecutive_failures = 0
                return False
            return True

    def _format_error(self, prefix: str, exc: Exception) -> str:
        msg = f"{prefix}: {exc}"
        if self._mode == "oss":
            err_str = str(exc).lower()
            if "connection" in err_str or "refused" in err_str or "timeout" in err_str:
                vs = self._config.get("oss", {}).get("vector_store", {})
                msg += f" (check that {vs.get('provider', 'vector store')} is running)"
        return msg

    def _record_success(self):
        with self._breaker_lock:
            self._consecutive_failures = 0

    def _record_failure(self):
        with self._breaker_lock:
            self._consecutive_failures += 1
            count = self._consecutive_failures
            if count >= _BREAKER_THRESHOLD:
                self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
            else:
                count = 0
        if count >= _BREAKER_THRESHOLD:
            hint = ""
            if self._mode == "oss":
                vs = self._config.get("oss", {}).get("vector_store", {})
                provider = vs.get("provider", "unknown")
                hint = f" Check that your {provider} vector store is running and reachable."
            logger.warning(
                "Mem0 circuit breaker tripped after %d consecutive failures. "
                "Pausing API calls for %ds.%s",
                count, _BREAKER_COOLDOWN_SECS, hint,
            )

    def initialize(self, session_id: str, **kwargs) -> None:
        self._config = _load_config()
        self._mode = self._config.get("mode", "platform")
        self._api_key = self._config.get("api_key", "")
        # Resolution order for user_id:
        #   1. Operator-configured MEM0_USER_ID (env or $HERMES_HOME/mem0.json) —
        #      the canonical principal, applied across every gateway so the same
        #      human gets one merged memory store.
        #   2. Gateway-native id from kwargs (Telegram numeric id, Discord
        #      snowflake, etc.) — preserves per-platform isolation when no
        #      override is configured.
        #   3. Hardcoded fallback _DEFAULT_USER_ID (CLI with no auth).
        # The literal _DEFAULT_USER_ID string is treated as unset so users who
        # ran the setup wizard with the suggested default still get gateway-
        # native ids instead of being silently bucketed together.
        configured = self._config.get("user_id")
        if configured == _DEFAULT_USER_ID:
            configured = None
        self._user_id = configured or kwargs.get("user_id") or _DEFAULT_USER_ID
        self._agent_id = self._config.get("agent_id", "hermes")
        self._channel = kwargs.get("platform") or "cli"
        self._backend = self._create_backend()
        if self._backend and not self._atexit_registered:
            atexit.register(self._shutdown_backend)
            self._atexit_registered = True

    def _read_filters(self) -> Dict[str, Any]:
        # Scoped to user_id only — by design — so recall surfaces memories
        # written from any gateway/agent under this principal. Writes attach
        # agent_id (and metadata.channel) so per-agent / per-channel views are
        # still possible at query time when needed; reads default to the wider
        # cross-agent recall.
        return {"user_id": self._user_id}

    def _write_metadata(self) -> Dict[str, Any]:
        # Tag every write with the gateway channel so the dashboard can offer
        # per-channel filtered views without coupling identity to the channel.
        return {"channel": self._channel} if self._channel else {}

    def system_prompt_block(self) -> str:
        mode_label = "platform (cloud API)" if self._mode == "platform" else "OSS (self-hosted)"
        rerank_note = " Rerank is available on search." if self._mode == "platform" else ""
        return (
            "# Mem0 Memory\n"
            f"Active. Mode: {mode_label}. User: {self._user_id}.\n"
            "Use mem0_search to find memories, mem0_add to store facts, "
            f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}"
        )

    def prefetch(self, query: str, *, session_id: str = "") -> str:
        if self._prefetch_thread and self._prefetch_thread.is_alive():
            self._prefetch_thread.join(timeout=3.0)
            # If the thread still hasn't finished, leave the result for the next call.
            if self._prefetch_thread and self._prefetch_thread.is_alive():
                return ""
        with self._prefetch_lock:
            result = self._prefetch_result
            self._prefetch_result = ""
        if not result:
            return ""
        return f"## Mem0 Memory\n{result}"

    def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
        if self._backend is None or self._is_breaker_open():
            return

        def _run():
            backend = self._backend
            if backend is None:
                return
            try:
                results = backend.search(query=query, filters=self._read_filters(), top_k=5, rerank=True)
                if results:
                    lines = [r.get("memory", "") for r in results if r.get("memory")]
                    with self._prefetch_lock:
                        self._prefetch_result = "\n".join(f"- {l}" for l in lines)
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.debug("Mem0 prefetch failed: %s", e)

        self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
        self._prefetch_thread.start()

    def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
        """Send the turn to Mem0 for server-side fact extraction (non-blocking)."""
        if self._backend is None or self._is_breaker_open():
            return

        def _sync():
            backend = self._backend
            if backend is None:
                return
            try:
                messages = [
                    {"role": "user", "content": user_content},
                    {"role": "assistant", "content": assistant_content},
                ]
                backend.add(
                    messages,
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                    infer=True,
                    metadata=self._write_metadata(),
                )
                self._record_success()
            except Exception as e:
                self._record_failure()
                logger.warning("Mem0 sync failed: %s", e)

        with self._sync_lock:
            if self._sync_thread and self._sync_thread.is_alive():
                self._sync_thread.join(timeout=5.0)
                # If still alive after timeout, skip to avoid duplicate ingestion.
                if self._sync_thread and self._sync_thread.is_alive():
                    return
            self._sync_thread = threading.Thread(target=_sync, daemon=True, name="mem0-sync")
            self._sync_thread.start()

    def get_tool_schemas(self) -> List[Dict[str, Any]]:
        return [LIST_SCHEMA, SEARCH_SCHEMA, ADD_SCHEMA, UPDATE_SCHEMA, DELETE_SCHEMA]

    def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
        if self._backend is None:
            err = getattr(self, "_init_error", "unknown error")
            hint = ""
            if self._mode == "oss":
                vs = self._config.get("oss", {}).get("vector_store", {})
                provider = vs.get("provider", "vector store")
                hint = f" Check that {provider} is running and reachable."
            return json.dumps({"error": f"Mem0 backend not initialized: {err}.{hint}"})

        if self._is_breaker_open():
            msg = "Mem0 temporarily unavailable (multiple consecutive failures). Will retry automatically."
            if self._mode == "oss":
                vs = self._config.get("oss", {}).get("vector_store", {})
                msg += f" Check that your {vs.get('provider', 'vector store')} is running."
            return json.dumps({"error": msg})

        if tool_name == "mem0_list":
            try:
                page = max(1, int(args.get("page", 1)))
                page_size = min(max(1, int(args.get("page_size", 100))), 200)
                response = self._backend.get_all(
                    filters=self._read_filters(), page=page, page_size=page_size,
                )
                self._record_success()
                results = response.get("results", [])
                if not results:
                    return json.dumps({"result": "No memories stored yet."})
                items = [{"id": m.get("id"), "memory": m.get("memory", "")}
                         for m in results]
                return json.dumps({
                    "results": items,
                    "count": response.get("count", len(items)),
                    "page": page, "page_size": page_size,
                })
            except Exception as e:
                if not _is_client_error(e):
                    self._record_failure()
                return tool_error(self._format_error("Failed to list memories", e))

        elif tool_name == "mem0_search":
            query = args.get("query", "")
            if not query:
                return tool_error("Missing required parameter: query")
            try:
                top_k = max(1, min(int(args.get("top_k", 10)), 50))
                rerank_raw = args.get("rerank", True)
                if isinstance(rerank_raw, str):
                    rerank = rerank_raw.lower() not in ("false", "0", "no")
                else:
                    rerank = bool(rerank_raw)
                results = self._backend.search(query, filters=self._read_filters(), top_k=top_k, rerank=rerank)
                self._record_success()
                if not results:
                    return json.dumps({"result": "No relevant memories found."})
                items = [{"id": r.get("id"), "memory": r.get("memory", ""),
                          "score": r.get("score", 0)} for r in results]
                return json.dumps({"results": items, "count": len(items)})
            except Exception as e:
                if not _is_client_error(e):
                    self._record_failure()
                return tool_error(self._format_error("Search failed", e))

        elif tool_name == "mem0_add":
            content = args.get("content", "")
            if not content:
                return tool_error("Missing required parameter: content")
            try:
                result = self._backend.add(
                    [{"role": "user", "content": content}],
                    user_id=self._user_id,
                    agent_id=self._agent_id,
                    infer=False,
                    metadata=self._write_metadata(),
                )
                self._record_success()
                event_id = result.get("event_id") if isinstance(result, dict) else None
                msg = "Fact stored." if self._mode == "oss" else "Fact queued for storage."
                return json.dumps({"result": msg, "event_id": event_id})
            except Exception as e:
                self._record_failure()
                return tool_error(self._format_error("Failed to store", e))

        elif tool_name == "mem0_update":
            memory_id = args.get("memory_id", "")
            text = args.get("text", "")
            if not memory_id:
                return tool_error("Missing required parameter: memory_id")
            if not text:
                return tool_error("Missing required parameter: text")
            try:
                result = self._backend.update(memory_id, text)
                self._record_success()
                return json.dumps(result)
            except Exception as e:
                if _is_client_error(e):
                    return tool_error(f"Memory not found: {memory_id}")
                self._record_failure()
                return tool_error(self._format_error("Update failed", e))

        elif tool_name == "mem0_delete":
            memory_id = args.get("memory_id", "")
            if not memory_id:
                return tool_error("Missing required parameter: memory_id")
            try:
                result = self._backend.delete(memory_id)
                self._record_success()
                return json.dumps(result)
            except Exception as e:
                if _is_client_error(e):
                    return tool_error(f"Memory not found: {memory_id}")
                self._record_failure()
                return tool_error(self._format_error("Delete failed", e))

        return tool_error(f"Unknown tool: {tool_name}")

    def _shutdown_backend(self):
        try:
            if self._backend:
                self._backend.close()
                self._backend = None
        except Exception:
            pass

    def shutdown(self) -> None:
        for t in (self._prefetch_thread, self._sync_thread):
            if t and t.is_alive():
                t.join(timeout=5.0)
        self._shutdown_backend()


def register(ctx) -> None:
    """Register Mem0 as a memory provider plugin."""
    ctx.register_memory_provider(Mem0MemoryProvider())