From c7e0501e9b58dd1e52fa7944e2b55dc60582af7c Mon Sep 17 00:00:00 2001 From: Hao Zhe Date: Mon, 22 Jun 2026 13:05:52 +0800 Subject: [PATCH] fix(openviking): drain memory mirror workers on shutdown --- plugins/memory/openviking/__init__.py | 20 +++++++- .../memory/test_openviking_provider.py | 48 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index c3b652c3d22..030f6a59aa1 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -1793,6 +1793,8 @@ class OpenVikingMemoryProvider(MemoryProvider): self._prefetch_thread: Optional[threading.Thread] = None self._runtime_start_lock = threading.Lock() self._runtime_start_thread: Optional[threading.Thread] = None + self._memory_write_lock = threading.Lock() + self._memory_write_threads: Set[threading.Thread] = set() # All prefetch threads ever spawned (daemon, short-lived). Tracked so # shutdown() can drain them and rapid re-queues don't orphan a still- # running thread by overwriting the single _prefetch_thread slot. @@ -2901,9 +2903,20 @@ class OpenVikingMemoryProvider(MemoryProvider): }) except Exception as e: logger.debug("OpenViking memory mirror failed: %s", e) + finally: + with self._memory_write_lock: + self._memory_write_threads.discard(threading.current_thread()) t = threading.Thread(target=_write, daemon=True, name="openviking-memwrite") - t.start() + with self._memory_write_lock: + if self._shutting_down: + return + self._memory_write_threads.add(t) + try: + t.start() + except Exception as e: + self._memory_write_threads.discard(t) + logger.debug("OpenViking memory mirror worker failed to start: %s", e) def get_tool_schemas(self) -> List[Dict[str, Any]]: return [ @@ -2949,6 +2962,8 @@ class OpenVikingMemoryProvider(MemoryProvider): deferred_workers = list(self._deferred_commit_threads) with self._prefetch_lock: prefetch_workers = list(self._prefetch_threads) + with self._memory_write_lock: + memory_write_workers = list(self._memory_write_threads) for t in all_workers: if t.is_alive(): t.join(timeout=5.0) @@ -2958,6 +2973,9 @@ class OpenVikingMemoryProvider(MemoryProvider): for t in prefetch_workers: if t.is_alive(): t.join(timeout=5.0) + for t in memory_write_workers: + if t.is_alive(): + t.join(timeout=5.0) # Clear atexit reference so it doesn't double-commit. global _last_active_provider if _last_active_provider is self: diff --git a/tests/plugins/memory/test_openviking_provider.py b/tests/plugins/memory/test_openviking_provider.py index d5b5f347994..f176492ca95 100644 --- a/tests/plugins/memory/test_openviking_provider.py +++ b/tests/plugins/memory/test_openviking_provider.py @@ -2746,6 +2746,54 @@ def test_on_memory_write_uses_content_write_independent_of_session_rotation(): ) +def test_shutdown_waits_for_memory_write_worker(monkeypatch): + import threading + + provider = OpenVikingMemoryProvider() + provider._client = MagicMock() + provider._endpoint = "http://test" + provider._api_key = "" + provider._account = "acct" + provider._user = "usr" + provider._agent = "hermes" + + worker_started = threading.Event() + release_worker = threading.Event() + worker_finished = threading.Event() + shutdown_returned = threading.Event() + + class StubClient: + def __init__(self, *a, **kw): + pass + + def post(self, path, payload=None, **kwargs): + assert path == "/api/v1/content/write" + worker_started.set() + release_worker.wait(timeout=2.0) + worker_finished.set() + return {} + + monkeypatch.setattr(openviking_module, "_VikingClient", StubClient) + + provider.on_memory_write("add", "user", "remember this") + assert worker_started.wait(timeout=2.0), "worker never entered post()" + + shutdown_thread = threading.Thread( + target=lambda: (provider.shutdown(), shutdown_returned.set()), + daemon=True, + ) + shutdown_thread.start() + + returned_before_worker_finished = shutdown_returned.wait(timeout=0.1) + release_worker.set() + assert shutdown_returned.wait(timeout=2.0), "shutdown did not return after worker finished" + shutdown_thread.join(timeout=2.0) + + assert not returned_before_worker_finished + assert worker_finished.is_set() + assert provider._memory_write_threads == set() + + @pytest.mark.parametrize( ("action", "content"), [