diff --git a/plugins/memory/supermemory/__init__.py b/plugins/memory/supermemory/__init__.py index 05583fae3..f798b4a14 100644 --- a/plugins/memory/supermemory/__init__.py +++ b/plugins/memory/supermemory/__init__.py @@ -391,6 +391,7 @@ class SupermemoryMemoryProvider(MemoryProvider): self._prefetch_lock = threading.Lock() self._prefetch_thread: Optional[threading.Thread] = None self._sync_thread: Optional[threading.Thread] = None + self._write_thread: Optional[threading.Thread] = None self._auto_recall = True self._auto_capture = True self._max_recall_results = _DEFAULT_MAX_RECALL_RESULTS @@ -524,6 +525,7 @@ class SupermemoryMemoryProvider(MemoryProvider): if self._sync_thread and self._sync_thread.is_alive(): self._sync_thread.join(timeout=2.0) + self._sync_thread = None self._sync_thread = threading.Thread(target=_run, daemon=True, name="supermemory-sync") self._sync_thread.start() @@ -565,7 +567,18 @@ class SupermemoryMemoryProvider(MemoryProvider): except Exception: logger.debug("Supermemory on_memory_write failed", exc_info=True) - threading.Thread(target=_run, daemon=True, name="supermemory-memory-write").start() + if self._write_thread and self._write_thread.is_alive(): + self._write_thread.join(timeout=2.0) + self._write_thread = None + self._write_thread = threading.Thread(target=_run, daemon=False, name="supermemory-memory-write") + self._write_thread.start() + + def shutdown(self) -> None: + for attr_name in ("_prefetch_thread", "_sync_thread", "_write_thread"): + thread = getattr(self, attr_name, None) + if thread and thread.is_alive(): + thread.join(timeout=5.0) + setattr(self, attr_name, None) def get_tool_schemas(self) -> List[Dict[str, Any]]: return [STORE_SCHEMA, SEARCH_SCHEMA, FORGET_SCHEMA, PROFILE_SCHEMA] diff --git a/tests/plugins/memory/test_supermemory_provider.py b/tests/plugins/memory/test_supermemory_provider.py index 0bee1d215..689793f15 100644 --- a/tests/plugins/memory/test_supermemory_provider.py +++ b/tests/plugins/memory/test_supermemory_provider.py @@ -1,4 +1,5 @@ import json +import threading import pytest @@ -163,6 +164,52 @@ def test_on_session_end_ingests_clean_messages(provider): ] +def test_on_memory_write_tracks_thread(provider): + provider.on_memory_write("add", "memory", "Jordan likes concise docs") + assert provider._write_thread is not None + provider._write_thread.join(timeout=1) + assert len(provider._client.add_calls) == 1 + assert provider._client.add_calls[0]["metadata"]["type"] == "explicit_memory" + + +def test_shutdown_joins_and_clears_threads(provider, monkeypatch): + started = threading.Event() + release = threading.Event() + + def slow_add_memory(content, metadata=None, *, entity_context=""): + started.set() + release.wait(timeout=1) + provider._client.add_calls.append({ + "content": content, + "metadata": metadata, + "entity_context": entity_context, + }) + return {"id": "mem_slow"} + + monkeypatch.setattr(provider._client, "add_memory", slow_add_memory) + + provider.sync_turn( + "Please remember this request in long-term memory", + "Absolutely, I will keep that in long-term memory.", + session_id="session-1", + ) + assert started.wait(timeout=1) + assert provider._sync_thread is not None + + started.clear() + provider.on_memory_write("add", "memory", "Jordan likes concise docs") + assert started.wait(timeout=1) + assert provider._write_thread is not None + + release.set() + provider.shutdown() + + assert provider._sync_thread is None + assert provider._write_thread is None + assert provider._prefetch_thread is None + assert len(provider._client.add_calls) == 2 + + def test_store_tool_returns_saved_payload(provider): result = json.loads(provider.handle_tool_call("supermemory_store", {"content": "Jordan likes concise docs"})) assert result["saved"] is True