mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-01 01:51:44 +00:00
fix(hindsight): drain retain queue cleanly on shutdown
The plugin used to spawn one daemon thread per sync_turn() to do the
aretain_batch network write. On CLI exit, that pattern raced interpreter
shutdown — the last retain could reach aiohttp after asyncio's
"cannot schedule new futures" guard had fired, producing noisy logs and
silently losing the final unsaved turn:
WARNING ... Hindsight sync failed: cannot schedule new futures after
interpreter shutdown
ERROR asyncio: Unclosed client session
client_session: <aiohttp.client.ClientSession object at 0x...>
Switch to a single-writer model: each provider owns one long-lived
writer thread plus a queue. sync_turn() snapshots state and enqueues a
job; the writer drains sequentially. Once shutdown() is called:
- new sync_turn() / queue_prefetch() calls are dropped, not enqueued
- a sentinel wakes the writer so it finishes in-flight work
- shutdown joins the writer (10s) before nulling the client
Also register an idempotent atexit hook from the first sync_turn(), so
exit paths that don't go through MemoryManager.shutdown_all() (Ctrl-C,
abrupt exit) still get a chance to drain.
Tests: keep _sync_thread as a legacy alias to the writer, swap join()
calls to _retain_queue.join() (canonical wait-for-drain), add a new
TestShutdownRace suite covering single-writer reuse, post-shutdown drop,
queue draining, and shutdown idempotency.
This commit is contained in:
parent
5662ac2afc
commit
0565497dcc
2 changed files with 228 additions and 57 deletions
|
|
@ -669,7 +669,7 @@ class TestSyncTurn:
|
|||
p._client = _make_mock_client()
|
||||
|
||||
p.sync_turn("hello", "hi there")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = p._client.aretain_batch.call_args.kwargs
|
||||
|
|
@ -710,8 +710,7 @@ class TestSyncTurn:
|
|||
def test_sync_turn_with_tags(self, provider_with_config):
|
||||
p = provider_with_config(retain_tags=["conv", "session1"])
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert "conv" in item["tags"]
|
||||
assert "session1" in item["tags"]
|
||||
|
|
@ -720,8 +719,7 @@ class TestSyncTurn:
|
|||
def test_sync_turn_uses_aretain_batch(self, provider):
|
||||
"""sync_turn should use aretain_batch with retain_async."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
provider._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["document_id"].startswith("test-session-")
|
||||
|
|
@ -732,8 +730,7 @@ class TestSyncTurn:
|
|||
def test_sync_turn_custom_context(self, provider_with_config):
|
||||
p = provider_with_config(retain_context="my-agent")
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert item["context"] == "my-agent"
|
||||
|
||||
|
|
@ -744,7 +741,7 @@ class TestSyncTurn:
|
|||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
assert p._sync_thread is None
|
||||
p.sync_turn("turn3-user", "turn3-asst")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = p._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["document_id"].startswith("test-session-")
|
||||
|
|
@ -765,15 +762,13 @@ class TestSyncTurn:
|
|||
|
||||
p.sync_turn("turn1-user", "turn1-asst")
|
||||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
p._client.aretain_batch.reset_mock()
|
||||
|
||||
p.sync_turn("turn3-user", "turn3-asst")
|
||||
p.sync_turn("turn4-user", "turn4-asst")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
|
||||
# Should contain ALL turns from the session
|
||||
|
|
@ -785,8 +780,7 @@ class TestSyncTurn:
|
|||
def test_sync_turn_passes_document_id(self, provider):
|
||||
"""sync_turn should pass document_id (session_id + per-startup ts)."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
||||
# Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
|
||||
assert call_kwargs["document_id"].startswith("test-session-")
|
||||
|
|
@ -819,8 +813,7 @@ class TestSyncTurn:
|
|||
def test_sync_turn_session_tag(self, provider):
|
||||
"""Each retain should be tagged with session:<id> for filtering."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
item = provider._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert "session:test-session" in item["tags"]
|
||||
|
||||
|
|
@ -841,8 +834,7 @@ class TestSyncTurn:
|
|||
)
|
||||
p._client = _make_mock_client()
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert "session:child-session" in item["tags"]
|
||||
|
|
@ -851,15 +843,14 @@ class TestSyncTurn:
|
|||
def test_sync_turn_error_does_not_raise(self, provider):
|
||||
provider._client.aretain_batch.side_effect = RuntimeError("network error")
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._retain_queue.join()
|
||||
|
||||
def test_sync_turn_preserves_unicode(self, provider_with_config):
|
||||
"""Non-ASCII text (CJK, ZWJ emoji) must survive JSON round-trip intact."""
|
||||
p = provider_with_config()
|
||||
p._client = _make_mock_client()
|
||||
p.sync_turn("안녕 こんにちは 你好", "👨👩👧👦 family")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
p._retain_queue.join()
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
# ensure_ascii=False means non-ASCII chars appear as-is in the raw JSON,
|
||||
|
|
@ -871,6 +862,71 @@ class TestSyncTurn:
|
|||
assert "👨👩👧👦" in raw_json
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shutdown / writer tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShutdownRace:
    """Shutdown/writer-thread behavior for the single-writer retain model."""

    def test_sync_turn_uses_single_writer_thread(self, provider):
        """Every retain is serviced by one persistent writer thread."""
        provider.sync_turn("a", "b")
        provider._retain_queue.join()

        writer_before = provider._writer_thread
        assert writer_before is not None
        assert writer_before.is_alive()

        provider.sync_turn("c", "d")
        provider._retain_queue.join()

        # The second retain must ride the SAME thread — no per-call spawn.
        assert provider._writer_thread is writer_before
        assert provider._client.aretain_batch.call_count == 2

    def test_sync_turn_after_shutdown_is_dropped(self, provider):
        """Once shutdown has fired, new sync_turn() calls are no-ops.

        This is the core of the fix: the plugin must not enqueue a retain
        during interpreter teardown — that's what causes the
        'cannot schedule new futures' RuntimeError + unclosed aiohttp
        sessions on CLI exit.
        """
        # Capture the client handle before shutdown nulls self._client.
        batch_client = provider._client
        provider.shutdown()
        baseline = batch_client.aretain_batch.call_count

        provider.sync_turn("late", "turn")

        # Nothing was enqueued — the retain queue stays empty.
        assert provider._retain_queue.empty()
        # Nor did the client get a new call (asserted via the captured
        # handle, since the provider's own reference is now gone).
        assert batch_client.aretain_batch.call_count == baseline

    def test_queue_prefetch_after_shutdown_is_dropped(self, provider):
        """Prefetch requests after shutdown are discarded, not scheduled."""
        provider.shutdown()
        provider.queue_prefetch("late query")
        assert provider._prefetch_thread is None

    def test_shutdown_drains_pending_retains(self, provider):
        """Shutdown must wait for queued retains to complete, not abandon them.

        Otherwise the LAST in-flight turn — typically the most important —
        is silently lost.
        """
        batch_client = provider._client

        provider.sync_turn("a", "b")
        provider.sync_turn("c", "d")
        provider.shutdown()

        # Shutdown returned only after both queued retains had run.
        assert batch_client.aretain_batch.call_count == 2
        assert provider._retain_queue.empty()

    def test_shutdown_is_idempotent(self, provider):
        """Calling shutdown() twice is safe — no error, no double-close."""
        provider.sync_turn("a", "b")
        provider.shutdown()
        provider.shutdown()
        assert provider._shutting_down.is_set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# System prompt tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue