fix(memory): run end-of-turn sync off the turn thread (#41945)

A misconfigured/slow external memory provider could hold the agent in
the 'running' state for minutes after the final response was delivered.
MemoryManager.sync_all / queue_prefetch_all looped provider.sync_turn /
queue_prefetch INLINE on the turn-completion path; a provider making a
blocking network/daemon call (a broken Hindsight daemon was observed
blocking ~298s before failing) blocked run_conversation from returning.
Because every interface (CLI, TUI, gateway) marks the agent 'running'
until run_conversation returns, the agent stayed busy for the full block
and any follow-up message triggered an aggressive interrupt that dropped
the message.

Dispatch provider sync/prefetch to a lazily-created single-worker
background executor. sync_all / queue_prefetch_all return immediately;
work completes (or fails, logged) in the background. A single worker
serializes writes so turn N lands before turn N+1. flush_pending()
provides a barrier for session boundaries and deterministic tests.
shutdown_all() drains the executor with a bounded timeout so a wedged
provider can never hang teardown.

Builtin-only / no-provider sessions spawn no executor (zero new threads
in the common case).
This commit is contained in:
Teknium 2026-06-08 02:18:59 -07:00 committed by GitHub
parent a5c12f5f59
commit aa6f2775fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 348 additions and 31 deletions

View file

@ -28,6 +28,8 @@ from __future__ import annotations
import logging
import re
import inspect
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional
from agent.memory_provider import MemoryProvider
@ -35,6 +37,12 @@ from tools.registry import tool_error
logger = logging.getLogger(__name__)
# How long shutdown_all() waits for in-flight background sync/prefetch work
# to drain before abandoning it. A wedged provider must never block process
# teardown indefinitely — the worker threads are daemon, so anything still
# running past this window dies with the interpreter.
_SYNC_DRAIN_TIMEOUT_S = 5.0
# ---------------------------------------------------------------------------
# Context fencing helpers
@ -252,6 +260,13 @@ class MemoryManager:
self._providers: List[MemoryProvider] = []
self._tool_to_provider: Dict[str, MemoryProvider] = {}
self._has_external: bool = False # True once a non-builtin provider is added
# Background executor for end-of-turn sync/prefetch. Lazily created on
# first use so the common builtin-only path spawns no extra threads.
# A single worker serializes a provider's writes (turn N must land
# before turn N+1) and caps thread growth at one per manager. See
# _submit_background() and the sync_all/queue_prefetch_all rationale.
self._sync_executor: Optional[ThreadPoolExecutor] = None
self._sync_executor_lock = threading.Lock()
# -- Registration --------------------------------------------------------
@ -375,15 +390,27 @@ class MemoryManager:
return "\n\n".join(parts)
def queue_prefetch_all(self, query: str, *, session_id: str = "") -> None:
"""Queue background prefetch on all providers for the next turn."""
for provider in self._providers:
try:
provider.queue_prefetch(query, session_id=session_id)
except Exception as e:
logger.debug(
"Memory provider '%s' queue_prefetch failed (non-fatal): %s",
provider.name, e,
)
"""Queue background prefetch on all providers for the next turn.
Provider work is dispatched to a background worker so a slow or
wedged provider can never block the caller. See ``sync_all`` for
the full rationale (agent stuck "running" minutes after a turn).
"""
providers = list(self._providers)
if not providers:
return
def _run() -> None:
for provider in providers:
try:
provider.queue_prefetch(query, session_id=session_id)
except Exception as e:
logger.debug(
"Memory provider '%s' queue_prefetch failed (non-fatal): %s",
provider.name, e,
)
self._submit_background(_run)
# -- Sync ----------------------------------------------------------------
@ -407,27 +434,120 @@ class MemoryManager:
session_id: str = "",
messages: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""Sync a completed turn to all providers."""
for provider in self._providers:
"""Sync a completed turn to all providers.
Runs on a background worker thread, NOT inline on the
turn-completion path. A provider's ``sync_turn`` may make a
blocking network/daemon call (a misconfigured Hindsight daemon
was observed blocking ~298s before failing); doing that inline
held ``run_conversation`` open long after the user saw their
response, so every interface (CLI, TUI, gateway) kept the agent
marked "running" for minutes and any follow-up message triggered
an aggressive interrupt. Dispatching off-thread means a slow or
broken provider can never stall the turn the sync simply
completes (or fails, logged) in the background.
Writes are serialized through a single worker so turn N lands
before turn N+1; provider implementations don't need their own
ordering guarantees.
"""
providers = list(self._providers)
if not providers:
return
def _run() -> None:
for provider in providers:
try:
if messages is not None and self._provider_sync_accepts_messages(provider):
provider.sync_turn(
user_content,
assistant_content,
session_id=session_id,
messages=messages,
)
else:
provider.sync_turn(
user_content,
assistant_content,
session_id=session_id,
)
except Exception as e:
logger.warning(
"Memory provider '%s' sync_turn failed: %s",
provider.name, e,
)
self._submit_background(_run)
# -- Background dispatch -------------------------------------------------
def _submit_background(self, fn) -> None:
"""Run ``fn`` on the manager's background worker.
The executor is created lazily and shared across calls. If the
executor can't be created or has already been shut down, ``fn``
runs inline as a last-resort fallback losing the async benefit
but never losing the write itself. ``fn`` must do its own
per-provider error handling; this wrapper only guards executor
plumbing.
"""
executor = self._get_sync_executor()
if executor is None:
# Executor unavailable (shut down / creation failed) — run
# inline rather than drop the work. Slow, but correct.
try:
if messages is not None and self._provider_sync_accepts_messages(provider):
provider.sync_turn(
user_content,
assistant_content,
session_id=session_id,
messages=messages,
fn()
except Exception as e: # pragma: no cover - fn guards internally
logger.debug("Inline memory background task failed: %s", e)
return
try:
executor.submit(fn)
except RuntimeError:
# Executor was shut down between the get and the submit
# (teardown race). Fall back to inline.
try:
fn()
except Exception as e: # pragma: no cover - fn guards internally
logger.debug("Inline memory background task failed: %s", e)
def _get_sync_executor(self) -> Optional[ThreadPoolExecutor]:
"""Lazily create the single-worker background executor."""
if self._sync_executor is not None:
return self._sync_executor
with self._sync_executor_lock:
if self._sync_executor is None:
try:
self._sync_executor = ThreadPoolExecutor(
max_workers=1,
thread_name_prefix="mem-sync",
)
else:
provider.sync_turn(
user_content,
assistant_content,
session_id=session_id,
)
except Exception as e:
logger.warning(
"Memory provider '%s' sync_turn failed: %s",
provider.name, e,
)
except Exception as e: # pragma: no cover - resource exhaustion
logger.warning("Failed to create memory sync executor: %s", e)
return None
return self._sync_executor
def flush_pending(self, timeout: Optional[float] = None) -> bool:
"""Block until queued sync/prefetch work has drained.
Single-worker executor means submitting a sentinel and waiting on
it guarantees every previously-submitted task has run. Returns
True if the barrier completed within ``timeout`` (or no executor
exists), False on timeout. Used at real session boundaries and by
tests that need to assert provider state deterministically.
"""
executor = self._sync_executor
if executor is None:
return True
try:
fut = executor.submit(lambda: None)
except RuntimeError:
# Executor already shut down — nothing pending.
return True
try:
fut.result(timeout=timeout)
return True
except Exception:
return False
# -- Tools ---------------------------------------------------------------
@ -653,7 +773,15 @@ class MemoryManager:
)
def shutdown_all(self) -> None:
"""Shut down all providers (reverse order for clean teardown)."""
"""Shut down all providers (reverse order for clean teardown).
Drains the background sync/prefetch executor first (bounded by
``_SYNC_DRAIN_TIMEOUT_S``) so a turn's final sync has a chance to
land before providers are torn down. The worker threads are
daemon, so anything still wedged past the drain window dies with
the interpreter rather than blocking exit.
"""
self._drain_sync_executor()
for provider in reversed(self._providers):
try:
provider.shutdown()
@ -663,6 +791,52 @@ class MemoryManager:
provider.name, e,
)
def _drain_sync_executor(self) -> None:
"""Shut down the background executor, waiting briefly for drain.
Bounded by ``_SYNC_DRAIN_TIMEOUT_S``: a wedged provider must never
hang process/session teardown. We stop accepting new work and
cancel anything still queued, then wait at most the drain timeout
for the currently-running task on a watcher thread. The worker is
daemon, so an over-running task dies with the interpreter.
"""
with self._sync_executor_lock:
executor = self._sync_executor
self._sync_executor = None
if executor is None:
return
try:
# Stop accepting new work and drop anything still queued, but
# do NOT block here — cancel_futures cancels not-yet-started
# tasks; the in-flight one keeps running on its daemon thread.
executor.shutdown(wait=False, cancel_futures=True)
except TypeError:
# Older Python without cancel_futures kwarg.
try:
executor.shutdown(wait=False)
except Exception as e: # pragma: no cover
logger.debug("Memory sync executor shutdown failed: %s", e)
return
except Exception as e: # pragma: no cover
logger.debug("Memory sync executor shutdown failed: %s", e)
return
# Give an in-flight sync a bounded chance to finish on a watcher
# thread so we don't block the caller past the drain timeout.
drainer = threading.Thread(
target=lambda: self._bounded_executor_wait(executor),
daemon=True,
name="mem-sync-drain",
)
drainer.start()
drainer.join(timeout=_SYNC_DRAIN_TIMEOUT_S)
@staticmethod
def _bounded_executor_wait(executor: ThreadPoolExecutor) -> None:
try:
executor.shutdown(wait=True)
except Exception as e: # pragma: no cover
logger.debug("Memory sync executor drain wait failed: %s", e)
def initialize_all(self, session_id: str, **kwargs) -> None:
"""Initialize all providers.

View file

@ -0,0 +1,138 @@
"""Regression guard: end-of-turn memory sync must not block the turn.
Before this fix, ``MemoryManager.sync_all`` / ``queue_prefetch_all`` looped
``provider.sync_turn`` / ``provider.queue_prefetch`` INLINE on the
turn-completion path. A provider making a blocking network/daemon call (a
misconfigured Hindsight daemon was observed blocking ~298s before failing)
held ``run_conversation`` open long after the user saw their response, so
every interface (CLI, TUI, gateway) kept the agent marked "running" for
minutes and any follow-up message triggered an aggressive interrupt that
dropped the message.
The fix dispatches provider work to a single-worker background executor.
``sync_all`` / ``queue_prefetch_all`` return immediately; the work completes
(or fails, logged) in the background. ``flush_pending`` provides a barrier
for session boundaries and deterministic tests. ``shutdown_all`` drains the
executor with a bounded timeout so a wedged provider can't hang teardown.
"""
import time
import pytest
from agent.memory_provider import MemoryProvider
from agent.memory_manager import MemoryManager
class _SlowProvider(MemoryProvider):
"""Provider whose sync/prefetch block, simulating a slow backend."""
_name = "slow"
def __init__(self, delay: float = 1.0):
self._delay = delay
self.sync_done = False
self.prefetch_done = False
@property
def name(self) -> str:
return self._name
def initialize(self, session_id: str = "", **kwargs) -> None:
pass
def is_available(self) -> bool:
return True
def system_prompt_block(self) -> str:
return ""
def prefetch(self, query, *, session_id: str = "") -> str:
return ""
def queue_prefetch(self, query, *, session_id: str = "") -> None:
time.sleep(self._delay)
self.prefetch_done = True
def sync_turn(self, user_content, assistant_content, *, session_id: str = "", messages=None) -> None:
time.sleep(self._delay)
self.sync_done = True
def get_tool_schemas(self):
return []
def handle_tool_call(self, tool_name, args, **kwargs) -> str:
return ""
def test_sync_all_does_not_block_on_slow_provider():
"""The crux of the fix: a slow provider must NOT stall the caller."""
mgr = MemoryManager()
mgr.add_provider(_SlowProvider(delay=2.0))
t0 = time.time()
mgr.sync_all("hi", "hey", session_id="s1")
mgr.queue_prefetch_all("hi", session_id="s1")
elapsed = time.time() - t0
# Provider blocks 2s per call inline; off-thread dispatch returns ~instantly.
assert elapsed < 0.5, f"turn-completion path blocked {elapsed:.2f}s"
def test_background_work_still_completes():
"""Dispatching off-thread must not silently drop the write."""
mgr = MemoryManager()
p = _SlowProvider(delay=0.1)
mgr.add_provider(p)
mgr.sync_all("hi", "hey", session_id="s1")
mgr.queue_prefetch_all("hi", session_id="s1")
assert mgr.flush_pending(timeout=10) is True
assert p.sync_done is True
assert p.prefetch_done is True
def test_flush_pending_no_executor_is_true():
"""flush_pending must be a no-op (return True) before any sync ran."""
mgr = MemoryManager()
assert mgr.flush_pending(timeout=1) is True
def test_no_providers_does_not_create_executor():
"""Builtin-only / no-provider sessions must not spawn an executor."""
mgr = MemoryManager()
mgr.sync_all("hi", "hey")
mgr.queue_prefetch_all("hi")
assert mgr._sync_executor is None
def test_shutdown_all_is_bounded_with_wedged_provider():
"""A provider that never returns must not hang teardown."""
mgr = MemoryManager()
mgr.add_provider(_SlowProvider(delay=30.0))
mgr.sync_all("hi", "hey")
t0 = time.time()
mgr.shutdown_all()
elapsed = time.time() - t0
# Bounded by _SYNC_DRAIN_TIMEOUT_S (5s) plus a little slack.
assert elapsed < 8.0, f"shutdown blocked {elapsed:.1f}s on wedged provider"
def test_writes_are_serialized_in_order():
"""Single-worker executor must preserve turn ordering (N before N+1)."""
order = []
class _OrderProvider(_SlowProvider):
_name = "order"
def sync_turn(self, user_content, assistant_content, *, session_id="", messages=None):
order.append(user_content)
mgr = MemoryManager()
mgr.add_provider(_OrderProvider(delay=0.0))
for i in range(5):
mgr.sync_all(f"turn-{i}", "resp", session_id="s1")
assert mgr.flush_pending(timeout=10) is True
assert order == [f"turn-{i}" for i in range(5)]

View file

@ -229,6 +229,7 @@ class TestMemoryManager:
mgr.add_provider(p2)
mgr.queue_prefetch_all("next turn")
mgr.flush_pending(timeout=5)
assert p1.queued_prefetches == ["next turn"]
assert p2.queued_prefetches == ["next turn"]
@ -240,6 +241,7 @@ class TestMemoryManager:
mgr.add_provider(p2)
mgr.sync_all("user msg", "assistant msg")
mgr.flush_pending(timeout=5)
assert p1.synced_turns == [("user msg", "assistant msg")]
assert p2.synced_turns == [("user msg", "assistant msg")]
@ -253,7 +255,7 @@ class TestMemoryManager:
]
mgr.sync_all("user msg", "assistant msg", session_id="sess-1", messages=messages)
mgr.flush_pending(timeout=5)
assert p.synced_turns == [("user msg", "assistant msg", "sess-1", messages)]
def test_sync_all_omits_messages_for_legacy_provider(self):
@ -262,7 +264,7 @@ class TestMemoryManager:
mgr.add_provider(p)
mgr.sync_all("user msg", "assistant msg", messages=[{"role": "tool"}])
mgr.flush_pending(timeout=5)
assert p.synced_turns == [("user msg", "assistant msg")]
def test_sync_failure_doesnt_block_others(self):
@ -275,6 +277,7 @@ class TestMemoryManager:
mgr.add_provider(p2)
mgr.sync_all("user", "assistant")
mgr.flush_pending(timeout=5)
# p1 failed but p2 still synced
assert p2.synced_turns == [("user", "assistant")]

View file

@ -179,6 +179,7 @@ def test_sync_all_propagates_session_id_to_providers():
p = _RecordingProvider()
mm.add_provider(p)
mm.sync_all("hello", "world", session_id="sess-42")
mm.flush_pending(timeout=5)
assert p.sync_calls == [
{"user": "hello", "asst": "world", "session_id": "sess-42"}
]
@ -189,6 +190,7 @@ def test_queue_prefetch_all_propagates_session_id_to_providers():
p = _RecordingProvider()
mm.add_provider(p)
mm.queue_prefetch_all("next query", session_id="sess-42")
mm.flush_pending(timeout=5)
assert p.queue_calls == [{"query": "next query", "session_id": "sess-42"}]