diff --git a/plugins/memory/mem0/__init__.py b/plugins/memory/mem0/__init__.py index f11d1f0cd83..3647496375d 100644 --- a/plugins/memory/mem0/__init__.py +++ b/plugins/memory/mem0/__init__.py @@ -44,6 +44,7 @@ logger = logging.getLogger(__name__) # for _BREAKER_COOLDOWN_SECS to avoid hammering a down server. _BREAKER_THRESHOLD = 5 _BREAKER_COOLDOWN_SECS = 120 +_PREFETCH_WAIT_SECS = 1.5 _CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError") @@ -109,8 +110,10 @@ def _load_config() -> dict: LIST_SCHEMA = { "name": "mem0_list", "description": ( - "List all stored memories about the user. " - "Use at conversation start for full overview." + "List ALL stored memories about the user, unranked and paginated. " + "Use for a full overview/audit at conversation start, or to browse " + "everything when you don't have a specific query. For answering a " + "specific question, prefer mem0_search." ), "parameters": { "type": "object", @@ -125,7 +128,13 @@ LIST_SCHEMA = { SEARCH_SCHEMA = { "name": "mem0_search", "description": ( - "Search memories by meaning. Returns relevant facts ranked by relevance." + "Search the user's memories by meaning; returns facts ranked by " + "relevance. Use this BEFORE answering any question that may depend on " + "what you know about the user (preferences, facts, history, people, " + "projects, past decisions). For multi-part or multi-hop questions, " + "call it MULTIPLE times — vary the wording and run follow-up searches " + "on what earlier results reveal; one search is rarely enough. Set " + "rerank=true for higher accuracy on important queries." ), "parameters": { "type": "object", @@ -141,8 +150,11 @@ SEARCH_SCHEMA = { ADD_SCHEMA = { "name": "mem0_add", "description": ( - "Store a durable fact about the user. Stored verbatim (no LLM extraction). " - "Use for explicit preferences, corrections, or decisions." + "Store a durable fact about the user, verbatim (no LLM extraction). " + "Call this the moment the user states a lasting preference, correction, " + "decision, or personal detail worth recalling on future turns — don't " + "wait to be asked to remember. Skip transient chit-chat and facts you've " + "already stored." ), "parameters": { "type": "object", @@ -155,7 +167,11 @@ ADD_SCHEMA = { UPDATE_SCHEMA = { "name": "mem0_update", - "description": "Update an existing memory's text by its ID.", + "description": ( + "Replace the text of an existing memory by its ID (take the ID from a " + "mem0_search or mem0_list result). Use when a stored fact has changed " + "or was wrong — correct it in place instead of adding a duplicate." + ), "parameters": { "type": "object", "properties": { @@ -168,7 +184,11 @@ UPDATE_SCHEMA = { DELETE_SCHEMA = { "name": "mem0_delete", - "description": "Delete a memory by its ID.", + "description": ( + "Delete a memory by its ID (take the ID from a mem0_search or mem0_list " + "result). Use when a stored fact is obsolete or the user asks you to " + "forget it; prefer mem0_update if the fact merely changed." + ), "parameters": { "type": "object", "properties": { @@ -197,15 +217,17 @@ class Mem0MemoryProvider(MemoryProvider): self._user_id = _DEFAULT_USER_ID self._agent_id = "hermes" self._channel = "cli" # gateway channel name (cli/telegram/discord/...) - self._prefetch_result = "" - self._prefetch_lock = threading.Lock() - self._prefetch_thread = None self._sync_thread = None + self._prefetch_thread = None + self._prefetch_query = "" + self._prefetch_result = "" + self._prefetch_done = False # Circuit breaker state self._consecutive_failures = 0 self._breaker_open_until = 0.0 self._breaker_lock = threading.Lock() self._sync_lock = threading.Lock() + self._prefetch_lock = threading.Lock() self._atexit_registered = False @property @@ -361,44 +383,83 @@ class Mem0MemoryProvider(MemoryProvider): return ( "# Mem0 Memory\n" f"Active. Mode: {mode_label}. User: {self._user_id}.\n" - "Use mem0_search to find memories, mem0_add to store facts, " + "You have persistent memory of this user from past conversations. " + "ALWAYS call mem0_search before answering anything that could depend " + "on prior context (the user's preferences, facts, history, people, " + "projects, or earlier decisions) — do not rely on the chat window " + "alone, and do not assume you have no memory.\n" + "For multi-part or multi-hop questions, run SEVERAL searches with " + "different wording/angles and follow-up searches on what the first " + "results surface; one search is rarely enough. Keep searching until " + "you have every fact the question needs before you answer.\n" + "Tools: mem0_search to find memories, mem0_add to store facts, " f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}" ) - def prefetch(self, query: str, *, session_id: str = "") -> str: - if self._prefetch_thread and self._prefetch_thread.is_alive(): - self._prefetch_thread.join(timeout=3.0) - # If the thread still hasn't finished, leave the result for the next call. - if self._prefetch_thread and self._prefetch_thread.is_alive(): - return "" + def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None: + self._start_prefetch(message) + + def _consume_prefetch_result(self, query: str) -> str | None: with self._prefetch_lock: + if self._prefetch_query != query or not self._prefetch_done: + return None result = self._prefetch_result self._prefetch_result = "" - if not result: - return "" - return f"## Mem0 Memory\n{result}" + self._prefetch_done = False + return result - def queue_prefetch(self, query: str, *, session_id: str = "") -> None: - if self._backend is None or self._is_breaker_open(): + def _start_prefetch(self, query: str) -> None: + if not query or self._backend is None or self._is_breaker_open(): return + backend = self._backend + with self._prefetch_lock: + if self._prefetch_query == query: + if self._prefetch_done: + return + if self._prefetch_thread and self._prefetch_thread.is_alive(): + return + self._prefetch_query = query + self._prefetch_result = "" + self._prefetch_done = False def _run(): - backend = self._backend - if backend is None: - return + body = "" try: - results = backend.search(query=query, filters=self._read_filters(), top_k=5, rerank=True) - if results: - lines = [r.get("memory", "") for r in results if r.get("memory")] - with self._prefetch_lock: - self._prefetch_result = "\n".join(f"- {l}" for l in lines) + results = backend.search( + query, filters=self._read_filters(), top_k=10, rerank=True, + ) + lines = [r.get("memory", "") for r in (results or []) if r.get("memory")] + if lines: + body = "## Mem0 Memory\n" + "\n".join(f"- {l}" for l in lines) self._record_success() except Exception as e: self._record_failure() logger.debug("Mem0 prefetch failed: %s", e) + with self._prefetch_lock: + if self._prefetch_query == query: + self._prefetch_result = body + self._prefetch_done = True - self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="mem0-prefetch") - self._prefetch_thread.start() + t = threading.Thread(target=_run, daemon=True, name="mem0-prefetch") + with self._prefetch_lock: + self._prefetch_thread = t + t.start() + + def prefetch(self, query: str, *, session_id: str = "") -> str: + """Recall memories for the CURRENT question with a short hot-path wait.""" + cached = self._consume_prefetch_result(query) + if cached is not None: + return cached + self._start_prefetch(query) + with self._prefetch_lock: + thread = self._prefetch_thread if self._prefetch_query == query else None + if thread: + thread.join(timeout=_PREFETCH_WAIT_SECS) + cached = self._consume_prefetch_result(query) + if cached is not None: + return cached + # Slow backend: skip injection; mem0_search tool remains the backstop. + return "" def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: """Send the turn to Mem0 for server-side fact extraction (non-blocking).""" diff --git a/plugins/memory/mem0/plugin.yaml b/plugins/memory/mem0/plugin.yaml index 1d9dec52306..ef06b0f37f7 100644 --- a/plugins/memory/mem0/plugin.yaml +++ b/plugins/memory/mem0/plugin.yaml @@ -1,5 +1,5 @@ name: mem0 -version: 1.1.0 +version: 1.2.0 description: "Mem0 — server-side LLM fact extraction with semantic search, reranking, and automatic deduplication." pip_dependencies: - mem0ai>=2.0.7,<3 diff --git a/tests/plugins/memory/test_mem0_v3.py b/tests/plugins/memory/test_mem0_v3.py index e83a4171a4a..0e381d6a894 100644 --- a/tests/plugins/memory/test_mem0_v3.py +++ b/tests/plugins/memory/test_mem0_v3.py @@ -1,8 +1,10 @@ """Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools.""" import json +import time import pytest +import plugins.memory.mem0 as mem0_plugin from plugins.memory.mem0 import Mem0MemoryProvider @@ -280,6 +282,90 @@ class TestMem0V3Internal: assert "error" in result +class TestMem0Prefetch: + """prefetch() must recall on the CURRENT question, synchronously. + + The old implementation ignored its ``query`` and returned whatever a + background ``queue_prefetch`` had warmed from the PREVIOUS turn — so the + first turn injected nothing and later turns injected stale, off-topic + memories. These lock the corrected behaviour. + """ + + def _make_provider(self, backend): + provider = Mem0MemoryProvider() + provider.initialize("test-session") + provider._user_id = "u123" + provider._agent_id = "hermes" + provider._backend = backend + return provider + + def test_prefetch_searches_current_query(self): + backend = FakeBackend(search_results=[{"id": "m1", "memory": "user prefers dark mode"}]) + provider = self._make_provider(backend) + result = provider.prefetch("what theme do I like?") + kind, query, opts = backend.captured[0] + assert kind == "search" + assert query == "what theme do I like?" + assert opts["filters"] == {"user_id": "u123"} + assert opts["top_k"] == 10 + assert opts["rerank"] is True + assert "## Mem0 Memory" in result + assert "user prefers dark mode" in result + + def test_prefetch_returns_memories_on_first_call(self): + # No prior queue_prefetch / warm — the very first call must still recall. + backend = FakeBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}]) + provider = self._make_provider(backend) + result = provider.prefetch("where do I live?") + assert "lives in Berlin" in result + + def test_on_turn_start_queues_current_query(self): + backend = FakeBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}]) + provider = self._make_provider(backend) + provider.on_turn_start(1, "where do I live?") + provider._prefetch_thread.join(timeout=1) + result = provider.prefetch("where do I live?") + assert "lives in Berlin" in result + assert len([c for c in backend.captured if c[0] == "search"]) == 1 + + def test_slow_prefetch_returns_quickly(self, monkeypatch): + class SlowBackend(FakeBackend): + def search(self, query, *, filters, top_k=10, rerank=True): + time.sleep(0.2) + return super().search(query, filters=filters, top_k=top_k, rerank=rerank) + + monkeypatch.setattr(mem0_plugin, "_PREFETCH_WAIT_SECS", 0.01) + provider = self._make_provider( + SlowBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}]) + ) + started = time.monotonic() + assert provider.prefetch("where do I live?") == "" + assert time.monotonic() - started < 0.1 + provider._prefetch_thread.join(timeout=1) + assert "lives in Berlin" in provider.prefetch("where do I live?") + + def test_prefetch_empty_results_returns_empty(self): + backend = FakeBackend(search_results=[]) + provider = self._make_provider(backend) + assert provider.prefetch("anything") == "" + + def test_prefetch_skips_when_breaker_open(self): + backend = FakeBackend(search_results=[{"id": "m1", "memory": "x"}]) + provider = self._make_provider(backend) + provider._consecutive_failures = 5 + provider._breaker_open_until = float("inf") + assert provider.prefetch("q") == "" + assert backend.captured == [] + + def test_queue_prefetch_fires_no_search(self): + # prefetch is synchronous now, so the post-turn warm is redundant and + # must not fire a wasted backend search. + backend = FakeBackend(search_results=[{"id": "m1", "memory": "x"}]) + provider = self._make_provider(backend) + provider.queue_prefetch("previous turn text") + assert backend.captured == [] + + class TestMem0V3Config: def test_tool_schemas_five_tools(self):