mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-23 10:42:00 +00:00
fix(openviking): drain memory mirror workers on shutdown
This commit is contained in:
parent
70e7132e2f
commit
c7e0501e9b
2 changed files with 67 additions and 1 deletions
|
|
@ -1793,6 +1793,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
self._prefetch_thread: Optional[threading.Thread] = None
|
||||
self._runtime_start_lock = threading.Lock()
|
||||
self._runtime_start_thread: Optional[threading.Thread] = None
|
||||
self._memory_write_lock = threading.Lock()
|
||||
self._memory_write_threads: Set[threading.Thread] = set()
|
||||
# All prefetch threads ever spawned (daemon, short-lived). Tracked so
|
||||
# shutdown() can drain them and rapid re-queues don't orphan a still-
|
||||
# running thread by overwriting the single _prefetch_thread slot.
|
||||
|
|
@ -2901,9 +2903,20 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
})
|
||||
except Exception as e:
|
||||
logger.debug("OpenViking memory mirror failed: %s", e)
|
||||
finally:
|
||||
with self._memory_write_lock:
|
||||
self._memory_write_threads.discard(threading.current_thread())
|
||||
|
||||
t = threading.Thread(target=_write, daemon=True, name="openviking-memwrite")
|
||||
t.start()
|
||||
with self._memory_write_lock:
|
||||
if self._shutting_down:
|
||||
return
|
||||
self._memory_write_threads.add(t)
|
||||
try:
|
||||
t.start()
|
||||
except Exception as e:
|
||||
self._memory_write_threads.discard(t)
|
||||
logger.debug("OpenViking memory mirror worker failed to start: %s", e)
|
||||
|
||||
def get_tool_schemas(self) -> List[Dict[str, Any]]:
|
||||
return [
|
||||
|
|
@ -2949,6 +2962,8 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
deferred_workers = list(self._deferred_commit_threads)
|
||||
with self._prefetch_lock:
|
||||
prefetch_workers = list(self._prefetch_threads)
|
||||
with self._memory_write_lock:
|
||||
memory_write_workers = list(self._memory_write_threads)
|
||||
for t in all_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
|
|
@ -2958,6 +2973,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
|
|||
for t in prefetch_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
for t in memory_write_workers:
|
||||
if t.is_alive():
|
||||
t.join(timeout=5.0)
|
||||
# Clear atexit reference so it doesn't double-commit.
|
||||
global _last_active_provider
|
||||
if _last_active_provider is self:
|
||||
|
|
|
|||
|
|
@ -2746,6 +2746,54 @@ def test_on_memory_write_uses_content_write_independent_of_session_rotation():
|
|||
)
|
||||
|
||||
|
||||
def test_shutdown_waits_for_memory_write_worker(monkeypatch):
|
||||
import threading
|
||||
|
||||
provider = OpenVikingMemoryProvider()
|
||||
provider._client = MagicMock()
|
||||
provider._endpoint = "http://test"
|
||||
provider._api_key = ""
|
||||
provider._account = "acct"
|
||||
provider._user = "usr"
|
||||
provider._agent = "hermes"
|
||||
|
||||
worker_started = threading.Event()
|
||||
release_worker = threading.Event()
|
||||
worker_finished = threading.Event()
|
||||
shutdown_returned = threading.Event()
|
||||
|
||||
class StubClient:
|
||||
def __init__(self, *a, **kw):
|
||||
pass
|
||||
|
||||
def post(self, path, payload=None, **kwargs):
|
||||
assert path == "/api/v1/content/write"
|
||||
worker_started.set()
|
||||
release_worker.wait(timeout=2.0)
|
||||
worker_finished.set()
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(openviking_module, "_VikingClient", StubClient)
|
||||
|
||||
provider.on_memory_write("add", "user", "remember this")
|
||||
assert worker_started.wait(timeout=2.0), "worker never entered post()"
|
||||
|
||||
shutdown_thread = threading.Thread(
|
||||
target=lambda: (provider.shutdown(), shutdown_returned.set()),
|
||||
daemon=True,
|
||||
)
|
||||
shutdown_thread.start()
|
||||
|
||||
returned_before_worker_finished = shutdown_returned.wait(timeout=0.1)
|
||||
release_worker.set()
|
||||
assert shutdown_returned.wait(timeout=2.0), "shutdown did not return after worker finished"
|
||||
shutdown_thread.join(timeout=2.0)
|
||||
|
||||
assert not returned_before_worker_finished
|
||||
assert worker_finished.is_set()
|
||||
assert provider._memory_write_threads == set()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("action", "content"),
|
||||
[
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue