From 2e779d11a03dbe37db8309a80750763b4b8d1b45 Mon Sep 17 00:00:00 2001
From: Kartik
Date: Mon, 22 Jun 2026 18:00:47 +0530
Subject: [PATCH] feat(mem0): v3 API, OSS mode, update/delete tools, telemetry & review fixes (#15624)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* fix: update to version 3 endpoints and adding update and delete tool

* chore: removing the test md file

* fix: prevent circuit breaker on client errors in Mem0 provider

* chore: add telemetry for platform version

* feat: add OSS mode support to Mem0 memory provider

* chore: bump mem0ai dependency to >=2.0.1 in memory plugin

* refactor: enhance dependency checks and embedder config in mem0 backend

* refactor: adjust fact storage message for OSS mode

* refactor: expand user paths, add collection recreation on dimension change for Qdrant

* fix(mem0): make MEM0_USER_ID override gateway-native ids and tag writes with channel

When MEM0_USER_ID was configured (env or mem0.json), the gateway-native id
from kwargs (Telegram numeric id, Discord snowflake, ...) still won, so the
same human ended up under different user_ids per channel and memories never
merged across CLI / Telegram / Slack / Discord.

Mirrors openclaw's cfg.userId pattern: configured override wins,
gateway-native id is the fallback. The legacy "hermes-user" placeholder
default written by the setup wizard is treated as unset to avoid silently
bucketing every gateway user together.

Also tag every write with metadata.channel (cli/telegram/discord/...) so the
dashboard can offer per-channel filtered views without coupling identity to
the channel; document the read/write filter asymmetry as intentional (reads
scope to user_id only for cross-agent recall).
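
For reviewers, the resolution order described above can be sketched roughly
as follows. This is a minimal illustration, not the plugin code itself:
`resolve_user_id` and `DEFAULT_USER_ID` are hypothetical names chosen for the
example, and the real logic lives in `initialize()` in the diff.

```python
from typing import Optional

# Placeholder the setup wizard historically wrote; treated as "not configured".
# (Name is illustrative; the patch uses its own module-level constant.)
DEFAULT_USER_ID = "hermes-user"


def resolve_user_id(configured: Optional[str], gateway_id: Optional[str]) -> str:
    """Hypothetical helper: configured override wins, gateway-native id is the fallback."""
    # The legacy placeholder must not bucket every gateway user together,
    # so it is handled exactly like an unset value.
    if configured == DEFAULT_USER_ID:
        configured = None
    # Fall through: explicit override, then gateway-native id, then the default.
    return configured or gateway_id or DEFAULT_USER_ID
```

Under these assumptions, an explicit override such as "alice" is used on
every channel, while a config that still holds the placeholder lets the
Telegram or Discord id through unchanged.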
Co-Authored-By: Claude Opus 4.7 (1M context)

* refactor: improve Mem0 memory provider backend, pagination, config, and error handling

* refactor: update mem0 telemetry code, docs, and bump version

* fix(mem0): make get_config_schema() return unified schema with mode-aware required flag

Schema always includes api_key field so picker shows "API key / local" for
both modes. In OSS mode api_key.required=False so status won't mislead.

Co-Authored-By: Claude Opus 4.6

* refactor: improve mem0 telemetry, add env var key and OSS mode detection

* chore: bump mem0ai lower bound to 2.0.4 (latest SDK release)

* refactor: set telemetry sample rate to 1.0 and update docs for opt‑out

* fix(mem0): resolve 15 correctness, thread-safety, and resource bugs

Thread safety:
- Protect circuit breaker counters with _breaker_lock (race between
  prefetch/sync daemon threads and main thread)
- Wrap sync_turn thread creation in _sync_lock; skip if previous sync is
  still alive after 5 s join to prevent duplicate memory ingestion
- Guard _schedule_flush timer creation under _queue_lock (TOCTOU race)
- Capture local `backend` reference in prefetch/sync closures so shutdown()
  nulling self._backend cannot crash in-flight threads

Correctness:
- Fix bool("false")==True for rerank param; parse string values explicitly
- Guard page/top_k with max(1,...) and move int() inside try blocks
- Fix fact_count=0 always in OSS mode (Memory.add returns list, not dict)
- Fix prefetch() not clearing result when thread still alive after timeout
- Fix atexit.register accumulating on repeated initialize() calls

Backend / setup:
- Handle Qdrant named-vector collections in
  _recreate_collection_if_dims_changed (vectors is a dict; .size access
  raised AttributeError, swallowed silently)
- Wrap QdrantClient and psycopg2 conn/cursor in try/finally to prevent leaks
- Resolve ollama_bin at top of _ensure_ollama; use it for ollama pull
- Fix embedder key lookup when LLM provider has no env_var (e.g. ollama)

Also: remove _telemetry_enabled cache (env var check is cheap), bump required
mem0ai to >=2.0.7, minor README wording fix.

* fix(mem0): fix brittle qdrant path test + add telemetry sample-rate docs

- Replace generator-throw lambda with a proper def in
  test_qdrant_path_not_writable; use tmp_path instead of a hardcoded
  /nonexistent path so the test is root-safe
- Add MEM0_TELEMETRY_SAMPLE_RATE to memory-providers.md (was only in the
  plugin README, not the user-guide docs)

* revert: remove MEM0_TELEMETRY_SAMPLE_RATE from user-guide docs

* refactor: remove telemetry from mem0 plugin and update documentation

* fix(mem0): set stdin=DEVNULL on setup subprocess calls

The TUI stdin guard (scripts/check_subprocess_stdin.py) requires every
subprocess call in plugin code to set stdin= so it can't inherit the
gateway's JSON-RPC stdin fd. Muzzle the docker/ollama calls in the OSS setup
wizard with stdin=subprocess.DEVNULL (none need interactive input). Also
covers the docker-inspect call the linter's regex misses.
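
As a rough illustration of the stdin pattern above (a sketch only; the
`run_without_stdin` helper and the child command are made up for the example
and are not the setup wizard's actual docker/ollama calls):

```python
import subprocess
import sys


def run_without_stdin(cmd: list[str]) -> subprocess.CompletedProcess:
    """Hypothetical wrapper: run cmd so it cannot read the parent's stdin."""
    return subprocess.run(
        cmd,
        stdin=subprocess.DEVNULL,  # child gets EOF instead of inheriting our stdin fd
        capture_output=True,
        text=True,
        check=False,
    )


# Illustrative child process: read all of stdin and print how much arrived.
result = run_without_stdin(
    [sys.executable, "-c", "import sys; print(len(sys.stdin.read()))"]
)
```

With stdin pointed at the null device the child reads nothing, so a stray
read in a helper process should not be able to consume bytes meant for the
gateway's JSON-RPC stream.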
--------- Co-authored-by: chaithanyak42 Co-authored-by: Claude Opus 4.7 (1M context) --- plugins/memory/mem0/README.md | 145 ++- plugins/memory/mem0/__init__.py | 460 +++++++--- plugins/memory/mem0/_backend.py | 243 +++++ plugins/memory/mem0/_oss_providers.py | 84 ++ plugins/memory/mem0/_setup.py | 858 ++++++++++++++++++ plugins/memory/mem0/plugin.yaml | 4 +- scripts/release.py | 2 + tests/plugins/memory/test_mem0_backend.py | 209 +++++ tests/plugins/memory/test_mem0_providers.py | 107 +++ tests/plugins/memory/test_mem0_setup.py | 251 +++++ tests/plugins/memory/test_mem0_v2.py | 241 ----- tests/plugins/memory/test_mem0_v3.py | 463 ++++++++++ .../user-guide/features/memory-providers.md | 42 +- 13 files changed, 2688 insertions(+), 421 deletions(-) create mode 100644 plugins/memory/mem0/_backend.py create mode 100644 plugins/memory/mem0/_oss_providers.py create mode 100644 plugins/memory/mem0/_setup.py create mode 100644 tests/plugins/memory/test_mem0_backend.py create mode 100644 tests/plugins/memory/test_mem0_providers.py create mode 100644 tests/plugins/memory/test_mem0_setup.py delete mode 100644 tests/plugins/memory/test_mem0_v2.py create mode 100644 tests/plugins/memory/test_mem0_v3.py diff --git a/plugins/memory/mem0/README.md b/plugins/memory/mem0/README.md index 62c7494af77..53046b08e3a 100644 --- a/plugins/memory/mem0/README.md +++ b/plugins/memory/mem0/README.md @@ -1,53 +1,152 @@ # Mem0 Memory Provider -Server-side LLM fact extraction with semantic search, reranking, and automatic deduplication. - -Supports both [Mem0 Cloud](https://app.mem0.ai) and self-hosted instances. +Server-side LLM fact extraction with semantic search and hybrid multi-signal retrieval via the Mem0 Platform v3 API. 
## Requirements - `pip install mem0ai` -- Mem0 Cloud API key **or** a self-hosted Mem0 server +- Mem0 API key from [app.mem0.ai](https://app.mem0.ai) ## Setup -### Cloud - ```bash hermes memory setup # select "mem0" ``` Or manually: - ```bash hermes config set memory.provider mem0 echo "MEM0_API_KEY=your-key" >> ~/.hermes/.env ``` -### Self-Hosted - -```bash -hermes config set memory.provider mem0 -echo "MEM0_HOST=http://your-mem0-server:24220" >> ~/.hermes/.env -echo "MEM0_API_KEY=your-api-key" >> ~/.hermes/.env # if auth is enabled -``` - ## Config -Config file: `$HERMES_HOME/mem0.json` +Behavioral settings live in `$HERMES_HOME/mem0.json` (set them via `hermes memory setup`). Only the secret `MEM0_API_KEY` belongs in `~/.hermes/.env`. | Key | Default | Description | |-----|---------|-------------| -| `api_key` | — | API key (required for cloud; optional for self-hosted without auth) | -| `host` | `https://api.mem0.ai` | Self-hosted Mem0 URL. When set, overrides the cloud endpoint. | -| `user_id` | `hermes-user` | User identifier | +| `mode` | `platform` | `platform` (Mem0 Cloud) or `oss` (self-hosted) | +| `user_id` | `hermes-user` | User identifier on Mem0 | | `agent_id` | `hermes` | Agent identifier | -| `rerank` | `true` | Enable reranking for recall | +| `rerank` | `true` | Rerank search results for relevance (platform mode only) | + +## OSS (Self-Hosted) Mode + +Run Mem0 locally with your own LLM, embedder, and vector store. + +### Interactive Setup + +```bash +hermes memory setup +# Select "mem0" → "Open Source (self-hosted)" +# Follow prompts for LLM, embedder, and vector store +``` + +### Agent-Driven Setup (Flags) + +```bash +hermes memory setup mem0 --mode oss \ + --oss-llm openai --oss-llm-key sk-... 
\ + --oss-vector qdrant +``` + +### Supported Providers + +| Component | Providers | +|-----------|-----------| +| LLM | openai, ollama | +| Embedder | openai, ollama | +| Vector Store | qdrant (local/server), pgvector | + +### Flags Reference + +| Flag | Description | +|------|-------------| +| `--mode` | `platform` or `oss` | +| `--oss-llm` | LLM provider (default: openai) | +| `--oss-llm-key` | LLM API key | +| `--oss-embedder` | Embedder provider (default: openai) | +| `--oss-vector` | Vector store (default: qdrant) | +| `--oss-vector-path` | Qdrant local path | +| `--user-id` | User identifier | + +## Switching Modes + +### Platform to OSS + +```bash +hermes memory setup mem0 --mode oss --oss-llm-key sk-... +``` + +Or edit `$HERMES_HOME/mem0.json` directly: +```json +{ + "mode": "oss", + "oss": { + "llm": {"provider": "openai", "config": {"model": "gpt-5-mini"}}, + "embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}}, + "vector_store": {"provider": "qdrant", "config": {"path": "~/.hermes/mem0_qdrant"}} + } +} +``` + +### OSS to Platform + +```bash +hermes memory setup mem0 --mode platform --api-key sk-... +``` + +### Dry Run (preview without writing) + +```bash +hermes memory setup mem0 --mode oss --oss-llm-key sk-... --dry-run +``` ## Tools | Tool | Description | |------|-------------| -| `mem0_profile` | All stored memories about the user | -| `mem0_search` | Semantic search with optional reranking | -| `mem0_conclude` | Store a fact verbatim (no LLM extraction) | +| `mem0_list` | List all stored memories (paginated) | +| `mem0_search` | Semantic search by meaning | +| `mem0_add` | Store a fact verbatim (no LLM extraction) | +| `mem0_update` | Update a memory's text by ID | +| `mem0_delete` | Delete a memory by ID | + +## Troubleshooting + +### "Mem0 temporarily unavailable" + +Circuit breaker tripped after 5 consecutive failures. Resets after 2 minutes. + +- **Platform mode**: Check API key and internet connectivity. 
+- **OSS mode**: Check that your vector store (qdrant/pgvector) is running. + +### OSS: Qdrant connection refused + +```bash +# If using local Qdrant, check the storage path is writable: +ls -la ~/.hermes/mem0_qdrant + +# If using Qdrant server, check it's reachable: +curl http://localhost:6333/healthz +``` + +### OSS: PGVector connection refused + +```bash +# Verify PostgreSQL is running and accepting connections: +pg_isready -h localhost -p 5432 +``` + +### OSS: Ollama not reachable + +```bash +# Check Ollama is running: +curl http://localhost:11434/api/tags +``` + +### Memories not appearing + +- `mem0_add` stores verbatim (no extraction). Use `sync_turn` for LLM extraction. +- Search uses semantic matching — try broader queries. +- Check `user_id` matches between sessions (`$HERMES_HOME/mem0.json`). diff --git a/plugins/memory/mem0/__init__.py b/plugins/memory/mem0/__init__.py index 65cd2f355d1..eccf6ad53fe 100644 --- a/plugins/memory/mem0/__init__.py +++ b/plugins/memory/mem0/__init__.py @@ -1,21 +1,33 @@ """Mem0 memory plugin — MemoryProvider interface. -Server-side LLM fact extraction, semantic search with reranking, and -automatic deduplication via the Mem0 Platform API or self-hosted instance. +Server-side LLM fact extraction, semantic search, and automatic deduplication +via the Mem0 Platform API (cloud) or OSS (self-hosted) via Memory. Original PR #2933 by kartik-mem0, adapted to MemoryProvider ABC. -Config via environment variables: - MEM0_API_KEY — Mem0 API key (required for cloud, optional for self-hosted) - MEM0_HOST — Self-hosted Mem0 URL (default: https://api.mem0.ai) - MEM0_USER_ID — User identifier (default: hermes-user) - MEM0_AGENT_ID — Agent identifier (default: hermes) +Configuration +------------- +Secret (lives in $HERMES_HOME/.env or the environment): + MEM0_API_KEY — Mem0 Platform API key (required for platform mode) -Or via $HERMES_HOME/mem0.json. 
+Behavioral settings (live in $HERMES_HOME/mem0.json, set via `hermes memory +setup`): + mode — Backend mode: "platform" (default) or "oss" + user_id — Canonical user identifier. When set, it is applied + uniformly across every gateway (CLI, Telegram, Slack, + Discord, …) so the same human gets one merged memory + store. When unset, the gateway-native id (e.g. Telegram + numeric id, Discord snowflake) is used instead. + agent_id — Agent identifier (default: hermes) + +The matching MEM0_MODE / MEM0_USER_ID / MEM0_AGENT_ID environment variables are +still read as a backward-compatible fallback, but mem0.json is the canonical +home for these non-secret settings. """ from __future__ import annotations +import atexit import json import logging import os @@ -33,12 +45,29 @@ logger = logging.getLogger(__name__) _BREAKER_THRESHOLD = 5 _BREAKER_COOLDOWN_SECS = 120 +_CLIENT_ERROR_TYPES = ("MemoryNotFoundError", "ValidationError") + +# Sentinel returned when neither MEM0_USER_ID nor a gateway-native id is +# available. Treated as "no operator-configured user_id" by initialize() so +# that legacy mem0.json files written by the setup wizard (which historically +# wrote this exact placeholder) still allow gateway-native ids to flow +# through instead of silently overriding them with the placeholder. +_DEFAULT_USER_ID = "hermes-user" + + +def _is_client_error(exc: Exception) -> bool: + """True for user-caused errors (bad ID, not found) that should NOT trip circuit breaker.""" + etype = type(exc).__name__ + if etype in _CLIENT_ERROR_TYPES: + return True + err_str = str(exc).lower() + return "404" in err_str or "not found" in err_str or "valid uuid" in err_str + # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- - def _load_config() -> dict: """Load config from env vars, with $HERMES_HOME/mem0.json overrides. 
@@ -49,13 +78,17 @@ def _load_config() -> dict: from hermes_constants import get_hermes_home config = { + "mode": os.environ.get("MEM0_MODE", "platform"), "api_key": os.environ.get("MEM0_API_KEY", ""), - "host": os.environ.get("MEM0_HOST", ""), - "user_id": os.environ.get("MEM0_USER_ID", "hermes-user"), "agent_id": os.environ.get("MEM0_AGENT_ID", "hermes"), - "rerank": True, - "keyword_search": False, + "oss": {}, } + # Only carry user_id when the operator explicitly configured one (env or + # mem0.json). An absent key tells initialize() to fall back to the + # gateway-native id from kwargs instead of overriding it with a placeholder. + env_user_id = os.environ.get("MEM0_USER_ID") + if env_user_id: + config["user_id"] = env_user_id config_path = get_hermes_home() / "mem0.json" if config_path.exists(): @@ -73,34 +106,40 @@ def _load_config() -> dict: # Tool schemas # --------------------------------------------------------------------------- -PROFILE_SCHEMA = { - "name": "mem0_profile", +LIST_SCHEMA = { + "name": "mem0_list", "description": ( - "Retrieve all stored memories about the user — preferences, facts, " - "project context. Fast, no reranking. Use at conversation start." + "List all stored memories about the user. " + "Use at conversation start for full overview." ), - "parameters": {"type": "object", "properties": {}, "required": []}, + "parameters": { + "type": "object", + "properties": { + "page": {"type": "integer", "description": "Page number (default: 1)."}, + "page_size": {"type": "integer", "description": "Results per page (default: 100, max: 200)."}, + }, + "required": [], + }, } SEARCH_SCHEMA = { "name": "mem0_search", "description": ( - "Search memories by meaning. Returns relevant facts ranked by similarity. " - "Set rerank=true for higher accuracy on important queries." + "Search memories by meaning. Returns relevant facts ranked by relevance." 
), "parameters": { "type": "object", "properties": { "query": {"type": "string", "description": "What to search for."}, - "rerank": {"type": "boolean", "description": "Enable reranking for precision (default: false)."}, "top_k": {"type": "integer", "description": "Max results (default: 10, max: 50)."}, + "rerank": {"type": "boolean", "description": "Rerank results for relevance (default: true, platform mode only)."}, }, "required": ["query"], }, } -CONCLUDE_SCHEMA = { - "name": "mem0_conclude", +ADD_SCHEMA = { + "name": "mem0_add", "description": ( "Store a durable fact about the user. Stored verbatim (no LLM extraction). " "Use for explicit preferences, corrections, or decisions." @@ -108,9 +147,34 @@ CONCLUDE_SCHEMA = { "parameters": { "type": "object", "properties": { - "conclusion": {"type": "string", "description": "The fact to store."}, + "content": {"type": "string", "description": "The fact to store."}, }, - "required": ["conclusion"], + "required": ["content"], + }, +} + +UPDATE_SCHEMA = { + "name": "mem0_update", + "description": "Update an existing memory's text by its ID.", + "parameters": { + "type": "object", + "properties": { + "memory_id": {"type": "string", "description": "Memory UUID to update."}, + "text": {"type": "string", "description": "New text content."}, + }, + "required": ["memory_id", "text"], + }, +} + +DELETE_SCHEMA = { + "name": "mem0_delete", + "description": "Delete a memory by its ID.", + "parameters": { + "type": "object", + "properties": { + "memory_id": {"type": "string", "description": "Memory UUID to delete."}, + }, + "required": ["memory_id"], }, } @@ -122,19 +186,17 @@ CONCLUDE_SCHEMA = { class Mem0MemoryProvider(MemoryProvider): """Mem0 memory with server-side extraction and semantic search. - Supports both Mem0 Cloud (api.mem0.ai) and self-hosted instances - via the ``host`` config key or ``MEM0_HOST`` env var. + Supports Platform API (cloud) and OSS (self-hosted) modes via MEM0_MODE. 
""" def __init__(self): self._config = None - self._client = None - self._client_lock = threading.Lock() + self._backend = None + self._mode = "platform" self._api_key = "" - self._host = "" - self._user_id = "hermes-user" + self._user_id = _DEFAULT_USER_ID self._agent_id = "hermes" - self._rerank = True + self._channel = "cli" # gateway channel name (cli/telegram/discord/...) self._prefetch_result = "" self._prefetch_lock = threading.Lock() self._prefetch_thread = None @@ -142,6 +204,9 @@ class Mem0MemoryProvider(MemoryProvider): # Circuit breaker state self._consecutive_failures = 0 self._breaker_open_until = 0.0 + self._breaker_lock = threading.Lock() + self._sync_lock = threading.Lock() + self._atexit_registered = False @property def name(self) -> str: @@ -149,9 +214,10 @@ class Mem0MemoryProvider(MemoryProvider): def is_available(self) -> bool: cfg = _load_config() - host = cfg.get("host", "") - api_key = cfg.get("api_key", "") - return bool(host) or bool(api_key) + mode = cfg.get("mode", "platform") + if mode == "oss": + return bool(cfg.get("oss", {}).get("vector_store")) + return bool(cfg.get("api_key")) def save_config(self, values, hermes_home): """Write config to $HERMES_HOME/mem0.json.""" @@ -169,95 +235,130 @@ class Mem0MemoryProvider(MemoryProvider): atomic_json_write(config_path, existing, mode=0o600) def get_config_schema(self): + cfg = _load_config() + mode = cfg.get("mode", "platform") + api_key_required = mode != "oss" return [ - {"key": "api_key", "description": "Mem0 API key (cloud or self-hosted)", "secret": True, "required": True, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"}, - {"key": "host", "description": "Self-hosted Mem0 URL (e.g. 
http://localhost:24220)", "default": "", "env_var": "MEM0_HOST"}, + {"key": "api_key", "description": "Mem0 Platform API key", "secret": True, "required": api_key_required, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"}, {"key": "user_id", "description": "User identifier", "default": "hermes-user"}, {"key": "agent_id", "description": "Agent identifier", "default": "hermes"}, {"key": "rerank", "description": "Enable reranking for recall", "default": "true", "choices": ["true", "false"]}, ] - def _get_client(self): - """Thread-safe client accessor with lazy initialization.""" - with self._client_lock: - if self._client is not None: - return self._client - try: - from mem0 import MemoryClient - kwargs = {} - if self._host: - kwargs["host"] = self._host - if self._api_key: - kwargs["api_key"] = self._api_key - elif not self._host: - raise ValueError("Mem0: either api_key or host is required") - self._client = MemoryClient(**kwargs) - return self._client - except ImportError: - raise RuntimeError("mem0 package not installed. 
Run: pip install mem0ai") + def post_setup(self, hermes_home: str, config: dict) -> None: + from ._setup import post_setup + post_setup(hermes_home, config) + + def _create_backend(self): + try: + if self._mode == "oss": + from ._backend import OSSBackend + return OSSBackend(self._config.get("oss", {})) + from ._backend import PlatformBackend + return PlatformBackend(self._api_key) + except Exception as e: + logger.error("Mem0 backend failed to initialize (%s mode): %s", self._mode, e) + self._init_error = str(e) + return None def _is_breaker_open(self) -> bool: """Return True if the circuit breaker is tripped (too many failures).""" - if self._consecutive_failures < _BREAKER_THRESHOLD: - return False - if time.monotonic() >= self._breaker_open_until: - # Cooldown expired — reset and allow a retry - self._consecutive_failures = 0 - return False - return True + with self._breaker_lock: + if self._consecutive_failures < _BREAKER_THRESHOLD: + return False + if time.monotonic() >= self._breaker_open_until: + self._consecutive_failures = 0 + return False + return True + + def _format_error(self, prefix: str, exc: Exception) -> str: + msg = f"{prefix}: {exc}" + if self._mode == "oss": + err_str = str(exc).lower() + if "connection" in err_str or "refused" in err_str or "timeout" in err_str: + vs = self._config.get("oss", {}).get("vector_store", {}) + msg += f" (check that {vs.get('provider', 'vector store')} is running)" + return msg def _record_success(self): - self._consecutive_failures = 0 + with self._breaker_lock: + self._consecutive_failures = 0 def _record_failure(self): - self._consecutive_failures += 1 - if self._consecutive_failures >= _BREAKER_THRESHOLD: - self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS + with self._breaker_lock: + self._consecutive_failures += 1 + count = self._consecutive_failures + if count >= _BREAKER_THRESHOLD: + self._breaker_open_until = time.monotonic() + _BREAKER_COOLDOWN_SECS + else: + count = 0 + if count >= 
_BREAKER_THRESHOLD: + hint = "" + if self._mode == "oss": + vs = self._config.get("oss", {}).get("vector_store", {}) + provider = vs.get("provider", "unknown") + hint = f" Check that your {provider} vector store is running and reachable." logger.warning( "Mem0 circuit breaker tripped after %d consecutive failures. " - "Pausing API calls for %ds.", - self._consecutive_failures, _BREAKER_COOLDOWN_SECS, + "Pausing API calls for %ds.%s", + count, _BREAKER_COOLDOWN_SECS, hint, ) def initialize(self, session_id: str, **kwargs) -> None: self._config = _load_config() + self._mode = self._config.get("mode", "platform") self._api_key = self._config.get("api_key", "") - self._host = self._config.get("host", "") - # Prefer gateway-provided user_id for per-user memory scoping; - # fall back to config/env default for CLI (single-user) sessions. - self._user_id = kwargs.get("user_id") or self._config.get("user_id", "hermes-user") + # Resolution order for user_id: + # 1. Operator-configured MEM0_USER_ID (env or $HERMES_HOME/mem0.json) — + # the canonical principal, applied across every gateway so the same + # human gets one merged memory store. + # 2. Gateway-native id from kwargs (Telegram numeric id, Discord + # snowflake, etc.) — preserves per-platform isolation when no + # override is configured. + # 3. Hardcoded fallback _DEFAULT_USER_ID (CLI with no auth). + # The literal _DEFAULT_USER_ID string is treated as unset so users who + # ran the setup wizard with the suggested default still get gateway- + # native ids instead of being silently bucketed together. 
+ configured = self._config.get("user_id") + if configured == _DEFAULT_USER_ID: + configured = None + self._user_id = configured or kwargs.get("user_id") or _DEFAULT_USER_ID self._agent_id = self._config.get("agent_id", "hermes") - self._rerank = self._config.get("rerank", True) + self._channel = kwargs.get("platform") or "cli" + self._backend = self._create_backend() + if self._backend and not self._atexit_registered: + atexit.register(self._shutdown_backend) + self._atexit_registered = True def _read_filters(self) -> Dict[str, Any]: - """Filters for search/get_all — scoped to user only for cross-session recall.""" + # Scoped to user_id only — by design — so recall surfaces memories + # written from any gateway/agent under this principal. Writes attach + # agent_id (and metadata.channel) so per-agent / per-channel views are + # still possible at query time when needed; reads default to the wider + # cross-agent recall. return {"user_id": self._user_id} - def _write_filters(self) -> Dict[str, Any]: - """Filters for add — scoped to user + agent for attribution.""" - return {"user_id": self._user_id, "agent_id": self._agent_id} - - @staticmethod - def _unwrap_results(response: Any) -> list: - """Normalize Mem0 API response — v2 wraps results in {"results": [...]}.""" - if isinstance(response, dict): - return response.get("results", []) - if isinstance(response, list): - return response - return [] + def _write_metadata(self) -> Dict[str, Any]: + # Tag every write with the gateway channel so the dashboard can offer + # per-channel filtered views without coupling identity to the channel. + return {"channel": self._channel} if self._channel else {} def system_prompt_block(self) -> str: - target = self._host or "cloud" + mode_label = "platform (cloud API)" if self._mode == "platform" else "OSS (self-hosted)" + rerank_note = " Rerank is available on search." if self._mode == "platform" else "" return ( - f"# Mem0 Memory ({target})\n" - f"Active. 
User: {self._user_id}.\n" - "Use mem0_search to find memories, mem0_conclude to store facts, " - "mem0_profile for a full overview." + "# Mem0 Memory\n" + f"Active. Mode: {mode_label}. User: {self._user_id}.\n" + "Use mem0_search to find memories, mem0_add to store facts, " + f"mem0_list for a full overview, mem0_update and mem0_delete to manage by ID.{rerank_note}" ) def prefetch(self, query: str, *, session_id: str = "") -> str: if self._prefetch_thread and self._prefetch_thread.is_alive(): self._prefetch_thread.join(timeout=3.0) + # If the thread still hasn't finished, leave the result for the next call. + if self._prefetch_thread and self._prefetch_thread.is_alive(): + return "" with self._prefetch_lock: result = self._prefetch_result self._prefetch_result = "" @@ -266,18 +367,15 @@ class Mem0MemoryProvider(MemoryProvider): return f"## Mem0 Memory\n{result}" def queue_prefetch(self, query: str, *, session_id: str = "") -> None: - if self._is_breaker_open(): + if self._backend is None or self._is_breaker_open(): return def _run(): + backend = self._backend + if backend is None: + return try: - client = self._get_client() - results = self._unwrap_results(client.search( - query=query, - filters=self._read_filters(), - rerank=self._rerank, - top_k=5, - )) + results = backend.search(query=query, filters=self._read_filters(), top_k=5, rerank=True) if results: lines = [r.get("memory", "") for r in results if r.get("memory")] with self._prefetch_lock: @@ -292,101 +390,171 @@ class Mem0MemoryProvider(MemoryProvider): def sync_turn(self, user_content: str, assistant_content: str, *, session_id: str = "") -> None: """Send the turn to Mem0 for server-side fact extraction (non-blocking).""" - if self._is_breaker_open(): + if self._backend is None or self._is_breaker_open(): return def _sync(): + backend = self._backend + if backend is None: + return try: - client = self._get_client() messages = [ {"role": "user", "content": user_content}, {"role": "assistant", "content": 
assistant_content}, ] - client.add(messages, **self._write_filters()) + backend.add( + messages, + user_id=self._user_id, + agent_id=self._agent_id, + infer=True, + metadata=self._write_metadata(), + ) self._record_success() except Exception as e: self._record_failure() logger.warning("Mem0 sync failed: %s", e) - # Wait for any previous sync before starting a new one - if self._sync_thread and self._sync_thread.is_alive(): - self._sync_thread.join(timeout=5.0) - - self._sync_thread = threading.Thread(target=_sync, daemon=True, name="mem0-sync") - self._sync_thread.start() + with self._sync_lock: + if self._sync_thread and self._sync_thread.is_alive(): + self._sync_thread.join(timeout=5.0) + # If still alive after timeout, skip to avoid duplicate ingestion. + if self._sync_thread and self._sync_thread.is_alive(): + return + self._sync_thread = threading.Thread(target=_sync, daemon=True, name="mem0-sync") + self._sync_thread.start() def get_tool_schemas(self) -> List[Dict[str, Any]]: - return [PROFILE_SCHEMA, SEARCH_SCHEMA, CONCLUDE_SCHEMA] + return [LIST_SCHEMA, SEARCH_SCHEMA, ADD_SCHEMA, UPDATE_SCHEMA, DELETE_SCHEMA] def handle_tool_call(self, tool_name: str, args: dict, **kwargs) -> str: + if self._backend is None: + err = getattr(self, "_init_error", "unknown error") + hint = "" + if self._mode == "oss": + vs = self._config.get("oss", {}).get("vector_store", {}) + provider = vs.get("provider", "vector store") + hint = f" Check that {provider} is running and reachable." + return json.dumps({"error": f"Mem0 backend not initialized: {err}.{hint}"}) + if self._is_breaker_open(): - return json.dumps({ - "error": "Mem0 API temporarily unavailable (multiple consecutive failures). Will retry automatically." - }) + msg = "Mem0 temporarily unavailable (multiple consecutive failures). Will retry automatically." + if self._mode == "oss": + vs = self._config.get("oss", {}).get("vector_store", {}) + msg += f" Check that your {vs.get('provider', 'vector store')} is running." 
+            return json.dumps({"error": msg})
 
-        try:
-            client = self._get_client()
-        except Exception as e:
-            return tool_error(str(e))
-
-        if tool_name == "mem0_profile":
+        if tool_name == "mem0_list":
             try:
-                memories = self._unwrap_results(client.get_all(filters=self._read_filters()))
+                page = max(1, int(args.get("page", 1)))
+                page_size = min(max(1, int(args.get("page_size", 100))), 200)
+                response = self._backend.get_all(
+                    filters=self._read_filters(), page=page, page_size=page_size,
+                )
                 self._record_success()
-                if not memories:
+                results = response.get("results", [])
+                if not results:
                     return json.dumps({"result": "No memories stored yet."})
-                lines = [m.get("memory", "") for m in memories if m.get("memory")]
-                return json.dumps({"result": "\n".join(lines), "count": len(lines)})
+                items = [{"id": m.get("id"), "memory": m.get("memory", "")}
+                         for m in results]
+                return json.dumps({
+                    "results": items,
+                    "count": response.get("count", len(items)),
+                    "page": page, "page_size": page_size,
+                })
             except Exception as e:
-                self._record_failure()
-                return tool_error(f"Failed to fetch profile: {e}")
+                if not _is_client_error(e):
+                    self._record_failure()
+                return tool_error(self._format_error("Failed to list memories", e))
 
         elif tool_name == "mem0_search":
             query = args.get("query", "")
             if not query:
                 return tool_error("Missing required parameter: query")
-            rerank = args.get("rerank", False)
-            top_k = min(int(args.get("top_k", 10)), 50)
             try:
-                results = self._unwrap_results(client.search(
-                    query=query,
-                    filters=self._read_filters(),
-                    rerank=rerank,
-                    top_k=top_k,
-                ))
+                top_k = max(1, min(int(args.get("top_k", 10)), 50))
+                rerank_raw = args.get("rerank", True)
+                if isinstance(rerank_raw, str):
+                    rerank = rerank_raw.lower() not in ("false", "0", "no")
+                else:
+                    rerank = bool(rerank_raw)
+                results = self._backend.search(query, filters=self._read_filters(), top_k=top_k, rerank=rerank)
                 self._record_success()
                 if not results:
                     return json.dumps({"result": "No relevant memories found."})
-                items = [{"memory": r.get("memory", ""), "score": r.get("score", 0)} for r in results]
+                items = [{"id": r.get("id"), "memory": r.get("memory", ""),
+                          "score": r.get("score", 0)} for r in results]
                 return json.dumps({"results": items, "count": len(items)})
             except Exception as e:
-                self._record_failure()
-                return tool_error(f"Search failed: {e}")
+                if not _is_client_error(e):
+                    self._record_failure()
+                return tool_error(self._format_error("Search failed", e))
 
-        elif tool_name == "mem0_conclude":
-            conclusion = args.get("conclusion", "")
-            if not conclusion:
-                return tool_error("Missing required parameter: conclusion")
+        elif tool_name == "mem0_add":
+            content = args.get("content", "")
+            if not content:
+                return tool_error("Missing required parameter: content")
             try:
-                client.add(
-                    [{"role": "user", "content": conclusion}],
-                    **self._write_filters(),
+                result = self._backend.add(
+                    [{"role": "user", "content": content}],
+                    user_id=self._user_id,
+                    agent_id=self._agent_id,
                     infer=False,
+                    metadata=self._write_metadata(),
                 )
                 self._record_success()
-                return json.dumps({"result": "Fact stored."})
+                event_id = result.get("event_id") if isinstance(result, dict) else None
+                msg = "Fact stored." if self._mode == "oss" else "Fact queued for storage."
+                return json.dumps({"result": msg, "event_id": event_id})
             except Exception as e:
                 self._record_failure()
-                return tool_error(f"Failed to store: {e}")
+                return tool_error(self._format_error("Failed to store", e))
+
+        elif tool_name == "mem0_update":
+            memory_id = args.get("memory_id", "")
+            text = args.get("text", "")
+            if not memory_id:
+                return tool_error("Missing required parameter: memory_id")
+            if not text:
+                return tool_error("Missing required parameter: text")
+            try:
+                result = self._backend.update(memory_id, text)
+                self._record_success()
+                return json.dumps(result)
+            except Exception as e:
+                if _is_client_error(e):
+                    return tool_error(f"Memory not found: {memory_id}")
+                self._record_failure()
+                return tool_error(self._format_error("Update failed", e))
+
+        elif tool_name == "mem0_delete":
+            memory_id = args.get("memory_id", "")
+            if not memory_id:
+                return tool_error("Missing required parameter: memory_id")
+            try:
+                result = self._backend.delete(memory_id)
+                self._record_success()
+                return json.dumps(result)
+            except Exception as e:
+                if _is_client_error(e):
+                    return tool_error(f"Memory not found: {memory_id}")
+                self._record_failure()
+                return tool_error(self._format_error("Delete failed", e))
 
         return tool_error(f"Unknown tool: {tool_name}")
 
+    def _shutdown_backend(self):
+        try:
+            if self._backend:
+                self._backend.close()
+                self._backend = None
+        except Exception:
+            pass
+
     def shutdown(self) -> None:
         for t in (self._prefetch_thread, self._sync_thread):
             if t and t.is_alive():
                 t.join(timeout=5.0)
-        with self._client_lock:
-            self._client = None
+        self._shutdown_backend()
 
 
 def register(ctx) -> None:
diff --git a/plugins/memory/mem0/_backend.py b/plugins/memory/mem0/_backend.py
new file mode 100644
index 00000000000..429a4f741be
--- /dev/null
+++ b/plugins/memory/mem0/_backend.py
@@ -0,0 +1,243 @@
+"""Backend abstraction for Mem0 Platform and OSS modes."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class Mem0Backend(ABC):
+    """Unified interface over Platform (MemoryClient) and OSS (Memory) backends."""
+
+    @abstractmethod
+    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list[dict]:
+        ...
+
+    @abstractmethod
+    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
+        ...
+
+    @abstractmethod
+    def add(
+        self,
+        messages: list,
+        *,
+        user_id: str,
+        agent_id: str,
+        infer: bool = False,
+        metadata: dict | None = None,
+    ) -> dict:
+        ...
+
+    @abstractmethod
+    def update(self, memory_id: str, text: str) -> dict:
+        ...
+
+    @abstractmethod
+    def delete(self, memory_id: str) -> dict:
+        ...
+
+    def close(self) -> None:
+        pass
+
+
+def _unwrap_results(response: Any) -> list:
+    """Normalize API response — extract results list from dict or pass through."""
+    if isinstance(response, dict):
+        return response.get("results", [])
+    if isinstance(response, list):
+        return response
+    return []
+
+
+class PlatformBackend(Mem0Backend):
+    """Wraps mem0.MemoryClient for Mem0 Platform (cloud API)."""
+
+    def __init__(self, api_key: str):
+        from mem0 import MemoryClient
+        self._client = MemoryClient(api_key=api_key)
+
+    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list[dict]:
+        response = self._client.search(query, filters=filters, top_k=top_k, rerank=rerank)
+        return _unwrap_results(response)
+
+    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
+        response = self._client.get_all(filters=filters, page=page, page_size=page_size)
+        results = response.get("results", []) if isinstance(response, dict) else response
+        count = response.get("count", len(results)) if isinstance(response, dict) else len(results)
+        return {"results": results, "count": count}
+
+    def add(
+        self,
+        messages: list,
+        *,
+        user_id: str,
+        agent_id: str,
+        infer: bool = False,
+        metadata: dict | None = None,
+    ) -> dict:
+        kwargs: dict[str, Any] = {"user_id": user_id, "agent_id": agent_id, "infer": infer}
+        if metadata:
+            kwargs["metadata"] = metadata
+        return self._client.add(messages, **kwargs)
+
+    def update(self, memory_id: str, text: str) -> dict:
+        self._client.update(memory_id=memory_id, text=text)
+        return {"result": "Memory updated.", "memory_id": memory_id}
+
+    def delete(self, memory_id: str) -> dict:
+        self._client.delete(memory_id=memory_id)
+        return {"result": "Memory deleted.", "memory_id": memory_id}
+
+
+class OSSBackend(Mem0Backend):
+    """Wraps mem0.Memory for self-hosted (OSS) mode."""
+
+    def __init__(self, oss_config: dict):
+        import os
+        from mem0 import Memory
+
+        vector_store = dict(oss_config["vector_store"])
+        vs_config = dict(vector_store.get("config", {}))
+
+        if "path" in vs_config:
+            vs_config["path"] = os.path.expanduser(vs_config["path"])
+
+        embedder_config = oss_config.get("embedder", {}).get("config", {})
+        dims = embedder_config.get("embedding_dims")
+        if not dims:
+            from ._oss_providers import KNOWN_DIMS
+            model = embedder_config.get("model", "")
+            dims = KNOWN_DIMS.get(model)
+        if dims:
+            vs_config["embedding_model_dims"] = dims
+            self._recreate_collection_if_dims_changed(
+                vector_store.get("provider", "qdrant"), vs_config, dims,
+            )
+
+        vector_store["config"] = vs_config
+
+        config = {
+            "vector_store": vector_store,
+            "llm": oss_config["llm"],
+            "embedder": oss_config["embedder"],
+            "version": "v1.1",
+        }
+        self._memory = Memory.from_config(config)
+
+    @staticmethod
+    def _recreate_collection_if_dims_changed(provider: str, vs_config: dict, expected_dims: int) -> None:
+        """Delete stale vector collection when embedding dimensions change."""
+        collection_name = vs_config.get("collection_name", "mem0")
+        if provider == "qdrant":
+            try:
+                from qdrant_client import QdrantClient
+                path = vs_config.get("path")
+                url = vs_config.get("url")
+                if path:
+                    client = QdrantClient(path=path)
+                elif url:
+                    client = QdrantClient(url=url, api_key=vs_config.get("api_key"))
+                else:
+                    return
+                try:
+                    if not client.collection_exists(collection_name):
+                        return
+                    info = client.get_collection(collection_name)
+                    vectors = info.config.params.vectors
+                    # Named-vector collections expose a dict; unnamed expose an object with .size.
+                    if isinstance(vectors, dict):
+                        first = next(iter(vectors.values()), None)
+                        current_dims = first.size if first else None
+                    else:
+                        current_dims = getattr(vectors, "size", None)
+                    if current_dims is not None and current_dims != expected_dims:
+                        client.delete_collection(collection_name)
+                finally:
+                    client.close()
+            except Exception:
+                pass
+        elif provider == "pgvector":
+            try:
+                import psycopg2
+                from psycopg2 import sql as pgsql
+                conn_params = {}
+                for k in ("host", "port", "user", "password", "dbname"):
+                    if vs_config.get(k):
+                        conn_params[k] = vs_config[k]
+                if vs_config.get("sslmode"):
+                    conn_params["sslmode"] = vs_config["sslmode"]
+                conn = psycopg2.connect(**conn_params)
+                conn.autocommit = True
+                try:
+                    cur = conn.cursor()
+                    try:
+                        cur.execute(
+                            "SELECT atttypmod FROM pg_attribute "
+                            "WHERE attrelid = %s::regclass AND attname = 'vector'",
+                            (collection_name,),
+                        )
+                        row = cur.fetchone()
+                        if row and row[0] > 0 and row[0] != expected_dims:
+                            cur.execute(pgsql.SQL("DROP TABLE IF EXISTS {}").format(
+                                pgsql.Identifier(collection_name)
+                            ))
+                    finally:
+                        cur.close()
+                finally:
+                    conn.close()
+            except Exception:
+                pass
+
+    def search(self, query: str, *, filters: dict, top_k: int = 10, rerank: bool = True) -> list[dict]:
+        response = self._memory.search(query, filters=filters, top_k=top_k)
+        return _unwrap_results(response)
+
+    def get_all(self, *, filters: dict, page: int = 1, page_size: int = 100) -> dict:
+        response = self._memory.get_all(filters=filters)
+        all_results = _unwrap_results(response)
+        total = len(all_results)
+        start = (page - 1) * page_size
+        results = all_results[start : start + page_size]
+        return {"results": results, "count": total}
+
+    def add(
+        self,
+        messages: list,
+        *,
+        user_id: str,
+        agent_id: str,
+        infer: bool = False,
+        metadata: dict | None = None,
+    ) -> dict:
+        kwargs: dict[str, Any] = {"user_id": user_id, "agent_id": agent_id, "infer": infer}
+        if metadata:
+            kwargs["metadata"] = metadata
+        return self._memory.add(messages, **kwargs)
+
+    def update(self, memory_id: str, text: str) -> dict:
+        self._memory.update(memory_id, data=text)
+        return {"result": "Memory updated.", "memory_id": memory_id}
+
+    def delete(self, memory_id: str) -> dict:
+        self._memory.delete(memory_id)
+        return {"result": "Memory deleted.", "memory_id": memory_id}
+
+    def close(self):
+        try:
+            telemetry = getattr(self._memory, "telemetry", None)
+            if telemetry and hasattr(telemetry, "posthog"):
+                try:
+                    telemetry.posthog.shutdown()
+                except Exception:
+                    pass
+            if hasattr(self._memory, "close"):
+                self._memory.close()
+            vs = getattr(self._memory, "vector_store", None)
+            if vs and hasattr(vs, "close"):
+                vs.close()
+            client = getattr(vs, "client", None)
+            if client and hasattr(client, "close"):
+                client.close()
+        except Exception:
+            pass
diff --git a/plugins/memory/mem0/_oss_providers.py b/plugins/memory/mem0/_oss_providers.py
new file mode 100644
index 00000000000..fa36e73a91f
--- /dev/null
+++ b/plugins/memory/mem0/_oss_providers.py
@@ -0,0 +1,84 @@
+"""OSS provider definitions for LLM, embedder, and vector store."""
+
+from __future__ import annotations
+
+import os
+from typing import Any
+
+LLM_PROVIDERS: dict[str, dict[str, Any]] = {
+    "openai": {
+        "label": "OpenAI",
+        "needs_key": True,
+        "env_var": "OPENAI_API_KEY",
+        "default_model": "gpt-5-mini",
+    },
+    "ollama": {
+        "label": "Ollama (local)",
+        "needs_key": False,
+        "default_model": "llama3.1:8b",
+        "default_url": "http://localhost:11434",
+        "pip_dep": "ollama",
+    },
+}
+
+EMBEDDER_PROVIDERS: dict[str, dict[str, Any]] = {
+    "openai": {
+        "label": "OpenAI",
+        "needs_key": True,
+        "env_var": "OPENAI_API_KEY",
+        "default_model": "text-embedding-3-small",
+        "dims": 1536,
+    },
+    "ollama": {
+        "label": "Ollama (local)",
+        "needs_key": False,
+        "default_model": "nomic-embed-text",
+        "default_url": "http://localhost:11434",
+        "dims": 768,
+        "pip_dep": "ollama",
+    },
+}
+
+VECTOR_PROVIDERS: dict[str, dict[str, Any]] = {
+    "qdrant": {
+        "label": "Qdrant",
+        "default_config": {"path": os.path.expanduser("~/.hermes/mem0_qdrant")},
+        "pip_dep": "qdrant-client",
+    },
+    "pgvector": {
+        "label": "PGVector",
+        "default_config": {"host": "localhost", "port": 5432, "user": os.getenv("USER", "postgres"), "dbname": "postgres"},
+        "pip_dep": "psycopg2-binary",
+    },
+}
+
+KNOWN_DIMS: dict[str, int] = {
+    "text-embedding-3-small": 1536,
+    "text-embedding-3-large": 3072,
+    "text-embedding-ada-002": 1536,
+    "nomic-embed-text": 768,
+}
+
+
+def validate_oss_config(oss_config: dict) -> list[str]:
+    """Validate an OSS config dict. Returns list of error strings (empty = valid)."""
+    errors: list[str] = []
+
+    for section, registry in [("llm", LLM_PROVIDERS), ("embedder", EMBEDDER_PROVIDERS),
+                              ("vector_store", VECTOR_PROVIDERS)]:
+        block = oss_config.get(section)
+        if not block or not isinstance(block, dict):
+            errors.append(f"Missing required section: {section}")
+            continue
+        provider_id = block.get("provider", "")
+        if provider_id not in registry:
+            valid = ", ".join(registry.keys())
+            errors.append(f"Unknown {section} provider '{provider_id}'. Valid: {valid}")
+
+    vs = oss_config.get("vector_store", {})
+    if vs.get("provider") == "pgvector":
+        cfg = vs.get("config", {})
+        if not cfg.get("user"):
+            errors.append("PGVector requires 'user' in vector_store.config")
+
+    return errors
diff --git a/plugins/memory/mem0/_setup.py b/plugins/memory/mem0/_setup.py
new file mode 100644
index 00000000000..4fd9795b32d
--- /dev/null
+++ b/plugins/memory/mem0/_setup.py
@@ -0,0 +1,858 @@
+"""Setup wizard for Mem0 plugin — interactive and flag-based modes."""
+
+from __future__ import annotations
+
+import getpass
+import json
+import os
+import shutil
+import socket
+import subprocess
+import sys
+import urllib.request
+from pathlib import Path
+from typing import Any
+
+from hermes_constants import get_hermes_home
+
+from ._oss_providers import (
+    LLM_PROVIDERS,
+    EMBEDDER_PROVIDERS,
+    VECTOR_PROVIDERS,
+    KNOWN_DIMS,
+    validate_oss_config,
+)
+
+
+def _curses_select(title: str, items: list[tuple[str, str]], default: int = 0) -> int:
+    """Interactive single-select with arrow keys."""
+    from hermes_cli.curses_ui import curses_radiolist
+    display_items = [
+        f"{label} {desc}" if desc else label
+        for label, desc in items
+    ]
+    return curses_radiolist(title, display_items, selected=default, cancel_returns=default)
+
+
+def _prompt(label: str, default: str | None = None, secret: bool = False) -> str:
+    """Prompt for a value with optional default and secret masking."""
+    suffix = f" [{default}]" if default else ""
+    if secret:
+        sys.stdout.write(f" {label}{suffix}: ")
+        sys.stdout.flush()
+        if sys.stdin.isatty():
+            val = getpass.getpass(prompt="")
+        else:
+            val = sys.stdin.readline().strip()
+    else:
+        sys.stdout.write(f" {label}{suffix}: ")
+        sys.stdout.flush()
+        val = sys.stdin.readline().strip()
+    return val or (default or "")
+
+
+def has_oss_flags() -> bool:
+    """Check if OSS-related flags are present in sys.argv."""
+    flags = parse_flags(sys.argv[1:])
+    if flags["mode"] == "oss":
+        return True
+    if any(flags.get(k) for k in ("oss_llm_key", "oss_vector_path", "oss_vector_url")):
+        return True
+    return False
+
+
+def parse_flags(argv: list[str] | None = None) -> dict[str, str]:
+    """Parse CLI flags from argv. Returns dict of flag values."""
+    args = argv if argv is not None else sys.argv[1:]
+    flags: dict[str, str] = {
+        "mode": "",
+        "api_key": "",
+        "oss_llm": "openai",
+        "oss_llm_key": "",
+        "oss_llm_model": "",
+        "oss_llm_url": "",
+        "oss_embedder": "openai",
+        "oss_embedder_key": "",
+        "oss_embedder_model": "",
+        "oss_embedder_url": "",
+        "oss_vector": "qdrant",
+        "oss_vector_path": "",
+        "oss_vector_url": "",
+        "oss_vector_host": "",
+        "oss_vector_port": "",
+        "oss_vector_user": "",
+        "oss_vector_password": "",
+        "oss_vector_dbname": "",
+        "user_id": "",
+        "dry_run": False,
+    }
+
+    flag_map = {
+        "--mode": "mode",
+        "--api-key": "api_key",
+        "--oss-llm": "oss_llm",
+        "--oss-llm-key": "oss_llm_key",
+        "--oss-llm-model": "oss_llm_model",
+        "--oss-llm-url": "oss_llm_url",
+        "--oss-embedder": "oss_embedder",
+        "--oss-embedder-key": "oss_embedder_key",
+        "--oss-embedder-model": "oss_embedder_model",
+        "--oss-embedder-url": "oss_embedder_url",
+        "--oss-vector": "oss_vector",
+        "--oss-vector-path": "oss_vector_path",
+        "--oss-vector-url": "oss_vector_url",
+        "--oss-vector-host": "oss_vector_host",
+        "--oss-vector-port": "oss_vector_port",
+        "--oss-vector-user": "oss_vector_user",
+        "--oss-vector-password": "oss_vector_password",
+        "--oss-vector-dbname": "oss_vector_dbname",
+        "--user-id": "user_id",
+    }
+
+    i = 0
+    while i < len(args):
+        if args[i] == "--dry-run":
+            flags["dry_run"] = True
+            i += 1
+        elif args[i] in flag_map and i + 1 < len(args):
+            flags[flag_map[args[i]]] = args[i + 1]
+            i += 2
+        else:
+            i += 1
+
+    return flags
+
+
+def build_oss_config(flags: dict[str, str]) -> tuple[dict, dict[str, str]]:
+    """Build OSS config dict + env_writes from parsed flags.
+
+    Returns (oss_config, env_writes) where oss_config goes into mem0.json
+    and env_writes maps env var names to secret values for .env.
+    """
+    llm_id = flags.get("oss_llm", "openai")
+    llm_def = LLM_PROVIDERS[llm_id]
+    llm_model = flags.get("oss_llm_model") or llm_def["default_model"]
+    llm_config: dict[str, Any] = {"model": llm_model}
+    if "default_url" in llm_def:
+        llm_config["ollama_base_url"] = flags.get("oss_llm_url") or llm_def["default_url"]
+
+    embedder_id = flags.get("oss_embedder", "openai")
+    embedder_def = EMBEDDER_PROVIDERS[embedder_id]
+    embedder_model = flags.get("oss_embedder_model") or embedder_def["default_model"]
+    embedder_config: dict[str, Any] = {"model": embedder_model}
+    if "default_url" in embedder_def:
+        embedder_config["ollama_base_url"] = flags.get("oss_embedder_url") or embedder_def["default_url"]
+    dims = KNOWN_DIMS.get(embedder_model)
+    if dims:
+        embedder_config["embedding_dims"] = dims
+
+    vector_id = flags.get("oss_vector", "qdrant")
+    vector_def = VECTOR_PROVIDERS[vector_id]
+    vector_config = dict(vector_def["default_config"])
+    if vector_id == "qdrant":
+        if flags.get("oss_vector_path"):
+            vector_config["path"] = flags["oss_vector_path"]
+        if flags.get("oss_vector_url"):
+            vector_config.pop("path", None)
+            vector_config["url"] = flags["oss_vector_url"]
+    elif vector_id == "pgvector":
+        if flags.get("oss_vector_host"):
+            vector_config["host"] = flags["oss_vector_host"]
+        if flags.get("oss_vector_port"):
+            vector_config["port"] = int(flags["oss_vector_port"])
+        if flags.get("oss_vector_user"):
+            vector_config["user"] = flags["oss_vector_user"]
+        if flags.get("oss_vector_password"):
+            vector_config["password"] = flags["oss_vector_password"]
+        if flags.get("oss_vector_dbname"):
+            vector_config["dbname"] = flags["oss_vector_dbname"]
+
+    oss_config = {
+        "llm": {"provider": llm_id, "config": llm_config},
+        "embedder": {"provider": embedder_id, "config": embedder_config},
+        "vector_store": {"provider": vector_id, "config": vector_config},
+    }
+
+    env_writes: dict[str, str] = {}
+    if llm_def.get("needs_key") and flags.get("oss_llm_key"):
+        env_writes[llm_def["env_var"]] = flags["oss_llm_key"]
+    if embedder_def.get("needs_key") and flags.get("oss_embedder_key"):
+        env_writes[embedder_def["env_var"]] = flags["oss_embedder_key"]
+    elif embedder_def.get("needs_key") and embedder_id == llm_id and flags.get("oss_llm_key"):
+        env_writes[embedder_def["env_var"]] = flags["oss_llm_key"]
+
+    return oss_config, env_writes
+
+
+def _write_env(env_path: Path, env_writes: dict[str, str]) -> None:
+    """Append or update env vars in .env file."""
+    env_path.parent.mkdir(parents=True, exist_ok=True)
+    existing_lines: list[str] = []
+    if env_path.exists():
+        existing_lines = env_path.read_text().splitlines()
+
+    updated_keys: set[str] = set()
+    new_lines: list[str] = []
+    for line in existing_lines:
+        key_match = line.split("=", 1)[0].strip() if "=" in line and not line.startswith("#") else None
+        if key_match and key_match in env_writes:
+            new_lines.append(f"{key_match}={env_writes[key_match]}")
+            updated_keys.add(key_match)
+        else:
+            new_lines.append(line)
+    for k, v in env_writes.items():
+        if k not in updated_keys:
+            new_lines.append(f"{k}={v}")
+
+    env_path.write_text("\n".join(new_lines) + "\n")
+
+
+def _save_mem0_json(hermes_home: str, data: dict) -> None:
+    """Merge-write to mem0.json."""
+    config_path = Path(hermes_home) / "mem0.json"
+    existing = {}
+    if config_path.exists():
+        try:
+            existing = json.loads(config_path.read_text(encoding="utf-8"))
+        except Exception:
+            pass
+    existing.update(data)
+    config_path.write_text(json.dumps(existing, indent=2) + "\n")
+
+
+def _setup_platform(hermes_home: str, config: dict, flags: dict[str, str]) -> None:
+    """Platform mode setup — uses the framework's schema-based flow.
+
+    Delegates to the same code path the framework uses when post_setup
+    doesn't exist, preserving the original platform onboarding experience.
+    """
+    schema = [
+        {"key": "api_key", "description": "Mem0 Platform API key", "secret": True, "required": True, "env_var": "MEM0_API_KEY", "url": "https://app.mem0.ai"},
+        {"key": "user_id", "description": "User identifier", "default": "hermes-user"},
+        {"key": "agent_id", "description": "Agent identifier", "default": "hermes"},
+        {"key": "rerank", "description": "Enable reranking for recall", "default": "true", "choices": ["true", "false"]},
+    ]
+
+    existing_config = {}
+    config_path = Path(hermes_home) / "mem0.json"
+    if config_path.exists():
+        try:
+            existing_config = json.loads(config_path.read_text())
+        except Exception:
+            pass
+
+    provider_config = dict(existing_config)
+    env_writes: dict[str, str] = {}
+
+    print("\n Configuring mem0:\n")
+
+    for field in schema:
+        key = field["key"]
+        desc = field.get("description", key)
+        default = field.get("default")
+        is_secret = field.get("secret", False)
+        choices = field.get("choices")
+        env_var = field.get("env_var")
+        url = field.get("url")
+
+        if flags.get("api_key") and key == "api_key":
+            env_writes["MEM0_API_KEY"] = flags["api_key"]
+            continue
+
+        if choices and not is_secret:
+            choice_items = [(c, "") for c in choices]
+            current = provider_config.get(key, default)
+            current_idx = 0
+            if current and str(current).lower() in choices:
+                current_idx = choices.index(str(current).lower())
+            sel = _curses_select(f" {desc}", choice_items, default=current_idx)
+            provider_config[key] = choices[sel]
+        elif is_secret:
+            existing = os.environ.get(env_var, "") if env_var else ""
+            if existing:
+                masked = f"...{existing[-4:]}" if len(existing) > 4 else "set"
+                val = _prompt(f"{desc} (current: {masked}, blank to keep)", secret=True)
+            else:
+                if url:
+                    print(f" Get yours at {url}")
+                val = _prompt(desc, secret=True)
+            if val and env_var:
+                env_writes[env_var] = val
+        else:
+            current = provider_config.get(key)
+            effective_default = current or default
+            val = _prompt(desc, default=str(effective_default) if effective_default else None)
+            if val:
+                provider_config[key] = val
+
+    if flags.get("dry_run"):
+        print(f"\n [dry-run] Would save config: {provider_config}")
+        if env_writes:
+            print(" [dry-run] Would write API key to .env")
+        print(" [dry-run] No files written.\n")
+        return
+
+    provider_config["mode"] = "platform"
+
+    from hermes_cli.config import save_config
+    config["memory"]["provider"] = "mem0"
+    save_config(config)
+
+    from plugins.memory.mem0 import Mem0MemoryProvider
+    provider = Mem0MemoryProvider()
+    provider.save_config(provider_config, hermes_home)
+
+    if env_writes:
+        _write_env(Path(hermes_home) / ".env", env_writes)
+
+    print(f"\n Memory provider: mem0")
+    print(f" Activation saved to config.yaml")
+    print(f" Provider config saved")
+    if env_writes:
+        print(f" API keys saved to .env")
+    print(f"\n Start a new session to activate.\n")
+
+
+def _setup_oss(hermes_home: str, config: dict, flags: dict[str, str]) -> None:
+    """OSS mode setup — build config from flags or interactive prompts.
+
+    Non-interactive when --mode was set explicitly via flags (post_setup already
+    resolved mode). Interactive only when mode was chosen via curses picker.
+    """
+    if not flags.get("_mode_from_flag"):
+        _setup_oss_interactive(hermes_home, config)
+        return
+
+    oss_config, env_writes = build_oss_config(flags)
+    errors = validate_oss_config(oss_config)
+    if errors:
+        for e in errors:
+            print(f" Error: {e}", file=sys.stderr)
+        sys.exit(1)
+
+    user_id = flags.get("user_id") or os.getenv("USER", "hermes-user")
+
+    llm_id = oss_config["llm"]["provider"]
+    embedder_id = oss_config["embedder"]["provider"]
+    vector_id = oss_config["vector_store"]["provider"]
+
+    if flags.get("dry_run"):
+        print("\n [dry-run] OSS config would be:")
+        print(f" LLM: {oss_config['llm']['provider']} ({oss_config['llm']['config'].get('model', '')})")
+        print(f" Embedder: {oss_config['embedder']['provider']} ({oss_config['embedder']['config'].get('model', '')})")
+        print(f" Vector: {vector_id}")
+        if env_writes:
+            print(f" Env vars: {', '.join(env_writes.keys())}")
+        _run_connectivity_checks(oss_config)
+        print(" [dry-run] No files written.\n")
+        return
+
+    if env_writes:
+        _write_env(Path(hermes_home) / ".env", env_writes)
+    _save_mem0_json(hermes_home, {"mode": "oss", "user_id": user_id, "agent_id": "hermes", "oss": oss_config})
+
+    _install_provider_deps(llm_id, embedder_id, vector_id)
+
+    from hermes_cli.config import save_config
+    config["memory"]["provider"] = "mem0"
+    save_config(config)
+
+    _run_connectivity_checks(oss_config)
+    print(f"\n ✓ Mem0 configured (OSS mode)")
+    print(f" LLM: {oss_config['llm']['provider']} ({oss_config['llm']['config'].get('model', '')})")
+    print(f" Embedder: {oss_config['embedder']['provider']} ({oss_config['embedder']['config'].get('model', '')})")
+    print(f" Vector: {vector_id}")
+    if env_writes:
+        print(f" API keys saved to .env")
+    print(f" Config saved to mem0.json")
+    print(f" Provider set in config.yaml")
+    print("\n Start a new session to activate.\n")
+
+
+def _prompt_api_key(label: str, env_var: str, hermes_home: str) -> str:
+    """Prompt for API key, showing masked existing value if found."""
+    existing = os.environ.get(env_var, "")
+    if not existing:
+        env_path = Path(hermes_home) / ".env"
+        if env_path.exists():
+            for line in env_path.read_text().splitlines():
+                if line.startswith(f"{env_var}="):
+                    existing = line.split("=", 1)[1].strip()
+                    break
+    if existing:
+        masked = f"...{existing[-4:]}" if len(existing) > 4 else "set"
+        return getpass.getpass(f" {label} API key (current: {masked}, blank to keep): ").strip()
+    return getpass.getpass(f" {label} API key: ").strip()
+
+
+_PGVECTOR_CONTAINER = "hermes-pgvector"
+_PGVECTOR_IMAGE = "pgvector/pgvector:pg17"
+_PGVECTOR_PASSWORD = "hermes"
+
+
+def _ensure_pgvector(host: str = "localhost", port: int = 5432) -> dict | None:
+    """Ensure pgvector is reachable; offer Docker setup if not.
+
+    Returns updated vector_config dict if Docker was started, None otherwise.
+    """
+    ok, _ = _check_pgvector(host, port)
+    if ok:
+        print(f" ✓ PostgreSQL reachable at {host}:{port}")
+        return None
+
+    print(f" PostgreSQL not reachable at {host}:{port}")
+
+    # Check if our container already exists but is stopped
+    if shutil.which("docker"):
+        try:
+            result = subprocess.run(
+                ["docker", "inspect", _PGVECTOR_CONTAINER, "--format", "{{.State.Status}}"],
+                capture_output=True, text=True, timeout=10, stdin=subprocess.DEVNULL,
+            )
+            if result.returncode == 0 and "exited" in result.stdout:
+                print(f" Found stopped container '{_PGVECTOR_CONTAINER}', restarting...")
+                subprocess.run(["docker", "start", _PGVECTOR_CONTAINER],
+                               capture_output=True, timeout=15,
+                               stdin=subprocess.DEVNULL)
+                _wait_for_port(host, port, timeout=15)
+                ok, _ = _check_pgvector(host, port)
+                if ok:
+                    print(f" ✓ PostgreSQL container restarted")
+                    return None
+        except Exception:
+            pass
+
+        answer = input(" Start pgvector via Docker? [Y/n]: ").strip().lower()
+        if answer in ("", "y", "yes"):
+            return _start_pgvector_docker(host, port)
+        else:
+            print(" Skipping Docker setup. Make sure PostgreSQL with pgvector is running.")
+            return None
+    else:
+        print(" Docker not found. Install Docker to auto-start pgvector,")
+        print(" or run PostgreSQL with pgvector manually.")
+        return None
+
+
+def _start_pgvector_docker(host: str, port: int) -> dict | None:
+    """Pull and start pgvector Docker container."""
+    try:
+        print(f" Pulling {_PGVECTOR_IMAGE}...")
+        subprocess.run(["docker", "pull", _PGVECTOR_IMAGE],
+                       capture_output=True, timeout=120,
+                       stdin=subprocess.DEVNULL)
+
+        # Remove existing container if present
+        subprocess.run(["docker", "rm", "-f", _PGVECTOR_CONTAINER],
+                       capture_output=True, timeout=10,
+                       stdin=subprocess.DEVNULL)
+
+        print(f" Starting container '{_PGVECTOR_CONTAINER}' on port {port}...")
+        subprocess.run([
+            "docker", "run", "-d",
+            "--name", _PGVECTOR_CONTAINER,
+            "-e", f"POSTGRES_PASSWORD={_PGVECTOR_PASSWORD}",
+            "-p", f"{port}:5432",
+            _PGVECTOR_IMAGE,
+        ], capture_output=True, timeout=30, check=True, stdin=subprocess.DEVNULL)
+
+        _wait_for_port(host, port, timeout=20)
+        ok, _ = _check_pgvector(host, port)
+        if ok:
+            print(f" ✓ pgvector running on {host}:{port}")
+            return {
+                "host": host, "port": port,
+                "user": "postgres", "password": _PGVECTOR_PASSWORD,
+                "dbname": "postgres",
+            }
+        else:
+            print(" Warning: Container started but PostgreSQL not yet accepting connections.")
+            print(" It may need a few more seconds. Config will be saved; retry later.")
+            return {
+                "host": host, "port": port,
+                "user": "postgres", "password": _PGVECTOR_PASSWORD,
+                "dbname": "postgres",
+            }
+    except subprocess.CalledProcessError as e:
+        print(f" Failed to start Docker container: {e}")
+        return None
+    except Exception as e:
+        print(f" Docker error: {e}")
+        return None
+
+
+def _ensure_ollama(models: list[str]) -> bool:
+    """Ensure Ollama is running and required models are pulled.
+
+    Returns True if Ollama is ready, False if user needs to handle it manually.
+    """
+    url = "http://localhost:11434"
+    ollama_bin = shutil.which("ollama")
+    ok, _ = _check_ollama(url)
+
+    if not ok:
+        if ollama_bin:
+            print(" Ollama installed but not running. Starting...")
+            try:
+                subprocess.Popen(
+                    [ollama_bin, "serve"],
+                    stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
+                )
+                _wait_for_port("localhost", 11434, timeout=10)
+                ok, _ = _check_ollama(url)
+                if ok:
+                    print(" ✓ Ollama started")
+            except Exception as e:
+                print(f" Could not start Ollama: {e}")
+        else:
+            print(" Ollama not found. Install it:")
+            print(" curl -fsSL https://ollama.com/install.sh | sh")
+            print(" Or on macOS: brew install ollama")
+            return False
+
+    if not ok:
+        print(" Warning: Ollama not reachable. Models cannot be pulled.")
+        return False
+
+    # Pull required models
+    for model in models:
+        if _ollama_has_model(url, model):
+            print(f" ✓ Model '{model}' available")
+        else:
+            print(f" Pulling '{model}'... (this may take a few minutes)")
+            try:
+                subprocess.run([ollama_bin or "ollama", "pull", model], timeout=600,
+                               stdin=subprocess.DEVNULL)
+                print(f" ✓ Model '{model}' pulled")
+            except Exception as e:
+                print(f" Warning: Could not pull '{model}': {e}")
+                print(f" Run manually: ollama pull {model}")
+
+    return True
+
+
+def _ollama_has_model(url: str, model: str) -> bool:
+    """Check if Ollama already has a model pulled."""
+    try:
+        req = urllib.request.Request(f"{url}/api/tags", method="GET")
+        resp = urllib.request.urlopen(req, timeout=5)
+        data = json.loads(resp.read())
+        names = [m.get("name", "") for m in data.get("models", [])]
+        base_model = model.split(":")[0]
+        return any(model in n or base_model in n for n in names)
+    except Exception:
+        return False
+
+
+def _ensure_pgvector_extension(pg_config: dict) -> None:
+    """Create the pgvector extension if it doesn't exist."""
+    try:
+        import psycopg2
+    except ImportError:
+        return
+    conn_params = {
+        "host": pg_config.get("host", "localhost"),
+        "port": pg_config.get("port", 5432),
+        "user": pg_config.get("user", "postgres"),
+        "dbname": pg_config.get("dbname", "postgres"),
+    }
+    if pg_config.get("password"):
+        conn_params["password"] = pg_config["password"]
+    try:
+        conn = psycopg2.connect(**conn_params)
+        conn.autocommit = True
+        cur = conn.cursor()
+        cur.execute("CREATE EXTENSION IF NOT EXISTS vector")
+        cur.close()
+        conn.close()
+        print(" ✓ pgvector extension enabled")
+    except Exception as e:
+        print(f" Warning: Could not enable pgvector extension: {e}")
+
+
+def _wait_for_port(host: str, port: int, timeout: int = 15) -> None:
+    """Wait until a TCP port is accepting connections."""
+    import time
+    deadline = time.monotonic() + timeout
+    while time.monotonic() < deadline:
+        try:
+            sock = socket.create_connection((host, port), timeout=1)
+            sock.close()
+            return
+        except OSError:
+            time.sleep(0.5)
+
+
+def _provider_description(v: dict) -> str:
+    """Description for LLM/embedder picker: model + URL if applicable."""
+    model = v.get("default_model", "")
+    url = v.get("default_url")
+    if url:
+        return f"{model} ({url})"
+    return model
+
+
+def _vector_description(pid: str, v: dict) -> str:
+    cfg = v.get("default_config", {})
+    if pid == "qdrant":
+        return cfg.get("path", "local storage")
+    if pid == "pgvector":
+        return f"{cfg.get('host', 'localhost')}:{cfg.get('port', 5432)}"
+    return pid
+
+
+def _setup_oss_interactive(hermes_home: str, config: dict) -> None:
+    """Interactive OSS setup using curses pickers."""
+    llm_items = [(v["label"], _provider_description(v)) for pid, v in LLM_PROVIDERS.items()]
+    llm_idx = _curses_select("LLM Provider", llm_items, 0)
+    llm_id = list(LLM_PROVIDERS.keys())[llm_idx]
+    llm_def = LLM_PROVIDERS[llm_id]
+
+    env_writes: dict[str, str] = {}
+    llm_model = llm_def["default_model"]
+    llm_url = llm_def.get("default_url")
+    if llm_def["needs_key"]:
+        key = _prompt_api_key(llm_def["label"], llm_def["env_var"], hermes_home)
+        if key:
+            env_writes[llm_def["env_var"]] = key
+    if llm_id == "ollama":
+        llm_model = input(f" LLM model [{llm_def['default_model']}]: ").strip() or llm_def["default_model"]
+        llm_url = input(f" Ollama URL [{llm_def['default_url']}]: ").strip() or llm_def["default_url"]
+
+    embedder_items = [(v["label"], _provider_description(v)) for pid, v in EMBEDDER_PROVIDERS.items()]
+    embedder_idx = _curses_select("Embedder Provider", embedder_items, 0)
+    embedder_id = list(EMBEDDER_PROVIDERS.keys())[embedder_idx]
+    embedder_def = EMBEDDER_PROVIDERS[embedder_id]
+
+    embedder_model = embedder_def["default_model"]
+    embedder_url = embedder_def.get("default_url")
+    if embedder_def["needs_key"] and embedder_id != llm_id:
+        key = _prompt_api_key(f"{embedder_def['label']} embedder", embedder_def["env_var"], hermes_home)
+        if key:
+            env_writes[embedder_def["env_var"]] = key
+    elif embedder_def["needs_key"] and embedder_id == llm_id:
+        if llm_def.get("env_var") in env_writes:
+            env_writes[embedder_def["env_var"]] = env_writes[llm_def["env_var"]]
+    if embedder_id == "ollama":
+        embedder_model = input(f" Embedder model [{embedder_def['default_model']}]: ").strip() or embedder_def["default_model"]
+        embedder_url = input(f" Ollama URL [{embedder_def['default_url']}]: ").strip() or embedder_def["default_url"]
+
+    vector_items = [(v["label"], _vector_description(pid, v)) for pid, v in VECTOR_PROVIDERS.items()]
+    vector_idx = _curses_select("Vector Store", vector_items, 0)
+    vector_id = list(VECTOR_PROVIDERS.keys())[vector_idx]
+
+    # Auto-setup: ensure Ollama is running and models are pulled
+    ollama_models = []
+    if llm_id == "ollama":
+        ollama_models.append(llm_model)
+    if embedder_id == "ollama":
+        ollama_models.append(embedder_model)
+    if ollama_models:
+        _ensure_ollama(ollama_models)
+
+    # Auto-setup: ensure pgvector is reachable (offer Docker if not)
+    pgvector_config = None
+    if vector_id == "pgvector":
+        pgvector_config = _ensure_pgvector()
+        if not pgvector_config:
+            # Native PostgreSQL — prompt for connection details
+            default_user = os.getenv("USER", "postgres")
+            pg_user = input(f" PostgreSQL user [{default_user}]: ").strip() or default_user
+            pg_host = input(" PostgreSQL host [localhost]: ").strip() or "localhost"
+            pg_port = input(" PostgreSQL port [5432]: ").strip() or "5432"
+            pg_dbname = input(" PostgreSQL database [postgres]: ").strip() or "postgres"
+            pg_password = getpass.getpass(" PostgreSQL password (blank if none): ").strip()
+            pgvector_config = {
+                "host": pg_host, "port": int(pg_port),
+                "user": pg_user, "dbname": pg_dbname,
+            }
+            if pg_password:
+                pgvector_config["password"] = pg_password
+
+    user_id = input(f" User ID [{os.getenv('USER', 'hermes-user')}]: ").strip()
+    user_id = user_id or os.getenv("USER", "hermes-user")
+
+    agent_id = input(" Agent ID [hermes]: ").strip()
+    agent_id = agent_id or "hermes"
+
+    flags = {
+        "oss_llm": llm_id,
+        "oss_llm_key": env_writes.get(llm_def["env_var"], "") if llm_def.get("env_var") else "",
+        "oss_llm_model": llm_model,
+        "oss_llm_url": llm_url or "",
+        "oss_embedder": embedder_id,
+        "oss_embedder_model": embedder_model,
+        "oss_embedder_url": embedder_url or "",
+        "oss_vector": vector_id,
+        "user_id": user_id,
+    }
+
+    if pgvector_config:
+        flags["oss_vector_host"] = pgvector_config["host"]
+        flags["oss_vector_port"] = str(pgvector_config["port"])
+        flags["oss_vector_user"] = pgvector_config["user"]
+        if pgvector_config.get("password"):
+            flags["oss_vector_password"] = pgvector_config["password"]
+        flags["oss_vector_dbname"] = pgvector_config["dbname"]
+
+    oss_config, _ = build_oss_config(flags)
+
+    if env_writes:
+        _write_env(Path(hermes_home) / ".env", env_writes)
+    _save_mem0_json(hermes_home, {"mode": "oss", "user_id": user_id, "agent_id": agent_id, "oss": oss_config})
+
+    _install_provider_deps(llm_id, embedder_id, vector_id)
+
+    if vector_id == "pgvector" and pgvector_config:
+        _ensure_pgvector_extension(pgvector_config)
+
+    from hermes_cli.config import save_config
+    config["memory"]["provider"] = "mem0"
+    save_config(config)
+
+    _run_connectivity_checks(oss_config)
+    print(f"\n ✓ Mem0 configured (OSS mode)")
+    print(f" LLM: {oss_config['llm']['provider']} ({oss_config['llm']['config'].get('model', '')})")
+    print(f" Embedder: 
{oss_config['embedder']['provider']} ({oss_config['embedder']['config'].get('model', '')})") + print(f" Vector: {vector_id}") + if env_writes: + print(f" API keys saved to .env") + print(f" Config saved to mem0.json") + print(f" Provider set in config.yaml") + print("\n Start a new session to activate.\n") + + +def _install_provider_deps(llm_id: str, embedder_id: str, vector_id: str) -> None: + """Install all optional pip deps for selected providers.""" + deps: set[str] = set() + for registry, pid in [(LLM_PROVIDERS, llm_id), (EMBEDDER_PROVIDERS, embedder_id), + (VECTOR_PROVIDERS, vector_id)]: + dep = registry.get(pid, {}).get("pip_dep") + if dep: + deps.add(dep) + for dep in sorted(deps): + try: + print(f" Installing {dep}...") + subprocess.run( + ["uv", "pip", "install", "--python", sys.executable, dep], + capture_output=True, timeout=60, check=True, + ) + print(f" ✓ Installed {dep}") + except Exception: + print(f" Warning: Could not install {dep}. Install manually: uv pip install {dep}") + if deps: + import importlib + importlib.invalidate_caches() + + +def _check_qdrant_path(path: str) -> tuple[bool, str]: + """Check that qdrant local storage parent dir is writable.""" + p = Path(path).expanduser() + parent = p.parent + try: + parent.mkdir(parents=True, exist_ok=True) + return True, f"Directory writable: {parent}" + except OSError as e: + return False, f"Cannot write to {parent}: {e}" + + +def _check_ollama(url: str) -> tuple[bool, str]: + """Check Ollama is reachable via /api/tags.""" + try: + req = urllib.request.Request(f"{url.rstrip('/')}/api/tags", method="GET") + urllib.request.urlopen(req, timeout=3) + return True, "Ollama reachable" + except Exception as e: + return False, f"Ollama not reachable at {url}: {e}" + + +def _check_pgvector(host: str, port: int) -> tuple[bool, str]: + """Check PGVector via TCP socket.""" + try: + sock = socket.create_connection((host, port), timeout=3) + sock.close() + return True, f"PGVector reachable at {host}:{port}" + except 
Exception as e: + return False, f"PGVector not reachable at {host}:{port}: {e}" + + +def _run_connectivity_checks(oss_config: dict) -> None: + """Run connectivity checks and print warnings.""" + vs = oss_config.get("vector_store", {}) + if vs.get("provider") == "qdrant": + path = vs.get("config", {}).get("path") + url = vs.get("config", {}).get("url") + if path: + ok, msg = _check_qdrant_path(path) + if not ok: + print(f" Warning: {msg}") + elif url: + try: + req = urllib.request.Request(f"{url.rstrip('/')}/healthz", method="GET") + urllib.request.urlopen(req, timeout=3) + except Exception as e: + print(f" Warning: Qdrant not reachable at {url}: {e}") + elif vs.get("provider") == "pgvector": + cfg = vs.get("config", {}) + ok, msg = _check_pgvector(cfg.get("host", "localhost"), cfg.get("port", 5432)) + if not ok: + print(f" Warning: {msg}") + + llm = oss_config.get("llm", {}) + if llm.get("provider") == "ollama": + url = llm.get("config", {}).get("ollama_base_url", "http://localhost:11434") + ok, msg = _check_ollama(url) + if not ok: + print(f" Warning: {msg}") + + +def _check_min_dep_version() -> None: + """Warn if the installed mem0ai is older than the minimum in plugin.yaml.""" + try: + import mem0 + installed_ver = getattr(mem0, "__version__", None) + if not installed_ver: + return + installed_parts = tuple(int(x) for x in installed_ver.split(".")[:3]) + required_parts = (2, 0, 7) + if installed_parts < required_parts: + req_str = ".".join(str(x) for x in required_parts) + print(f"\n ⚠ mem0ai {installed_ver} installed but >={req_str} required.") + print(f" Run: uv pip install --python {sys.executable} 'mem0ai>={req_str}'") + except ImportError: + pass + except Exception: + pass + + +def post_setup(hermes_home: str, config: dict) -> None: + """Entry point called by hermes memory setup framework. + + Dispatches on --mode: "oss" runs _setup_oss, "platform" runs + _setup_platform. 
With no --mode flag, shows an interactive picker + whose first entry (Platform) is the default; picking Open Source + runs the OSS flow instead. + """ + _check_min_dep_version() + flags = parse_flags(sys.argv[1:]) + + if flags["mode"] == "oss": + flags["_mode_from_flag"] = True + _setup_oss(hermes_home, config, flags) + return + + if flags["mode"] == "platform": + _setup_platform(hermes_home, config, flags) + return + + # No --mode flag: show interactive picker + mode_items = [ + ("Platform", "Mem0 Cloud API (lightweight, just needs an API key)"), + ("Open Source", "Run Mem0 locally (self-hosted LLM + vector store)"), + ] + mode_idx = _curses_select(" Select mode", mode_items, 0) + if mode_idx == 1: + flags["_mode_from_flag"] = False + _setup_oss(hermes_home, config, flags) + else: + _setup_platform(hermes_home, config, flags) diff --git a/plugins/memory/mem0/plugin.yaml b/plugins/memory/mem0/plugin.yaml index 2e7104d75c4..1d9dec52306 100644 --- a/plugins/memory/mem0/plugin.yaml +++ b/plugins/memory/mem0/plugin.yaml @@ -1,5 +1,5 @@ name: mem0 -version: 1.0.0 +version: 1.1.0 description: "Mem0 — server-side LLM fact extraction with semantic search, reranking, and automatic deduplication." 
pip_dependencies: - - mem0ai + - mem0ai>=2.0.7,<3 diff --git a/scripts/release.py b/scripts/release.py index 9dae0c8bc29..74ce3def810 100755 --- a/scripts/release.py +++ b/scripts/release.py @@ -1410,6 +1410,8 @@ AUTHOR_MAP = { "caojiguang@gmail.com": "caojiguang", # PR #35117 carries #31853 (weixin _api_post/_api_get wait_for) "gooku94123@gmail.com": "goku94123", # PR #46609 salvage (MiniMax reasoning extra_body) # pander: empty email, salvaged via PR #19665 from #16126 by @ms-alan + "chaithanya.kumar42a@gmail.com": "chaithanyak42", # PR #15624 + "kartik.labhshetwar@mem0.ai": "kartik-mem0", # PR #15624 "ayman.a.kamal@hotmail.com": "A-kamal", # PR #18678 (xAI image resolution fix) # Kanban bug-fix batch salvage (May 2026) "frowte3k@gmail.com": "Frowtek", # salvage of #23206 (gateway --board auto-subscribe) diff --git a/tests/plugins/memory/test_mem0_backend.py b/tests/plugins/memory/test_mem0_backend.py new file mode 100644 index 00000000000..221da10823b --- /dev/null +++ b/tests/plugins/memory/test_mem0_backend.py @@ -0,0 +1,209 @@ +"""Tests for Mem0Backend abstraction — PlatformBackend and OSSBackend.""" + +import pytest + +from plugins.memory.mem0._backend import Mem0Backend, PlatformBackend, OSSBackend + + +class FakePlatformClient: + """Fake MemoryClient for PlatformBackend tests.""" + + def __init__(self): + self.calls = [] + + def search(self, query, **kwargs): + self.calls.append(("search", query, kwargs)) + return {"results": [{"id": "m1", "memory": "fact1", "score": 0.9}]} + + def get_all(self, **kwargs): + self.calls.append(("get_all", kwargs)) + return {"count": 1, "next": None, "results": [{"id": "m1", "memory": "fact1"}]} + + def add(self, messages, **kwargs): + self.calls.append(("add", messages, kwargs)) + return {"status": "PENDING", "event_id": "evt-1"} + + def update(self, **kwargs): + self.calls.append(("update", kwargs)) + return {"id": kwargs["memory_id"], "text": kwargs["text"]} + + def delete(self, **kwargs): + self.calls.append(("delete", 
kwargs)) + + +class TestPlatformBackend: + + def _make(self): + client = FakePlatformClient() + backend = PlatformBackend.__new__(PlatformBackend) + backend._client = client + return backend, client + + def test_search_forwards_params(self): + backend, client = self._make() + result = backend.search("test query", filters={"user_id": "u1"}, top_k=5) + assert client.calls[0][0] == "search" + assert client.calls[0][1] == "test query" + assert client.calls[0][2]["filters"] == {"user_id": "u1"} + assert client.calls[0][2]["top_k"] == 5 + + def test_search_forwards_rerank(self): + backend, client = self._make() + backend.search("q", filters={}, rerank=False) + assert client.calls[0][2]["rerank"] is False + + def test_search_rerank_default_true(self): + backend, client = self._make() + backend.search("q", filters={}) + assert client.calls[0][2]["rerank"] is True + + def test_search_returns_list(self): + backend, _ = self._make() + result = backend.search("q", filters={}) + assert isinstance(result, list) + assert result[0]["id"] == "m1" + + def test_get_all_forwards_pagination(self): + backend, client = self._make() + result = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50) + assert client.calls[0][1]["page"] == 2 + assert client.calls[0][1]["page_size"] == 50 + assert "count" in result + + def test_add_forwards_kwargs(self): + backend, client = self._make() + msgs = [{"role": "user", "content": "hi"}] + result = backend.add(msgs, user_id="u1", agent_id="hermes", infer=False) + call = client.calls[0] + assert call[2]["user_id"] == "u1" + assert call[2]["infer"] is False + # metadata kwarg should be omitted entirely when not provided so we + # don't surprise older mem0 client versions with an unknown kwarg. 
+ assert "metadata" not in call[2] + + def test_add_forwards_metadata_when_present(self): + backend, client = self._make() + msgs = [{"role": "user", "content": "hi"}] + backend.add( + msgs, + user_id="u1", + agent_id="hermes", + infer=False, + metadata={"channel": "telegram"}, + ) + assert client.calls[0][2]["metadata"] == {"channel": "telegram"} + + def test_add_omits_empty_metadata(self): + backend, client = self._make() + msgs = [{"role": "user", "content": "hi"}] + backend.add(msgs, user_id="u1", agent_id="hermes", infer=False, metadata={}) + assert "metadata" not in client.calls[0][2] + + def test_update_forwards(self): + backend, client = self._make() + backend.update("m1", "new text") + assert client.calls[0][1] == {"memory_id": "m1", "text": "new text"} + + def test_delete_forwards(self): + backend, client = self._make() + backend.delete("m1") + assert client.calls[0][1] == {"memory_id": "m1"} + + +class FakeOSSMemory: + """Fake mem0.Memory for OSSBackend tests.""" + + def __init__(self): + self.calls = [] + + def search(self, query, **kwargs): + self.calls.append(("search", query, kwargs)) + return {"results": [{"id": "m1", "memory": "fact1", "score": 0.8}]} + + def get_all(self, **kwargs): + self.calls.append(("get_all", kwargs)) + return {"results": [{"id": "m1", "memory": "fact1"}]} + + def add(self, messages, **kwargs): + self.calls.append(("add", messages, kwargs)) + return {"results": [{"id": "m1", "memory": "fact1", "event": "ADD"}]} + + def update(self, memory_id, **kwargs): + self.calls.append(("update", memory_id, kwargs)) + return {"message": "Memory updated successfully!"} + + def delete(self, memory_id): + self.calls.append(("delete", memory_id)) + return {"message": "Memory deleted successfully!"} + + +class TestOSSBackend: + + def _make(self): + memory = FakeOSSMemory() + backend = OSSBackend.__new__(OSSBackend) + backend._memory = memory + return backend, memory + + def test_search_returns_list(self): + backend, _ = self._make() + result = 
backend.search("test", filters={"user_id": "u1"}) + assert isinstance(result, list) + assert result[0]["id"] == "m1" + + def test_search_passes_filters(self): + backend, memory = self._make() + backend.search("q", filters={"user_id": "u1"}, top_k=3) + assert memory.calls[0][2]["filters"] == {"user_id": "u1"} + assert memory.calls[0][2]["top_k"] == 3 + + def test_search_ignores_rerank(self): + """OSS backend accepts rerank param but does not forward it to Memory.""" + backend, memory = self._make() + backend.search("q", filters={}, rerank=True) + assert "rerank" not in memory.calls[0][2] + + def test_get_all_ignores_pagination(self): + """OSSBackend accepts page/page_size but does NOT forward to Memory.get_all().""" + backend, memory = self._make() + result = backend.get_all(filters={"user_id": "u1"}, page=2, page_size=50) + call_kwargs = memory.calls[0][1] + assert "page" not in call_kwargs + assert "page_size" not in call_kwargs + assert result["count"] == 1 + + def test_get_all_returns_envelope(self): + backend, _ = self._make() + result = backend.get_all(filters={"user_id": "u1"}) + assert "results" in result + assert "count" in result + + def test_add_forwards_kwargs(self): + backend, memory = self._make() + msgs = [{"role": "user", "content": "hi"}] + backend.add(msgs, user_id="u1", agent_id="hermes", infer=False) + assert memory.calls[0][2]["user_id"] == "u1" + assert memory.calls[0][2]["infer"] is False + + def test_update_maps_text_to_data(self): + """OSS Memory.update uses `data=` param, not `text=`.""" + backend, memory = self._make() + backend.update("m1", "new text") + assert memory.calls[0][0] == "update" + assert memory.calls[0][1] == "m1" + assert memory.calls[0][2] == {"data": "new text"} + + def test_delete_positional_arg(self): + backend, memory = self._make() + backend.delete("m1") + assert memory.calls[0] == ("delete", "m1") + + def test_update_normalizes_response(self): + backend, _ = self._make() + result = backend.update("m1", "text") + 
assert result == {"result": "Memory updated.", "memory_id": "m1"} + + def test_delete_normalizes_response(self): + backend, _ = self._make() + result = backend.delete("m1") + assert result == {"result": "Memory deleted.", "memory_id": "m1"} diff --git a/tests/plugins/memory/test_mem0_providers.py b/tests/plugins/memory/test_mem0_providers.py new file mode 100644 index 00000000000..010e3263a5f --- /dev/null +++ b/tests/plugins/memory/test_mem0_providers.py @@ -0,0 +1,107 @@ +"""Tests for OSS provider definitions and validation.""" + +import pytest + +from plugins.memory.mem0._oss_providers import ( + LLM_PROVIDERS, + EMBEDDER_PROVIDERS, + VECTOR_PROVIDERS, + KNOWN_DIMS, + validate_oss_config, +) + + +class TestProviderDefinitions: + + def test_llm_providers_have_required_keys(self): + for pid, p in LLM_PROVIDERS.items(): + assert "label" in p + assert "needs_key" in p + assert "default_model" in p + + def test_embedder_providers_have_required_keys(self): + for pid, p in EMBEDDER_PROVIDERS.items(): + assert "label" in p + assert "needs_key" in p + assert "default_model" in p + assert "dims" in p + + def test_embedder_provider_ids(self): + assert set(EMBEDDER_PROVIDERS.keys()) == {"openai", "ollama"} + + def test_vector_providers_have_required_keys(self): + for pid, p in VECTOR_PROVIDERS.items(): + assert "label" in p + assert "default_config" in p + + def test_vector_provider_ids(self): + assert set(VECTOR_PROVIDERS.keys()) == {"qdrant", "pgvector"} + + def test_known_dims_covers_defaults(self): + for pid, p in EMBEDDER_PROVIDERS.items(): + assert p["default_model"] in KNOWN_DIMS + + +class TestValidation: + + def test_valid_openai_config(self): + cfg = { + "llm": {"provider": "openai", "config": {"model": "gpt-4o-mini"}}, + "embedder": {"provider": "openai", "config": {"model": "text-embedding-3-small"}}, + "vector_store": {"provider": "qdrant", "config": {"path": "/tmp/test"}}, + } + errors = validate_oss_config(cfg) + assert errors == [] + + def 
test_unknown_llm_provider(self): + cfg = { + "llm": {"provider": "gemini", "config": {}}, + "embedder": {"provider": "openai", "config": {}}, + "vector_store": {"provider": "qdrant", "config": {}}, + } + errors = validate_oss_config(cfg) + assert any("llm" in e.lower() for e in errors) + + def test_unknown_embedder_provider(self): + cfg = { + "llm": {"provider": "openai", "config": {}}, + "embedder": {"provider": "cohere", "config": {}}, + "vector_store": {"provider": "qdrant", "config": {}}, + } + errors = validate_oss_config(cfg) + assert any("embedder" in e.lower() for e in errors) + + def test_unknown_vector_provider(self): + cfg = { + "llm": {"provider": "openai", "config": {}}, + "embedder": {"provider": "openai", "config": {}}, + "vector_store": {"provider": "redis", "config": {}}, + } + errors = validate_oss_config(cfg) + assert any("vector" in e.lower() for e in errors) + + def test_missing_llm_section(self): + cfg = { + "embedder": {"provider": "openai", "config": {}}, + "vector_store": {"provider": "qdrant", "config": {}}, + } + errors = validate_oss_config(cfg) + assert any("llm" in e.lower() for e in errors) + + def test_pgvector_needs_user(self): + cfg = { + "llm": {"provider": "openai", "config": {}}, + "embedder": {"provider": "openai", "config": {}}, + "vector_store": {"provider": "pgvector", "config": {"host": "localhost"}}, + } + errors = validate_oss_config(cfg) + assert any("user" in e.lower() for e in errors) + + def test_pgvector_with_user_valid(self): + cfg = { + "llm": {"provider": "openai", "config": {}}, + "embedder": {"provider": "openai", "config": {}}, + "vector_store": {"provider": "pgvector", "config": {"host": "localhost", "user": "pg"}}, + } + errors = validate_oss_config(cfg) + assert errors == [] diff --git a/tests/plugins/memory/test_mem0_setup.py b/tests/plugins/memory/test_mem0_setup.py new file mode 100644 index 00000000000..e67293e8a23 --- /dev/null +++ b/tests/plugins/memory/test_mem0_setup.py @@ -0,0 +1,251 @@ +"""Tests 
for Mem0 setup wizard — flag parsing, config building, validation.""" + +import json +import sys +import types +import pytest +from pathlib import Path +from unittest.mock import patch, MagicMock + +from plugins.memory.mem0._setup import ( + parse_flags, + build_oss_config, + _write_env, + post_setup, + _check_qdrant_path, + _check_ollama, + _check_pgvector, +) + + +def _inject_fake_hermes_cli(monkeypatch): + """Inject fake hermes_cli modules so yaml/curses aren't required.""" + fake_config_mod = types.ModuleType("hermes_cli.config") + fake_config_mod.save_config = lambda c: None + + fake_setup_mod = types.ModuleType("hermes_cli.memory_setup") + fake_setup_mod._curses_select = lambda *a, **kw: 0 + fake_setup_mod._prompt = lambda label, default=None, secret=False: default or "" + + fake_hermes_cli = types.ModuleType("hermes_cli") + fake_hermes_cli.config = fake_config_mod + fake_hermes_cli.memory_setup = fake_setup_mod + + monkeypatch.setitem(sys.modules, "hermes_cli", fake_hermes_cli) + monkeypatch.setitem(sys.modules, "hermes_cli.config", fake_config_mod) + monkeypatch.setitem(sys.modules, "hermes_cli.memory_setup", fake_setup_mod) + + monkeypatch.setattr("plugins.memory.mem0._setup._curses_select", lambda *a, **kw: 0) + monkeypatch.setattr("plugins.memory.mem0._setup._prompt", lambda label, default=None, secret=False: default or "") + return fake_config_mod + + +class TestParseFlags: + + def test_mode_platform(self): + flags = parse_flags(["--mode", "platform", "--api-key", "sk-test"]) + assert flags["mode"] == "platform" + assert flags["api_key"] == "sk-test" + + def test_mode_oss_defaults(self): + flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"]) + assert flags["mode"] == "oss" + assert flags["oss_llm"] == "openai" + assert flags["oss_embedder"] == "openai" + assert flags["oss_vector"] == "qdrant" + + def test_mode_oss_all_flags(self): + flags = parse_flags([ + "--mode", "oss", + "--oss-llm", "ollama", + "--oss-llm-model", "llama3:latest", + 
"--oss-embedder", "ollama", + "--oss-embedder-model", "nomic-embed-text", + "--oss-vector", "pgvector", + "--oss-vector-host", "db.local", + "--oss-vector-port", "5433", + "--oss-vector-user", "pguser", + "--oss-vector-password", "secret", + "--oss-vector-dbname", "memdb", + "--user-id", "my-user", + ]) + assert flags["oss_llm"] == "ollama" + assert flags["oss_llm_model"] == "llama3:latest" + assert flags["oss_vector"] == "pgvector" + assert flags["oss_vector_user"] == "pguser" + assert flags["user_id"] == "my-user" + + def test_no_flags_returns_empty_mode(self): + flags = parse_flags([]) + assert flags["mode"] == "" + + def test_oss_vector_path_flag(self): + flags = parse_flags(["--mode", "oss", "--oss-vector-path", "/data/qdrant"]) + assert flags["oss_vector_path"] == "/data/qdrant" + + +class TestBuildOSSConfig: + + def test_openai_defaults(self): + flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"]) + oss, env_writes = build_oss_config(flags) + assert oss["llm"]["provider"] == "openai" + assert oss["llm"]["config"]["model"] == "gpt-5-mini" + assert oss["embedder"]["provider"] == "openai" + assert oss["embedder"]["config"]["model"] == "text-embedding-3-small" + assert oss["vector_store"]["provider"] == "qdrant" + assert env_writes["OPENAI_API_KEY"] == "sk-oai" + + def test_ollama_no_key_needed(self): + flags = parse_flags(["--mode", "oss", "--oss-llm", "ollama", "--oss-embedder", "ollama"]) + oss, env_writes = build_oss_config(flags) + assert oss["llm"]["provider"] == "ollama" + assert "model" in oss["llm"]["config"] + assert env_writes == {} + + def test_embedder_reuses_llm_key(self): + """When LLM and embedder share same provider, key written once.""" + flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"]) + _, env_writes = build_oss_config(flags) + assert env_writes == {"OPENAI_API_KEY": "sk-oai"} + + def test_different_embedder_needs_separate_key(self): + flags = parse_flags([ + "--mode", "oss", + "--oss-llm", "ollama", + 
"--oss-embedder", "openai", "--oss-embedder-key", "sk-oai", + ]) + _, env_writes = build_oss_config(flags) + assert env_writes == {"OPENAI_API_KEY": "sk-oai"} + + def test_pgvector_config(self): + flags = parse_flags([ + "--mode", "oss", "--oss-llm-key", "sk-oai", + "--oss-vector", "pgvector", + "--oss-vector-host", "db.local", "--oss-vector-port", "5433", + "--oss-vector-user", "pg", "--oss-vector-dbname", "memdb", + ]) + oss, _ = build_oss_config(flags) + vs = oss["vector_store"] + assert vs["provider"] == "pgvector" + assert vs["config"]["host"] == "db.local" + assert vs["config"]["port"] == 5433 + assert vs["config"]["user"] == "pg" + + def test_known_dims_auto_set(self): + flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai"]) + oss, _ = build_oss_config(flags) + dims = oss["embedder"]["config"].get("embedding_dims") + assert dims == 1536 + + def test_custom_qdrant_path(self): + flags = parse_flags([ + "--mode", "oss", "--oss-llm-key", "sk-oai", + "--oss-vector-path", "/data/qdrant", + ]) + oss, _ = build_oss_config(flags) + assert oss["vector_store"]["config"]["path"] == "/data/qdrant" + + +class TestWriteEnv: + + def test_write_new_vars(self, tmp_path): + env_path = tmp_path / ".env" + _write_env(env_path, {"OPENAI_API_KEY": "sk-test"}) + content = env_path.read_text() + assert "OPENAI_API_KEY=sk-test" in content + + def test_update_existing_var(self, tmp_path): + env_path = tmp_path / ".env" + env_path.write_text("OPENAI_API_KEY=old\nOTHER=keep\n") + _write_env(env_path, {"OPENAI_API_KEY": "new"}) + content = env_path.read_text() + assert "OPENAI_API_KEY=new" in content + assert "OTHER=keep" in content + assert "old" not in content + + +class TestPostSetup: + + def test_platform_flag_mode(self, tmp_path, monkeypatch): + monkeypatch.setattr("sys.argv", ["hermes", "--mode", "platform", "--api-key", "sk-test"]) + monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path) + _inject_fake_hermes_cli(monkeypatch) + config = 
{"memory": {}} + post_setup(str(tmp_path), config) + assert config["memory"]["provider"] == "mem0" + env_content = (tmp_path / ".env").read_text() + assert "MEM0_API_KEY=sk-test" in env_content + mem0_json = json.loads((tmp_path / "mem0.json").read_text()) + assert mem0_json["mode"] == "platform" + + def test_oss_flag_mode(self, tmp_path, monkeypatch): + monkeypatch.setattr("sys.argv", [ + "hermes", "--mode", "oss", "--oss-llm-key", "sk-oai", + ]) + monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path) + _inject_fake_hermes_cli(monkeypatch) + monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None) + config = {"memory": {}} + post_setup(str(tmp_path), config) + assert config["memory"]["provider"] == "mem0" + mem0_json = json.loads((tmp_path / "mem0.json").read_text()) + assert mem0_json["mode"] == "oss" + assert mem0_json["oss"]["llm"]["provider"] == "openai" + + +class TestDryRun: + + def test_dry_run_flag_parsed(self): + flags = parse_flags(["--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run"]) + assert flags["dry_run"] is True + + def test_dry_run_not_set_by_default(self): + flags = parse_flags(["--mode", "oss"]) + assert flags["dry_run"] is False + + def test_dry_run_platform_no_files(self, tmp_path, monkeypatch): + monkeypatch.setattr("sys.argv", ["hermes", "--mode", "platform", "--api-key", "sk-test", "--dry-run"]) + monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path) + _inject_fake_hermes_cli(monkeypatch) + config = {"memory": {}} + post_setup(str(tmp_path), config) + assert not (tmp_path / ".env").exists() + assert not (tmp_path / "mem0.json").exists() + assert "provider" not in config["memory"] + + def test_dry_run_oss_no_files(self, tmp_path, monkeypatch): + monkeypatch.setattr("sys.argv", [ + "hermes", "--mode", "oss", "--oss-llm-key", "sk-oai", "--dry-run", + ]) + monkeypatch.setattr("plugins.memory.mem0._setup.get_hermes_home", lambda: tmp_path) + 
_inject_fake_hermes_cli(monkeypatch)
+        monkeypatch.setattr("plugins.memory.mem0._setup._install_provider_deps", lambda l, e, v: None)
+        config = {"memory": {}}
+        post_setup(str(tmp_path), config)
+        assert not (tmp_path / ".env").exists()
+        assert not (tmp_path / "mem0.json").exists()
+        assert "provider" not in config["memory"]
+
+
+class TestConnectivityChecks:
+
+    def test_qdrant_path_writable(self, tmp_path):
+        ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
+        assert ok is True
+
+    def test_qdrant_path_not_writable(self, tmp_path, monkeypatch):
+        def _raise_oserror(*a, **kw):
+            raise OSError("Permission denied")
+        monkeypatch.setattr(Path, "mkdir", _raise_oserror)
+        ok, msg = _check_qdrant_path(str(tmp_path / "qdrant"))
+        assert ok is False
+        assert "Permission denied" in msg
+
+    def test_ollama_unreachable(self):
+        ok, msg = _check_ollama("http://localhost:1")
+        assert ok is False
+
+    def test_pgvector_unreachable(self):
+        ok, msg = _check_pgvector("localhost", 1)
+        assert ok is False
diff --git a/tests/plugins/memory/test_mem0_v2.py b/tests/plugins/memory/test_mem0_v2.py
deleted file mode 100644
index a9a86676452..00000000000
--- a/tests/plugins/memory/test_mem0_v2.py
+++ /dev/null
@@ -1,241 +0,0 @@
-"""Tests for Mem0 API v2 compatibility — filters param and dict response unwrapping.
-
-Salvaged from PRs #5301 (qaqcvc) and #5117 (vvvanguards).
-"""
-
-import json
-import os
-import stat
-
-import pytest
-
-from plugins.memory.mem0 import Mem0MemoryProvider
-
-
-class FakeClientV2:
-    """Fake Mem0 client that returns v2-style dict responses and captures call kwargs."""
-
-    def __init__(self, search_results=None, all_results=None):
-        self._search_results = search_results or {"results": []}
-        self._all_results = all_results or {"results": []}
-        self.captured_search = {}
-        self.captured_get_all = {}
-        self.captured_add = []
-
-    def search(self, **kwargs):
-        self.captured_search = kwargs
-        return self._search_results
-
-    def get_all(self, **kwargs):
-        self.captured_get_all = kwargs
-        return self._all_results
-
-    def add(self, messages, **kwargs):
-        self.captured_add.append({"messages": messages, **kwargs})
-
-
-# ---------------------------------------------------------------------------
-# Filter migration: bare user_id= -> filters={}
-# ---------------------------------------------------------------------------
-
-
-class TestMem0FiltersV2:
-    """All API calls must use filters={} instead of bare user_id= kwargs."""
-
-    def _make_provider(self, monkeypatch, client):
-        provider = Mem0MemoryProvider()
-        provider.initialize("test-session")
-        provider._user_id = "u123"
-        provider._agent_id = "hermes"
-        monkeypatch.setattr(provider, "_get_client", lambda: client)
-        return provider
-
-    def test_search_uses_filters(self, monkeypatch):
-        client = FakeClientV2()
-        provider = self._make_provider(monkeypatch, client)
-
-        provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3, "rerank": False})
-
-        assert client.captured_search["query"] == "hello"
-        assert client.captured_search["top_k"] == 3
-        assert client.captured_search["rerank"] is False
-        assert client.captured_search["filters"] == {"user_id": "u123"}
-        # Must NOT have bare user_id kwarg
-        assert "user_id" not in {k for k in client.captured_search if k != "filters"}
-
-    def test_profile_uses_filters(self, monkeypatch):
-        client = FakeClientV2()
-        provider = self._make_provider(monkeypatch, client)
-
-        provider.handle_tool_call("mem0_profile", {})
-
-        assert client.captured_get_all["filters"] == {"user_id": "u123"}
-        assert "user_id" not in {k for k in client.captured_get_all if k != "filters"}
-
-    def test_prefetch_uses_filters(self, monkeypatch):
-        client = FakeClientV2()
-        provider = self._make_provider(monkeypatch, client)
-
-        provider.queue_prefetch("hello")
-        provider._prefetch_thread.join(timeout=2)
-
-        assert client.captured_search["query"] == "hello"
-        assert client.captured_search["filters"] == {"user_id": "u123"}
-        assert "user_id" not in {k for k in client.captured_search if k != "filters"}
-
-    def test_sync_turn_uses_write_filters(self, monkeypatch):
-        client = FakeClientV2()
-        provider = self._make_provider(monkeypatch, client)
-
-        provider.sync_turn("user said this", "assistant replied", session_id="s1")
-        provider._sync_thread.join(timeout=2)
-
-        assert len(client.captured_add) == 1
-        call = client.captured_add[0]
-        assert call["user_id"] == "u123"
-        assert call["agent_id"] == "hermes"
-
-    def test_conclude_uses_write_filters(self, monkeypatch):
-        client = FakeClientV2()
-        provider = self._make_provider(monkeypatch, client)
-
-        provider.handle_tool_call("mem0_conclude", {"conclusion": "user likes dark mode"})
-
-        assert len(client.captured_add) == 1
-        call = client.captured_add[0]
-        assert call["user_id"] == "u123"
-        assert call["agent_id"] == "hermes"
-        assert call["infer"] is False
-
-    def test_read_filters_no_agent_id(self):
-        """Read filters should use user_id only — cross-session recall across agents."""
-        provider = Mem0MemoryProvider()
-        provider._user_id = "u123"
-        provider._agent_id = "hermes"
-        assert provider._read_filters() == {"user_id": "u123"}
-
-    def test_write_filters_include_agent_id(self):
-        """Write filters should include agent_id for attribution."""
-        provider = Mem0MemoryProvider()
-        provider._user_id = "u123"
-        provider._agent_id = "hermes"
-        assert provider._write_filters() == {"user_id": "u123", "agent_id": "hermes"}
-
-
-# ---------------------------------------------------------------------------
-# Dict response unwrapping (API v2 wraps in {"results": [...]})
-# ---------------------------------------------------------------------------
-
-
-class TestMem0ResponseUnwrapping:
-    """API v2 returns {"results": [...]} dicts; we must extract the list."""
-
-    def _make_provider(self, monkeypatch, client):
-        provider = Mem0MemoryProvider()
-        provider.initialize("test-session")
-        monkeypatch.setattr(provider, "_get_client", lambda: client)
-        return provider
-
-    def test_profile_dict_response(self, monkeypatch):
-        client = FakeClientV2(all_results={"results": [{"memory": "alpha"}, {"memory": "beta"}]})
-        provider = self._make_provider(monkeypatch, client)
-
-        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
-
-        assert result["count"] == 2
-        assert "alpha" in result["result"]
-        assert "beta" in result["result"]
-
-    def test_profile_list_response_backward_compat(self, monkeypatch):
-        """Old API returned bare lists — still works."""
-        client = FakeClientV2(all_results=[{"memory": "gamma"}])
-        provider = self._make_provider(monkeypatch, client)
-
-        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
-        assert result["count"] == 1
-        assert "gamma" in result["result"]
-
-    def test_search_dict_response(self, monkeypatch):
-        client = FakeClientV2(search_results={
-            "results": [{"memory": "foo", "score": 0.9}, {"memory": "bar", "score": 0.7}]
-        })
-        provider = self._make_provider(monkeypatch, client)
-
-        result = json.loads(provider.handle_tool_call(
-            "mem0_search", {"query": "test", "top_k": 5}
-        ))
-
-        assert result["count"] == 2
-        assert result["results"][0]["memory"] == "foo"
-
-    def test_search_list_response_backward_compat(self, monkeypatch):
-        """Old API returned bare lists — still works."""
-        client = FakeClientV2(search_results=[{"memory": "baz", "score": 0.8}])
-        provider = self._make_provider(monkeypatch, client)
-
-        result = json.loads(provider.handle_tool_call(
-            "mem0_search", {"query": "test"}
-        ))
-        assert result["count"] == 1
-
-    def test_unwrap_results_edge_cases(self):
-        """_unwrap_results handles all shapes gracefully."""
-        assert Mem0MemoryProvider._unwrap_results({"results": [1, 2]}) == [1, 2]
-        assert Mem0MemoryProvider._unwrap_results([3, 4]) == [3, 4]
-        assert Mem0MemoryProvider._unwrap_results({}) == []
-        assert Mem0MemoryProvider._unwrap_results(None) == []
-        assert Mem0MemoryProvider._unwrap_results("unexpected") == []
-
-    def test_prefetch_dict_response(self, monkeypatch):
-        client = FakeClientV2(search_results={
-            "results": [{"memory": "user prefers dark mode"}]
-        })
-        provider = Mem0MemoryProvider()
-        provider.initialize("test-session")
-        monkeypatch.setattr(provider, "_get_client", lambda: client)
-
-        provider.queue_prefetch("preferences")
-        provider._prefetch_thread.join(timeout=2)
-        result = provider.prefetch("preferences")
-
-        assert "dark mode" in result
-
-
-# ---------------------------------------------------------------------------
-# Default preservation
-# ---------------------------------------------------------------------------
-
-
-@pytest.mark.skipif(os.name == "nt", reason="POSIX mode bits not enforced on Windows")
-def test_save_config_sets_owner_only_permissions(tmp_path):
-    """mem0.json must be written with 0o600 so API key is not world-readable."""
-    provider = Mem0MemoryProvider()
-    provider.save_config({"api_key": "m0-test-key"}, str(tmp_path))
-    config_file = tmp_path / "mem0.json"
-    assert config_file.exists()
-    mode = stat.S_IMODE(config_file.stat().st_mode)
-    assert mode == 0o600, f"Expected 0o600 (owner-only), got {oct(mode)}"
-
-
-class TestMem0Defaults:
-    """Ensure we don't break existing users' defaults."""
-
-    def test_default_user_id_hermes_user(self, monkeypatch, tmp_path):
-        monkeypatch.setenv("MEM0_API_KEY", "test-key")
-        monkeypatch.delenv("MEM0_USER_ID", raising=False)
-        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
-
-        provider = Mem0MemoryProvider()
-        provider.initialize("test")
-
-        assert provider._user_id == "hermes-user"
-
-    def test_default_agent_id_hermes(self, monkeypatch, tmp_path):
-        monkeypatch.setenv("MEM0_API_KEY", "test-key")
-        monkeypatch.delenv("MEM0_AGENT_ID", raising=False)
-        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
-
-        provider = Mem0MemoryProvider()
-        provider.initialize("test")
-
-        assert provider._agent_id == "hermes"
diff --git a/tests/plugins/memory/test_mem0_v3.py b/tests/plugins/memory/test_mem0_v3.py
new file mode 100644
index 00000000000..e83a4171a4a
--- /dev/null
+++ b/tests/plugins/memory/test_mem0_v3.py
@@ -0,0 +1,463 @@
+"""Tests for Mem0 v3 API — new tool names, paginated responses, update/delete tools."""
+
+import json
+import pytest
+
+from plugins.memory.mem0 import Mem0MemoryProvider
+
+
+class FakeBackend:
+    """Fake Mem0Backend for provider-level tests."""
+
+    def __init__(self, search_results=None, all_results=None):
+        self._search_results = search_results or []
+        self._all_results = all_results or {"results": [], "count": 0}
+        self.captured = []
+
+    def search(self, query, *, filters, top_k=10, rerank=True):
+        self.captured.append(("search", query, {"filters": filters, "top_k": top_k, "rerank": rerank}))
+        return self._search_results
+
+    def get_all(self, *, filters, page=1, page_size=100):
+        self.captured.append(("get_all", {"filters": filters, "page": page, "page_size": page_size}))
+        return self._all_results
+
+    def add(self, messages, *, user_id, agent_id, infer=False, metadata=None):
+        self.captured.append((
+            "add",
+            messages,
+            {"user_id": user_id, "agent_id": agent_id, "infer": infer, "metadata": metadata},
+        ))
+        return {"status": "PENDING", "event_id": "evt-test-123"}
+
+    def update(self, memory_id, text):
+        self.captured.append(("update", memory_id, text))
+        return {"result": "Memory updated.", "memory_id": memory_id}
+
+    def delete(self, memory_id):
+        self.captured.append(("delete", memory_id))
+        return {"result": "Memory deleted.", "memory_id": memory_id}
+
+
+class TestMem0V3Tools:
+    """Test v3 tool names and response handling."""
+
+    def _make_provider(self, monkeypatch, backend):
+        provider = Mem0MemoryProvider()
+        provider.initialize("test-session")
+        provider._user_id = "u123"
+        provider._agent_id = "hermes"
+        provider._backend = backend
+        return provider
+
+    def test_list_returns_paginated_with_ids(self, monkeypatch):
+        backend = FakeBackend(all_results={
+            "count": 2,
+            "results": [
+                {"id": "mem-1", "memory": "alpha"},
+                {"id": "mem-2", "memory": "beta"},
+            ]
+        })
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_list", {}))
+        assert result["count"] == 2
+        assert result["results"][0]["id"] == "mem-1"
+        assert result["results"][0]["memory"] == "alpha"
+
+    def test_list_pagination_params(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        provider.handle_tool_call("mem0_list", {"page": 2, "page_size": 50})
+        assert backend.captured[0][1]["page"] == 2
+        assert backend.captured[0][1]["page_size"] == 50
+
+    def test_list_empty(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_list", {}))
+        assert result["result"] == "No memories stored yet."
+
+    def test_search_returns_ids(self, monkeypatch):
+        backend = FakeBackend(search_results=[{"id": "mem-1", "memory": "foo", "score": 0.9}])
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_search", {"query": "test"}))
+        assert result["results"][0]["id"] == "mem-1"
+
+    def test_search_uses_filters(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        provider.handle_tool_call("mem0_search", {"query": "hello", "top_k": 3})
+        assert backend.captured[0][2]["filters"] == {"user_id": "u123"}
+        assert backend.captured[0][2]["top_k"] == 3
+
+    def test_search_rerank_default_true(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        provider.handle_tool_call("mem0_search", {"query": "test"})
+        assert backend.captured[0][2]["rerank"] is True
+
+    def test_search_rerank_override_false(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        provider.handle_tool_call("mem0_search", {"query": "test", "rerank": False})
+        assert backend.captured[0][2]["rerank"] is False
+
+    def test_add_uses_content_param(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"}))
+        assert len(backend.captured) == 1
+        call = backend.captured[0]
+        assert call[2]["infer"] is False
+        assert call[2]["user_id"] == "u123"
+        assert call[2]["agent_id"] == "hermes"
+        assert "event_id" in result
+
+    def test_add_returns_event_id(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_add", {"content": "test"}))
+        assert result["event_id"] == "evt-test-123"
+
+    def test_add_missing_content(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_add", {}))
+        assert "error" in result
+
+    def test_old_tool_names_return_unknown(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
+        assert "error" in result
+        result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
+        assert "error" in result
+
+
+class TestMem0UpdateDelete:
+
+    def _make_provider(self, monkeypatch, backend):
+        provider = Mem0MemoryProvider()
+        provider.initialize("test-session")
+        provider._user_id = "u123"
+        provider._agent_id = "hermes"
+        provider._backend = backend
+        return provider
+
+    def test_update_calls_sdk(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call(
+            "mem0_update", {"memory_id": "mem-1", "text": "updated fact"}
+        ))
+        assert backend.captured[0][1] == "mem-1"
+        assert backend.captured[0][2] == "updated fact"
+        assert result["result"] == "Memory updated."
+        assert result["memory_id"] == "mem-1"
+
+    def test_update_missing_memory_id(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_update", {"text": "no id"}))
+        assert "error" in result
+
+    def test_update_missing_text(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_update", {"memory_id": "mem-1"}))
+        assert "error" in result
+
+    def test_delete_calls_sdk(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call(
+            "mem0_delete", {"memory_id": "mem-1"}
+        ))
+        assert backend.captured[0][1] == "mem-1"
+        assert result["result"] == "Memory deleted."
+
+    def test_delete_missing_memory_id(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_delete", {}))
+        assert "error" in result
+
+
+class TestMem0ErrorHandling:
+
+    def _make_provider(self, monkeypatch, backend):
+        provider = Mem0MemoryProvider()
+        provider.initialize("test-session")
+        provider._user_id = "u123"
+        provider._agent_id = "hermes"
+        provider._backend = backend
+        return provider
+
+    def test_update_404_no_circuit_breaker(self, monkeypatch):
+        backend = FakeBackend()
+        backend.update = lambda mid, text: (_ for _ in ()).throw(Exception("404 Not Found"))
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call(
+            "mem0_update", {"memory_id": "bad-id", "text": "x"}
+        ))
+        assert "error" in result
+        assert provider._consecutive_failures == 0
+
+    def test_delete_404_no_circuit_breaker(self, monkeypatch):
+        backend = FakeBackend()
+        backend.delete = lambda mid: (_ for _ in ()).throw(Exception("404 not found"))
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call(
+            "mem0_delete", {"memory_id": "bad-id"}
+        ))
+        assert "error" in result
+        assert provider._consecutive_failures == 0
+
+    def test_update_validation_error_no_circuit_breaker(self, monkeypatch):
+        """ValidationError (bad UUID format) should not trip circuit breaker."""
+        class ValidationError(Exception):
+            pass
+        backend = FakeBackend()
+        backend.update = lambda mid, text: (_ for _ in ()).throw(
+            ValidationError('{"error":"memory_id should be a valid UUID"}')
+        )
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call(
+            "mem0_update", {"memory_id": "not-a-uuid", "text": "x"}
+        ))
+        assert "error" in result
+        assert provider._consecutive_failures == 0
+
+    def test_delete_validation_error_no_circuit_breaker(self, monkeypatch):
+        class ValidationError(Exception):
+            pass
+        backend = FakeBackend()
+        backend.delete = lambda mid: (_ for _ in ()).throw(
+            ValidationError('{"error":"memory_id should be a valid UUID"}')
+        )
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call(
+            "mem0_delete", {"memory_id": "not-a-uuid"}
+        ))
+        assert "error" in result
+        assert provider._consecutive_failures == 0
+
+    def test_update_5xx_trips_circuit_breaker(self, monkeypatch):
+        backend = FakeBackend()
+        backend.update = lambda mid, text: (_ for _ in ()).throw(Exception("500 Internal Server Error"))
+        provider = self._make_provider(monkeypatch, backend)
+        provider.handle_tool_call("mem0_update", {"memory_id": "mem-1", "text": "x"})
+        assert provider._consecutive_failures == 1
+
+
+class TestMem0V3Internal:
+
+    def _make_provider(self, monkeypatch, backend):
+        provider = Mem0MemoryProvider()
+        provider.initialize("test-session")
+        provider._user_id = "u123"
+        provider._agent_id = "hermes"
+        provider._backend = backend
+        return provider
+
+    def test_sync_turn_explicit_kwargs(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        provider.sync_turn("user said", "assistant replied", session_id="s1")
+        provider._sync_thread.join(timeout=2)
+        assert len(backend.captured) == 1
+        call = backend.captured[0]
+        assert call[2]["user_id"] == "u123"
+        assert call[2]["agent_id"] == "hermes"
+        assert call[2]["infer"] is True
+
+    def test_old_tool_names_return_unknown(self, monkeypatch):
+        backend = FakeBackend()
+        provider = self._make_provider(monkeypatch, backend)
+        result = json.loads(provider.handle_tool_call("mem0_profile", {}))
+        assert "error" in result
+        result = json.loads(provider.handle_tool_call("mem0_conclude", {}))
+        assert "error" in result
+
+
+class TestMem0V3Config:
+
+    def test_tool_schemas_five_tools(self):
+        provider = Mem0MemoryProvider()
+        schemas = provider.get_tool_schemas()
+        names = [s["name"] for s in schemas]
+        assert names == ["mem0_list", "mem0_search", "mem0_add", "mem0_update", "mem0_delete"]
+
+    def test_system_prompt_new_tool_names(self):
+        provider = Mem0MemoryProvider()
+        provider._user_id = "test"
+        block = provider.system_prompt_block()
+        assert "mem0_search" in block
+        assert "mem0_add" in block
+        assert "mem0_list" in block
+        assert "mem0_update" in block
+        assert "mem0_delete" in block
+        assert "mem0_profile" not in block
+        assert "mem0_conclude" not in block
+
+    def test_system_prompt_shows_platform_mode(self):
+        provider = Mem0MemoryProvider()
+        provider._user_id = "test"
+        provider._mode = "platform"
+        block = provider.system_prompt_block()
+        assert "platform" in block
+        assert "Rerank" in block
+
+    def test_system_prompt_shows_oss_mode(self):
+        provider = Mem0MemoryProvider()
+        provider._user_id = "test"
+        provider._mode = "oss"
+        block = provider.system_prompt_block()
+        assert "OSS" in block
+        assert "Rerank" not in block
+
+    def test_search_schema_has_rerank(self):
+        """rerank property available in SEARCH_SCHEMA for platform mode."""
+        provider = Mem0MemoryProvider()
+        schemas = provider.get_tool_schemas()
+        search = next(s for s in schemas if s["name"] == "mem0_search")
+        assert "rerank" in search["parameters"]["properties"]
+        assert search["parameters"]["properties"]["rerank"]["type"] == "boolean"
+
+
+class TestMem0ModeSwitch:
+
+    def test_default_mode_is_platform(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        monkeypatch.setenv("MEM0_API_KEY", "test-key")
+        provider = Mem0MemoryProvider()
+        provider.initialize("test")
+        assert provider._mode == "platform"
+
+    def test_missing_mode_key_defaults_platform(self, monkeypatch, tmp_path):
+        """Backward compat: old mem0.json without mode key works."""
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        config_path = tmp_path / "mem0.json"
+        config_path.write_text('{"user_id": "old-user"}')
+        monkeypatch.setenv("MEM0_API_KEY", "test-key")
+        provider = Mem0MemoryProvider()
+        provider.initialize("test")
+        assert provider._mode == "platform"
+        assert provider._user_id == "old-user"
+
+    def test_is_available_platform_needs_key(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        monkeypatch.delenv("MEM0_API_KEY", raising=False)
+        provider = Mem0MemoryProvider()
+        assert provider.is_available() is False
+
+    def test_is_available_oss_needs_vector(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        config_path = tmp_path / "mem0.json"
+        config_path.write_text('{"mode": "oss", "oss": {"vector_store": {"provider": "qdrant"}}}')
+        provider = Mem0MemoryProvider()
+        assert provider.is_available() is True
+
+    def test_is_available_oss_no_vector(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        config_path = tmp_path / "mem0.json"
+        config_path.write_text('{"mode": "oss", "oss": {}}')
+        provider = Mem0MemoryProvider()
+        assert provider.is_available() is False
+
+    def test_tool_schemas_unchanged(self):
+        provider = Mem0MemoryProvider()
+        schemas = provider.get_tool_schemas()
+        names = [s["name"] for s in schemas]
+        assert names == ["mem0_list", "mem0_search", "mem0_add", "mem0_update", "mem0_delete"]
+
+    def test_system_prompt_includes_mode(self):
+        provider = Mem0MemoryProvider()
+        provider._user_id = "test"
+        provider._mode = "oss"
+        block = provider.system_prompt_block()
+        assert "mem0_search" in block
+        assert "mem0_list" in block
+        assert "OSS" in block
+
+
+class TestMem0UserIdResolution:
+    """user_id resolution: configured override > gateway-native id > placeholder.
+
+    Same human across CLI / Telegram / Discord / Slack / etc. should map to
+    the same memory store when MEM0_USER_ID is set, and only fall back to the
+    gateway-native id when it isn't.
+    """
+
+    def _provider(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("HERMES_HOME", str(tmp_path))
+        monkeypatch.setenv("MEM0_API_KEY", "test-key")
+        provider = Mem0MemoryProvider()
+        # Skip backend instantiation — we only care about identity resolution.
+        provider._create_backend = lambda: None  # type: ignore[method-assign]
+        return provider
+
+    def test_env_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
+        monkeypatch.setenv("MEM0_USER_ID", "ryan@example.com")
+        provider = self._provider(monkeypatch, tmp_path)
+        provider.initialize("test", user_id="123456789", platform="telegram")
+        assert provider._user_id == "ryan@example.com"
+
+    def test_file_override_beats_gateway_native_id(self, monkeypatch, tmp_path):
+        monkeypatch.delenv("MEM0_USER_ID", raising=False)
+        (tmp_path / "mem0.json").write_text('{"user_id": "ryan@example.com"}')
+        provider = self._provider(monkeypatch, tmp_path)
+        provider.initialize("test", user_id="123456789", platform="telegram")
+        assert provider._user_id == "ryan@example.com"
+
+    def test_unset_falls_back_to_gateway_native_id(self, monkeypatch, tmp_path):
+        monkeypatch.delenv("MEM0_USER_ID", raising=False)
+        provider = self._provider(monkeypatch, tmp_path)
+        provider.initialize("test", user_id="123456789", platform="telegram")
+        assert provider._user_id == "123456789"
+
+    def test_unset_and_no_kwargs_falls_back_to_default(self, monkeypatch, tmp_path):
+        monkeypatch.delenv("MEM0_USER_ID", raising=False)
+        provider = self._provider(monkeypatch, tmp_path)
+        provider.initialize("test")
+        assert provider._user_id == "hermes-user"
+
+    def test_legacy_placeholder_in_config_does_not_override_kwargs(self, monkeypatch, tmp_path):
+        # Setup wizard historically wrote {"user_id": "hermes-user"} as the
+        # suggested default. Treat that placeholder as unset so users on
+        # gateways still get gateway-native ids — not silent collisions.
+        monkeypatch.delenv("MEM0_USER_ID", raising=False)
+        (tmp_path / "mem0.json").write_text('{"user_id": "hermes-user"}')
+        provider = self._provider(monkeypatch, tmp_path)
+        provider.initialize("test", user_id="123456789", platform="telegram")
+        assert provider._user_id == "123456789"
+
+
+class TestMem0WriteMetadata:
+    """Writes carry metadata.channel so per-channel filtered views are possible
+    without coupling identity to the channel.
+    """
+
+    def _make_provider(self, channel: str = "cli"):
+        provider = Mem0MemoryProvider()
+        provider._user_id = "u123"
+        provider._agent_id = "hermes"
+        provider._channel = channel
+        provider._backend = FakeBackend()
+        return provider
+
+    def test_add_tool_passes_channel_metadata(self):
+        provider = self._make_provider("telegram")
+        provider.handle_tool_call("mem0_add", {"content": "user likes dark mode"})
+        call = provider._backend.captured[-1]
+        assert call[2]["metadata"] == {"channel": "telegram"}
+
+    def test_sync_turn_passes_channel_metadata(self):
+        provider = self._make_provider("discord")
+        provider.sync_turn("hi", "hello", session_id="s")
+        # sync_turn fires a daemon thread; wait for it.
+        if provider._sync_thread:
+            provider._sync_thread.join(timeout=5.0)
+        adds = [c for c in provider._backend.captured if c[0] == "add"]
+        assert adds, "expected an add call from sync_turn"
+        assert adds[-1][2]["metadata"] == {"channel": "discord"}
diff --git a/website/docs/user-guide/features/memory-providers.md b/website/docs/user-guide/features/memory-providers.md
index e3054cf236a..6ba95342b49 100644
--- a/website/docs/user-guide/features/memory-providers.md
+++ b/website/docs/user-guide/features/memory-providers.md
@@ -315,31 +315,55 @@ echo "OPENVIKING_API_KEY=..." >> ~/.hermes/.env

 ### Mem0

-Server-side LLM fact extraction with semantic search, reranking, and automatic deduplication.
+Server-side LLM fact extraction with semantic search, reranking, and automatic deduplication. Supports both Mem0 Platform (cloud) and OSS (self-hosted) modes.

 | | |
 |---|---|
 | **Best for** | Hands-off memory management — Mem0 handles extraction automatically |
-| **Requires** | `pip install mem0ai` + API key |
-| **Data storage** | Mem0 Cloud |
-| **Cost** | Mem0 pricing |
+| **Requires** | `pip install mem0ai` + API key (platform) or LLM/vector store (OSS) |
+| **Data storage** | Mem0 Cloud (platform) or self-hosted (OSS) |
+| **Cost** | Mem0 pricing (platform) / free (OSS) |

-**Tools:** `mem0_profile` (all stored memories), `mem0_search` (semantic search + reranking), `mem0_conclude` (store verbatim facts)
+**Tools (5):** `mem0_list` (list all memories, paginated), `mem0_search` (semantic search with reranking in platform mode), `mem0_add` (store verbatim facts), `mem0_update` (update by ID), `mem0_delete` (delete by ID)

-**Setup:**
+**Setup (Platform):**
 ```bash
-hermes memory setup # select "mem0"
+hermes memory setup # select "mem0" → "Platform"
 # Or manually:
 hermes config set memory.provider mem0
 echo "MEM0_API_KEY=your-key" >> ~/.hermes/.env
 ```

-**Config:** `$HERMES_HOME/mem0.json`
+**Setup (OSS):**
+```bash
+hermes memory setup # select "mem0" → "Open Source (self-hosted)"
+# Or via flags:
+hermes memory setup mem0 --mode oss --oss-llm openai --oss-llm-key sk-... --oss-vector qdrant
+```
+
+Preview without writing files:
+```bash
+hermes memory setup mem0 --mode oss --oss-llm-key sk-... --dry-run
+```
+
+**Config:** `$HERMES_HOME/mem0.json` (behavioral settings). Only the secret `MEM0_API_KEY` belongs in `~/.hermes/.env`.

 | Key | Default | Description |
 |-----|---------|-------------|
+| `mode` | `platform` | `platform` (Mem0 Cloud) or `oss` (self-hosted) |
 | `user_id` | `hermes-user` | User identifier |
 | `agent_id` | `hermes` | Agent identifier |
+| `rerank` | `true` | Rerank search results for relevance (platform mode only) |
+
+**OSS supported providers:**
+
+| Component | Providers |
+|-----------|-----------|
+| LLM | openai, ollama |
+| Embedder | openai, ollama |
+| Vector Store | qdrant (local/server), pgvector |
+
+**Switching modes:** Re-run `hermes memory setup mem0 --mode ` or edit `mem0.json` directly.

 ---

@@ -569,7 +593,7 @@ hermes memory setup
 |----------|---------|------|-------|-------------|----------------|
 | **Honcho** | Cloud | Paid | 5 | `honcho-ai` | Dialectic user modeling + session-scoped context |
 | **OpenViking** | Self-hosted | Free | 5 | `openviking` + server | Filesystem hierarchy + tiered loading |
-| **Mem0** | Cloud | Paid | 3 | `mem0ai` | Server-side LLM extraction |
+| **Mem0** | Cloud/Self-hosted | Free/Paid | 5 | `mem0ai` | Server-side LLM extraction + OSS mode |
 | **Hindsight** | Cloud/Local | Free/Paid | 3 | `hindsight-client` | Knowledge graph + reflect synthesis |
 | **Holographic** | Local | Free | 2 | None | HRR algebra + trust scoring |
 | **RetainDB** | Cloud | $20/mo | 5 | `requests` | Delta compression |
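
Reviewer note: the `user_id` precedence exercised by `TestMem0UserIdResolution` (configured override, then gateway-native id, then the `hermes-user` placeholder, with a configured `hermes-user` treated as unset) can be sketched as below. This is a minimal illustration only, not the provider's actual code: `resolve_user_id`, its parameter names, and `LEGACY_PLACEHOLDER` are hypothetical, and checking the environment value before the `mem0.json` value is an assumption the tests do not pin down.

```python
from typing import Optional

# Default that the setup wizard historically wrote into mem0.json.
LEGACY_PLACEHOLDER = "hermes-user"


def resolve_user_id(
    env_value: Optional[str],
    file_value: Optional[str],
    gateway_id: Optional[str],
) -> str:
    """Hypothetical helper: configured override > gateway-native id > placeholder."""
    # Assumption: the env var is consulted before the config file.
    for configured in (env_value, file_value):
        # The legacy placeholder counts as unset so it cannot shadow a gateway id.
        if configured and configured != LEGACY_PLACEHOLDER:
            return configured
    # No real override configured: use the gateway id, else the placeholder.
    return gateway_id or LEGACY_PLACEHOLDER
```

Treating the placeholder as unset is what keeps gateway users from being bucketed under one shared id when an old `mem0.json` still carries the wizard default.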