fix(openviking): drain memory mirror workers on shutdown

This commit is contained in:
Hao Zhe 2026-06-22 13:05:52 +08:00 committed by Teknium
parent 70e7132e2f
commit c7e0501e9b
2 changed files with 67 additions and 1 deletion

View file

@ -1793,6 +1793,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._prefetch_thread: Optional[threading.Thread] = None
self._runtime_start_lock = threading.Lock()
self._runtime_start_thread: Optional[threading.Thread] = None
self._memory_write_lock = threading.Lock()
self._memory_write_threads: Set[threading.Thread] = set()
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
# shutdown() can drain them and rapid re-queues don't orphan a still-
# running thread by overwriting the single _prefetch_thread slot.
@ -2901,9 +2903,20 @@ class OpenVikingMemoryProvider(MemoryProvider):
})
except Exception as e:
logger.debug("OpenViking memory mirror failed: %s", e)
finally:
with self._memory_write_lock:
self._memory_write_threads.discard(threading.current_thread())
t = threading.Thread(target=_write, daemon=True, name="openviking-memwrite")
t.start()
with self._memory_write_lock:
if self._shutting_down:
return
self._memory_write_threads.add(t)
try:
t.start()
except Exception as e:
self._memory_write_threads.discard(t)
logger.debug("OpenViking memory mirror worker failed to start: %s", e)
def get_tool_schemas(self) -> List[Dict[str, Any]]:
return [
@ -2949,6 +2962,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
deferred_workers = list(self._deferred_commit_threads)
with self._prefetch_lock:
prefetch_workers = list(self._prefetch_threads)
with self._memory_write_lock:
memory_write_workers = list(self._memory_write_threads)
for t in all_workers:
if t.is_alive():
t.join(timeout=5.0)
@ -2958,6 +2973,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
for t in prefetch_workers:
if t.is_alive():
t.join(timeout=5.0)
for t in memory_write_workers:
if t.is_alive():
t.join(timeout=5.0)
# Clear atexit reference so it doesn't double-commit.
global _last_active_provider
if _last_active_provider is self:

View file

@ -2746,6 +2746,54 @@ def test_on_memory_write_uses_content_write_independent_of_session_rotation():
)
def test_shutdown_waits_for_memory_write_worker(monkeypatch):
    """shutdown() must not return while a memory-mirror write is in flight.

    A stub client parks the background worker inside ``post()`` until the
    test releases it. shutdown() runs on a helper thread; it has to still be
    blocked while the worker is parked and has to return once the worker is
    released. Afterwards the provider's worker registry must be empty.
    """
    import threading

    provider = OpenVikingMemoryProvider()
    for attr_name, attr_value in (
        ("_client", MagicMock()),
        ("_endpoint", "http://test"),
        ("_api_key", ""),
        ("_account", "acct"),
        ("_user", "usr"),
        ("_agent", "hermes"),
    ):
        setattr(provider, attr_name, attr_value)

    # Hand-shake events between the test, the mirror worker and shutdown().
    entered_post = threading.Event()
    allow_post_to_finish = threading.Event()
    post_completed = threading.Event()
    shutdown_done = threading.Event()

    class StubClient:
        """Client stand-in whose post() blocks until the test releases it."""

        def __init__(self, *args, **kwargs):
            pass

        def post(self, path, payload=None, **kwargs):
            assert path == "/api/v1/content/write"
            entered_post.set()
            # Bounded wait so a broken test cannot hang the worker forever.
            allow_post_to_finish.wait(timeout=2.0)
            post_completed.set()
            return {}

    monkeypatch.setattr(openviking_module, "_VikingClient", StubClient)

    provider.on_memory_write("add", "user", "remember this")
    assert entered_post.wait(timeout=2.0), "worker never entered post()"

    def _shutdown_then_signal():
        provider.shutdown()
        shutdown_done.set()

    shutdown_thread = threading.Thread(target=_shutdown_then_signal, daemon=True)
    shutdown_thread.start()

    # While the worker is still parked, shutdown() should stay blocked; record
    # whether it slipped through before letting the worker go.
    returned_early = shutdown_done.wait(timeout=0.1)
    allow_post_to_finish.set()

    assert shutdown_done.wait(timeout=2.0), "shutdown did not return after worker finished"
    shutdown_thread.join(timeout=2.0)

    assert not returned_early
    assert post_completed.is_set()
    assert provider._memory_write_threads == set()
@pytest.mark.parametrize(
("action", "content"),
[