mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
feat(mem0): v3 API, OSS mode, update/delete tools, telemetry & review fixes (#15624)
* fix: update to version 3 endpoints and add update and delete tools
* chore: remove the test md file
* fix: prevent circuit breaker on client errors in Mem0 provider
* chore: add telemetry for platform version
* feat: add OSS mode support to Mem0 memory provider
* chore: bump mem0ai dependency to >=2.0.1 in memory plugin
* refactor: enhance dependency checks and embedder config in mem0 backend
* refactor: adjust fact storage message for OSS mode
* refactor: expand user paths, add collection recreation on dimension change for Qdrant
* fix(mem0): make MEM0_USER_ID override gateway-native ids and tag writes with channel
When MEM0_USER_ID was configured (env or mem0.json), the gateway-native id
from kwargs (Telegram numeric id, Discord snowflake, ...) still won, so the
same human ended up under a different user_id per channel and memories never
merged across CLI / Telegram / Slack / Discord. This change mirrors openclaw's
cfg.userId pattern: the configured override wins, and the gateway-native id is
the fallback.
The legacy "hermes-user" placeholder default written by the setup wizard is
treated as unset to avoid silently bucketing every gateway user together.
Also tag every write with metadata.channel (cli/telegram/discord/...) so the
dashboard can offer per-channel filtered views without coupling identity to
the channel; document the read/write filter asymmetry as intentional
(reads scope to user_id only for cross-agent recall).
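The resolution order described above can be sketched as a small standalone
function. This is an illustration only: `resolve_user_id` and
`LEGACY_PLACEHOLDER` are hypothetical names, and the plugin performs the same
steps inline in initialize().

```python
from typing import Optional

# Placeholder the setup wizard historically wrote; treated as "not configured".
LEGACY_PLACEHOLDER = "hermes-user"


def resolve_user_id(configured: Optional[str], gateway_id: Optional[str]) -> str:
    """Pick the user_id: configured override, then gateway-native id, then fallback."""
    if configured == LEGACY_PLACEHOLDER:
        # The wizard default must not bucket every gateway user together.
        configured = None
    return configured or gateway_id or LEGACY_PLACEHOLDER


print(resolve_user_id("alice", "123456789"))        # configured override wins
print(resolve_user_id("hermes-user", "123456789"))  # placeholder ignored, gateway id used
print(resolve_user_id(None, None))                  # CLI with no id falls back to placeholder
```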
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
* refactor: improve Mem0 memory provider backend, pagination, config, and error handling
* refactor: update mem0 telemetry code, docs, and bump version
* fix(mem0): make get_config_schema() return unified schema with mode-aware required flag
The schema always includes the api_key field, so the picker shows "API key / local" for
both modes. In OSS mode api_key.required=False, so the status is not misleading.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
* refactor: improve mem0 telemetry, add env var key and OSS mode detection
* chore: bump mem0ai lower bound to 2.0.4 (latest SDK release)
* refactor: set telemetry sample rate to 1.0 and update docs for opt-out
* fix(mem0): resolve 15 correctness, thread-safety, and resource bugs
Thread safety:
- Protect circuit breaker counters with _breaker_lock (race between
prefetch/sync daemon threads and main thread)
- Wrap sync_turn thread creation in _sync_lock; skip if previous sync
is still alive after 5 s join to prevent duplicate memory ingestion
- Guard _schedule_flush timer creation under _queue_lock (TOCTOU race)
- Capture local `backend` reference in prefetch/sync closures so
shutdown() nulling self._backend cannot crash in-flight threads
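A minimal sketch of the sync_turn guard described above, assuming a
hypothetical `TurnSyncer` wrapper (the plugin keeps this logic inside the
provider class). Thread creation happens under a lock, and a new sync is
skipped when the previous one is still running after the join timeout.

```python
import threading


class TurnSyncer:
    """Runs at most one background sync at a time (illustrative only)."""

    def __init__(self, work, join_timeout=5.0):
        self._work = work                  # callable invoked with the payload
        self._join_timeout = join_timeout  # seconds to wait for the previous sync
        self._lock = threading.Lock()
        self._thread = None

    def sync_turn(self, payload) -> bool:
        """Start a sync; return False if it was skipped."""
        with self._lock:
            prev = self._thread
            if prev is not None and prev.is_alive():
                prev.join(timeout=self._join_timeout)
                if prev.is_alive():
                    # Previous sync is still running: skip instead of
                    # risking duplicate ingestion.
                    return False
            self._thread = threading.Thread(
                target=self._work, args=(payload,), daemon=True
            )
            self._thread.start()
            return True
```

Holding the lock while joining serializes callers, so a slow backend costs
each caller at most one join timeout.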
Correctness:
- Fix bool("false")==True for rerank param; parse string values explicitly
- Guard page/top_k with max(1,...) and move int() inside try blocks
- Fix fact_count=0 always in OSS mode (Memory.add returns list, not dict)
- Fix prefetch() not clearing result when thread still alive after timeout
- Fix atexit.register accumulating on repeated initialize() calls
Backend / setup:
- Handle Qdrant named-vector collections in _recreate_collection_if_dims_changed
(vectors is a dict; .size access raised AttributeError, swallowed silently)
- Wrap QdrantClient and psycopg2 conn/cursor in try/finally to prevent leaks
- Resolve ollama_bin at top of _ensure_ollama; use it for ollama pull
- Fix embedder key lookup when LLM provider has no env_var (e.g. ollama)
Also: remove _telemetry_enabled cache (env var check is cheap), bump
required mem0ai to >=2.0.7, minor README wording fix.
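The rerank and page/top_k fixes can be illustrated with two small helpers.
`parse_bool` and `clamp_int` are hypothetical names, the accepted string
spellings are an assumption, and falling back to a default on bad input is one
possible choice (the plugin may report an error instead).

```python
def parse_bool(value, default=False):
    """Parse a boolean that may arrive as a string such as "false"."""
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        text = value.strip().lower()
        if text in ("true", "1", "yes", "on"):
            return True
        if text in ("false", "0", "no", "off", ""):
            return False
        return default
    if value is None:
        return default
    return bool(value)


def clamp_int(raw, default, lo, hi):
    """Convert to int inside a try block, then clamp to the range [lo, hi]."""
    try:
        number = int(raw)
    except (TypeError, ValueError):
        return default
    return max(lo, min(number, hi))


# bool() treats any non-empty string as True, which is the original bug.
print(bool("false"), parse_bool("false"))
print(clamp_int("0", 1, 1, 200), clamp_int("500", 100, 1, 200))
```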
* fix(mem0): fix brittle qdrant path test + add telemetry sample-rate docs
- Replace generator-throw lambda with a proper def in
test_qdrant_path_not_writable; use tmp_path instead of a hardcoded
/nonexistent path so the test is root-safe
- Add MEM0_TELEMETRY_SAMPLE_RATE to memory-providers.md (was only
in the plugin README, not the user-guide docs)
* revert: remove MEM0_TELEMETRY_SAMPLE_RATE from user-guide docs
* refactor: remove telemetry from mem0 plugin and update documentation
* fix(mem0): set stdin=DEVNULL on setup subprocess calls
The TUI stdin guard (scripts/check_subprocess_stdin.py) requires every
subprocess call in plugin code to set stdin= so it can't inherit the
gateway's JSON-RPC stdin fd. Muzzle the docker/ollama calls in the OSS
setup wizard with stdin=subprocess.DEVNULL (none need interactive input).
Also covers the docker-inspect call the linter's regex misses.
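A minimal sketch of the pattern, assuming a hypothetical helper
`run_without_stdin`. The child process gets an empty stdin instead of
inheriting the parent's descriptor, so it cannot consume bytes meant for the
gateway.

```python
import subprocess
import sys


def run_without_stdin(cmd, timeout=30):
    """Run a command whose stdin is the null device, capturing its output."""
    return subprocess.run(
        cmd,
        stdin=subprocess.DEVNULL,  # do not inherit the parent's stdin fd
        capture_output=True,
        text=True,
        timeout=timeout,
        check=False,
    )


# The child reads all of stdin and prints what it saw; DEVNULL yields an empty string.
result = run_without_stdin(
    [sys.executable, "-c", "import sys; print(repr(sys.stdin.read()))"]
)
print(result.returncode, result.stdout.strip())
```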
---------
Co-authored-by: chaithanyak42 <chaithanya.kumar42a@gmail.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent
a904ff1724
commit
2e779d11a0
13 changed files with 2688 additions and 421 deletions
|
|
@ -1,53 +1,152 @@
|
|||
# Mem0 Memory Provider
|
||||
|
||||
Server-side LLM fact extraction with semantic search, reranking, and automatic deduplication.
|
||||
|
||||
Supports both [Mem0 Cloud](https://app.mem0.ai) and self-hosted instances.
|
||||
Server-side LLM fact extraction with semantic search and hybrid multi-signal retrieval via the Mem0 Platform v3 API.
|
||||
|
||||
## Requirements
|
||||
|
||||
- `pip install mem0ai`
|
||||
- Mem0 Cloud API key **or** a self-hosted Mem0 server
|
||||
- Mem0 API key from [app.mem0.ai](https://app.mem0.ai)
|
||||
|
||||
## Setup
|
||||
|
||||
### Cloud
|
||||
|
||||
```bash
|
||||
hermes memory setup # select "mem0"
|
||||
```
|
||||
|
||||
Or manually:
|
||||
|
||||
```bash
|
||||
hermes config set memory.provider mem0
|
||||
echo "MEM0_API_KEY=your-key" >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
### Self-Hosted
|
||||
|
||||
```bash
|
||||
hermes config set memory.provider mem0
|
||||
echo "MEM0_HOST=http://your-mem0-server:24220" >> ~/.hermes/.env
|
||||
echo "MEM0_API_KEY=your-api-key" >> ~/.hermes/.env # if auth is enabled
|
||||
```
|
||||
|
||||
## Config
|
||||
|
||||
Config file: `$HERMES_HOME/mem0.json`
|
||||
Behavioral settings live in `$HERMES_HOME/mem0.json` (set them via `hermes memory setup`). Only the secret `MEM0_API_KEY` belongs in `~/.hermes/.env`.
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `api_key` | — | API key (required for cloud; optional for self-hosted without auth) |
|
||||
| `host` | `https://api.mem0.ai` | Self-hosted Mem0 URL. When set, overrides the cloud endpoint. |
|
||||
| `user_id` | `hermes-user` | User identifier |
|
||||
| `mode` | `platform` | `platform` (Mem0 Cloud) or `oss` (self-hosted) |
|
||||
| `user_id` | `hermes-user` | User identifier on Mem0 |
|
||||
| `agent_id` | `hermes` | Agent identifier |
|
||||
| `rerank` | `true` | Enable reranking for recall |
|
||||
| `rerank` | `true` | Rerank search results for relevance (platform mode only) |
|
||||
|
||||
## OSS (Self-Hosted) Mode
|
||||
|
||||
Run Mem0 locally with your own LLM, embedder, and vector store.
|
||||
|
||||
### Interactive Setup
|
||||
|
||||
```bash
|
||||
hermes memory setup
|
||||
# Select "mem0" → "Open Source (self-hosted)"
|
||||
# Follow prompts for LLM, embedder, and vector store
|
||||
```
|
||||
|
||||
### Agent-Driven Setup (Flags)
|
||||
|
||||
```bash
|
||||
hermes memory setup mem0 --mode oss \
|
||||
--oss-llm openai --oss-llm-key sk-... \
|
||||
--oss-vector qdrant
|
||||
```
|
||||
|
||||
### Supported Providers
|
||||
|
||||
| Component | Providers |
|
||||
|-----------|-----------|
|
||||
| LLM | openai, ollama |
|
||||
| Embedder | openai, ollama |
|
||||
| Vector Store | qdrant (local/server), pgvector |
|
||||
|
||||
### Flags Reference
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--mode` | `platform` or `oss` |
|
||||
| `--oss-llm` | LLM provider (default: openai) |
|
||||
| `--oss-llm-key` | LLM API key |
|
||||
| `--oss-embedder` | Embedder provider (default: openai) |
|
||||
| `--oss-vector` | Vector store (default: qdrant) |
|
||||
| `--oss-vector-path` | Qdrant local path |
|
||||
| `--user-id` | User identifier |
|
||||
|
||||
## Switching Modes
|
||||
|
||||
### Platform to OSS
|
||||
|
||||
```bash
|
||||
hermes memory setup mem0 --mode oss --oss-llm-key sk-...
|
||||
```
|
||||
|
||||
Or edit `$HERMES_HOME/mem0.json` directly:
|
||||
```json
|
||||
{
|
||||
"mode": "oss",
|
||||
"oss": {
|
||||
"llm": {"provider": "openai", "config": {"model": "gpt-5-mini"}},
|
||||
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
|
||||
"vector_store": {"provider": "qdrant", "config": {"path": "~/.hermes/mem0_qdrant"}}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### OSS to Platform
|
||||
|
||||
```bash
|
||||
hermes memory setup mem0 --mode platform --api-key sk-...
|
||||
```
|
||||
|
||||
### Dry Run (preview without writing)
|
||||
|
||||
```bash
|
||||
hermes memory setup mem0 --mode oss --oss-llm-key sk-... --dry-run
|
||||
```
|
||||
|
||||
## Tools
|
||||
|
||||
| Tool | Description |
|
||||
|------|-------------|
|
||||
| `mem0_profile` | All stored memories about the user |
|
||||
| `mem0_search` | Semantic search with optional reranking |
|
||||
| `mem0_conclude` | Store a fact verbatim (no LLM extraction) |
|
||||
| `mem0_list` | List all stored memories (paginated) |
|
||||
| `mem0_search` | Semantic search by meaning |
|
||||
| `mem0_add` | Store a fact verbatim (no LLM extraction) |
|
||||
| `mem0_update` | Update a memory's text by ID |
|
||||
| `mem0_delete` | Delete a memory by ID |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Mem0 temporarily unavailable"
|
||||
|
||||
Circuit breaker tripped after 5 consecutive failures. Resets after 2 minutes.
|
||||
|
||||
- **Platform mode**: Check API key and internet connectivity.
|
||||
- **OSS mode**: Check that your vector store (qdrant/pgvector) is running.
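
The breaker behavior described above can be sketched as follows. This is a simplified illustration, not the plugin's code: the `CircuitBreaker` class and its injectable `clock` parameter are assumptions made so the example can be run on its own, while the threshold of 5 and the 120-second cooldown come from the text above.

```python
import threading
import time


class CircuitBreaker:
    """Opens after `threshold` consecutive failures, resets after a cooldown."""

    def __init__(self, threshold=5, cooldown_secs=120.0, clock=time.monotonic):
        self._threshold = threshold
        self._cooldown = cooldown_secs
        self._clock = clock
        self._failures = 0
        self._open_until = 0.0
        self._lock = threading.Lock()  # counters are shared across threads

    def is_open(self) -> bool:
        with self._lock:
            if self._failures < self._threshold:
                return False
            if self._clock() >= self._open_until:
                # Cooldown expired: reset and allow a retry.
                self._failures = 0
                return False
            return True

    def record_success(self) -> None:
        with self._lock:
            self._failures = 0

    def record_failure(self) -> None:
        with self._lock:
            self._failures += 1
            if self._failures >= self._threshold:
                self._open_until = self._clock() + self._cooldown
```

Callers check `is_open()` before each request and skip the call while it returns `True`. Client-side errors such as an unknown memory ID should not be passed to `record_failure()`, since they do not indicate an outage.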
|
||||
|
||||
### OSS: Qdrant connection refused
|
||||
|
||||
```bash
|
||||
# If using local Qdrant, check the storage path is writable:
|
||||
ls -la ~/.hermes/mem0_qdrant
|
||||
|
||||
# If using Qdrant server, check it's reachable:
|
||||
curl http://localhost:6333/healthz
|
||||
```
|
||||
|
||||
### OSS: PGVector connection refused
|
||||
|
||||
```bash
|
||||
# Verify PostgreSQL is running and accepting connections:
|
||||
pg_isready -h localhost -p 5432
|
||||
```
|
||||
|
||||
### OSS: Ollama not reachable
|
||||
|
||||
```bash
|
||||
# Check Ollama is running:
|
||||
curl http://localhost:11434/api/tags
|
||||
```
|
||||
|
||||
### Memories not appearing
|
||||
|
||||
- `mem0_add` stores verbatim (no extraction). Use `sync_turn` for LLM extraction.
|
||||
- Search uses semantic matching — try broader queries.
|
||||
- Check `user_id` matches between sessions (`$HERMES_HOME/mem0.json`).
|
||||
|
|
|
|||
|
|
@ -1,21 +1,33 @@
|
|||
"""Mem0 memory plugin — MemoryProvider interface.
|
||||
|
||||
Server-side LLM fact extraction, semantic search with reranking, and
|
||||
automatic deduplication via the Mem0 Platform API or self-hosted instance.
|
||||
Server-side LLM fact extraction, semantic search, and automatic deduplication
|
||||
via the Mem0 Platform API (cloud) or OSS (self-hosted) via Memory.
|
||||
|
||||
Original PR #2933 by kartik-mem0, adapted to MemoryProvider ABC.
|
||||
|
||||
Config via environment variables:
|
||||
MEM0_API_KEY — Mem0 API key (required for cloud, optional for self-hosted)
|
||||
MEM0_HOST — Self-hosted Mem0 URL (default: https://api.mem0.ai)
|
||||
MEM0_USER_ID — User identifier (default: hermes-user)
|
||||
MEM0_AGENT_ID — Agent identifier (default: hermes)
|
||||
Configuration
|
||||
-------------
|
||||
Secret (lives in $HERMES_HOME/.env or the environment):
|
||||
MEM0_API_KEY — Mem0 Platform API key (required for platform mode)
|
||||
|
||||
Or via $HERMES_HOME/mem0.json.
|
||||
Behavioral settings (live in $HERMES_HOME/mem0.json, set via `hermes memory
|
||||
setup`):
|
||||
mode — Backend mode: "platform" (default) or "oss"
|
||||
user_id — Canonical user identifier. When set, it is applied
|
||||
uniformly across every gateway (CLI, Telegram, Slack,
|
||||
Discord, …) so the same human gets one merged memory
|
||||
store. When unset, the gateway-native id (e.g. Telegram
|
||||
numeric id, Discord snowflake) is used instead.
|
||||
agent_id — Agent identifier (default: hermes)
|
||||
|
||||
The matching MEM0_MODE / MEM0_USER_ID / MEM0_AGENT_ID environment variables are
|
||||
still read as a backward-compatible fallback, but mem0.json is the canonical
|
||||
home for these non-secret settings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
|
@ -33,12 +45,29 @@ logger = logging.getLogger(__name__)
|
|||
_BREAKER_THRESHOLD = 5
|
||||
_BREAKER_COOLDOWN_SECS = 120
|
||||
|
||||
_CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError")
|
||||
|
||||
# Sentinel returned when neither MEM0_USER_ID nor a gateway-native id is
|
||||
# available. Treated as "no operator-configured user_id" by initialize() so
|
||||
# that legacy mem0.json files written by the setup wizard (which historically
|
||||
# wrote this exact placeholder) still allow gateway-native ids to flow
|
||||
# through instead of silently overriding them with the placeholder.
|
||||
_DEFAULT_USER_ID = "hermes-user"
|
||||
|
||||
|
||||
def _is_client_error(exc: Exception) -> bool:
|
||||
"""True for user-caused errors (bad ID, not found) that should NOT trip circuit breaker."""
|
||||
etype = type(exc).__name__
|
||||
if etype in _CLIENT_ERROR_TYPES:
|
||||
return True
|
||||
err_str = str(exc).lower()
|
||||
return "404" in err_str or "not found" in err_str or "valid uuid" in err_str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _load_config() -> dict:
|
||||
"""Load config from env vars, with $HERMES_HOME/mem0.json overrides.
|
||||
|
||||
|
|
@ -49,13 +78,17 @@ def _load_config() -> dict:
|
|||
from hermes_constants import get_hermes_home
|
||||
|
||||
config = {
|
||||
"mode": os.environ.get("MEM0_MODE", "platform"),
|
||||
"api_key": os.environ.get("MEM0_API_KEY", ""),
|
||||
"host": os.environ.get("MEM0_HOST", ""),
|
||||
"user_id": os.environ.get("MEM0_USER_ID", "hermes-user"),
|
||||
"agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"),
|
||||
"rerank": True,
|
||||
"keyword_search": False,
|
||||
"oss": {},
|
||||
}
|
||||
# Only carry user_id when the operator explicitly configured one (env or
|
||||
# mem0.json). An absent key tells initialize() to fall back to the
|
||||
# gateway-native id from kwargs instead of overriding it with a placeholder.
|
||||
env_user_id = os.environ.get("MEM0_USER_ID")
|
||||
if env_user_id:
|
||||
config["user_id"] = env_user_id
|
||||
|
||||
config_path = get_hermes_home() / "mem0.json"
|
||||
if config_path.exists():
|
||||
|
|
@ -73,34 +106,40 @@ def _load_config() -> dict:
|
|||
# Tool schemas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PROFILE_SCHEMA = {
|
||||
"name": "mem0_profile",
|
||||
LIST_SCHEMA = {
|
||||
"name": "mem0_list",
|
||||
"description": (
|
||||
"Retrieve all stored memories about the user — preferences, facts, "
|
||||
"project context. Fast, no reranking. Use at conversation start."
|
||||
"List all stored memories about the user. "
|
||||
"Use at conversation start for full overview."
|
||||
),
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"page": {"type": "integer", "description": "Page number (default: 1)."},
|
||||
"page_size": {"type": "integer", "description": "Results per page (default: 100, max: 200)."},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
SEARCH_SCHEMA = {
|
||||
"name": "mem0_search",
|
||||
"description": (
|
||||
"Search memories by meaning. Returns relevant facts ranked by similarity. "
|
||||
"Set rerank=true for higher accuracy on important queries."
|
||||
"Search memories by meaning. Returns relevant facts ranked by relevance."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "What to search for."},
|
||||
"rerank": {"type": "boolean", "description": "Enable reranking for precision (default: false)."},
|
||||
"top_k": {"type": "integer", "description": "Max results (default: 10, max: 50)."},
|
||||
"rerank": {"type": "boolean", "description": "Rerank results for relevance (default: true, platform mode only)."},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
}
|
||||
|
||||
CONCLUDE_SCHEMA = {
|
||||
"name": "mem0_conclude",
|
||||
ADD_SCHEMA = {
|
||||
"name": "mem0_add",
|
||||
"description": (
|
||||
"Store a durable fact about the user. Stored verbatim (no LLM extraction). "
|
||||
"Use for explicit preferences, corrections, or decisions."
|
||||
|
|
@ -108,9 +147,34 @@ CONCLUDE_SCHEMA = {
|
|||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"conclusion": {"type": "string", "description": "The fact to store."},
|
||||
"content": {"type": "string", "description": "The fact to store."},
|
||||
},
|
||||
"required": ["conclusion"],
|
||||
"required": ["content"],
|
||||
},
|
||||
}
|
||||
|
||||
UPDATE_SCHEMA = {
|
||||
"name": "mem0_update",
|
||||
"description": "Update an existing memory's text by its ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"memory_id": {"type": "string", "description": "Memory UUID to update."},
|
||||
"text": {"type": "string", "description": "New text content."},
|
||||
},
|
||||
"required": ["memory_id", "text"],
|
||||
},
|
||||
}
|
||||
|
||||
DELETE_SCHEMA = {
|
||||
"name": "mem0_delete",
|
||||
"description": "Delete a memory by its ID.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"memory_id": {"type": "string", "description": "Memory UUID to delete."},
|
||||
},
|
||||
"required": ["memory_id"],
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -122,19 +186,17 @@ CONCLUDE_SCHEMA = {
|
|||
class Mem0MemoryProvider(MemoryProvider):
|
||||
"""Mem0 memory with server-side extraction and semantic search.
|
||||
|
||||
Supports both Mem0 Cloud (api.mem0.ai) and self-hosted instances
|
||||
via the ``host`` config key or ``MEM0_HOST`` env var.
|
||||
Supports Platform API (cloud) and OSS (self-hosted) modes via MEM0_MODE.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._config = None
|
||||
self._client = None
|
||||
self._client_lock = threading.Lock()
|
||||
self._backend = None
|
||||
self._mode = "platform"
|
||||
self._api_key = ""
|
||||
self._host = ""
|
||||
self._user_id = "hermes-user"
|
||||
self._user_id = _DEFAULT_USER_ID
|
||||
self._agent_id = "hermes"
|
||||
self._rerank = True
|
||||
self._channel = "cli" # gateway channel name (cli/telegram/discord/...)
|
||||
self._prefetch_result = ""
|
||||
self._prefetch_lock = threading.Lock()
|
||||
self._prefetch_thread = None
|
||||
|
|
@ -142,6 +204,9 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
# Circuit breaker state
|
||||
self._consecutive_failures = 0
|
||||
self._breaker_open_until = 0.0
|
||||
self._breaker_lock = threading.Lock()
|
||||
self._sync_lock = threading.Lock()
|
||||
self._atexit_registered = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -149,9 +214,10 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
|
||||
def is_available(self) -> bool:
|
||||
cfg = _load_config()
|
||||
host = cfg.get("host", "")
|
||||
api_key = cfg.get("api_key", "")
|
||||
return bool(host) or bool(api_key)
|
||||
mode = cfg.get("mode", "platform")
|
||||
if mode == "oss":
|
||||
return bool(cfg.get("oss", {}).get("vector_store"))
|
||||
return bool(cfg.get("api_key"))
|
||||
|
||||
def save_config(self, values, hermes_home):
|
||||
"""Write config to $HERMES_HOME/mem0.json."""
|
||||
|
|
@ -169,95 +235,130 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
atomic_json_write(config_path, existing, mode=0o600)
|
||||
|
||||
def get_config_schema(self):
|
||||
cfg = _load_config()
|
||||
mode = cfg.get("mode", "platform")
|
||||
api_key_required = mode != "oss"
|
||||
return [
|
||||
{"key": "api_key", "description": "Mem0 API key (cloud or self-hosted)", "secret": True, "required": True, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
|
||||
{"key": "host", "description": "Self-hosted Mem0 URL (e.g. http://localhost:24220)", "default": "", "env_var": "MEM0_HOST"},
|
||||
{"key": "api_key", "description": "Mem0 Platform API key", "secret": True, "required": api_key_required, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
|
||||
{"key": "user_id", "description": "User identifier", "default": "hermes-user"},
|
||||
{"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
|
||||
{"key": "rerank", "description": "Enable reranking for recall", "default": "true", "choices": ["true", "false"]},
|
||||
]
|
||||
|
||||
def _get_client(self):
|
||||
"""Thread-safe client accessor with lazy initialization."""
|
||||
with self._client_lock:
|
||||
if self._client is not None:
|
||||
return self._client
|
||||
try:
|
||||
from mem0 import MemoryClient
|
||||
kwargs = {}
|
||||
if self._host:
|
||||
kwargs["host"] = self._host
|
||||
if self._api_key:
|
||||
kwargs["api_key"] = self._api_key
|
||||
elif not self._host:
|
||||
raise ValueError("Mem0: either api_key or host is required")
|
||||
self._client = MemoryClient(**kwargs)
|
||||
return self._client
|
||||
except ImportError:
|
||||
raise RuntimeError("mem0 package not installed. Run: pip install mem0ai")
|
||||
def post_setup(self, hermes_home: str, config: dict) -> None:
|
||||
from ._setup import post_setup
|
||||
post_setup(hermes_home, config)
|
||||
|
||||
def _create_backend(self):
|
||||
try:
|
||||
if self._mode == "oss":
|
||||
from ._backend import OSSBackend
|
||||
return OSSBackend(self._config.get("oss", {}))
|
||||
from ._backend import PlatformBackend
|
||||
return PlatformBackend(self._api_key)
|
||||
except Exception as e:
|
||||
logger.error("Mem0 backend failed to initialize (%s mode): %s", self._mode, e)
|
||||
self._init_error = str(e)
|
||||
return None
|
||||
|
||||
def _is_breaker_open(self) -> bool:
|
||||
"""Return True if the circuit breaker is tripped (too many failures)."""
|
||||
if self._consecutive_failures < _BREAKER_THRESHOLD:
|
||||
return False
|
||||
if time.monotonic() >= self._breaker_open_until:
|
||||
# Cooldown expired — reset and allow a retry
|
||||
self._consecutive_failures = 0
|
||||
return False
|
||||
return True
|
||||
with self._breaker_lock:
|
||||
if self._consecutive_failures < _BREAKER_THRESHOLD:
|
||||
return False
|
||||
if time.monotonic() >= self._breaker_open_until:
|
||||
self._consecutive_failures = 0
|
||||
return False
|
||||
return True
|
||||
|
||||
def _format_error(self, prefix: str, exc: Exception) -> str:
|
||||
msg = f"{prefix}: {exc}"
|
||||
if self._mode == "oss":
|
||||
err_str = str(exc).lower()
|
||||
if "connection" in err_str or "refused" in err_str or "timeout" in err_str:
|
||||
vs = self._config.get("oss", {}).get("vector_store", {})
|
||||
msg += f" (check that {vs.get('provider', 'vector store')} is running)"
|
||||
return msg
|
||||
|
||||
def _record_success(self):
|
||||
self._consecutive_failures = 0
|
||||
with self._breaker_lock:
|
||||
self._consecutive_failures = 0
|
||||
|
||||
def _record_failure(self):
|
||||
self._consecutive_failures += 1
|
||||
if self._consecutive_failures >= _BREAKER_THRESHOLD:
|
||||
self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
|
||||
with self._breaker_lock:
|
||||
self._consecutive_failures += 1
|
||||
count = self._consecutive_failures
|
||||
if count >= _BREAKER_THRESHOLD:
|
||||
self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS
|
||||
else:
|
||||
count = 0
|
||||
if count >= _BREAKER_THRESHOLD:
|
||||
hint = ""
|
||||
if self._mode == "oss":
|
||||
vs = self._config.get("oss", {}).get("vector_store", {})
|
||||
provider = vs.get("provider", "unknown")
|
||||
hint = f" Check that your {provider} vector store is running and reachable."
|
||||
logger.warning(
|
||||
"Mem0 circuit breaker tripped after %d consecutive failures. "
|
||||
"Pausing API calls for %ds.",
|
||||
self._consecutive_failures, _BREAKER_COOLDOWN_SECS,
|
||||
"Pausing API calls for %ds.%s",
|
||||
count, _BREAKER_COOLDOWN_SECS, hint,
|
||||
)
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
self._config = _load_config()
|
||||
self._mode = self._config.get("mode", "platform")
|
||||
self._api_key = self._config.get("api_key", "")
|
||||
self._host = self._config.get("host", "")
|
||||
# Prefer gateway-provided user_id for per-user memory scoping;
|
||||
# fall back to config/env default for CLI (single-user) sessions.
|
||||
self._user_id = kwargs.get("user_id") or self._config.get("user_id", "hermes-user")
|
||||
# Resolution order for user_id:
|
||||
# 1. Operator-configured MEM0_USER_ID (env or $HERMES_HOME/mem0.json) —
|
||||
# the canonical principal, applied across every gateway so the same
|
||||
# human gets one merged memory store.
|
||||
# 2. Gateway-native id from kwargs (Telegram numeric id, Discord
|
||||
# snowflake, etc.) — preserves per-platform isolation when no
|
||||
# override is configured.
|
||||
# 3. Hardcoded fallback _DEFAULT_USER_ID (CLI with no auth).
|
||||
# The literal _DEFAULT_USER_ID string is treated as unset so users who
|
||||
# ran the setup wizard with the suggested default still get gateway-
|
||||
# native ids instead of being silently bucketed together.
|
||||
configured = self._config.get("user_id")
|
||||
if configured == _DEFAULT_USER_ID:
|
||||
configured = None
|
||||
self._user_id = configured or kwargs.get("user_id") or _DEFAULT_USER_ID
|
||||
self._agent_id = self._config.get("agent_id", "hermes")
|
||||
self._rerank = self._config.get("rerank", True)
|
||||
self._channel = kwargs.get("platform") or "cli"
|
||||
self._backend = self._create_backend()
|
||||
if self._backend and not self._atexit_registered:
|
||||
atexit.register(self._shutdown_backend)
|
||||
self._atexit_registered = True
|
||||
|
||||
def _read_filters(self) -> Dict[str, Any]:
|
||||
"""Filters for search/get_all — scoped to user only for cross-session recall."""
|
||||
# Scoped to user_id only — by design — so recall surfaces memories
|
||||
# written from any gateway/agent under this principal. Writes attach
|
||||
# agent_id (and metadata.channel) so per-agent / per-channel views are
|
||||
# still possible at query time when needed; reads default to the wider
|
||||
# cross-agent recall.
|
||||
return {"user_id": self._user_id}
|
||||
|
||||
def _write_filters(self) -> Dict[str, Any]:
|
||||
"""Filters for add — scoped to user + agent for attribution."""
|
||||
return {"user_id": self._user_id, "agent_id": self._agent_id}
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_results(response: Any) -> list:
|
||||
"""Normalize Mem0 API response — v2 wraps results in {"results": [...]}."""
|
||||
if isinstance(response, dict):
|
||||
return response.get("results", [])
|
||||
if isinstance(response, list):
|
||||
return response
|
||||
return []
|
||||
def _write_metadata(self) -> Dict[str, Any]:
|
||||
# Tag every write with the gateway channel so the dashboard can offer
|
||||
# per-channel filtered views without coupling identity to the channel.
|
||||
return {"channel": self._channel} if self._channel else {}
|
||||
|
||||
def system_prompt_block(self) -> str:
|
||||
target = self._host or "cloud"
|
||||
mode_label = "platform (cloud API)" if self._mode == "platform" else "OSS (self-hosted)"
|
||||
rerank_note = " Rerank is available on search." if self._mode == "platform" else ""
|
||||
return (
|
||||
f"# Mem0 Memory ({target})\n"
|
||||
f"Active. User: {self._user_id}.\n"
|
||||
"Use mem0_search to find memories, mem0_conclude to store facts, "
|
||||
"mem0_profile for a full overview."
|
||||
"# Mem0 Memory\n"
|
||||
f"Active. Mode: {mode_label}. User: {self._user_id}.\n"
|
||||
"Use mem0_search to find memories, mem0_add to store facts, "
|
||||
f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}"
|
||||
)
|
||||
|
||||
def prefetch(self, query: str, *, session_id: str = "") -> str:
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
self._prefetch_thread.join(timeout=3.0)
|
||||
# If the thread still hasn't finished, leave the result for the next call.
|
||||
if self._prefetch_thread and self._prefetch_thread.is_alive():
|
||||
return ""
|
||||
with self._prefetch_lock:
|
||||
result = self._prefetch_result
|
||||
self._prefetch_result = ""
|
||||
|
|
@ -266,18 +367,15 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
return f"## Mem0 Memory\n{result}"
|
||||
|
||||
def queue_prefetch(self, query: str, *, session_id: str = "") -> None:
|
||||
if self._is_breaker_open():
|
||||
if self._backend is None or self._is_breaker_open():
|
||||
return
|
||||
|
||||
def _run():
|
||||
backend = self._backend
|
||||
if backend is None:
|
||||
return
|
||||
try:
|
||||
client = self._get_client()
|
||||
results = self._unwrap_results(client.search(
|
||||
query=query,
|
||||
filters=self._read_filters(),
|
||||
rerank=self._rerank,
|
||||
top_k=5,
|
||||
))
|
||||
results = backend.search(query=query, filters=self._read_filters(), top_k=5, rerank=True)
|
||||
if results:
|
||||
lines = [r.get("memory", "") for r in results if r.get("memory")]
|
||||
with self._prefetch_lock:
|
||||
|
|
@ -292,101 +390,171 @@ class Mem0MemoryProvider(MemoryProvider):
|
|||
|
||||
def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None:
|
||||
"""Send the turn to Mem0 for server-side fact extraction (non-blocking)."""
|
||||
if self._is_breaker_open():
|
||||
if self._backend is None or self._is_breaker_open():
|
||||
return
|
||||
|
||||
def _sync():
|
||||
backend = self._backend
|
||||
if backend is None:
|
||||
return
|
||||
try:
|
||||
client = self._get_client()
|
||||
messages = [
|
||||
{"role": "user", "content": user_content},
|
||||
{"role": "assistant", "content": assistant_content},
|
||||
]
|
||||
client.add(messages, **self._write_filters())
|
||||
backend.add(
|
||||
messages,
|
||||
user_id=self._user_id,
|
||||
agent_id=self._agent_id,
|
||||
infer=True,
|
||||
metadata=self._write_metadata(),
|
||||
)
|
||||
self._record_success()
|
||||
except Exception as e:
|
||||
self._record_failure()
|
||||
logger.warning("Mem0 sync failed: %s", e)
|
||||
|
||||
# Wait for any previous sync before starting a new one
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
|
||||
self._sync_thread = threading.Thread(target=_sync, daemon=True, name="mem0-sync")
|
||||
self._sync_thread.start()
|
||||
with self._sync_lock:
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
self._sync_thread.join(timeout=5.0)
|
||||
# If still alive after timeout, skip to avoid duplicate ingestion.
|
||||
if self._sync_thread and self._sync_thread.is_alive():
|
||||
return
|
||||
self._sync_thread = threading.Thread(target=_sync, daemon=True, name="mem0-sync")
|
||||
self._sync_thread.start()
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA]
|
||||
return [LIST_SCHEMA, SEARCH_SCHEMA, ADD_SCHEMA, UPDATE_SCHEMA, DELETE_SCHEMA]
|
||||
|
||||
     def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str:
+        if self._backend is None:
+            err = getattr(self, "_init_error", "unknown error")
+            hint = ""
+            if self._mode == "oss":
+                vs = self._config.get("oss", {}).get("vector_store", {})
+                provider = vs.get("provider", "vector store")
+                hint = f" Check that {provider} is running and reachable."
+            return json.dumps({"error": f"Mem0 backend not initialized: {err}.{hint}"})
+
         if self._is_breaker_open():
-            return json.dumps({
-                "error": "Mem0 API temporarily unavailable (multiple consecutive failures). Will retry automatically."
-            })
+            msg = "Mem0 temporarily unavailable (multiple consecutive failures). Will retry automatically."
+            if self._mode == "oss":
+                vs = self._config.get("oss", {}).get("vector_store", {})
+                msg += f" Check that your {vs.get('provider', 'vector store')} is running."
+            return json.dumps({"error": msg})

-        try:
-            client = self._get_client()
-        except Exception as e:
-            return tool_error(str(e))
-
-        if tool_name == "mem0_profile":
+        if tool_name == "mem0_list":
             try:
-                memories = self._unwrap_results(client.get_all(filters=self._read_filters()))
+                page = max(1, int(args.get("page", 1)))
+                page_size = min(max(1, int(args.get("page_size", 100))), 200)
+                response = self._backend.get_all(
+                    filters=self._read_filters(), page=page, page_size=page_size,
+                )
                 self._record_success()
-                if not memories:
+                results = response.get("results", [])
+                if not results:
                     return json.dumps({"result": "No memories stored yet."})
-                lines = [m.get("memory", "") for m in memories if m.get("memory")]
-                return json.dumps({"result": "\n".join(lines), "count": len(lines)})
+                items = [{"id": m.get("id"), "memory": m.get("memory", "")}
+                         for m in results]
+                return json.dumps({
+                    "results": items,
+                    "count": response.get("count", len(items)),
+                    "page": page, "page_size": page_size,
+                })
             except Exception as e:
-                self._record_failure()
-                return tool_error(f"Failed to fetch profile: {e}")
+                if not _is_client_error(e):
+                    self._record_failure()
+                return tool_error(self._format_error("Failed to list memories", e))

         elif tool_name == "mem0_search":
             query = args.get("query", "")
             if not query:
                 return tool_error("Missing required parameter: query")
-            rerank = args.get("rerank", False)
-            top_k = min(int(args.get("top_k", 10)), 50)
             try:
-                results = self._unwrap_results(client.search(
-                    query=query,
-                    filters=self._read_filters(),
-                    rerank=rerank,
-                    top_k=top_k,
-                ))
+                top_k = max(1, min(int(args.get("top_k", 10)), 50))
+                rerank_raw = args.get("rerank", True)
+                if isinstance(rerank_raw, str):
+                    rerank = rerank_raw.lower() not in ("false", "0", "no")
+                else:
+                    rerank = bool(rerank_raw)
+                results = self._backend.search(query, filters=self._read_filters(), top_k=top_k, rerank=rerank)
                 self._record_success()
                 if not results:
                     return json.dumps({"result": "No relevant memories found."})
-                items = [{"memory": r.get("memory", ""), "score": r.get("score", 0)} for r in results]
+                items = [{"id": r.get("id"), "memory": r.get("memory", ""),
+                          "score": r.get("score", 0)} for r in results]
                 return json.dumps({"results": items, "count": len(items)})
             except Exception as e:
-                self._record_failure()
-                return tool_error(f"Search failed: {e}")
+                if not _is_client_error(e):
+                    self._record_failure()
+                return tool_error(self._format_error("Search failed", e))

-        elif tool_name == "mem0_conclude":
-            conclusion = args.get("conclusion", "")
-            if not conclusion:
-                return tool_error("Missing required parameter: conclusion")
+        elif tool_name == "mem0_add":
+            content = args.get("content", "")
+            if not content:
+                return tool_error("Missing required parameter: content")
             try:
-                client.add(
-                    [{"role": "user", "content": conclusion}],
-                    **self._write_filters(),
+                result = self._backend.add(
+                    [{"role": "user", "content": content}],
+                    user_id=self._user_id,
+                    agent_id=self._agent_id,
                     infer=False,
+                    metadata=self._write_metadata(),
                 )
                 self._record_success()
-                return json.dumps({"result": "Fact stored."})
+                event_id = result.get("event_id") if isinstance(result, dict) else None
+                msg = "Fact stored." if self._mode == "oss" else "Fact queued for storage."
+                return json.dumps({"result": msg, "event_id": event_id})
             except Exception as e:
                 self._record_failure()
-                return tool_error(f"Failed to store: {e}")
+                return tool_error(self._format_error("Failed to store", e))

+        elif tool_name == "mem0_update":
+            memory_id = args.get("memory_id", "")
+            text = args.get("text", "")
+            if not memory_id:
+                return tool_error("Missing required parameter: memory_id")
+            if not text:
+                return tool_error("Missing required parameter: text")
+            try:
+                result = self._backend.update(memory_id, text)
+                self._record_success()
+                return json.dumps(result)
+            except Exception as e:
+                if _is_client_error(e):
+                    return tool_error(f"Memory not found: {memory_id}")
+                self._record_failure()
+                return tool_error(self._format_error("Update failed", e))
+
+        elif tool_name == "mem0_delete":
+            memory_id = args.get("memory_id", "")
+            if not memory_id:
+                return tool_error("Missing required parameter: memory_id")
+            try:
+                result = self._backend.delete(memory_id)
+                self._record_success()
+                return json.dumps(result)
+            except Exception as e:
+                if _is_client_error(e):
+                    return tool_error(f"Memory not found: {memory_id}")
+                self._record_failure()
+                return tool_error(self._format_error("Delete failed", e))
+
         return tool_error(f"Unknown tool: {tool_name}")

+    def _shutdown_backend(self):
+        try:
+            if self._backend:
+                self._backend.close()
+                self._backend = None
+        except Exception:
+            pass
+
     def shutdown(self) -> None:
         for t in (self._prefetch_thread, self._sync_thread):
             if t and t.is_alive():
                 t.join(timeout=5.0)
-        with self._client_lock:
-            self._client = None
+        self._shutdown_backend()


 def register(ctx) -> None:
|
||||
|
|
|
|||
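The `mem0_list` and `mem0_search` branches in the hunk above clamp numeric arguments and coerce a possibly string-valued `rerank` flag before calling the backend. A minimal sketch of that normalization, pulled out into standalone helpers so the edge cases are easy to see. The helper names `coerce_bool`, `clamp_int`, and `normalize_search_args` are hypothetical and are not part of the plugin; only the bounds and the set of false strings are taken from the diff.

```python
from typing import Any


def coerce_bool(raw: Any) -> bool:
    """Interpret a tool argument as a boolean (hypothetical helper).

    Strings are false only when they spell "false", "0", or "no"
    (case-insensitive); every other value goes through bool().
    """
    if isinstance(raw, str):
        return raw.lower() not in ("false", "0", "no")
    return bool(raw)


def clamp_int(raw: Any, lo: int, hi: int) -> int:
    """Convert to int and clamp into [lo, hi]; raises ValueError on bad input."""
    return max(lo, min(int(raw), hi))


def normalize_search_args(args: dict) -> dict:
    """Apply the same defaults and bounds the search branch uses."""
    return {
        "top_k": clamp_int(args.get("top_k", 10), 1, 50),
        "rerank": coerce_bool(args.get("rerank", True)),
    }


if __name__ == "__main__":
    print(normalize_search_args({"top_k": "500", "rerank": "False"}))
```

Because the conversion can raise `ValueError` on non-numeric input, the diff performs it inside the `try` block, where `_is_client_error` decides whether the failure counts toward the circuit breaker.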
243
plugins/memory/mem0/_backend.py
Normal file
@@ -0,0 +1,243 @@
"""Backend abstraction for Mem0 Platform and OSS modes."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class Mem0Backend(ABC):
    """Unified interface over Platform (MemoryClient) and OSS (Memory) backends."""

    @abstractmethod
    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list[dict]:
        ...

    @abstractmethod
    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
        ...

    @abstractmethod
    def add(
        self,
        messages: list,
        *,
        user_id: str,
        agent_id: str,
        infer: bool = False,
        metadata: dict | None = None,
    ) -> dict:
        ...

    @abstractmethod
    def update(self, memory_id: str, text: str) -> dict:
        ...

    @abstractmethod
    def delete(self, memory_id: str) -> dict:
        ...

    def close(self) -> None:
        pass


def _unwrap_results(response: Any) -> list:
    """Normalize API response — extract results list from dict or pass through."""
    if isinstance(response, dict):
        return response.get("results", [])
    if isinstance(response, list):
        return response
    return []


class PlatformBackend(Mem0Backend):
    """Wraps mem0.MemoryClient for Mem0 Platform (cloud API)."""

    def __init__(self, api_key: str):
        from mem0 import MemoryClient
        self._client = MemoryClient(api_key=api_key)

    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list[dict]:
        response = self._client.search(query, filters=filters, top_k=top_k, rerank=rerank)
        return _unwrap_results(response)

    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
        response = self._client.get_all(filters=filters, page=page, page_size=page_size)
        results = response.get("results", []) if isinstance(response, dict) else response
        count = response.get("count", len(results)) if isinstance(response, dict) else len(results)
        return {"results": results, "count": count}

    def add(
        self,
        messages: list,
        *,
        user_id: str,
        agent_id: str,
        infer: bool = False,
        metadata: dict | None = None,
    ) -> dict:
        kwargs: dict[str, Any] = {"user_id": user_id, "agent_id": agent_id, "infer": infer}
        if metadata:
            kwargs["metadata"] = metadata
        return self._client.add(messages, **kwargs)

    def update(self, memory_id: str, text: str) -> dict:
        self._client.update(memory_id=memory_id, text=text)
        return {"result": "Memory updated.", "memory_id": memory_id}

    def delete(self, memory_id: str) -> dict:
        self._client.delete(memory_id=memory_id)
        return {"result": "Memory deleted.", "memory_id": memory_id}


class OSSBackend(Mem0Backend):
    """Wraps mem0.Memory for self-hosted (OSS) mode."""

    def __init__(self, oss_config: dict):
        import os
        from mem0 import Memory

        # Copy the nested dicts so the caller's config is not mutated.
        vector_store = dict(oss_config["vector_store"])
        vs_config = dict(vector_store.get("config", {}))

        if "path" in vs_config:
            vs_config["path"] = os.path.expanduser(vs_config["path"])

        # Prefer an explicit embedding size; otherwise look the model up in KNOWN_DIMS.
        embedder_config = oss_config.get("embedder", {}).get("config", {})
        dims = embedder_config.get("embedding_dims")
        if not dims:
            from ._oss_providers import KNOWN_DIMS
            model = embedder_config.get("model", "")
            dims = KNOWN_DIMS.get(model)
        if dims:
            vs_config["embedding_model_dims"] = dims
            self._recreate_collection_if_dims_changed(
                vector_store.get("provider", "qdrant"), vs_config, dims,
            )

        vector_store["config"] = vs_config

        config = {
            "vector_store": vector_store,
            "llm": oss_config["llm"],
            "embedder": oss_config["embedder"],
            "version": "v1.1",
        }
        self._memory = Memory.from_config(config)

    @staticmethod
    def _recreate_collection_if_dims_changed(provider: str, vs_config: dict, expected_dims: int) -> None:
        """Delete stale vector collection when embedding dimensions change."""
        collection_name = vs_config.get("collection_name", "mem0")
        if provider == "qdrant":
            try:
                from qdrant_client import QdrantClient
                path = vs_config.get("path")
                url = vs_config.get("url")
                if path:
                    client = QdrantClient(path=path)
                elif url:
                    client = QdrantClient(url=url, api_key=vs_config.get("api_key"))
                else:
                    return
                try:
                    if not client.collection_exists(collection_name):
                        return
                    info = client.get_collection(collection_name)
                    vectors = info.config.params.vectors
                    # Named-vector collections expose a dict; unnamed expose an object with .size.
                    if isinstance(vectors, dict):
                        first = next(iter(vectors.values()), None)
                        current_dims = first.size if first else None
                    else:
                        current_dims = getattr(vectors, "size", None)
                    if current_dims is not None and current_dims != expected_dims:
                        client.delete_collection(collection_name)
                finally:
                    client.close()
            except Exception:
                # Best effort: any failure here leaves the existing collection untouched.
                pass
        elif provider == "pgvector":
            try:
                import psycopg2
                from psycopg2 import sql as pgsql
                conn_params = {}
                for k in ("host", "port", "user", "password", "dbname"):
                    if vs_config.get(k):
                        conn_params[k] = vs_config[k]
                if vs_config.get("sslmode"):
                    conn_params["sslmode"] = vs_config["sslmode"]
                conn = psycopg2.connect(**conn_params)
                conn.autocommit = True
                try:
                    cur = conn.cursor()
                    try:
                        cur.execute(
                            "SELECT atttypmod FROM pg_attribute "
                            "WHERE attrelid = %s::regclass AND attname = 'vector'",
                            (collection_name,),
                        )
                        row = cur.fetchone()
                        if row and row[0] > 0 and row[0] != expected_dims:
                            cur.execute(pgsql.SQL("DROP TABLE IF EXISTS {}").format(
                                pgsql.Identifier(collection_name)
                            ))
                    finally:
                        cur.close()
                finally:
                    conn.close()
            except Exception:
                # Best effort: any failure here leaves the existing table untouched.
                pass

    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list[dict]:
        # `rerank` is accepted for interface parity but is not forwarded to Memory.search.
        response = self._memory.search(query, filters=filters, top_k=top_k)
        return _unwrap_results(response)

    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
        # Memory.get_all is called without paging arguments, so pages are sliced client-side.
        response = self._memory.get_all(filters=filters)
        all_results = _unwrap_results(response)
        total = len(all_results)
        start = (page - 1) * page_size
        results = all_results[start : start + page_size]
        return {"results": results, "count": total}

    def add(
        self,
        messages: list,
        *,
        user_id: str,
        agent_id: str,
        infer: bool = False,
        metadata: dict | None = None,
    ) -> dict:
        kwargs: dict[str, Any] = {"user_id": user_id, "agent_id": agent_id, "infer": infer}
        if metadata:
            kwargs["metadata"] = metadata
        return self._memory.add(messages, **kwargs)

    def update(self, memory_id: str, text: str) -> dict:
        self._memory.update(memory_id, data=text)
        return {"result": "Memory updated.", "memory_id": memory_id}

    def delete(self, memory_id: str) -> dict:
        self._memory.delete(memory_id)
        return {"result": "Memory deleted.", "memory_id": memory_id}

    def close(self) -> None:
        try:
            telemetry = getattr(self._memory, "telemetry", None)
            if telemetry and hasattr(telemetry, "posthog"):
                try:
                    telemetry.posthog.shutdown()
                except Exception:
                    pass
            if hasattr(self._memory, "close"):
                self._memory.close()
            vs = getattr(self._memory, "vector_store", None)
            if vs and hasattr(vs, "close"):
                vs.close()
            client = getattr(vs, "client", None)
            if client and hasattr(client, "close"):
                client.close()
        except Exception:
            pass
|
||||
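`OSSBackend.get_all` above fetches every memory and slices the requested page locally, while `PlatformBackend.get_all` passes `page` and `page_size` through to the API. A minimal sketch of the client-side variant, assuming the response is either a `{"results": [...]}` dict or a bare list as `_unwrap_results` handles; the function names `unwrap_results` and `paginate` are illustrative stand-ins, not plugin API.

```python
from typing import Any


def unwrap_results(response: Any) -> list:
    """Return the results list from a dict response, or the list itself."""
    if isinstance(response, dict):
        return response.get("results", [])
    if isinstance(response, list):
        return response
    return []


def paginate(response: Any, page: int = 1, page_size: int = 100) -> dict:
    """Slice one 1-indexed page out of a full result set.

    `count` is the total number of items, not the size of the page.
    """
    all_results = unwrap_results(response)
    start = (page - 1) * page_size
    return {
        "results": all_results[start:start + page_size],
        "count": len(all_results),
    }


if __name__ == "__main__":
    fake_response = {"results": [{"id": i} for i in range(5)]}
    print(paginate(fake_response, page=2, page_size=2))
```

The slice assumes `page >= 1`; in the diff, the `mem0_list` handler enforces that with `max(1, int(...))` before calling the backend. A page past the end yields an empty `results` list with the total still reported in `count`.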
84
plugins/memory/mem0/_oss_providers.py
Normal file
@@ -0,0 +1,84 @@
"""OSS provider definitions for LLM, embedder, and vector store."""

from __future__ import annotations

import os
from typing import Any

LLM_PROVIDERS: dict[str, dict[str, Any]] = {
    "openai": {
        "label": "OpenAI",
        "needs_key": True,
        "env_var": "OPENAI_API_KEY",
        "default_model": "gpt-5-mini",
    },
    "ollama": {
        "label": "Ollama (local)",
        "needs_key": False,
        "default_model": "llama3.1:8b",
        "default_url": "http://localhost:11434",
        "pip_dep": "ollama",
    },
}

EMBEDDER_PROVIDERS: dict[str, dict[str, Any]] = {
    "openai": {
        "label": "OpenAI",
        "needs_key": True,
        "env_var": "OPENAI_API_KEY",
        "default_model": "text-embedding-3-small",
        "dims": 1536,
    },
    "ollama": {
        "label": "Ollama (local)",
        "needs_key": False,
        "default_model": "nomic-embed-text",
        "default_url": "http://localhost:11434",
        "dims": 768,
        "pip_dep": "ollama",
    },
}

VECTOR_PROVIDERS: dict[str, dict[str, Any]] = {
    "qdrant": {
        "label": "Qdrant",
        "default_config": {"path": os.path.expanduser("~/.hermes/mem0_qdrant")},
        "pip_dep": "qdrant-client",
    },
    "pgvector": {
        "label": "PGVector",
        "default_config": {"host": "localhost", "port": 5432, "user": os.getenv("USER", "postgres"), "dbname": "postgres"},
        "pip_dep": "psycopg2-binary",
    },
}

# Embedding sizes for models whose dimensions are known ahead of time.
KNOWN_DIMS: dict[str, int] = {
    "text-embedding-3-small": 1536,
    "text-embedding-3-large": 3072,
    "text-embedding-ada-002": 1536,
    "nomic-embed-text": 768,
}


def validate_oss_config(oss_config: dict) -> list[str]:
    """Validate an OSS config dict. Returns list of error strings (empty = valid)."""
    errors: list[str] = []

    for section, registry in [("llm", LLM_PROVIDERS), ("embedder", EMBEDDER_PROVIDERS),
                              ("vector_store", VECTOR_PROVIDERS)]:
        block = oss_config.get(section)
        if not block or not isinstance(block, dict):
            errors.append(f"Missing required section: {section}")
            continue
        provider_id = block.get("provider", "")
        if provider_id not in registry:
            valid = ", ".join(registry.keys())
            errors.append(f"Unknown {section} provider '{provider_id}'. Valid: {valid}")

    # Guard against a non-dict vector_store, which the loop above already reported.
    vs = oss_config.get("vector_store")
    if isinstance(vs, dict) and vs.get("provider") == "pgvector":
        cfg = vs.get("config") or {}
        if not cfg.get("user"):
            errors.append("PGVector requires 'user' in vector_store.config")

    return errors
|
||||
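`validate_oss_config` above expects three sections (`llm`, `embedder`, `vector_store`), each a dict with a `provider` key and a nested `config`. The sketch below shows a fully local config of that shape and a trimmed re-implementation of the checks so it runs on its own. `REGISTRIES`, `validate_sketch`, and `LOCAL_CONFIG` are hypothetical names for illustration; the provider IDs, models, URL, and error strings are copied from the file above.

```python
from typing import Any

# Provider IDs copied from the registries above; every other field is trimmed.
REGISTRIES: dict[str, tuple[str, ...]] = {
    "llm": ("openai", "ollama"),
    "embedder": ("openai", "ollama"),
    "vector_store": ("qdrant", "pgvector"),
}


def validate_sketch(oss_config: dict[str, Any]) -> list[str]:
    """Return a list of error strings; an empty list means the config is valid."""
    errors: list[str] = []
    for section, valid_ids in REGISTRIES.items():
        block = oss_config.get(section)
        if not block or not isinstance(block, dict):
            errors.append(f"Missing required section: {section}")
            continue
        provider_id = block.get("provider", "")
        if provider_id not in valid_ids:
            valid = ", ".join(valid_ids)
            errors.append(f"Unknown {section} provider '{provider_id}'. Valid: {valid}")
    vs = oss_config.get("vector_store")
    if isinstance(vs, dict) and vs.get("provider") == "pgvector":
        if not (vs.get("config") or {}).get("user"):
            errors.append("PGVector requires 'user' in vector_store.config")
    return errors


# A local-only setup: Ollama for both LLM and embedder, on-disk Qdrant for vectors.
LOCAL_CONFIG: dict[str, Any] = {
    "llm": {
        "provider": "ollama",
        "config": {"model": "llama3.1:8b", "ollama_base_url": "http://localhost:11434"},
    },
    "embedder": {
        "provider": "ollama",
        "config": {
            "model": "nomic-embed-text",
            "ollama_base_url": "http://localhost:11434",
            "embedding_dims": 768,
        },
    },
    "vector_store": {"provider": "qdrant", "config": {"path": "~/.hermes/mem0_qdrant"}},
}

if __name__ == "__main__":
    print(validate_sketch(LOCAL_CONFIG))
```

Returning a list of messages rather than raising lets a caller report every problem in one pass, which is how `_setup_oss` in `_setup.py` uses the real function before exiting.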
858
plugins/memory/mem0/_setup.py
Normal file
@@ -0,0 +1,858 @@
"""Setup wizard for Mem0 plugin — interactive and flag-based modes."""

from __future__ import annotations

import getpass
import json
import os
import shutil
import socket
import subprocess
import sys
import urllib.request
from pathlib import Path
from typing import Any

from hermes_constants import get_hermes_home

from ._oss_providers import (
    LLM_PROVIDERS,
    EMBEDDER_PROVIDERS,
    VECTOR_PROVIDERS,
    KNOWN_DIMS,
    validate_oss_config,
)


def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -> int:
    """Interactive single-select with arrow keys."""
    from hermes_cli.curses_ui import curses_radiolist
    display_items = [
        f"{label} {desc}" if desc else label
        for label, desc in items
    ]
    return curses_radiolist(title, display_items, selected=default, cancel_returns=default)


def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
    """Prompt for a value with optional default and secret masking."""
    suffix = f" [{default}]" if default else ""
    sys.stdout.write(f" {label}{suffix}: ")
    sys.stdout.flush()
    if secret and sys.stdin.isatty():
        # Hide typed characters only when attached to a terminal.
        val = getpass.getpass(prompt="")
    else:
        val = sys.stdin.readline().strip()
    return val or (default or "")


def has_oss_flags() -> bool:
    """Check if OSS-related flags are present in sys.argv."""
    flags = parse_flags(sys.argv[1:])
    if flags["mode"] == "oss":
        return True
    if any(flags.get(k) for k in ("oss_llm_key", "oss_vector_path", "oss_vector_url")):
        return True
    return False


def parse_flags(argv: list[str] | None = None) -> dict[str, Any]:
    """Parse CLI flags from argv. Returns dict of flag values."""
    args = argv if argv is not None else sys.argv[1:]
    # Values are strings except "dry_run", which is a bool.
    flags: dict[str, Any] = {
        "mode": "",
        "api_key": "",
        "oss_llm": "openai",
        "oss_llm_key": "",
        "oss_llm_model": "",
        "oss_llm_url": "",
        "oss_embedder": "openai",
        "oss_embedder_key": "",
        "oss_embedder_model": "",
        "oss_embedder_url": "",
        "oss_vector": "qdrant",
        "oss_vector_path": "",
        "oss_vector_url": "",
        "oss_vector_host": "",
        "oss_vector_port": "",
        "oss_vector_user": "",
        "oss_vector_password": "",
        "oss_vector_dbname": "",
        "user_id": "",
        "dry_run": False,
    }

    flag_map = {
        "--mode": "mode",
        "--api-key": "api_key",
        "--oss-llm": "oss_llm",
        "--oss-llm-key": "oss_llm_key",
        "--oss-llm-model": "oss_llm_model",
        "--oss-llm-url": "oss_llm_url",
        "--oss-embedder": "oss_embedder",
        "--oss-embedder-key": "oss_embedder_key",
        "--oss-embedder-model": "oss_embedder_model",
        "--oss-embedder-url": "oss_embedder_url",
        "--oss-vector": "oss_vector",
        "--oss-vector-path": "oss_vector_path",
        "--oss-vector-url": "oss_vector_url",
        "--oss-vector-host": "oss_vector_host",
        "--oss-vector-port": "oss_vector_port",
        "--oss-vector-user": "oss_vector_user",
        "--oss-vector-password": "oss_vector_password",
        "--oss-vector-dbname": "oss_vector_dbname",
        "--user-id": "user_id",
    }

    # "--dry-run" takes no value; every mapped flag consumes the next argument.
    i = 0
    while i < len(args):
        if args[i] == "--dry-run":
            flags["dry_run"] = True
            i += 1
        elif args[i] in flag_map and i + 1 < len(args):
            flags[flag_map[args[i]]] = args[i + 1]
            i += 2
        else:
            i += 1

    return flags


def build_oss_config(flags: dict[str, str]) -> tuple[dict, dict[str, str]]:
    """Build OSS config dict + env_writes from parsed flags.

    Returns (oss_config, env_writes) where oss_config goes into mem0.json
    and env_writes maps env var names to secret values for .env.
    """
    llm_id = flags.get("oss_llm", "openai")
    llm_def = LLM_PROVIDERS[llm_id]
    llm_model = flags.get("oss_llm_model") or llm_def["default_model"]
    llm_config: dict[str, Any] = {"model": llm_model}
    if "default_url" in llm_def:
        llm_config["ollama_base_url"] = flags.get("oss_llm_url") or llm_def["default_url"]

    embedder_id = flags.get("oss_embedder", "openai")
    embedder_def = EMBEDDER_PROVIDERS[embedder_id]
    embedder_model = flags.get("oss_embedder_model") or embedder_def["default_model"]
    embedder_config: dict[str, Any] = {"model": embedder_model}
    if "default_url" in embedder_def:
        embedder_config["ollama_base_url"] = flags.get("oss_embedder_url") or embedder_def["default_url"]
    dims = KNOWN_DIMS.get(embedder_model)
    if dims:
        embedder_config["embedding_dims"] = dims

    vector_id = flags.get("oss_vector", "qdrant")
    vector_def = VECTOR_PROVIDERS[vector_id]
    vector_config = dict(vector_def["default_config"])
    if vector_id == "qdrant":
        if flags.get("oss_vector_path"):
            vector_config["path"] = flags["oss_vector_path"]
        if flags.get("oss_vector_url"):
            vector_config.pop("path", None)
            vector_config["url"] = flags["oss_vector_url"]
    elif vector_id == "pgvector":
        if flags.get("oss_vector_host"):
            vector_config["host"] = flags["oss_vector_host"]
        if flags.get("oss_vector_port"):
            vector_config["port"] = int(flags["oss_vector_port"])
        if flags.get("oss_vector_user"):
            vector_config["user"] = flags["oss_vector_user"]
        if flags.get("oss_vector_password"):
            vector_config["password"] = flags["oss_vector_password"]
        if flags.get("oss_vector_dbname"):
            vector_config["dbname"] = flags["oss_vector_dbname"]

    oss_config = {
        "llm": {"provider": llm_id, "config": llm_config},
        "embedder": {"provider": embedder_id, "config": embedder_config},
        "vector_store": {"provider": vector_id, "config": vector_config},
    }

    env_writes: dict[str, str] = {}
    if llm_def.get("needs_key") and flags.get("oss_llm_key"):
        env_writes[llm_def["env_var"]] = flags["oss_llm_key"]
    if embedder_def.get("needs_key") and flags.get("oss_embedder_key"):
        env_writes[embedder_def["env_var"]] = flags["oss_embedder_key"]
    elif embedder_def.get("needs_key") and embedder_id == llm_id and flags.get("oss_llm_key"):
        env_writes[embedder_def["env_var"]] = flags["oss_llm_key"]

    return oss_config, env_writes


def _write_env(env_path: Path, env_writes: dict[str, str]) -> None:
    """Append or update env vars in .env file."""
    env_path.parent.mkdir(parents=True, exist_ok=True)
    existing_lines: list[str] = []
    if env_path.exists():
        existing_lines = env_path.read_text().splitlines()

    updated_keys: set[str] = set()
    new_lines: list[str] = []
    for line in existing_lines:
        # Treat KEY=value lines as assignments; skip comments, including indented ones.
        key_match = line.split("=", 1)[0].strip() if "=" in line and not line.lstrip().startswith("#") else None
        if key_match and key_match in env_writes:
            new_lines.append(f"{key_match}={env_writes[key_match]}")
            updated_keys.add(key_match)
        else:
            new_lines.append(line)
    # Append any keys that were not already present in the file.
    for k, v in env_writes.items():
        if k not in updated_keys:
            new_lines.append(f"{k}={v}")

    env_path.write_text("\n".join(new_lines) + "\n")


def _save_mem0_json(hermes_home: str, data: dict) -> None:
    """Merge-write to mem0.json."""
    config_path = Path(hermes_home) / "mem0.json"
    existing = {}
    if config_path.exists():
        try:
            existing = json.loads(config_path.read_text(encoding="utf-8"))
        except Exception:
            pass
    existing.update(data)
    # Write with the same encoding used for reading.
    config_path.write_text(json.dumps(existing, indent=2) + "\n", encoding="utf-8")


def _setup_platform(hermes_home: str, config: dict, flags: dict[str, str]) -> None:
    """Platform mode setup — uses the framework's schema-based flow.

    Delegates to the same code path the framework uses when post_setup
    doesn't exist, preserving the original platform onboarding experience.
    """
    schema = [
        {"key": "api_key", "description": "Mem0 Platform API key", "secret": True, "required": True, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
        {"key": "user_id", "description": "User identifier", "default": "hermes-user"},
        {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
        {"key": "rerank", "description": "Enable reranking for recall", "default": "true", "choices": ["true", "false"]},
    ]

    existing_config = {}
    config_path = Path(hermes_home) / "mem0.json"
    if config_path.exists():
        try:
            existing_config = json.loads(config_path.read_text())
        except Exception:
            pass

    provider_config = dict(existing_config)
    env_writes: dict[str, str] = {}

    print("\n Configuring mem0:\n")

    for field in schema:
        key = field["key"]
        desc = field.get("description", key)
        default = field.get("default")
        is_secret = field.get("secret", False)
        choices = field.get("choices")
        env_var = field.get("env_var")
        url = field.get("url")

        if flags.get("api_key") and key == "api_key":
            env_writes["MEM0_API_KEY"] = flags["api_key"]
            continue

        if choices and not is_secret:
            choice_items = [(c, "") for c in choices]
            current = provider_config.get(key, default)
            current_idx = 0
            if current and str(current).lower() in choices:
                current_idx = choices.index(str(current).lower())
            sel = _curses_select(f" {desc}", choice_items, default=current_idx)
            provider_config[key] = choices[sel]
        elif is_secret:
            existing = os.environ.get(env_var, "") if env_var else ""
            if existing:
                masked = f"...{existing[-4:]}" if len(existing) > 4 else "set"
                val = _prompt(f"{desc} (current: {masked}, blank to keep)", secret=True)
            else:
                if url:
                    print(f" Get yours at {url}")
                val = _prompt(desc, secret=True)
            if val and env_var:
                env_writes[env_var] = val
        else:
            current = provider_config.get(key)
            effective_default = current or default
            val = _prompt(desc, default=str(effective_default) if effective_default else None)
            if val:
                provider_config[key] = val

    if flags.get("dry_run"):
        print(f"\n [dry-run] Would save config: {provider_config}")
        if env_writes:
            print(" [dry-run] Would write API key to .env")
        print(" [dry-run] No files written.\n")
        return

    provider_config["mode"] = "platform"

    from hermes_cli.config import save_config
    config["memory"]["provider"] = "mem0"
    save_config(config)

    from plugins.memory.mem0 import Mem0MemoryProvider
    provider = Mem0MemoryProvider()
    provider.save_config(provider_config, hermes_home)

    if env_writes:
        _write_env(Path(hermes_home) / ".env", env_writes)

    print(f"\n Memory provider: mem0")
    print(f" Activation saved to config.yaml")
    print(f" Provider config saved")
    if env_writes:
        print(f" API keys saved to .env")
    print(f"\n Start a new session to activate.\n")


def _setup_oss(hermes_home: str, config: dict, flags: dict[str, str]) -> None:
    """OSS mode setup — build config from flags or interactive prompts.

    Non-interactive when --mode was set explicitly via flags (post_setup already
    resolved mode). Interactive only when mode was chosen via curses picker.
    """
    if not flags.get("_mode_from_flag"):
        _setup_oss_interactive(hermes_home, config)
        return

    oss_config, env_writes = build_oss_config(flags)
    errors = validate_oss_config(oss_config)
    if errors:
        for e in errors:
            print(f" Error: {e}", file=sys.stderr)
        sys.exit(1)

    user_id = flags.get("user_id") or os.getenv("USER", "hermes-user")

    llm_id = oss_config["llm"]["provider"]
    embedder_id = oss_config["embedder"]["provider"]
    vector_id = oss_config["vector_store"]["provider"]

    if flags.get("dry_run"):
        print("\n [dry-run] OSS config would be:")
        print(f" LLM: {oss_config['llm']['provider']} ({oss_config['llm']['config'].get('model', '')})")
        print(f" Embedder: {oss_config['embedder']['provider']} ({oss_config['embedder']['config'].get('model', '')})")
        print(f" Vector: {vector_id}")
        if env_writes:
            print(f" Env vars: {', '.join(env_writes.keys())}")
        _run_connectivity_checks(oss_config)
        print(" [dry-run] No files written.\n")
        return

    if env_writes:
        _write_env(Path(hermes_home) / ".env", env_writes)
    _save_mem0_json(hermes_home, {"mode": "oss", "user_id": user_id, "agent_id": "hermes", "oss": oss_config})

    _install_provider_deps(llm_id, embedder_id, vector_id)

    from hermes_cli.config import save_config
    config["memory"]["provider"] = "mem0"
    save_config(config)

    _run_connectivity_checks(oss_config)
    print(f"\n ✓ Mem0 configured (OSS mode)")
    print(f" LLM: {oss_config['llm']['provider']} ({oss_config['llm']['config'].get('model', '')})")
    print(f" Embedder: {oss_config['embedder']['provider']} ({oss_config['embedder']['config'].get('model', '')})")
    print(f" Vector: {vector_id}")
    if env_writes:
        print(f" API keys saved to .env")
    print(f" Config saved to mem0.json")
    print(f" Provider set in config.yaml")
    print("\n Start a new session to activate.\n")


def _prompt_api_key(label: str, env_var: str, hermes_home: str) -> str:
|
||||
"""Prompt for API key, showing masked existing value if found."""
|
||||
existing = os.environ.get(env_var, "")
|
||||
if not existing:
|
||||
env_path = Path(hermes_home) / ".env"
|
||||
if env_path.exists():
|
||||
for line in env_path.read_text().splitlines():
|
||||
if line.startswith(f"{env_var}="):
|
||||
existing = line.split("=", 1)[1].strip()
|
||||
break
|
||||
if existing:
|
||||
masked = f"...{existing[-4:]}" if len(existing) > 4 else "set"
|
||||
return getpass.getpass(f" {label} API key (current: {masked}, blank to keep): ").strip()
|
||||
return getpass.getpass(f" {label} API key: ").strip()
|
||||
|
||||
|
||||
_PGVECTOR_CONTAINER = "hermes-pgvector"
|
||||
_PGVECTOR_IMAGE = "pgvector/pgvector:pg17"
|
||||
_PGVECTOR_PASSWORD = "hermes"
|
||||
|
||||
|
||||
def _ensure_pgvector(host: str = "localhost", port: int = 5432) -> dict | None:
    """Ensure pgvector is reachable; offer Docker setup if not.

    Returns updated vector_config dict if Docker was started, None otherwise.
    """
    ok, _ = _check_pgvector(host, port)
    if ok:
        print(f" ✓ PostgreSQL reachable at {host}:{port}")
        return None

    print(f" PostgreSQL not reachable at {host}:{port}")

    # Check if our container already exists but is stopped
    if shutil.which("docker"):
        try:
            result = subprocess.run(
                ["docker", "inspect", _PGVECTOR_CONTAINER, "--format", "{{.State.Status}}"],
                capture_output=True, text=True, timeout=10, stdin=subprocess.DEVNULL,
            )
            if result.returncode == 0 and "exited" in result.stdout:
                print(f" Found stopped container '{_PGVECTOR_CONTAINER}', restarting...")
                subprocess.run(["docker", "start", _PGVECTOR_CONTAINER],
                               capture_output=True, timeout=15,
                               stdin=subprocess.DEVNULL)
                _wait_for_port(host, port, timeout=15)
                ok, _ = _check_pgvector(host, port)
                if ok:
                    print(f" ✓ PostgreSQL container restarted")
                    return None
        except Exception:
            pass

        answer = input(" Start pgvector via Docker? [Y/n]: ").strip().lower()
        if answer in ("", "y", "yes"):
            return _start_pgvector_docker(host, port)
        else:
            print(" Skipping Docker setup. Make sure PostgreSQL with pgvector is running.")
            return None
    else:
        print(" Docker not found. Install Docker to auto-start pgvector,")
        print(" or run PostgreSQL with pgvector manually.")
        return None


def _start_pgvector_docker(host: str, port: int) -> dict | None:
    """Pull and start pgvector Docker container."""
    try:
        print(f" Pulling {_PGVECTOR_IMAGE}...")
        subprocess.run(["docker", "pull", _PGVECTOR_IMAGE],
                       capture_output=True, timeout=120,
                       stdin=subprocess.DEVNULL)

        # Remove existing container if present
        subprocess.run(["docker", "rm", "-f", _PGVECTOR_CONTAINER],
                       capture_output=True, timeout=10,
                       stdin=subprocess.DEVNULL)

        print(f" Starting container '{_PGVECTOR_CONTAINER}' on port {port}...")
        subprocess.run([
            "docker", "run", "-d",
            "--name", _PGVECTOR_CONTAINER,
            "-e", f"POSTGRES_PASSWORD={_PGVECTOR_PASSWORD}",
            "-p", f"{port}:5432",
            _PGVECTOR_IMAGE,
        ], capture_output=True, timeout=30, check=True, stdin=subprocess.DEVNULL)

        _wait_for_port(host, port, timeout=20)
        ok, _ = _check_pgvector(host, port)
        if ok:
            print(f" ✓ pgvector running on {host}:{port}")
            return {
                "host": host, "port": port,
                "user": "postgres", "password": _PGVECTOR_PASSWORD,
                "dbname": "postgres",
            }
        else:
            print(" Warning: Container started but PostgreSQL not yet accepting connections.")
            print(" It may need a few more seconds. Config will be saved; retry later.")
            return {
                "host": host, "port": port,
                "user": "postgres", "password": _PGVECTOR_PASSWORD,
                "dbname": "postgres",
            }
    except subprocess.CalledProcessError as e:
        print(f" Failed to start Docker container: {e}")
        return None
    except Exception as e:
        print(f" Docker error: {e}")
        return None


def _ensure_ollama(models: list[str]) -> bool:
    """Ensure Ollama is running and required models are pulled.

    Returns True if Ollama is ready, False if user needs to handle it manually.
    """
    url = "http://localhost:11434"
    ollama_bin = shutil.which("ollama")
    ok, _ = _check_ollama(url)

    if not ok:
        if ollama_bin:
            print(" Ollama installed but not running. Starting...")
            try:
                subprocess.Popen(
                    [ollama_bin, "serve"],
                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                )
                _wait_for_port("localhost", 11434, timeout=10)
                ok, _ = _check_ollama(url)
                if ok:
                    print(" ✓ Ollama started")
            except Exception as e:
                print(f" Could not start Ollama: {e}")
        else:
            print(" Ollama not found. Install it:")
            print(" curl -fsSL https://ollama.com/install.sh | sh")
            print(" Or on macOS: brew install ollama")
            return False

    if not ok:
        print(" Warning: Ollama not reachable. Models cannot be pulled.")
        return False

    # Pull required models
    for model in models:
        if _ollama_has_model(url, model):
            print(f" ✓ Model '{model}' available")
        else:
            print(f" Pulling '{model}'... (this may take a few minutes)")
            try:
                # check=True so a failed pull raises and reaches the warning
                # below instead of being reported as pulled.
                subprocess.run([ollama_bin or "ollama", "pull", model], timeout=600,
                               stdin=subprocess.DEVNULL, check=True)
                print(f" ✓ Model '{model}' pulled")
            except Exception as e:
                print(f" Warning: Could not pull '{model}': {e}")
                print(f" Run manually: ollama pull {model}")

    return True


def _ollama_has_model(url: str, model: str) -> bool:
    """Check if Ollama already has a model pulled."""
    try:
        req = urllib.request.Request(f"{url}/api/tags", method="GET")
        resp = urllib.request.urlopen(req, timeout=5)
        data = json.loads(resp.read())
        names = [m.get("name", "") for m in data.get("models", [])]
        base_model = model.split(":")[0]
        return any(model in n or base_model in n for n in names)
    except Exception:
        return False


def _ensure_pgvector_extension(pg_config: dict) -> None:
    """Create the pgvector extension if it doesn't exist."""
    try:
        import psycopg2
    except ImportError:
        return
    conn_params = {
        "host": pg_config.get("host", "localhost"),
        "port": pg_config.get("port", 5432),
        "user": pg_config.get("user", "postgres"),
        "dbname": pg_config.get("dbname", "postgres"),
    }
    if pg_config.get("password"):
        conn_params["password"] = pg_config["password"]
    try:
        conn = psycopg2.connect(**conn_params)
        conn.autocommit = True
        cur = conn.cursor()
        cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
        cur.close()
        conn.close()
        print(" ✓ pgvector extension enabled")
    except Exception as e:
        print(f" Warning: Could not enable pgvector extension: {e}")


def _wait_for_port(host: str, port: int, timeout: int = 15) -> None:
    """Wait until a TCP port is accepting connections."""
    import time
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            sock = socket.create_connection((host, port), timeout=1)
            sock.close()
            return
        except OSError:
            time.sleep(0.5)


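Review note: `_wait_for_port` returns `None` whether the port opened or the deadline passed, which is why its callers re-run `_check_pgvector` or `_check_ollama` afterwards. A minimal sketch of a variant that reports the outcome is shown here. The name `wait_for_port` and the `interval` parameter are hypothetical and not part of this change.

```python
import socket
import time


def wait_for_port(host: str, port: int, timeout: float = 15.0, interval: float = 0.5) -> bool:
    """Poll a TCP port until it accepts a connection.

    Returns True as soon as a connect succeeds, False once the deadline passes.
    Hypothetical helper for illustration; the plugin's own function returns None.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            # The context manager closes the probe socket immediately.
            with socket.create_connection((host, port), timeout=1):
                return True
        except OSError:
            time.sleep(interval)
    return False
```

With a boolean result a caller could branch directly, for example `if not wait_for_port(host, port, timeout=20): print("still starting")`, rather than probing the port a second time.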
def _provider_description(v: dict) -> str:
    """Description for LLM/embedder picker: model + URL if applicable."""
    model = v.get("default_model", "")
    url = v.get("default_url")
    if url:
        return f"{model} ({url})"
    return model


def _vector_description(pid: str, v: dict) -> str:
    cfg = v.get("default_config", {})
    if pid == "qdrant":
        return cfg.get("path", "local storage")
    if pid == "pgvector":
        return f"{cfg.get('host', 'localhost')}:{cfg.get('port', 5432)}"
    return pid


def _setup_oss_interactive(hermes_home: str, config: dict) -> None:
    """Interactive OSS setup using curses pickers."""
    llm_items = [(v["label"], _provider_description(v)) for pid, v in LLM_PROVIDERS.items()]
    llm_idx = _curses_select("LLM Provider", llm_items, 0)
    llm_id = list(LLM_PROVIDERS.keys())[llm_idx]
    llm_def = LLM_PROVIDERS[llm_id]

    env_writes: dict[str, str] = {}
    llm_model = llm_def["default_model"]
    llm_url = llm_def.get("default_url")
    if llm_def["needs_key"]:
        key = _prompt_api_key(llm_def["label"], llm_def["env_var"], hermes_home)
        if key:
            env_writes[llm_def["env_var"]] = key
    if llm_id == "ollama":
        llm_model = input(f" LLM model [{llm_def['default_model']}]: ").strip() or llm_def["default_model"]
        llm_url = input(f" Ollama URL [{llm_def['default_url']}]: ").strip() or llm_def["default_url"]

    embedder_items = [(v["label"], _provider_description(v)) for pid, v in EMBEDDER_PROVIDERS.items()]
    embedder_idx = _curses_select("Embedder Provider", embedder_items, 0)
    embedder_id = list(EMBEDDER_PROVIDERS.keys())[embedder_idx]
    embedder_def = EMBEDDER_PROVIDERS[embedder_id]

    embedder_model = embedder_def["default_model"]
    embedder_url = embedder_def.get("default_url")
    if embedder_def["needs_key"] and embedder_id != llm_id:
        key = _prompt_api_key(f"{embedder_def['label']} embedder", embedder_def["env_var"], hermes_home)
        if key:
            env_writes[embedder_def["env_var"]] = key
    elif embedder_def["needs_key"] and embedder_id == llm_id:
        if llm_def.get("env_var") in env_writes:
            env_writes[embedder_def["env_var"]] = env_writes[llm_def["env_var"]]
    if embedder_id == "ollama":
        embedder_model = input(f" Embedder model [{embedder_def['default_model']}]: ").strip() or embedder_def["default_model"]
        embedder_url = input(f" Ollama URL [{embedder_def['default_url']}]: ").strip() or embedder_def["default_url"]

    vector_items = [(v["label"], _vector_description(pid, v)) for pid, v in VECTOR_PROVIDERS.items()]
    vector_idx = _curses_select("Vector Store", vector_items, 0)
    vector_id = list(VECTOR_PROVIDERS.keys())[vector_idx]

    # Auto-setup: ensure Ollama is running and models are pulled
    ollama_models = []
    if llm_id == "ollama":
        ollama_models.append(llm_model)
    if embedder_id == "ollama":
        ollama_models.append(embedder_model)
    if ollama_models:
        _ensure_ollama(ollama_models)

    # Auto-setup: ensure pgvector is reachable (offer Docker if not)
    pgvector_config = None
    if vector_id == "pgvector":
        pgvector_config = _ensure_pgvector()
        if not pgvector_config:
            # Native PostgreSQL — prompt for connection details
            default_user = os.getenv("USER", "postgres")
            pg_user = input(f" PostgreSQL user [{default_user}]: ").strip() or default_user
            pg_host = input(" PostgreSQL host [localhost]: ").strip() or "localhost"
            pg_port = input(" PostgreSQL port [5432]: ").strip() or "5432"
            pg_dbname = input(" PostgreSQL database [postgres]: ").strip() or "postgres"
            pg_password = getpass.getpass(" PostgreSQL password (blank if none): ").strip()
            pgvector_config = {
                "host": pg_host, "port": int(pg_port),
                "user": pg_user, "dbname": pg_dbname,
            }
            if pg_password:
                pgvector_config["password"] = pg_password

    user_id = input(f" User ID [{os.getenv('USER', 'hermes-user')}]: ").strip()
    user_id = user_id or os.getenv("USER", "hermes-user")

    agent_id = input(" Agent ID [hermes]: ").strip()
    agent_id = agent_id or "hermes"

    flags = {
        "oss_llm": llm_id,
        "oss_llm_key": env_writes.get(llm_def["env_var"], "") if llm_def.get("env_var") else "",
        "oss_llm_model": llm_model,
        "oss_llm_url": llm_url or "",
        "oss_embedder": embedder_id,
        "oss_embedder_model": embedder_model,
        "oss_embedder_url": embedder_url or "",
        "oss_vector": vector_id,
        "user_id": user_id,
    }

    if pgvector_config:
        flags["oss_vector_host"] = pgvector_config["host"]
        flags["oss_vector_port"] = str(pgvector_config["port"])
        flags["oss_vector_user"] = pgvector_config["user"]
        if pgvector_config.get("password"):
            flags["oss_vector_password"] = pgvector_config["password"]
        flags["oss_vector_dbname"] = pgvector_config["dbname"]

    oss_config, _ = build_oss_config(flags)

    if env_writes:
        _write_env(Path(hermes_home) / ".env", env_writes)
    _save_mem0_json(hermes_home, {"mode": "oss", "user_id": user_id, "agent_id": agent_id, "oss": oss_config})

    _install_provider_deps(llm_id, embedder_id, vector_id)

    if vector_id == "pgvector" and pgvector_config:
        _ensure_pgvector_extension(pgvector_config)

    from hermes_cli.config import save_config
    config["memory"]["provider"] = "mem0"
    save_config(config)

    _run_connectivity_checks(oss_config)
    print(f"\n ✓ Mem0 configured (OSS mode)")
    print(f" LLM: {oss_config['llm']['provider']} ({oss_config['llm']['config'].get('model', '')})")
    print(f" Embedder: {oss_config['embedder']['provider']} ({oss_config['embedder']['config'].get('model', '')})")
    print(f" Vector: {vector_id}")
    if env_writes:
        print(f" API keys saved to .env")
    print(f" Config saved to mem0.json")
    print(f" Provider set in config.yaml")
    print("\n Start a new session to activate.\n")


def _install_provider_deps(llm_id: str, embedder_id: str, vector_id: str) -> None:
    """Install all optional pip deps for selected providers."""
    deps: set[str] = set()
    for registry, pid in [(LLM_PROVIDERS, llm_id), (EMBEDDER_PROVIDERS, embedder_id),
                          (VECTOR_PROVIDERS, vector_id)]:
        dep = registry.get(pid, {}).get("pip_dep")
        if dep:
            deps.add(dep)
    for dep in sorted(deps):
        try:
            print(f" Installing {dep}...")
            # check=True so a non-zero exit from uv raises and is reported as
            # a failure instead of being printed as a successful install.
            subprocess.run(
                ["uv", "pip", "install", "--python", sys.executable, dep],
                capture_output=True, timeout=60, check=True,
            )
            print(f" ✓ Installed {dep}")
        except Exception:
            print(f" Warning: Could not install {dep}. Install manually: uv pip install {dep}")
    if deps:
        import importlib
        importlib.invalidate_caches()


def _check_qdrant_path(path: str) -> tuple[bool, str]:
    """Check that qdrant local storage parent dir is writable."""
    p = Path(path).expanduser()
    parent = p.parent
    try:
        parent.mkdir(parents=True, exist_ok=True)
        return True, f"Directory writable: {parent}"
    except OSError as e:
        return False, f"Cannot write to {parent}: {e}"


def _check_ollama(url: str) -> tuple[bool, str]:
    """Check Ollama is reachable via /api/tags."""
    try:
        req = urllib.request.Request(f"{url.rstrip('/')}/api/tags", method="GET")
        urllib.request.urlopen(req, timeout=3)
        return True, "Ollama reachable"
    except Exception as e:
        return False, f"Ollama not reachable at {url}: {e}"


def _check_pgvector(host: str, port: int) -> tuple[bool, str]:
    """Check PGVector via TCP socket."""
    try:
        sock = socket.create_connection((host, port), timeout=3)
        sock.close()
        return True, f"PGVector reachable at {host}:{port}"
    except Exception as e:
        return False, f"PGVector not reachable at {host}:{port}: {e}"


def _run_connectivity_checks(oss_config: dict) -> None:
    """Run connectivity checks and print warnings."""
    vs = oss_config.get("vector_store", {})
    if vs.get("provider") == "qdrant":
        path = vs.get("config", {}).get("path")
        url = vs.get("config", {}).get("url")
        if path:
            ok, msg = _check_qdrant_path(path)
            if not ok:
                print(f" Warning: {msg}")
        elif url:
            try:
                req = urllib.request.Request(f"{url.rstrip('/')}/healthz", method="GET")
                urllib.request.urlopen(req, timeout=3)
            except Exception as e:
                print(f" Warning: Qdrant not reachable at {url}: {e}")
    elif vs.get("provider") == "pgvector":
        cfg = vs.get("config", {})
        ok, msg = _check_pgvector(cfg.get("host", "localhost"), cfg.get("port", 5432))
        if not ok:
            print(f" Warning: {msg}")

    llm = oss_config.get("llm", {})
    if llm.get("provider") == "ollama":
        url = llm.get("config", {}).get("ollama_base_url", "http://localhost:11434")
        ok, msg = _check_ollama(url)
        if not ok:
            print(f" Warning: {msg}")


def _check_min_dep_version() -> None:
    """Ensure mem0ai meets the minimum version from plugin.yaml."""
    try:
        import mem0
        installed_ver = getattr(mem0, "__version__", None)
        if not installed_ver:
            return
        installed_parts = tuple(int(x) for x in installed_ver.split(".")[:3])
        required_parts = (2, 0, 7)
        if installed_parts < required_parts:
            req_str = ".".join(str(x) for x in required_parts)
            print(f"\n ⚠ mem0ai {installed_ver} installed but >={req_str} required.")
            print(f" Run: uv pip install --python {sys.executable} 'mem0ai>={req_str}'")
    except ImportError:
        pass
    except Exception:
        pass


def post_setup(hermes_home: str, config: dict) -> None:
    """Entry point called by hermes memory setup framework.

    Only intercepts when OSS mode is requested (via --mode oss flag or
    interactive picker). For platform mode, returns without action so the
    framework's schema-based flow handles it (preserving the original
    platform onboarding experience).
    """
    _check_min_dep_version()
    flags = parse_flags(sys.argv[1:])

    if flags["mode"] == "oss":
        flags["_mode_from_flag"] = True
        _setup_oss(hermes_home, config, flags)
        return

    if flags["mode"] == "platform":
        _setup_platform(hermes_home, config, flags)
        return

    # No --mode flag: show interactive picker
    mode_items = [
        ("Platform", "Mem0 Cloud API (lightweight, just needs an API key)"),
        ("Open Source", "Run Mem0 locally (self-hosted LLM + vector store)"),
    ]
    mode_idx = _curses_select(" Select mode", mode_items, 0)
    if mode_idx == 1:
        flags["_mode_from_flag"] = False
        _setup_oss(hermes_home, config, flags)
    else:
        _setup_platform(hermes_home, config, flags)
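Review note: `_check_min_dep_version` converts each dotted component with `int(x)`, so a version string such as `2.1.0rc1` raises `ValueError`, and the broad `except Exception` then skips the check without printing anything. The sketch below shows one tolerant way to compare versions, assuming that only the leading digits of each component matter. `version_tuple` and `meets_minimum` are hypothetical names, not functions from this change.

```python
import re


def version_tuple(version: str, width: int = 3) -> tuple[int, ...]:
    """Turn '2.1.0rc1' into (2, 1, 0), keeping only leading digits per component.

    Missing or non-numeric components are padded with 0. Hypothetical helper.
    """
    parts: list[int] = []
    for piece in version.split(".")[:width]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (width - len(parts))
    return tuple(parts)


def meets_minimum(installed: str, required: tuple[int, ...] = (2, 0, 7)) -> bool:
    """Compare numerically, so '2.0.10' counts as newer than '2.0.7'."""
    return version_tuple(installed, len(required)) >= required
```

This simplification treats a pre-release such as `2.0.7rc1` as equal to `2.0.7`. A stricter comparison would need a full version parser, for example the third-party `packaging` library.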
@@ -1,5 +1,5 @@
 name: mem0
-version: 1.0.0
+version: 1.1.0
 description: "Mem0 — server-side LLM fact extraction with semantic search, reranking, and automatic deduplication."
 pip_dependencies:
-  - mem0ai
+  - mem0ai>=2.0.7,<3
@@ -1410,6 +1410,8 @@ AUTHOR_MAP = {
     "caojiguang@gmail.com": "caojiguang",  # PR #35117 carries #31853 (weixin _api_post/_api_get wait_for)
     "gooku94123@gmail.com": "goku94123",  # PR #46609 salvage (MiniMax reasoning extra_body)
     # pander: empty email, salvaged via PR #19665 from #16126 by @ms-alan
+    "chaithanya.kumar42a@gmail.com": "chaithanyak42",  # PR #15624
+    "kartik.labhshetwar@mem0.ai": "kartik-mem0",  # PR #15624
     "ayman.a.kamal@hotmail.com": "A-kamal",  # PR #18678 (xAI image resolution fix)
     # Kanban bug-fix batch salvage (May 2026)
     "frowte3k@gmail.com": "Frowtek",  # salvage of #23206 (gateway --board auto-subscribe)
209
tests/plugins/memory/test_mem0_backend.py
Normal file
@@ -0,0 +1,209 @@
"""Tests for Mem0Backend abstraction — PlatformBackend and OSSBackend."""

import pytest

from plugins.memory.mem0._backend import Mem0Backend, PlatformBackend, OSSBackend


class FakePlatformClient:
    """Fake MemoryClient for PlatformBackend tests."""

    def __init__(self):
        self.calls = []

    def search(self, query, **kwargs):
        self.calls.append(("search", query, kwargs))
        return {"results": [{"id": "m1", "memory": "fact1", "score": 0.9}]}

    def get_all(self, **kwargs):
        self.calls.append(("get_all", kwargs))
        return {"count": 1, "next": None, "results": [{"id": "m1", "memory": "fact1"}]}

    def add(self, messages, **kwargs):
        self.calls.append(("add", messages, kwargs))
        return {"status": "PENDING", "event_id": "evt-1"}

    def update(self, **kwargs):
        self.calls.append(("update", kwargs))
        return {"id": kwargs["memory_id"], "text": kwargs["text"]}

    def delete(self, **kwargs):
        self.calls.append(("delete", kwargs))


class TestPlatformBackend:

    def _make(self):
        client = FakePlatformClient()
        backend = PlatformBackend.__new__(PlatformBackend)
        backend._client = client
        return backend, client

    def test_search_forwards_params(self):
        backend, client = self._make()
        result = backend.search("test query", filters={"user_id": "u1"}, top_k=5)
        assert client.calls[0][0] == "search"
        assert client.calls[0][1] == "test query"
        assert client.calls[0][2]["filters"] == {"user_id": "u1"}
        assert client.calls[0][2]["top_k"] == 5

    def test_search_forwards_rerank(self):
        backend, client = self._make()
        backend.search("q", filters={}, rerank=False)
        assert client.calls[0][2]["rerank"] is False

    def test_search_rerank_default_true(self):
        backend, client = self._make()
        backend.search("q", filters={})
        assert client.calls[0][2]["rerank"] is True

    def test_search_returns_list(self):
        backend, _ = self._make()
        result = backend.search("q", filters={})
        assert isinstance(result, list)
        assert result[0]["id"] == "m1"

    def test_get_all_forwards_pagination(self):
        backend, client = self._make()
        result = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50)
        assert client.calls[0][1]["page"] == 2
        assert client.calls[0][1]["page_size"] == 50
        assert "count" in result

    def test_add_forwards_kwargs(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        result = backend.add(msgs, user_id="u1", agent_id="hermes", infer=False)
        call = client.calls[0]
        assert call[2]["user_id"] == "u1"
        assert call[2]["infer"] is False
        # metadata kwarg should be omitted entirely when not provided so we
        # don't surprise older mem0 client versions with an unknown kwarg.
        assert "metadata" not in call[2]

    def test_add_forwards_metadata_when_present(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(
            msgs,
            user_id="u1",
            agent_id="hermes",
            infer=False,
            metadata={"channel": "telegram"},
        )
        assert client.calls[0][2]["metadata"] == {"channel": "telegram"}

    def test_add_omits_empty_metadata(self):
        backend, client = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False, metadata={})
        assert "metadata" not in client.calls[0][2]

    def test_update_forwards(self):
        backend, client = self._make()
        backend.update("m1", "new text")
        assert client.calls[0][1] == {"memory_id": "m1", "text": "new text"}

    def test_delete_forwards(self):
        backend, client = self._make()
        backend.delete("m1")
        assert client.calls[0][1] == {"memory_id": "m1"}


class FakeOSSMemory:
    """Fake mem0.Memory for OSSBackend tests."""

    def __init__(self):
        self.calls = []

    def search(self, query, **kwargs):
        self.calls.append(("search", query, kwargs))
        return {"results": [{"id": "m1", "memory": "fact1", "score": 0.8}]}

    def get_all(self, **kwargs):
        self.calls.append(("get_all", kwargs))
        return {"results": [{"id": "m1", "memory": "fact1"}]}

    def add(self, messages, **kwargs):
        self.calls.append(("add", messages, kwargs))
        return {"results": [{"id": "m1", "memory": "fact1", "event": "ADD"}]}

    def update(self, memory_id, **kwargs):
        self.calls.append(("update", memory_id, kwargs))
        return {"message": "Memory updated successfully!"}

    def delete(self, memory_id):
        self.calls.append(("delete", memory_id))
        return {"message": "Memory deleted successfully!"}


class TestOSSBackend:

    def _make(self):
        memory = FakeOSSMemory()
        backend = OSSBackend.__new__(OSSBackend)
        backend._memory = memory
        return backend, memory

    def test_search_returns_list(self):
        backend, _ = self._make()
        result = backend.search("test", filters={"user_id": "u1"})
        assert isinstance(result, list)
        assert result[0]["id"] == "m1"

    def test_search_passes_filters(self):
        backend, memory = self._make()
        backend.search("q", filters={"user_id": "u1"}, top_k=3)
        assert memory.calls[0][2]["filters"] == {"user_id": "u1"}
        assert memory.calls[0][2]["top_k"] == 3

    def test_search_ignores_rerank(self):
        """OSS backend accepts rerank param but does not forward it to Memory."""
        backend, memory = self._make()
        backend.search("q", filters={}, rerank=True)
        assert "rerank" not in memory.calls[0][2]

    def test_get_all_ignores_pagination(self):
        """OSSBackend accepts page/page_size but does NOT forward to Memory.get_all()."""
        backend, memory = self._make()
        result = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50)
        call_kwargs = memory.calls[0][1]
        assert "page" not in call_kwargs
        assert "page_size" not in call_kwargs
        assert result["count"] == 1

    def test_get_all_returns_envelope(self):
        backend, _ = self._make()
        result = backend.get_all(filters={"user_id": "u1"})
        assert "results" in result
        assert "count" in result

    def test_add_forwards_kwargs(self):
        backend, memory = self._make()
        msgs = [{"role": "user", "content": "hi"}]
        backend.add(msgs, user_id="u1", agent_id="hermes", infer=False)
        assert memory.calls[0][2]["user_id"] == "u1"
        assert memory.calls[0][2]["infer"] is False

    def test_update_maps_text_to_data(self):
        """OSS Memory.update uses `data=` param, not `text=`."""
        backend, memory = self._make()
        backend.update("m1", "new text")
        assert memory.calls[0][0] == "update"
        assert memory.calls[0][1] == "m1"
        assert memory.calls[0][2] == {"data": "new text"}

    def test_delete_positional_arg(self):
        backend, memory = self._make()
        backend.delete("m1")
        assert memory.calls[0] == ("delete", "m1")

    def test_update_normalizes_response(self):
        backend, _ = self._make()
        result = backend.update("m1", "text")
        assert result == {"result": "Memory updated.", "memory_id": "m1"}

    def test_delete_normalizes_response(self):
        backend, _ = self._make()
        result = backend.delete("m1")
        assert result == {"result": "Memory deleted.", "memory_id": "m1"}
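Review note: taken together, the `TestOSSBackend` cases describe an adapter contract. `search` unwraps the `results` list and drops `rerank`, `get_all` drops the paging arguments and adds a `count`, `update` maps `text` to `data=`, and both `update` and `delete` return a normalized dict. The sketch below is one adapter that would satisfy those assertions. It is not the plugin's actual `OSSBackend`; `SketchOSSBackend`, `RecordingMemory`, and the defaults for `top_k`, `page`, and `page_size` are assumptions made for illustration.

```python
from typing import Any


class RecordingMemory:
    """Stand-in for mem0.Memory that records calls (illustrative only)."""

    def __init__(self) -> None:
        self.calls: list[tuple] = []

    def search(self, query: str, **kwargs: Any) -> dict:
        self.calls.append(("search", query, kwargs))
        return {"results": [{"id": "m1", "memory": "fact1", "score": 0.8}]}

    def get_all(self, **kwargs: Any) -> dict:
        self.calls.append(("get_all", kwargs))
        return {"results": [{"id": "m1", "memory": "fact1"}]}

    def update(self, memory_id: str, **kwargs: Any) -> dict:
        self.calls.append(("update", memory_id, kwargs))
        return {"message": "Memory updated successfully!"}

    def delete(self, memory_id: str) -> dict:
        self.calls.append(("delete", memory_id))
        return {"message": "Memory deleted successfully!"}


class SketchOSSBackend:
    """Hypothetical adapter matching the behaviour the tests assert."""

    def __init__(self, memory: Any) -> None:
        self._memory = memory

    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list:
        # rerank is accepted for interface parity but not forwarded.
        raw = self._memory.search(query, filters=filters, top_k=top_k)
        return raw.get("results", []) if isinstance(raw, dict) else list(raw)

    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
        # page and page_size are accepted but not forwarded.
        raw = self._memory.get_all(filters=filters)
        results = raw.get("results", []) if isinstance(raw, dict) else list(raw)
        return {"results": results, "count": len(results)}

    def update(self, memory_id: str, text: str) -> dict:
        # The fake Memory in the tests takes the new text as data=, not text=.
        self._memory.update(memory_id, data=text)
        return {"result": "Memory updated.", "memory_id": memory_id}

    def delete(self, memory_id: str) -> dict:
        self._memory.delete(memory_id)
        return {"result": "Memory deleted.", "memory_id": memory_id}
```

Normalizing the return values in the adapter means callers can treat both backends the same way, whatever message text the underlying library returns.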
107
tests/plugins/memory/test_mem0_providers.py
Normal file
@@ -0,0 +1,107 @@
"""Tests for OSS provider definitions and validation."""

import pytest

from plugins.memory.mem0._oss_providers import (
    LLM_PROVIDERS,
    EMBEDDER_PROVIDERS,
    VECTOR_PROVIDERS,
    KNOWN_DIMS,
    validate_oss_config,
)


class TestProviderDefinitions:

    def test_llm_providers_have_required_keys(self):
        for pid, p in LLM_PROVIDERS.items():
            assert "label" in p
            assert "needs_key" in p
            assert "default_model" in p

    def test_embedder_providers_have_required_keys(self):
        for pid, p in EMBEDDER_PROVIDERS.items():
            assert "label" in p
            assert "needs_key" in p
            assert "default_model" in p
            assert "dims" in p

    def test_embedder_provider_ids(self):
        assert set(EMBEDDER_PROVIDERS.keys()) == {"openai", "ollama"}

    def test_vector_providers_have_required_keys(self):
        for pid, p in VECTOR_PROVIDERS.items():
            assert "label" in p
            assert "default_config" in p

    def test_vector_provider_ids(self):
        assert set(VECTOR_PROVIDERS.keys()) == {"qdrant", "pgvector"}

    def test_known_dims_covers_defaults(self):
        for pid, p in EMBEDDER_PROVIDERS.items():
            assert p["default_model"] in KNOWN_DIMS


class TestValidation:
|
||||
|
||||
def test_valid_openai_config(self):
|
||||
cfg = {
|
||||
"llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}},
|
||||
"embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}},
|
||||
"vector_store": {"provider": "qdrant", "config": {"path": "/tmp/test"}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert errors == []
|
||||
|
||||
def test_unknown_llm_provider(self):
|
||||
cfg = {
|
||||
"llm": {"provider": "gemini", "config": {}},
|
||||
"embedder": {"provider": "openai", "config": {}},
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert any("llm" in e.lower() for e in errors)
|
||||
|
||||
def test_unknown_embedder_provider(self):
|
||||
cfg = {
|
||||
"llm": {"provider": "openai", "config": {}},
|
||||
"embedder": {"provider": "cohere", "config": {}},
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert any("embedder" in e.lower() for e in errors)
|
||||
|
||||
def test_unknown_vector_provider(self):
|
||||
cfg = {
|
||||
"llm": {"provider": "openai", "config": {}},
|
||||
"embedder": {"provider": "openai", "config": {}},
|
||||
"vector_store": {"provider": "redis", "config": {}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert any("vector" in e.lower() for e in errors)
|
||||
|
||||
def test_missing_llm_section(self):
|
||||
cfg = {
|
||||
"embedder": {"provider": "openai", "config": {}},
|
||||
"vector_store": {"provider": "qdrant", "config": {}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert any("llm" in e.lower() for e in errors)
|
||||
|
||||
def test_pgvector_needs_user(self):
|
||||
cfg = {
|
||||
"llm": {"provider": "openai", "config": {}},
|
||||
"embedder": {"provider": "openai", "config": {}},
|
||||
"vector_store": {"provider": "pgvector", "config": {"host": "localhost"}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert any("user" in e.lower() for e in errors)
|
||||
|
||||
def test_pgvector_with_user_valid(self):
|
||||
cfg = {
|
||||
"llm": {"provider": "openai", "config": {}},
|
||||
"embedder": {"provider": "openai", "config": {}},
|
||||
"vector_store": {"provider": "pgvector", "config": {"host": "localhost", "user": "pg"}},
|
||||
}
|
||||
errors = validate_oss_config(cfg)
|
||||
assert errors == []
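
Reviewer note: the TestValidation cases above pin down the observable contract of validate_oss_config (a list of error strings, empty when the config is acceptable). The following is only a minimal sketch of a validator that would satisfy those cases. It is not the implementation from this commit: the function name validate_oss_config_sketch and the three SKETCH_* allow-lists are hypothetical and contain only the providers that appear in the tests.

```python
from typing import Any

# Hypothetical allow-lists: only the providers the tests above exercise.
SKETCH_LLMS = {"openai", "ollama"}
SKETCH_EMBEDDERS = {"openai", "ollama"}
SKETCH_VECTOR_STORES = {"qdrant", "pgvector"}


def validate_oss_config_sketch(cfg: dict[str, Any]) -> list[str]:
    """Return human-readable errors; an empty list means the config passed."""
    errors: list[str] = []
    sections = (
        ("llm", SKETCH_LLMS),
        ("embedder", SKETCH_EMBEDDERS),
        ("vector_store", SKETCH_VECTOR_STORES),
    )
    for name, allowed in sections:
        section = cfg.get(name)
        if not isinstance(section, dict):
            # A missing section is reported by name so callers can match on it.
            errors.append(f"Missing '{name}' section")
            continue
        provider = section.get("provider")
        if provider not in allowed:
            errors.append(f"Unknown {name} provider: {provider!r}")

    # Provider-specific rule taken from the pgvector tests: a user is required.
    vector_store = cfg.get("vector_store")
    if isinstance(vector_store, dict) and vector_store.get("provider") == "pgvector":
        if not (vector_store.get("config") or {}).get("user"):
            errors.append("pgvector requires a 'user' in the vector_store config")
    return errors
```

Collecting every problem into a list, rather than raising on the first one, is what lets the tests assert on substrings such as "llm" or "user" without depending on error ordering.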
251
tests/plugins/memory/test_mem0_setup.py
Normal file
@@ -0,0 +1,251 @@
"""Tests for Mem0 setup wizard — flag parsing, config building, validation."""

import json
import sys
import types
import pytest
from pathlib import Path
from unittest.mock import patch, MagicMock

from plugins.memory.mem0._setup import (
    parse_flags,
    build_oss_config,
    _write_env,
    post_setup,
    _check_qdrant_path,
    _check_ollama,
    _check_pgvector,
)


def _inject_fake_hermes_cli(monkeypatch):
    """Inject fake hermes_cli modules so yaml/curses aren't required."""
    fake_config_mod = types.ModuleType("hermes_cli.config")
    fake_config_mod.save_config = lambda c: None

    fake_setup_mod = types.ModuleType("hermes_cli.memory_setup")
    fake_setup_mod._curses_select = lambda *a, **kw: 0
    fake_setup_mod._prompt = lambda label, default=None, secret=False: default or ""

    fake_hermes_cli = types.ModuleType("hermes_cli")
    fake_hermes_cli.config = fake_config_mod
    fake_hermes_cli.memory_setup = fake_setup_mod

    monkeypatch.setitem(sys.modules, "hermes_cli", fake_hermes_cli)
    monkeypatch.setitem(sys.modules, "hermes_cli.config", fake_config_mod)
    monkeypatch.setitem(sys.modules, "hermes_cli.memory_setup", fake_setup_mod)

    monkeypatch.setattr("plugins.memory.mem0._setup._curses_select", lambda *a, **kw: 0)
    monkeypatch.setattr("plugins.memory.mem0._setup._prompt", lambda label, default=None, secret=False: default or "")
    return fake_config_mod


class TestParseFlags:

    def test_mode_platform(self):
        flags = parse_flags(["--mode", "platform", "--api-key", "sk-test"])
        assert flags["mode"] == "platform"
        assert flags["api_key"] == "sk-test"

    def test_mode_oss_defaults(self):
        flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        assert flags["mode"] == "oss"
        assert flags["oss_llm"] == "openai"
        assert flags["oss_embedder"] == "openai"
        assert flags["oss_vector"] == "qdrant"

    def test_mode_oss_all_flags(self):
        flags = parse_flags([
            "--mode", "oss",
            "--oss-llm", "ollama",
            "--oss-llm-model", "llama3:latest",
            "--oss-embedder", "ollama",
            "--oss-embedder-model", "nomic-embed-text",
            "--oss-vector", "pgvector",
            "--oss-vector-host", "db.local",
            "--oss-vector-port", "5433",
            "--oss-vector-user", "pguser",
            "--oss-vector-password", "secret",
            "--oss-vector-dbname", "memdb",
            "--user-id", "my-user",
        ])
        assert flags["oss_llm"] == "ollama"
        assert flags["oss_llm_model"] == "llama3:latest"
        assert flags["oss_vector"] == "pgvector"
        assert flags["oss_vector_user"] == "pguser"
        assert flags["user_id"] == "my-user"

    def test_no_flags_returns_empty_mode(self):
        flags = parse_flags([])
        assert flags["mode"] == ""

    def test_oss_vector_path_flag(self):
        flags = parse_flags(["--mode", "oss", "--oss-vector-path", "/data/qdrant"])
        assert flags["oss_vector_path"] == "/data/qdrant"


class TestBuildOSSConfig:

    def test_openai_defaults(self):
        flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        oss, env_writes = build_oss_config(flags)
        assert oss["llm"]["provider"] == "openai"
        assert oss["llm"]["config"]["model"] == "gpt-5-mini"
        assert oss["embedder"]["provider"] == "openai"
        assert oss["embedder"]["config"]["model"] == "text-embedding-3-small"
        assert oss["vector_store"]["provider"] == "qdrant"
        assert env_writes["OPENAI_API_KEY"] == "sk-oai"

    def test_ollama_no_key_needed(self):
        flags = parse_flags(["--mode", "oss", "--oss-llm", "ollama", "--oss-embedder", "ollama"])
        oss, env_writes = build_oss_config(flags)
        assert oss["llm"]["provider"] == "ollama"
        assert "model" in oss["llm"]["config"]
        assert env_writes == {}

    def test_embedder_reuses_llm_key(self):
        """When LLM and embedder share same provider, key written once."""
        flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        _, env_writes = build_oss_config(flags)
        assert env_writes == {"OPENAI_API_KEY": "sk-oai"}

    def test_different_embedder_needs_separate_key(self):
        flags = parse_flags([
            "--mode", "oss",
            "--oss-llm", "ollama",
            "--oss-embedder", "openai", "--oss-embedder-key", "sk-oai",
        ])
        _, env_writes = build_oss_config(flags)
        assert env_writes == {"OPENAI_API_KEY": "sk-oai"}

    def test_pgvector_config(self):
        flags = parse_flags([
            "--mode", "oss", "--oss-llm-key", "sk-oai",
            "--oss-vector", "pgvector",
            "--oss-vector-host", "db.local", "--oss-vector-port", "5433",
            "--oss-vector-user", "pg", "--oss-vector-dbname", "memdb",
        ])
        oss, _ = build_oss_config(flags)
        vs = oss["vector_store"]
        assert vs["provider"] == "pgvector"
        assert vs["config"]["host"] == "db.local"
        assert vs["config"]["port"] == 5433
        assert vs["config"]["user"] == "pg"

    def test_known_dims_auto_set(self):
        flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"])
        oss, _ = build_oss_config(flags)
        dims = oss["embedder"]["config"].get("embedding_dims")
        assert dims == 1536

    def test_custom_qdrant_path(self):
        flags = parse_flags([
            "--mode", "oss", "--oss-llm-key", "sk-oai",
            "--oss-vector-path", "/data/qdrant",
        ])
        oss, _ = build_oss_config(flags)
        assert oss["vector_store"]["config"]["path"] == "/data/qdrant"


class TestWriteEnv:

    def test_write_new_vars(self, tmp_path):
        env_path = tmp_path / ".env"
        _write_env(env_path, {"OPENAI_API_KEY": "sk-test"})
        content = env_path.read_text()
        assert "OPENAI_API_KEY=sk-test" in content

    def test_update_existing_var(self, tmp_path):
        env_path = tmp_path / ".env"
        env_path.write_text("OPENAI_API_KEY=old\nOTHER=keep\n")
        _write_env(env_path, {"OPENAI_API_KEY": "new"})
        content = env_path.read_text()
        assert "OPENAI_API_KEY=new" in content
        assert "OTHER=keep" in content
        assert "old" not in content


class TestPostSetup:

    def test_platform_flag_mode(self, tmp_path, monkeypatch):
        monkeypatch.setattr("sys.argv", ["hermes", "--mode", "platform", "--api-key", "sk-test"])
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)
        config = {"memory": {}}
        post_setup(str(tmp_path), config)
        assert config["memory"]["provider"] == "mem0"
        env_content = (tmp_path / ".env").read_text()
        assert "MEM0_API_KEY=sk-test" in env_content
        mem0_json = json.loads((tmp_path / "mem0.json").read_text())
        assert mem0_json["mode"] == "platform"

    def test_oss_flag_mode(self, tmp_path, monkeypatch):
        monkeypatch.setattr("sys.argv", [
            "hermes", "--mode", "oss", "--oss-llm-key", "sk-oai",
        ])
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)
        monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)
        config = {"memory": {}}
        post_setup(str(tmp_path), config)
        assert config["memory"]["provider"] == "mem0"
        mem0_json = json.loads((tmp_path / "mem0.json").read_text())
        assert mem0_json["mode"] == "oss"
        assert mem0_json["oss"]["llm"]["provider"] == "openai"


class TestDryRun:

    def test_dry_run_flag_parsed(self):
        flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run"])
        assert flags["dry_run"] is True

    def test_dry_run_not_set_by_default(self):
        flags = parse_flags(["--mode", "oss"])
        assert flags["dry_run"] is False

    def test_dry_run_platform_no_files(self, tmp_path, monkeypatch):
        monkeypatch.setattr("sys.argv", ["hermes", "--mode", "platform", "--api-key", "sk-test", "--dry-run"])
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)
        config = {"memory": {}}
        post_setup(str(tmp_path), config)
        assert not (tmp_path / ".env").exists()
        assert not (tmp_path / "mem0.json").exists()
        assert "provider" not in config["memory"]

    def test_dry_run_oss_no_files(self, tmp_path, monkeypatch):
        monkeypatch.setattr("sys.argv", [
            "hermes", "--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run",
        ])
        monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path)
        _inject_fake_hermes_cli(monkeypatch)
        monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)
        config = {"memory": {}}
        post_setup(str(tmp_path), config)
        assert not (tmp_path / ".env").exists()
        assert not (tmp_path / "mem0.json").exists()
        assert "provider" not in config["memory"]


class TestConnectivityChecks:

    def test_qdrant_path_writable(self, tmp_path):
        ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
        assert ok is True

    def test_qdrant_path_not_writable(self, tmp_path, monkeypatch):
        def _raise_oserror(*a, **kw):
            raise OSError("Permission denied")
        monkeypatch.setattr(Path, "mkdir", _raise_oserror)
        ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
        assert ok is False
        assert "Permission denied" in msg

    def test_ollama_unreachable(self):
        ok, msg = _check_ollama("http://localhost:1")
        assert ok is False

    def test_pgvector_unreachable(self):
        ok, msg = _check_pgvector("localhost", 1)
        assert ok is False
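
Reviewer note: the connectivity tests above expect each check to return an (ok, message) pair instead of raising, with the OS error text surfaced in the message on failure. A minimal sketch of that pattern follows. The helper names check_path_creatable and check_tcp_reachable are hypothetical and are not the _check_qdrant_path / _check_pgvector implementations from this commit; the TCP probe in particular is an assumption about how a host/port check could be done with the standard library.

```python
import socket
from pathlib import Path


def check_path_creatable(path: str) -> tuple[bool, str]:
    """Try to create the directory and report (ok, message) instead of raising."""
    try:
        Path(path).expanduser().mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        # Permission errors, "not a directory", read-only filesystems, etc.
        return False, f"Cannot create {path}: {exc}"
    return True, f"{path} exists or was created"


def check_tcp_reachable(host: str, port: int, timeout: float = 2.0) -> tuple[bool, str]:
    """Open and immediately close a TCP connection to see if anything listens."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True, f"{host}:{port} is reachable"
    except OSError as exc:
        # Covers refused connections, timeouts, and name-resolution failures.
        return False, f"Cannot reach {host}:{port}: {exc}"
```

Returning a tuple keeps a setup wizard in control of presentation: it can print the message and continue, which is harder to do cleanly when the check raises.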
@@ -1,241 +0,0 @@
"""Tests for Mem0 API v2 compatibility — filters param and dict response unwrapping.

Salvaged from PRs #5301 (qaqcvc) and #5117 (vvvanguards).
"""

import json
import os
import stat

import pytest

from plugins.memory.mem0 import Mem0MemoryProvider


class FakeClientV2:
    """Fake Mem0 client that returns v2-style dict responses and captures call kwargs."""

    def __init__(self, search_results=None, all_results=None):
        self._search_results = search_results or {"results": []}
        self._all_results = all_results or {"results": []}
        self.captured_search = {}
        self.captured_get_all = {}
        self.captured_add = []

    def search(self, **kwargs):
        self.captured_search = kwargs
        return self._search_results

    def get_all(self, **kwargs):
        self.captured_get_all = kwargs
        return self._all_results

    def add(self, messages, **kwargs):
        self.captured_add.append({"messages": messages, **kwargs})


# ---------------------------------------------------------------------------
# Filter migration: bare user_id= -> filters={}
# ---------------------------------------------------------------------------


class TestMem0FiltersV2:
    """All API calls must use filters={} instead of bare user_id= kwargs."""

    def _make_provider(self, monkeypatch, client):
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        monkeypatch.setattr(provider, "_get_client", lambda: client)
        return provider

    def test_search_uses_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)

        provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3, "rerank": False})

        assert client.captured_search["query"] == "hello"
        assert client.captured_search["top_k"] == 3
        assert client.captured_search["rerank"] is False
        assert client.captured_search["filters"] == {"user_id": "u123"}
        # Must NOT have bare user_id kwarg
        assert "user_id" not in {k for k in client.captured_search if k != "filters"}

    def test_profile_uses_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)

        provider.handle_tool_call("mem0_profile", {})

        assert client.captured_get_all["filters"] == {"user_id": "u123"}
        assert "user_id" not in {k for k in client.captured_get_all if k != "filters"}

    def test_prefetch_uses_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)

        provider.queue_prefetch("hello")
        provider._prefetch_thread.join(timeout=2)

        assert client.captured_search["query"] == "hello"
        assert client.captured_search["filters"] == {"user_id": "u123"}
        assert "user_id" not in {k for k in client.captured_search if k != "filters"}

    def test_sync_turn_uses_write_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)

        provider.sync_turn("user said this", "assistant replied", session_id="s1")
        provider._sync_thread.join(timeout=2)

        assert len(client.captured_add) == 1
        call = client.captured_add[0]
        assert call["user_id"] == "u123"
        assert call["agent_id"] == "hermes"

    def test_conclude_uses_write_filters(self, monkeypatch):
        client = FakeClientV2()
        provider = self._make_provider(monkeypatch, client)

        provider.handle_tool_call("mem0_conclude", {"conclusion": "user likes dark mode"})

        assert len(client.captured_add) == 1
        call = client.captured_add[0]
        assert call["user_id"] == "u123"
        assert call["agent_id"] == "hermes"
        assert call["infer"] is False

    def test_read_filters_no_agent_id(self):
        """Read filters should use user_id only — cross-session recall across agents."""
        provider = Mem0MemoryProvider()
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        assert provider._read_filters() == {"user_id": "u123"}

    def test_write_filters_include_agent_id(self):
        """Write filters should include agent_id for attribution."""
        provider = Mem0MemoryProvider()
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        assert provider._write_filters() == {"user_id": "u123", "agent_id": "hermes"}


# ---------------------------------------------------------------------------
# Dict response unwrapping (API v2 wraps in {"results": [...]})
# ---------------------------------------------------------------------------


class TestMem0ResponseUnwrapping:
    """API v2 returns {"results": [...]} dicts; we must extract the list."""

    def _make_provider(self, monkeypatch, client):
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        monkeypatch.setattr(provider, "_get_client", lambda: client)
        return provider

    def test_profile_dict_response(self, monkeypatch):
        client = FakeClientV2(all_results={"results": [{"memory": "alpha"}, {"memory": "beta"}]})
        provider = self._make_provider(monkeypatch, client)

        result = json.loads(provider.handle_tool_call("mem0_profile", {}))

        assert result["count"] == 2
        assert "alpha" in result["result"]
        assert "beta" in result["result"]

    def test_profile_list_response_backward_compat(self, monkeypatch):
        """Old API returned bare lists — still works."""
        client = FakeClientV2(all_results=[{"memory": "gamma"}])
        provider = self._make_provider(monkeypatch, client)

        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
        assert result["count"] == 1
        assert "gamma" in result["result"]

    def test_search_dict_response(self, monkeypatch):
        client = FakeClientV2(search_results={
            "results": [{"memory": "foo", "score": 0.9}, {"memory": "bar", "score": 0.7}]
        })
        provider = self._make_provider(monkeypatch, client)

        result = json.loads(provider.handle_tool_call(
            "mem0_search", {"query": "test", "top_k": 5}
        ))

        assert result["count"] == 2
        assert result["results"][0]["memory"] == "foo"

    def test_search_list_response_backward_compat(self, monkeypatch):
        """Old API returned bare lists — still works."""
        client = FakeClientV2(search_results=[{"memory": "baz", "score": 0.8}])
        provider = self._make_provider(monkeypatch, client)

        result = json.loads(provider.handle_tool_call(
            "mem0_search", {"query": "test"}
        ))
        assert result["count"] == 1

    def test_unwrap_results_edge_cases(self):
        """_unwrap_results handles all shapes gracefully."""
        assert Mem0MemoryProvider._unwrap_results({"results": [1, 2]}) == [1, 2]
        assert Mem0MemoryProvider._unwrap_results([3, 4]) == [3, 4]
        assert Mem0MemoryProvider._unwrap_results({}) == []
        assert Mem0MemoryProvider._unwrap_results(None) == []
        assert Mem0MemoryProvider._unwrap_results("unexpected") == []

    def test_prefetch_dict_response(self, monkeypatch):
        client = FakeClientV2(search_results={
            "results": [{"memory": "user prefers dark mode"}]
        })
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        monkeypatch.setattr(provider, "_get_client", lambda: client)

        provider.queue_prefetch("preferences")
        provider._prefetch_thread.join(timeout=2)
        result = provider.prefetch("preferences")

        assert "dark mode" in result


# ---------------------------------------------------------------------------
# Default preservation
# ---------------------------------------------------------------------------


@pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits not enforced on Windows")
def test_save_config_sets_owner_only_permissions(tmp_path):
    """mem0.json must be written with 0o600 so API key is not world-readable."""
    provider = Mem0MemoryProvider()
    provider.save_config({"api_key": "m0-test-key"}, str(tmp_path))
    config_file = tmp_path / "mem0.json"
    assert config_file.exists()
    mode = stat.S_IMODE(config_file.stat().st_mode)
    assert mode == 0o600, f"Expected 0o600 (owner-only), got {oct(mode)}"


class TestMem0Defaults:
    """Ensure we don't break existing users' defaults."""

    def test_default_user_id_hermes_user(self, monkeypatch, tmp_path):
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        monkeypatch.delenv("MEM0_USER_ID", raising=False)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        provider = Mem0MemoryProvider()
        provider.initialize("test")

        assert provider._user_id == "hermes-user"

    def test_default_agent_id_hermes(self, monkeypatch, tmp_path):
        monkeypatch.setenv("MEM0_API_KEY", "test-key")
        monkeypatch.delenv("MEM0_AGENT_ID", raising=False)
        monkeypatch.setenv("HERMES_HOME", str(tmp_path))

        provider = Mem0MemoryProvider()
        provider.initialize("test")

        assert provider._agent_id == "hermes"
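
Reviewer note: the removed v2 tests above document two behaviours worth keeping in mind when reading the v3 replacement: responses could arrive either as a bare list or wrapped in {"results": [...]}, and reads were scoped by user_id only while writes also carried agent_id. The sketch below restates both as standalone functions. The names unwrap_results, read_filters, and write_filters are hypothetical stand-ins for the private provider methods exercised by the tests, not the code removed or added by this commit.

```python
from typing import Any


def unwrap_results(response: Any) -> list:
    """Normalize a list-or-dict API response to a plain list."""
    if isinstance(response, dict):
        results = response.get("results", [])
        # Guard against a malformed payload such as {"results": None}.
        return results if isinstance(results, list) else []
    if isinstance(response, list):
        return response
    # None, strings, and any other unexpected shape become an empty list.
    return []


def read_filters(user_id: str) -> dict[str, str]:
    """Reads scope to the user only, so recall works across agents."""
    return {"user_id": user_id}


def write_filters(user_id: str, agent_id: str) -> dict[str, str]:
    """Writes also record which agent produced the memory."""
    return {"user_id": user_id, "agent_id": agent_id}
```

Checking for dict before list matters little here, but rejecting strings explicitly does: a str is iterable, so a looser "anything iterable" check would turn an error message into a list of characters.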
463
tests/plugins/memory/test_mem0_v3.py
Normal file
@@ -0,0 +1,463 @@
"""Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools."""

import json
import pytest

from plugins.memory.mem0 import Mem0MemoryProvider


class FakeBackend:
    """Fake Mem0Backend for provider-level tests."""

    def __init__(self, search_results=None, all_results=None):
        self._search_results = search_results or []
        self._all_results = all_results or {"results": [], "count": 0}
        self.captured = []

    def search(self, query, *, filters, top_k=10, rerank=True):
        self.captured.append(("search", query, {"filters": filters, "top_k": top_k, "rerank": rerank}))
        return self._search_results

    def get_all(self, *, filters, page=1, page_size=100):
        self.captured.append(("get_all", {"filters": filters, "page": page, "page_size": page_size}))
        return self._all_results

    def add(self, messages, *, user_id, agent_id, infer=False, metadata=None):
        self.captured.append((
            "add",
            messages,
            {"user_id": user_id, "agent_id": agent_id, "infer": infer, "metadata": metadata},
        ))
        return {"status": "PENDING", "event_id": "evt-test-123"}

    def update(self, memory_id, text):
        self.captured.append(("update", memory_id, text))
        return {"result": "Memory updated.", "memory_id": memory_id}

    def delete(self, memory_id):
        self.captured.append(("delete", memory_id))
        return {"result": "Memory deleted.", "memory_id": memory_id}


class TestMem0V3Tools:
    """Test v3 tool names and response handling."""

    def _make_provider(self, monkeypatch, backend):
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        provider._backend = backend
        return provider

    def test_list_returns_paginated_with_ids(self, monkeypatch):
        backend = FakeBackend(all_results={
            "count": 2,
            "results": [
                {"id": "mem-1", "memory": "alpha"},
                {"id": "mem-2", "memory": "beta"},
            ]
        })
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_list", {}))
        assert result["count"] == 2
        assert result["results"][0]["id"] == "mem-1"
        assert result["results"][0]["memory"] == "alpha"

    def test_list_pagination_params(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        provider.handle_tool_call("mem0_list", {"page": 2, "page_size": 50})
        assert backend.captured[0][1]["page"] == 2
        assert backend.captured[0][1]["page_size"] == 50

    def test_list_empty(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_list", {}))
        assert result["result"] == "No memories stored yet."

    def test_search_returns_ids(self, monkeypatch):
        backend = FakeBackend(search_results=[{"id": "mem-1", "memory": "foo", "score": 0.9}])
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_search", {"query": "test"}))
        assert result["results"][0]["id"] == "mem-1"

    def test_search_uses_filters(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3})
        assert backend.captured[0][2]["filters"] == {"user_id": "u123"}
        assert backend.captured[0][2]["top_k"] == 3

    def test_search_rerank_default_true(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        provider.handle_tool_call("mem0_search", {"query": "test"})
        assert backend.captured[0][2]["rerank"] is True

    def test_search_rerank_override_false(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        provider.handle_tool_call("mem0_search", {"query": "test", "rerank": False})
        assert backend.captured[0][2]["rerank"] is False

    def test_add_uses_content_param(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"}))
        assert len(backend.captured) == 1
        call = backend.captured[0]
        assert call[2]["infer"] is False
        assert call[2]["user_id"] == "u123"
        assert call[2]["agent_id"] == "hermes"
        assert "event_id" in result

    def test_add_returns_event_id(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_add", {"content": "test"}))
        assert result["event_id"] == "evt-test-123"

    def test_add_missing_content(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_add", {}))
        assert "error" in result

    def test_old_tool_names_return_unknown(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
        assert "error" in result
        result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
        assert "error" in result


class TestMem0UpdateDelete:

    def _make_provider(self, monkeypatch, backend):
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        provider._backend = backend
        return provider

    def test_update_calls_sdk(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call(
            "mem0_update", {"memory_id": "mem-1", "text": "updated fact"}
        ))
        assert backend.captured[0][1] == "mem-1"
        assert backend.captured[0][2] == "updated fact"
        assert result["result"] == "Memory updated."
        assert result["memory_id"] == "mem-1"

    def test_update_missing_memory_id(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_update", {"text": "no id"}))
        assert "error" in result

    def test_update_missing_text(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_update", {"memory_id": "mem-1"}))
        assert "error" in result

    def test_delete_calls_sdk(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call(
            "mem0_delete", {"memory_id": "mem-1"}
        ))
        assert backend.captured[0][1] == "mem-1"
        assert result["result"] == "Memory deleted."

    def test_delete_missing_memory_id(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_delete", {}))
        assert "error" in result


class TestMem0ErrorHandling:

    def _make_provider(self, monkeypatch, backend):
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        provider._backend = backend
        return provider

    def test_update_404_no_circuit_breaker(self, monkeypatch):
        backend = FakeBackend()
        backend.update = lambda mid, text: (_ for _ in ()).throw(Exception("404 Not Found"))
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call(
            "mem0_update", {"memory_id": "bad-id", "text": "x"}
        ))
        assert "error" in result
        assert provider._consecutive_failures == 0

    def test_delete_404_no_circuit_breaker(self, monkeypatch):
        backend = FakeBackend()
        backend.delete = lambda mid: (_ for _ in ()).throw(Exception("404 not found"))
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call(
            "mem0_delete", {"memory_id": "bad-id"}
        ))
        assert "error" in result
        assert provider._consecutive_failures == 0

    def test_update_validation_error_no_circuit_breaker(self, monkeypatch):
        """ValidationError (bad UUID format) should not trip circuit breaker."""
        class ValidationError(Exception):
            pass
        backend = FakeBackend()
        backend.update = lambda mid, text: (_ for _ in ()).throw(
            ValidationError('{"error":"memory_id should be a valid UUID"}')
        )
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call(
            "mem0_update", {"memory_id": "not-a-uuid", "text": "x"}
        ))
        assert "error" in result
        assert provider._consecutive_failures == 0

    def test_delete_validation_error_no_circuit_breaker(self, monkeypatch):
        class ValidationError(Exception):
            pass
        backend = FakeBackend()
        backend.delete = lambda mid: (_ for _ in ()).throw(
            ValidationError('{"error":"memory_id should be a valid UUID"}')
        )
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call(
            "mem0_delete", {"memory_id": "not-a-uuid"}
        ))
        assert "error" in result
        assert provider._consecutive_failures == 0

    def test_update_5xx_trips_circuit_breaker(self, monkeypatch):
        backend = FakeBackend()
        backend.update = lambda mid, text: (_ for _ in ()).throw(Exception("500 Internal Server Error"))
        provider = self._make_provider(monkeypatch, backend)
        provider.handle_tool_call("mem0_update", {"memory_id": "mem-1", "text": "x"})
        assert provider._consecutive_failures == 1


class TestMem0V3Internal:

    def _make_provider(self, monkeypatch, backend):
        provider = Mem0MemoryProvider()
        provider.initialize("test-session")
        provider._user_id = "u123"
        provider._agent_id = "hermes"
        provider._backend = backend
        return provider

    def test_sync_turn_explicit_kwargs(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        provider.sync_turn("user said", "assistant replied", session_id="s1")
        provider._sync_thread.join(timeout=2)
        assert len(backend.captured) == 1
        call = backend.captured[0]
        assert call[2]["user_id"] == "u123"
        assert call[2]["agent_id"] == "hermes"
        assert call[2]["infer"] is True

    def test_old_tool_names_return_unknown(self, monkeypatch):
        backend = FakeBackend()
        provider = self._make_provider(monkeypatch, backend)
        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
        assert "error" in result
        result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
        assert "error" in result


class TestMem0V3Config:
|
||||
|
||||
def test_tool_schemas_five_tools(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
schemas = provider.get_tool_schemas()
|
||||
names = [s["name"] for s in schemas]
|
||||
assert names == ["mem0_list", "mem0_search", "mem0_add", "mem0_update", "mem0_delete"]
|
||||
|
||||
def test_system_prompt_new_tool_names(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
block = provider.system_prompt_block()
|
||||
assert "mem0_search" in block
|
||||
assert "mem0_add" in block
|
||||
assert "mem0_list" in block
|
||||
assert "mem0_update" in block
|
||||
assert "mem0_delete" in block
|
||||
assert "mem0_profile" not in block
|
||||
assert "mem0_conclude" not in block
|
||||
|
||||
def test_system_prompt_shows_platform_mode(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
provider._mode = "platform"
|
||||
block = provider.system_prompt_block()
|
||||
assert "platform" in block
|
||||
assert "Rerank" in block
|
||||
|
||||
def test_system_prompt_shows_oss_mode(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
provider._mode = "oss"
|
||||
block = provider.system_prompt_block()
|
||||
assert "OSS" in block
|
||||
assert "Rerank" not in block
|
||||
|
||||
def test_search_schema_has_rerank(self):
|
||||
"""rerank property available in SEARCH_SCHEMA for platform mode."""
|
||||
provider = Mem0MemoryProvider()
|
||||
schemas = provider.get_tool_schemas()
|
||||
search = next(s for s in schemas if s["name"] == "mem0_search")
|
||||
assert "rerank" in search["parameters"]["properties"]
|
||||
assert search["parameters"]["properties"]["rerank"]["type"] == "boolean"
|
||||
|
||||
|
||||
class TestMem0ModeSwitch:
|
||||
|
||||
def test_default_mode_is_platform(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test")
|
||||
assert provider._mode == "platform"
|
||||
|
||||
def test_missing_mode_key_defaults_platform(self, monkeypatch, tmp_path):
|
||||
"""Backward compat: old mem0.json without mode key works."""
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config_path = tmp_path / "mem0.json"
|
||||
config_path.write_text('{"user_id": "old-user"}')
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
provider = Mem0MemoryProvider()
|
||||
provider.initialize("test")
|
||||
assert provider._mode == "platform"
|
||||
assert provider._user_id == "old-user"
|
||||
|
||||
def test_is_available_platform_needs_key(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.delenv("MEM0_API_KEY", raising=False)
|
||||
provider = Mem0MemoryProvider()
|
||||
assert provider.is_available() is False
|
||||
|
||||
def test_is_available_oss_needs_vector(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config_path = tmp_path / "mem0.json"
|
||||
config_path.write_text('{"mode": "oss", "oss": {"vector_store": {"provider": "qdrant"}}}')
|
||||
provider = Mem0MemoryProvider()
|
||||
assert provider.is_available() is True
|
||||
|
||||
def test_is_available_oss_no_vector(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
config_path = tmp_path / "mem0.json"
|
||||
config_path.write_text('{"mode": "oss", "oss": {}}')
|
||||
provider = Mem0MemoryProvider()
|
||||
assert provider.is_available() is False
|
||||
|
||||
def test_tool_schemas_unchanged(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
schemas = provider.get_tool_schemas()
|
||||
names = [s["name"] for s in schemas]
|
||||
assert names == ["mem0_list", "mem0_search", "mem0_add", "mem0_update", "mem0_delete"]
|
||||
|
||||
def test_system_prompt_includes_mode(self):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "test"
|
||||
provider._mode = "oss"
|
||||
block = provider.system_prompt_block()
|
||||
assert "mem0_search" in block
|
||||
assert "mem0_list" in block
|
||||
assert "OSS" in block
|
||||
|
||||
|
||||
class TestMem0UserIdResolution:
|
||||
"""user_id resolution: configured override > gateway-native id > placeholder.
|
||||
|
||||
Same human across CLI / Telegram / Discord / Slack / etc. should map to
|
||||
the same memory store when MEM0_USER_ID is set, and only fall back to the
|
||||
gateway-native id when it isn't.
|
||||
"""
|
||||
|
||||
def _provider(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("HERMES_HOME", str(tmp_path))
|
||||
monkeypatch.setenv("MEM0_API_KEY", "test-key")
|
||||
provider = Mem0MemoryProvider()
|
||||
# Skip backend instantiation — we only care about identity resolution.
|
||||
provider._create_backend = lambda: None # type: ignore[method-assign]
|
||||
return provider
|
||||
|
||||
def test_env_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
|
||||
monkeypatch.setenv("MEM0_USER_ID", "ryan@example.com")
|
||||
provider = self._provider(monkeypatch, tmp_path)
|
||||
provider.initialize("test", user_id="123456789", platform="telegram")
|
||||
assert provider._user_id == "ryan@example.com"
|
||||
|
||||
def test_file_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("MEM0_USER_ID", raising=False)
|
||||
(tmp_path / "mem0.json").write_text('{"user_id": "ryan@example.com"}')
|
||||
provider = self._provider(monkeypatch, tmp_path)
|
||||
provider.initialize("test", user_id="123456789", platform="telegram")
|
||||
assert provider._user_id == "ryan@example.com"
|
||||
|
||||
def test_unset_falls_back_to_gateway_native_id(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("MEM0_USER_ID", raising=False)
|
||||
provider = self._provider(monkeypatch, tmp_path)
|
||||
provider.initialize("test", user_id="123456789", platform="telegram")
|
||||
assert provider._user_id == "123456789"
|
||||
|
||||
def test_unset_and_no_kwargs_falls_back_to_default(self, monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("MEM0_USER_ID", raising=False)
|
||||
provider = self._provider(monkeypatch, tmp_path)
|
||||
provider.initialize("test")
|
||||
assert provider._user_id == "hermes-user"
|
||||
|
||||
def test_legacy_placeholder_in_config_does_not_override_kwargs(self, monkeypatch, tmp_path):
|
||||
# Setup wizard historically wrote {"user_id": "hermes-user"} as the
|
||||
# suggested default. Treat that placeholder as unset so users on
|
||||
# gateways still get gateway-native ids — not silent collisions.
|
||||
monkeypatch.delenv("MEM0_USER_ID", raising=False)
|
||||
(tmp_path / "mem0.json").write_text('{"user_id": "hermes-user"}')
|
||||
provider = self._provider(monkeypatch, tmp_path)
|
||||
provider.initialize("test", user_id="123456789", platform="telegram")
|
||||
assert provider._user_id == "123456789"
|
||||
|
||||
|
||||
class TestMem0WriteMetadata:
|
||||
"""Writes carry metadata.channel so per-channel filtered views are possible
|
||||
without coupling identity to the channel.
|
||||
"""
|
||||
|
||||
def _make_provider(self, channel: str = "cli"):
|
||||
provider = Mem0MemoryProvider()
|
||||
provider._user_id = "u123"
|
||||
provider._agent_id = "hermes"
|
||||
provider._channel = channel
|
||||
provider._backend = FakeBackend()
|
||||
return provider
|
||||
|
||||
def test_add_tool_passes_channel_metadata(self):
|
||||
provider = self._make_provider("telegram")
|
||||
provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"})
|
||||
call = provider._backend.captured[-1]
|
||||
assert call[2]["metadata"] == {"channel": "telegram"}
|
||||
|
||||
def test_sync_turn_passes_channel_metadata(self):
|
||||
provider = self._make_provider("discord")
|
||||
provider.sync_turn("hi", "hello", session_id="s")
|
||||
# sync_turn fires a daemon thread; wait for it.
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
adds = [c for c in provider._backend.captured if c[0] == "add"]
|
||||
assert adds, "expected an add call from sync_turn"
|
||||
assert adds[-1][2]["metadata"] == {"channel": "discord"}
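
The channel tagging checked by `TestMem0WriteMetadata` can be sketched as below. This is a minimal illustration, not the provider's code: `build_write_kwargs` and `build_read_filters` are hypothetical helpers. The only behaviour taken from the commit is that writes carry `user_id`, `agent_id`, and `metadata.channel`, and that reads scope to `user_id` alone.

```python
from typing import Any, Dict


def build_write_kwargs(user_id: str, agent_id: str, channel: str, infer: bool) -> Dict[str, Any]:
    """Hypothetical helper: assemble keyword arguments for a memory write."""
    return {
        "user_id": user_id,
        "agent_id": agent_id,
        "infer": infer,
        # The channel is stored as metadata so identity stays channel-independent.
        "metadata": {"channel": channel},
    }


def build_read_filters(user_id: str) -> Dict[str, Any]:
    """Hypothetical helper: reads filter on user_id only, so recall spans agents and channels."""
    return {"user_id": user_id}


if __name__ == "__main__":
    print(build_write_kwargs("u123", "hermes", "telegram", infer=True))
    print(build_read_filters("u123"))
```

Keeping the channel out of the read filter is what lets a memory written from Telegram be recalled from the CLI, while a dashboard can still filter on `metadata.channel`.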
|
||||
|
|
@@ -315,31 +315,55 @@ echo "OPENVIKING_API_KEY=..." >> ~/.hermes/.env
|
|||
|
||||
### Mem0
|
||||
|
||||
-Server-side LLM fact extraction with semantic search, reranking, and automatic deduplication.
|
||||
+Server-side LLM fact extraction with semantic search, reranking, and automatic deduplication. Supports both Mem0 Platform (cloud) and OSS (self-hosted) modes.
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| **Best for** | Hands-off memory management — Mem0 handles extraction automatically |
|
||||
-| **Requires** | `pip install mem0ai` + API key |
|
||||
-| **Data storage** | Mem0 Cloud |
|
||||
-| **Cost** | Mem0 pricing |
|
||||
+| **Requires** | `pip install mem0ai` + API key (platform) or LLM/vector store (OSS) |
|
||||
+| **Data storage** | Mem0 Cloud (platform) or self-hosted (OSS) |
|
||||
+| **Cost** | Mem0 pricing (platform) / free (OSS) |
|
||||
|
||||
-**Tools:** `mem0_profile` (all stored memories), `mem0_search` (semantic search + reranking), `mem0_conclude` (store verbatim facts)
|
||||
+**Tools (5):** `mem0_list` (list all memories, paginated), `mem0_search` (semantic search with reranking in platform mode), `mem0_add` (store verbatim facts), `mem0_update` (update by ID), `mem0_delete` (delete by ID)
|
||||
|
||||
-**Setup:**
|
||||
+**Setup (Platform):**
|
||||
```bash
|
||||
-hermes memory setup # select "mem0"
|
||||
+hermes memory setup # select "mem0" → "Platform"
|
||||
# Or manually:
|
||||
hermes config set memory.provider mem0
|
||||
echo "MEM0_API_KEY=your-key" >> ~/.hermes/.env
|
||||
```
|
||||
|
||||
-**Config:** `$HERMES_HOME/mem0.json`
|
||||
**Setup (OSS):**
|
||||
```bash
|
||||
hermes memory setup # select "mem0" → "Open Source (self-hosted)"
|
||||
# Or via flags:
|
||||
hermes memory setup mem0 --mode oss --oss-llm openai --oss-llm-key sk-... --oss-vector qdrant
|
||||
```
|
||||
|
||||
Preview without writing files:
|
||||
```bash
|
||||
hermes memory setup mem0 --mode oss --oss-llm-key sk-... --dry-run
|
||||
```
|
||||
|
||||
**Config:** `$HERMES_HOME/mem0.json` (behavioral settings). Only the secret `MEM0_API_KEY` belongs in `~/.hermes/.env`.
|
||||
|
||||
| Key | Default | Description |
|
||||
|-----|---------|-------------|
|
||||
| `mode` | `platform` | `platform` (Mem0 Cloud) or `oss` (self-hosted) |
|
||||
| `user_id` | `hermes-user` | User identifier |
|
||||
| `agent_id` | `hermes` | Agent identifier |
|
||||
| `rerank` | `true` | Rerank search results for relevance (platform mode only) |
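
The defaults in the table, together with the `user_id` precedence described in the commit (configured override, then gateway-native id, then the `hermes-user` placeholder), can be sketched as below. This is an illustration under stated assumptions, not the provider's implementation: `load_mem0_config` and `resolve_user_id` are hypothetical names, and checking the environment variable before the file is an assumption, since the tests only show that each one beats the gateway id.

```python
import json
from pathlib import Path
from typing import Any, Dict, Mapping, Optional

# Defaults as listed in the config table.
DEFAULTS: Dict[str, Any] = {
    "mode": "platform",
    "user_id": "hermes-user",
    "agent_id": "hermes",
    "rerank": True,
}
PLACEHOLDER_USER_ID = "hermes-user"


def load_mem0_config(hermes_home: Path) -> Dict[str, Any]:
    """Hypothetical loader: overlay mem0.json (if present) on the defaults."""
    config = dict(DEFAULTS)
    path = hermes_home / "mem0.json"
    if path.exists():
        config.update(json.loads(path.read_text()))
    return config


def resolve_user_id(
    config: Mapping[str, Any],
    env: Mapping[str, str],
    gateway_user_id: Optional[str] = None,
) -> str:
    """Hypothetical resolver: configured override > gateway-native id > placeholder."""
    # Assumption: env is consulted before the file; the placeholder counts as unset.
    for candidate in (env.get("MEM0_USER_ID"), config.get("user_id")):
        if candidate and candidate != PLACEHOLDER_USER_ID:
            return candidate
    return gateway_user_id or PLACEHOLDER_USER_ID


if __name__ == "__main__":
    cfg = {"user_id": "hermes-user"}  # legacy value written by the setup wizard
    print(resolve_user_id(cfg, {}, gateway_user_id="123456789"))
```

Treating the legacy placeholder as unset means a config file that still contains `"user_id": "hermes-user"` does not bucket every gateway user under one identity.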
|
||||
|
||||
**OSS supported providers:**
|
||||
|
||||
| Component | Providers |
|
||||
|-----------|-----------|
|
||||
| LLM | openai, ollama |
|
||||
| Embedder | openai, ollama |
|
||||
| Vector Store | qdrant (local/server), pgvector |
|
||||
|
||||
**Switching modes:** Re-run `hermes memory setup mem0 --mode <platform|oss>` or edit `mem0.json` directly.
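
Which settings each mode needs before the provider reports itself available can be sketched as below. The function name `is_mem0_available` is hypothetical, and the checks mirror only what the accompanying tests assert: platform mode needs `MEM0_API_KEY`, and OSS mode needs a configured vector store provider. The real provider may check more, such as installed dependencies.

```python
from typing import Any, Mapping


def is_mem0_available(config: Mapping[str, Any], env: Mapping[str, str]) -> bool:
    """Hypothetical availability check for the two modes."""
    # A missing "mode" key falls back to platform, matching older mem0.json files.
    mode = config.get("mode", "platform")
    if mode == "oss":
        vector_store = config.get("oss", {}).get("vector_store", {})
        return bool(vector_store.get("provider"))
    return bool(env.get("MEM0_API_KEY"))


if __name__ == "__main__":
    oss_config = {"mode": "oss", "oss": {"vector_store": {"provider": "qdrant"}}}
    print(is_mem0_available(oss_config, env={}))
    print(is_mem0_available({}, env={"MEM0_API_KEY": "test-key"}))
```

When editing `mem0.json` by hand to switch modes, a check of this shape suggests confirming that the target mode's prerequisite is already in place.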
|
||||
|
||||
---
|
||||
|
||||
|
|
@@ -569,7 +593,7 @@ hermes memory setup
|
|||
|----------|---------|------|-------|-------------|----------------|
|
||||
| **Honcho** | Cloud | Paid | 5 | `honcho-ai` | Dialectic user modeling + session-scoped context |
|
||||
| **OpenViking** | Self-hosted | Free | 5 | `openviking` + server | Filesystem hierarchy + tiered loading |
|
||||
-| **Mem0** | Cloud | Paid | 3 | `mem0ai` | Server-side LLM extraction |
|
||||
+| **Mem0** | Cloud/Self-hosted | Free/Paid | 5 | `mem0ai` | Server-side LLM extraction + OSS mode |
|
||||
| **Hindsight** | Cloud/Local | Free/Paid | 3 | `hindsight-client` | Knowledge graph + reflect synthesis |
|
||||
| **Holographic** | Local | Free | 2 | None | HRR algebra + trust scoring |
|
||||
| **RetainDB** | Cloud | $20/mo | 5 | `requests` | Delta compression |
|
||||
|
|
|
|||