mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-07-01 12:02:05 +00:00
fix(memory/mem0): recall on the current question + stronger search guidance (#55535)
This commit is contained in:
parent
b8ebe32866
commit
c6eb7f9e72
3 changed files with 180 additions and 33 deletions
|
|
@ -44,6 +44,7 @@ logger = logging.getLogger(__name__)
|
|||
# for _BREAKER_COOLDOWN_SECS to avoid hammering a down server.
|
||||
_BREAKER_THRESHOLD = 5
|
||||
_BREAKER_COOLDOWN_SECS = 120
|
||||
_PREFETCH_WAIT_SECS = 1.5
|
||||
|
||||
_CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError")
|
||||
|
||||
|
|
@ -109,8 +110,10 @@ def _load_config() -> dict:
|
|||
LIST_SCHEMA = {
|
||||
"name": "mem0_list",
|
||||
"description": (
|
||||
"List all stored memories about the user. "
|
||||
"Use at conversation start for full overview."
|
||||
"List ALL stored memories about the user, unranked and paginated. "
|
||||
"Use for a full overview/audit at conversation start, or to browse "
|
||||
"everything when you don't have a specific query. For answering a "
|
||||
"specific question, prefer mem0_search."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
|
@ -125,7 +128,13 @@ LIST_SCHEMA = {
|
|||
SEARCH_SCHEMA = {
|
||||
"name": "mem0_search",
|
||||
"description": (
|
||||
"Search memories by meaning. Returns relevant facts ranked by relevance."
|
||||
"Search the user's memories by meaning; returns facts ranked by "
|
||||
"relevance. Use this BEFORE answering any question that may depend on "
|
||||
"what you know about the user (preferences, facts, history, people, "
|
||||
"projects, past decisions). For multi-part or multi-hop questions, "
|
||||
"call it MULTIPLE times — vary the wording and run follow-up searches "
|
||||
"on what earlier results reveal; one search is rarely enough. Set "
|
||||
"rerank=true for higher accuracy on important queries."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
|
@ -141,8 +150,11 @@ SEARCH_SCHEMA = {
|
|||
ADD_SCHEMA = {
|
||||
"name": "mem0_add",
|
||||
"description": (
|
||||
"Store a durable fact about the user. Stored verbatim (no LLM extraction). "
|
||||
"Use for explicit preferences, corrections, or decisions."
|
||||
"Store a durable fact about the user, verbatim (no LLM extraction). "
|
||||
"Call this the moment the user states a lasting preference, correction, "
|
||||
"decision, or personal detail worth recalling on future turns — don't "
|
||||
"wait to be asked to remember. Skip transient chit-chat and facts you've "
|
||||
"already stored."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
|
@ -155,7 +167,11 @@ ADD_SCHEMA = {
|
|||
|
||||
UPDATE_SCHEMA = {
|
||||
"name": "mem0_update",
|
||||
"description": "Update an existing memory's text by its ID.",
|
||||
"description": (
|
||||
"Replace the text of an existing memory by its ID (take the ID from a "
|
||||
"mem0_search or mem0_list result). Use when a stored fact has changed "
|
||||
"or was wrong — correct it in place instead of adding a duplicate."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
@ -168,7 +184,11 @@ UPDATE_SCHEMA = {
|
|||
|
||||
DELETE_SCHEMA = {
|
||||
"name": "mem0_delete",
|
||||
"description": "Delete a memory by its ID.",
|
||||
"description": (
|
||||
"Delete a memory by its ID (take the ID from a mem0_search or mem0_list "
|
||||
"result). Use when a stored fact is obsolete or the user asks you to "
|
||||
"forget it; prefer mem0_update if the fact merely changed."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
@ -197,15 +217,17 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
self._user_id = _DEFAULT_USER_ID
|
||||
self._agent_id = "hermes"
|
||||
self._channel = "cli" # gateway channel name (cli/telegram/discord/...)
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread = None
|
||||
self._sync_thread = None
|
||||
self._prefetch_thread = None
|
||||
self._prefetch_query = ""
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_done = False
|
||||
# Circuit breaker state
|
||||
self._consecutive_failures = 0
|
||||
self._breaker_open_until = 0.0
|
||||
self._breaker_lock = threading.Lock()
|
||||
self._sync_lock = threading.Lock()
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._atexit_registered = False
|
||||
|
||||
@property
|
||||
|
|
@ -361,44 +383,83 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
return (
|
||||
"# Mem0 Memory\n"
|
||||
f"Active. Mode: {mode_label}. User: {self._user_id}.\n"
|
||||
"Use mem0_search to find memories, mem0_add to store facts, "
|
||||
"You have persistent memory of this user from past conversations. "
|
||||
"ALWAYS call mem0_search before answering anything that could depend "
|
||||
"on prior context (the user's preferences, facts, history, people, "
|
||||
"projects, or earlier decisions) — do not rely on the chat window "
|
||||
"alone, and do not assume you have no memory.\n"
|
||||
"For multi-part or multi-hop questions, run SEVERAL searches with "
|
||||
"different wording/angles and follow-up searches on what the first "
|
||||
"results surface; one search is rarely enough. Keep searching until "
|
||||
"you have every fact the question needs before you answer.\n"
|
||||
"Tools: mem0_search to find memories, mem0_add to store facts, "
|
||||
f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}"
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
# If the thread still hasn't finished, leave the result for the next call.
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
return ""
|
||||
def on_turn_start(self, turn_number: int, message: str, **kwargs) -> None:
|
||||
self._start_prefetch(message)
|
||||
|
||||
def _consume_prefetch_result(self, query: str) -> str | None:
|
||||
with self._prefetch_lock:
|
||||
if self._prefetch_query != query or not self._prefetch_done:
|
||||
return None
|
||||
result = self._prefetch_result
|
||||
self._prefetch_result = ""
|
||||
if not result:
|
||||
return ""
|
||||
return f"## Mem0 Memory\n{result}"
|
||||
self._prefetch_done = False
|
||||
return result
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
if self._backend is None or self._is_breaker_open():
|
||||
def _start_prefetch(self, query: str) -> None:
|
||||
if not query or self._backend is None or self._is_breaker_open():
|
||||
return
|
||||
backend = self._backend
|
||||
with self._prefetch_lock:
|
||||
if self._prefetch_query == query:
|
||||
if self._prefetch_done:
|
||||
return
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
return
|
||||
self._prefetch_query = query
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_done = False
|
||||
|
||||
def _run():
|
||||
backend = self._backend
|
||||
if backend is None:
|
||||
return
|
||||
body = ""
|
||||
try:
|
||||
results = backend.search(query=query, filters=self._read_filters(), top_k=5, rerank=True)
|
||||
if results:
|
||||
lines = [r.get("memory", "") for r in results if r.get("memory")]
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_result = "\n".join(f"- {l}" for l in lines)
|
||||
results = backend.search(
|
||||
query, filters=self._read_filters(), top_k=10, rerank=True,
|
||||
)
|
||||
lines = [r.get("memory", "") for r in (results or []) if r.get("memory")]
|
||||
if lines:
|
||||
body = "## Mem0 Memory\n" + "\n".join(f"- {l}" for l in lines)
|
||||
self._record_success()
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
logger.debug("Mem0 prefetch failed: %s", e)
|
||||
with self._prefetch_lock:
|
||||
if self._prefetch_query == query:
|
||||
self._prefetch_result = body
|
||||
self._prefetch_done = True
|
||||
|
||||
self._prefetch_thread = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
|
||||
self._prefetch_thread.start()
|
||||
t = threading.Thread(target=_run, daemon=True, name="mem0-prefetch")
|
||||
with self._prefetch_lock:
|
||||
self._prefetch_thread = t
|
||||
t.start()
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
"""Recall memories for the CURRENT question with a short hot-path wait."""
|
||||
cached = self._consume_prefetch_result(query)
|
||||
if cached is not None:
|
||||
return cached
|
||||
self._start_prefetch(query)
|
||||
with self._prefetch_lock:
|
||||
thread = self._prefetch_thread if self._prefetch_query == query else None
|
||||
if thread:
|
||||
thread.join(timeout=_PREFETCH_WAIT_SECS)
|
||||
cached = self._consume_prefetch_result(query)
|
||||
if cached is not None:
|
||||
return cached
|
||||
# Slow backend: skip injection; mem0_search tool remains the backstop.
|
||||
return ""
|
||||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Send the turn to Mem0 for server-side fact extraction (non-blocking)."""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
name: mem0
|
||||
version: 1.1.0
|
||||
version: 1.2.0
|
||||
description: "Mem0 — server-side LLM fact extraction with semantic search, reranking, and automatic deduplication."
|
||||
pip_dependencies:
|
||||
- mem0ai>=2.0.7,<3
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
"""Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools."""
|
||||
|
||||
import json
|
||||
import time
|
||||
import pytest
|
||||
|
||||
import plugins.memory.mem0 as mem0_plugin
|
||||
from plugins.memory.mem0 import Mem0MemoryProvider
|
||||
|
||||
|
||||
|
|
@ -280,6 +282,90 @@ class TestMem0V3Internal:
|
|||
assert "error" in result
|
||||
|
||||
|
||||
class TestMem0Prefetch:
|
||||
"""prefetch() must recall on the CURRENT question, synchronously.
|
||||
|
||||
The old implementation ignored its ``query`` and returned whatever a
|
||||
background ``queue_prefetch`` had warmed from the PREVIOUS turn — so the
|
||||
first turn injected nothing and later turns injected stale, off-topic
|
||||
memories. These lock the corrected behaviour.
|
||||
"""
|
||||
|
||||
def _make_provider(self, backend):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test-session")
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
provider._backend = backend
|
||||
return provider
|
||||
|
||||
def test_prefetch_searches_current_query(self):
|
||||
backend = FakeBackend(search_results=[{"id": "m1", "memory": "user prefers dark mode"}])
|
||||
provider = self._make_provider(backend)
|
||||
result = provider.prefetch("what theme do I like?")
|
||||
kind, query, opts = backend.captured[0]
|
||||
assert kind == "search"
|
||||
assert query == "what theme do I like?"
|
||||
assert opts["filters"] == {"user_id": "u123"}
|
||||
assert opts["top_k"] == 10
|
||||
assert opts["rerank"] is True
|
||||
assert "## Mem0 Memory" in result
|
||||
assert "user prefers dark mode" in result
|
||||
|
||||
def test_prefetch_returns_memories_on_first_call(self):
|
||||
# No prior queue_prefetch / warm — the very first call must still recall.
|
||||
backend = FakeBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}])
|
||||
provider = self._make_provider(backend)
|
||||
result = provider.prefetch("where do I live?")
|
||||
assert "lives in Berlin" in result
|
||||
|
||||
def test_on_turn_start_queues_current_query(self):
|
||||
backend = FakeBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}])
|
||||
provider = self._make_provider(backend)
|
||||
provider.on_turn_start(1, "where do I live?")
|
||||
provider._prefetch_thread.join(timeout=1)
|
||||
result = provider.prefetch("where do I live?")
|
||||
assert "lives in Berlin" in result
|
||||
assert len([c for c in backend.captured if c[0] == "search"]) == 1
|
||||
|
||||
def test_slow_prefetch_returns_quickly(self, monkeypatch):
|
||||
class SlowBackend(FakeBackend):
|
||||
def search(self, query, *, filters, top_k=10, rerank=True):
|
||||
time.sleep(0.2)
|
||||
return super().search(query, filters=filters, top_k=top_k, rerank=rerank)
|
||||
|
||||
monkeypatch.setattr(mem0_plugin, "_PREFETCH_WAIT_SECS", 0.01)
|
||||
provider = self._make_provider(
|
||||
SlowBackend(search_results=[{"id": "m1", "memory": "lives in Berlin"}])
|
||||
)
|
||||
started = time.monotonic()
|
||||
assert provider.prefetch("where do I live?") == ""
|
||||
assert time.monotonic() - started < 0.1
|
||||
provider._prefetch_thread.join(timeout=1)
|
||||
assert "lives in Berlin" in provider.prefetch("where do I live?")
|
||||
|
||||
def test_prefetch_empty_results_returns_empty(self):
|
||||
backend = FakeBackend(search_results=[])
|
||||
provider = self._make_provider(backend)
|
||||
assert provider.prefetch("anything") == ""
|
||||
|
||||
def test_prefetch_skips_when_breaker_open(self):
|
||||
backend = FakeBackend(search_results=[{"id": "m1", "memory": "x"}])
|
||||
provider = self._make_provider(backend)
|
||||
provider._consecutive_failures = 5
|
||||
provider._breaker_open_until = float("inf")
|
||||
assert provider.prefetch("q") == ""
|
||||
assert backend.captured == []
|
||||
|
||||
def test_queue_prefetch_fires_no_search(self):
|
||||
# prefetch is synchronous now, so the post-turn warm is redundant and
|
||||
# must not fire a wasted backend search.
|
||||
backend = FakeBackend(search_results=[{"id": "m1", "memory": "x"}])
|
||||
provider = self._make_provider(backend)
|
||||
provider.queue_prefetch("previous turn text")
|
||||
assert backend.captured == []
|
||||
|
||||
|
||||
class TestMem0V3Config:
|
||||
|
||||
def test_tool_schemas_five_tools(self):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue