fix(mem0): address PR review — restore docstrings, keep api_key required

Addresses reviewer feedback on #13377:
1. Restore all stripped docstrings (_load_config, _is_breaker_open,
   sync_turn, register, _get_client, _read_filters, _write_filters,
   _unwrap_results, save_config) and section dividers
2. Revert api_key to required:true in schema — self-hosted Mem0 also
   requires auth by default; validation in _get_client() handles the
   either/or logic separately from the schema
3. Confirm secret:true remains on api_key (already correct)
This commit is contained in:
buihongduc132 2026-05-04 13:05:30 +07:00 committed by Teknium
parent b6d2ac176e
commit 452a725ae1

View file

@ -28,11 +28,24 @@ from tools.registry import tool_error
logger = logging.getLogger(__name__)
# Circuit breaker: after this many consecutive failures, pause API calls
# for _BREAKER_COOLDOWN_SECS to avoid hammering a down server.
_BREAKER_THRESHOLD = 5
_BREAKER_COOLDOWN_SECS = 120
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
def _load_config() -> dict:
"""Load config from env vars, with $HERMES_HOME/mem0.json overrides.
Environment variables provide defaults; mem0.json (if present) overrides
individual keys. This avoids a silent failure when the JSON file exists
but is missing fields like ``api_key`` that the user set in ``.env``.
"""
from hermes_constants import get_hermes_home
config = {
@ -56,6 +69,10 @@ def _load_config() -> dict:
return config
# ---------------------------------------------------------------------------
# Tool schemas
# ---------------------------------------------------------------------------
PROFILE_SCHEMA = {
"name": "mem0_profile",
"description": (
@ -98,6 +115,10 @@ CONCLUDE_SCHEMA = {
}
# ---------------------------------------------------------------------------
# MemoryProvider implementation
# ---------------------------------------------------------------------------
class Mem0MemoryProvider(MemoryProvider):
"""Mem0 memory with server-side extraction and semantic search.
@ -118,6 +139,7 @@ class Mem0MemoryProvider(MemoryProvider):
self._prefetch_lock = threading.Lock()
self._prefetch_thread = None
self._sync_thread = None
# Circuit breaker state
self._consecutive_failures = 0
self._breaker_open_until = 0.0
@ -132,6 +154,7 @@ class Mem0MemoryProvider(MemoryProvider):
return bool(host) or bool(api_key)
def save_config(self, values, hermes_home):
"""Write config to $HERMES_HOME/mem0.json."""
import json
from pathlib import Path
config_path = Path(hermes_home) / "mem0.json"
@ -147,7 +170,7 @@ class Mem0MemoryProvider(MemoryProvider):
def get_config_schema(self):
return [
{"key": "api_key", "description": "Mem0 API key (cloud or self-hosted)", "secret": True, "required": False, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
{"key": "api_key", "description": "Mem0 API key (cloud or self-hosted)", "secret": True, "required": True, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
{"key": "host", "description": "Self-hosted Mem0 URL (e.g. http://localhost:24220)", "default": "", "env_var": "MEM0_HOST"},
{"key": "user_id", "description": "User identifier", "default": "hermes-user"},
{"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
@ -155,6 +178,7 @@ class Mem0MemoryProvider(MemoryProvider):
]
def _get_client(self):
"""Thread-safe client accessor with lazy initialization."""
with self._client_lock:
if self._client is not None:
return self._client
@ -173,9 +197,11 @@ class Mem0MemoryProvider(MemoryProvider):
raise RuntimeError("mem0 package not installed. Run: pip install mem0ai")
def _is_breaker_open(self) -> bool:
"""Return True if the circuit breaker is tripped (too many failures)."""
if self._consecutive_failures < _BREAKER_THRESHOLD:
return False
if time.monotonic() >= self._breaker_open_until:
# Cooldown expired — reset and allow a retry
self._consecutive_failures = 0
return False
return True
@ -197,18 +223,23 @@ class Mem0MemoryProvider(MemoryProvider):
self._config = _load_config()
self._api_key = self._config.get("api_key", "")
self._host = self._config.get("host", "")
# Prefer gateway-provided user_id for per-user memory scoping;
# fall back to config/env default for CLI (single-user) sessions.
self._user_id = kwargs.get("user_id") or self._config.get("user_id", "hermes-user")
self._agent_id = self._config.get("agent_id", "hermes")
self._rerank = self._config.get("rerank", True)
def _read_filters(self) -> Dict[str, Any]:
"""Filters for search/get_all — scoped to user only for cross-session recall."""
return {"user_id": self._user_id}
def _write_filters(self) -> Dict[str, Any]:
"""Filters for add — scoped to user + agent for attribution."""
return {"user_id": self._user_id, "agent_id": self._agent_id}
@staticmethod
def _unwrap_results(response: Any) -> list:
"""Normalize Mem0 API response — v2 wraps results in {"results": [...]}."""
if isinstance(response, dict):
return response.get("results", [])
if isinstance(response, list):
@ -260,6 +291,7 @@ class Mem0MemoryProvider(MemoryProvider):
self._prefetch_thread.start()
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
"""Send the turn to Mem0 for server-side fact extraction (non-blocking)."""
if self._is_breaker_open():
return
@ -276,6 +308,7 @@ class Mem0MemoryProvider(MemoryProvider):
self._record_failure()
logger.warning("Mem0 sync failed: %s", e)
# Wait for any previous sync before starting a new one
if self._sync_thread and self._sync_thread.is_alive():
self._sync_thread.join(timeout=5.0)
@ -357,4 +390,5 @@ class Mem0MemoryProvider(MemoryProvider):
def register(ctx) -> None:
"""Register Mem0 as a memory provider plugin."""
ctx.register_memory_provider(Mem0MemoryProvider())