fix(memory/mem0): recall on the current question + stronger search guidance (#55535)

This commit is contained in:
Kartik 2026-06-30 15:51:08 +05:30 committed by GitHub
parent b8ebe32866
commit c6eb7f9e72
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 180 additions and 33 deletions

View file

@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
# for _BREAKER_COOLDOWN_SECS to avoid hammering a down server.
_BREAKER_THRESHOLD = 5
_BREAKER_COOLDOWN_SECS = 120
# Upper bound, in seconds, that prefetch() blocks on the in-flight recall
# thread before giving up and returning without injected memories.
_PREFETCH_WAIT_SECS = 1.5
# NOTE(review): presumably exception class names that count as caller errors
# rather than backend failures — confirm against the circuit-breaker logic.
_CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError")
@ -109,8 +110,10 @@ def _load_config() -> dict:
LIST_SCHEMA = {
"name": "mem0_list",
"description": (
"List all stored memories about the user. "
"Use at conversation start for full overview."
"List ALL stored memories about the user, unranked and paginated. "
"Use for a full overview/audit at conversation start, or to browse "
"everything when you don't have a specific query. For answering a "
"specific question, prefer mem0_search."
),
"parameters": {
"type": "object",
@ -125,7 +128,13 @@ LIST_SCHEMA = {
SEARCH_SCHEMA = {
"name": "mem0_search",
"description": (
"Search memories by meaning. Returns relevant facts ranked by relevance."
"Search the user's memories by meaning; returns facts ranked by "
"relevance. Use this BEFORE answering any question that may depend on "
"what you know about the user (preferences, facts, history, people, "
"projects, past decisions). For multi-part or multi-hop questions, "
"call it MULTIPLE times — vary the wording and run follow-up searches "
"on what earlier results reveal; one search is rarely enough. Set "
"rerank=true for higher accuracy on important queries."
),
"parameters": {
"type": "object",
@ -141,8 +150,11 @@ SEARCH_SCHEMA = {
ADD_SCHEMA = {
"name": "mem0_add",
"description": (
"Store a durable fact about the user. Stored verbatim (no LLM extraction). "
"Use for explicit preferences, corrections, or decisions."
"Store a durable fact about the user, verbatim (no LLM extraction). "
"Call this the moment the user states a lasting preference, correction, "
"decision, or personal detail worth recalling on future turns — don't "
"wait to be asked to remember. Skip transient chit-chat and facts you've "
"already stored."
),
"parameters": {
"type": "object",
@ -155,7 +167,11 @@ ADD_SCHEMA = {
UPDATE_SCHEMA = {
"name": "mem0_update",
"description": "Update an existing memory's text by its ID.",
"description": (
"Replace the text of an existing memory by its ID (take the ID from a "
"mem0_search or mem0_list result). Use when a stored fact has changed "
"or was wrong — correct it in place instead of adding a duplicate."
),
"parameters": {
"type": "object",
"properties": {
@ -168,7 +184,11 @@ UPDATE_SCHEMA = {
DELETE_SCHEMA = {
"name": "mem0_delete",
"description": "Delete a memory by its ID.",
"description": (
"Delete a memory by its ID (take the ID from a mem0_search or mem0_list "
"result). Use when a stored fact is obsolete or the user asks you to "
"forget it; prefer mem0_update if the fact merely changed."
),
"parameters": {
"type": "object",
"properties": {
@ -197,15 +217,17 @@ class Mem0MemoryProvider(MemoryProvider):
self._user_id = _DEFAULT_USER_ID
self._agent_id = "hermes"
self._channel = "cli" # gateway channel name (cli/telegram/discord/...)
self._prefetch_result = ""
self._prefetch_lock = threading.Lock()
self._prefetch_thread = None
self._sync_thread = None
self._prefetch_thread = None
self._prefetch_query = ""
self._prefetch_result = ""
self._prefetch_done = False
# Circuit breaker state
self._consecutive_failures = 0
self._breaker_open_until = 0.0
self._breaker_lock = threading.Lock()
self._sync_lock = threading.Lock()
self._prefetch_lock = threading.Lock()
self._atexit_registered = False
@property
@ -361,44 +383,83 @@ class Mem0MemoryProvider(MemoryProvider):
return (
"# Mem0 Memory\n"
f"Active. Mode: {mode_label}. User: {self._user_id}.\n"
"Use mem0_search to find memories, mem0_add to store facts, "
"You have persistent memory of this user from past conversations. "
"ALWAYS call mem0_search before answering anything that could depend "
"on prior context (the user's preferences, facts, history, people, "
"projects, or earlier decisions) — do not rely on the chat window "
"alone, and do not assume you have no memory.\n"
"For multi-part or multi-hop questions, run SEVERAL searches with "
"different wording/angles and follow-up searches on what the first "
"results surface; one search is rarely enough. Keep searching until "
"you have every fact the question needs before you answer.\n"
"Tools: mem0_search to find memories, mem0_add to store facts, "
f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}"
)
def prefetch(self, query: str, *, session_id: str = "") -> str:
if self._prefetch_thread and self._prefetch_thread.is_alive():
self._prefetch_thread.join(timeout=3.0)
# If the thread still hasn't finished, leave the result for the next call.
if self._prefetch_thread and self._prefetch_thread.is_alive():
return ""
def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
self._start_prefetch(message)
def _consume_prefetch_result(self, query: str) -> str | None:
with self._prefetch_lock:
if self._prefetch_query != query or not self._prefetch_done:
return None
result = self._prefetch_result
self._prefetch_result = ""
if not result:
return ""
return f"## Mem0 Memory\n{result}"
self._prefetch_done = False
return result
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
if self._backend is None or self._is_breaker_open():
def _start_prefetch(self, query: str) -> None:
if not query or self._backend is None or self._is_breaker_open():
return
backend = self._backend
with self._prefetch_lock:
if self._prefetch_query == query:
if self._prefetch_done:
return
if self._prefetch_thread and self._prefetch_thread.is_alive():
return
self._prefetch_query = query
self._prefetch_result = ""
self._prefetch_done = False
def _run():
backend = self._backend
if backend is None:
return
body = ""
try:
results = backend.search(query=query, filters=self._read_filters(), top_k=5, rerank=True)
if results:
lines = [r.get("memory", "") for r in results if r.get("memory")]
with self._prefetch_lock:
self._prefetch_result = "\n".join(f"- {l}" for l in lines)
results = backend.search(
query, filters=self._read_filters(), top_k=10, rerank=True,
)
lines = [r.get("memory", "") for r in (results or []) if r.get("memory")]
if lines:
body = "## Mem0 Memory\n" + "\n".join(f"- {l}" for l in lines)
self._record_success()
except Exception as e:
self._record_failure()
logger.debug("Mem0 prefetch failed: %s", e)
with self._prefetch_lock:
if self._prefetch_query == query:
self._prefetch_result = body
self._prefetch_done = True
self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
self._prefetch_thread.start()
t = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
with self._prefetch_lock:
self._prefetch_thread = t
t.start()
def prefetch(self, query: str, *, session_id: str = "") -> str:
"""Recall memories for the CURRENT question with a short hot-path wait."""
cached = self._consume_prefetch_result(query)
if cached is not None:
return cached
self._start_prefetch(query)
with self._prefetch_lock:
thread = self._prefetch_thread if self._prefetch_query == query else None
if thread:
thread.join(timeout=_PREFETCH_WAIT_SECS)
cached = self._consume_prefetch_result(query)
if cached is not None:
return cached
# Slow backend: skip injection; mem0_search tool remains the backstop.
return ""
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Send the turn to Mem0 for server-side fact extraction (non-blocking)."""

View file

@ -1,5 +1,5 @@
name: mem0
version: 1.1.0
version: 1.2.0
description: "Mem0 — server-side LLM fact extraction with semantic search, reranking, and automatic deduplication."
pip_dependencies:
- mem0ai>=2.0.7,<3

View file

@ -1,8 +1,10 @@
"""Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools."""
import json
import time
import pytest
import plugins.memory.mem0 as mem0_plugin
from plugins.memory.mem0 import Mem0MemoryProvider
@ -280,6 +282,90 @@ class TestMem0V3Internal:
assert "error" in result
class TestMem0Prefetch:
    """Regression tests: prefetch() recalls memories for the CURRENT question.

    Earlier, prefetch() ignored its ``query`` argument and returned whatever a
    background ``queue_prefetch`` had warmed during the PREVIOUS turn, so the
    first turn injected nothing and later turns injected stale, off-topic
    memories. These tests pin the corrected behaviour, including the bounded
    wait when the backend is slow.
    """

    def _make_provider(self, backend):
        """Return an initialized provider wired to ``backend`` with fixed ids."""
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        provider._backend = backend
        return provider

    def test_prefetch_searches_current_query(self):
        memories = [{"id": "m1", "memory": "user prefers dark mode"}]
        backend = FakeBackend(search_results=memories)
        provider = self._make_provider(backend)

        injected = provider.prefetch("what theme do I like?")

        # The first recorded backend call must be a search for this question.
        kind, query, opts = backend.captured[0]
        assert kind == "search"
        assert query == "what theme do I like?"
        assert opts["filters"] == {"user_id": "u123"}
        assert opts["top_k"] == 10
        assert opts["rerank"] is True
        assert "## Mem0 Memory" in injected
        assert "user prefers dark mode" in injected

    def test_prefetch_returns_memories_on_first_call(self):
        # Nothing was warmed beforehand; the very first call must still recall.
        backend = FakeBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}])
        provider = self._make_provider(backend)

        injected = provider.prefetch("where do I live?")

        assert "lives in Berlin" in injected

    def test_on_turn_start_queues_current_query(self):
        backend = FakeBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}])
        provider = self._make_provider(backend)

        provider.on_turn_start(1, "where do I live?")
        provider._prefetch_thread.join(timeout=1)
        injected = provider.prefetch("where do I live?")

        assert "lives in Berlin" in injected
        # The warmed result is reused; prefetch() must not search a second time.
        search_calls = [call for call in backend.captured if call[0] == "search"]
        assert len(search_calls) == 1

    def test_slow_prefetch_returns_quickly(self, monkeypatch):
        class SlowBackend(FakeBackend):
            def search(self, query, *, filters, top_k=10, rerank=True):
                time.sleep(0.2)
                return super().search(query, filters=filters, top_k=top_k, rerank=rerank)

        monkeypatch.setattr(mem0_plugin, "_PREFETCH_WAIT_SECS", 0.01)
        backend = SlowBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}])
        provider = self._make_provider(backend)

        started = time.monotonic()
        injected = provider.prefetch("where do I live?")
        elapsed = time.monotonic() - started

        # The hot path gives up after the shortened wait instead of blocking.
        assert injected == ""
        assert elapsed < 0.1

        # Once the slow search lands, the same question picks up its result.
        provider._prefetch_thread.join(timeout=1)
        assert "lives in Berlin" in provider.prefetch("where do I live?")

    def test_prefetch_empty_results_returns_empty(self):
        provider = self._make_provider(FakeBackend(search_results=[]))

        assert provider.prefetch("anything") == ""

    def test_prefetch_skips_when_breaker_open(self):
        backend = FakeBackend(search_results=[{"id": "m1", "memory": "x"}])
        provider = self._make_provider(backend)
        # Force the circuit breaker open indefinitely.
        provider._consecutive_failures = 5
        provider._breaker_open_until = float("inf")

        assert provider.prefetch("q") == ""
        assert backend.captured == []

    def test_queue_prefetch_fires_no_search(self):
        # prefetch() recalls on the current question, so the post-turn warm is
        # redundant and must not spend a backend search.
        backend = FakeBackend(search_results=[{"id": "m1", "memory": "x"}])
        provider = self._make_provider(backend)

        provider.queue_prefetch("previous turn text")

        assert backend.captured == []
class TestMem0V3Config:
def test_tool_schemas_five_tools(self):