From aa6f2775fac7c460a73669f35d4d478fed393004 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Mon, 8 Jun 2026 02:18:59 -0700 Subject: [PATCH] fix(memory): run end-of-turn sync off the turn thread (#41945) A misconfigured/slow external memory provider could hold the agent in the 'running' state for minutes after the final response was delivered. MemoryManager.sync_all / queue_prefetch_all looped provider.sync_turn / queue_prefetch INLINE on the turn-completion path; a provider making a blocking network/daemon call (a broken Hindsight daemon was observed blocking ~298s before failing) blocked run_conversation from returning. Because every interface (CLI, TUI, gateway) marks the agent 'running' until run_conversation returns, the agent stayed busy for the full block and any follow-up message triggered an aggressive interrupt that dropped the message. Dispatch provider sync/prefetch to a lazily-created single-worker background executor. sync_all / queue_prefetch_all return immediately; work completes (or fails, logged) in the background. A single worker serializes writes so turn N lands before turn N+1. flush_pending() provides a barrier for session boundaries and deterministic tests. shutdown_all() drains the executor with a bounded timeout so a wedged provider can never hang teardown. Builtin-only / no-provider sessions spawn no executor (zero new threads in the common case). 
--- agent/memory_manager.py | 232 +++++++++++++++++++--- tests/agent/test_memory_async_sync.py | 138 +++++++++++++ tests/agent/test_memory_provider.py | 7 +- tests/agent/test_memory_session_switch.py | 2 + 4 files changed, 348 insertions(+), 31 deletions(-) create mode 100644 tests/agent/test_memory_async_sync.py diff --git a/agent/memory_manager.py b/agent/memory_manager.py index f0a72d35954..3cb3a734a8f 100644 --- a/agent/memory_manager.py +++ b/agent/memory_manager.py @@ -28,6 +28,8 @@ from __future__ import annotations import logging import re import inspect +import threading +from concurrent.futures import ThreadPoolExecutor from typing import Any, Dict, List, Optional from agent.memory_provider import MemoryProvider @@ -35,6 +37,12 @@ from tools.registry import tool_error logger = logging.getLogger(__name__) +# How long shutdown_all() waits for in-flight background sync/prefetch work +# to drain before giving up on it, so a wedged provider cannot stall the +# caller indefinitely. NOTE: ThreadPoolExecutor workers are joined at +# interpreter exit, so a still-running task can delay process exit itself. +_SYNC_DRAIN_TIMEOUT_S = 5.0 + # --------------------------------------------------------------------------- # Context fencing helpers @@ -252,6 +260,13 @@ class MemoryManager: self._providers: List[MemoryProvider] = [] self._tool_to_provider: Dict[str, MemoryProvider] = {} self._has_external: bool = False # True once a non-builtin provider is added + # Background executor for end-of-turn sync/prefetch. Lazily created on + # first use so the common builtin-only path spawns no extra threads. + # A single worker serializes a provider's writes (turn N must land + # before turn N+1) and caps thread growth at one per manager. See + # _submit_background() and the sync_all/queue_prefetch_all rationale. 
+ self._sync_executor: Optional[ThreadPoolExecutor] = None + self._sync_executor_lock = threading.Lock() # -- Registration -------------------------------------------------------- @@ -375,15 +390,27 @@ class MemoryManager: return "\n\n".join(parts) def queue_prefetch_all(self, query: str, *, session_id: str = "") -> None: - """Queue background prefetch on all providers for the next turn.""" - for provider in self._providers: - try: - provider.queue_prefetch(query, session_id=session_id) - except Exception as e: - logger.debug( - "Memory provider '%s' queue_prefetch failed (non-fatal): %s", - provider.name, e, - ) + """Queue background prefetch on all providers for the next turn. + + Provider work is dispatched to a background worker so a slow or + wedged provider can never block the caller. See ``sync_all`` for + the full rationale (agent stuck "running" minutes after a turn). + """ + providers = list(self._providers) + if not providers: + return + + def _run() -> None: + for provider in providers: + try: + provider.queue_prefetch(query, session_id=session_id) + except Exception as e: + logger.debug( + "Memory provider '%s' queue_prefetch failed (non-fatal): %s", + provider.name, e, + ) + + self._submit_background(_run) # -- Sync ---------------------------------------------------------------- @@ -407,27 +434,120 @@ class MemoryManager: session_id: str = "", messages: Optional[List[Dict[str, Any]]] = None, ) -> None: - """Sync a completed turn to all providers.""" - for provider in self._providers: + """Sync a completed turn to all providers. + + Runs on a background worker thread, NOT inline on the + turn-completion path. 
A provider's ``sync_turn`` may make a + blocking network/daemon call (a misconfigured Hindsight daemon + was observed blocking ~298s before failing); doing that inline + held ``run_conversation`` open long after the user saw their + response, so every interface (CLI, TUI, gateway) kept the agent + marked "running" for minutes and any follow-up message triggered + an aggressive interrupt. Dispatching off-thread means a slow or + broken provider can never stall the turn — the sync simply + completes (or fails, logged) in the background. + + Writes are serialized through a single worker so turn N lands + before turn N+1; provider implementations don't need their own + ordering guarantees. + """ + providers = list(self._providers) + if not providers: + return + + def _run() -> None: + for provider in providers: + try: + if messages is not None and self._provider_sync_accepts_messages(provider): + provider.sync_turn( + user_content, + assistant_content, + session_id=session_id, + messages=messages, + ) + else: + provider.sync_turn( + user_content, + assistant_content, + session_id=session_id, + ) + except Exception as e: + logger.warning( + "Memory provider '%s' sync_turn failed: %s", + provider.name, e, + ) + + self._submit_background(_run) + + # -- Background dispatch ------------------------------------------------- + + def _submit_background(self, fn) -> None: + """Run ``fn`` on the manager's background worker. + + The executor is created lazily and shared across calls. If the + executor can't be created or has already been shut down, ``fn`` + runs inline as a last-resort fallback — losing the async benefit + but never losing the write itself. ``fn`` must do its own + per-provider error handling; this wrapper only guards executor + plumbing. + """ + executor = self._get_sync_executor() + if executor is None: + # Executor unavailable (shut down / creation failed) — run + # inline rather than drop the work. Slow, but correct. 
try: - if messages is not None and self._provider_sync_accepts_messages(provider): - provider.sync_turn( - user_content, - assistant_content, - session_id=session_id, - messages=messages, + fn() + except Exception as e: # pragma: no cover - fn guards internally + logger.debug("Inline memory background task failed: %s", e) + return + try: + executor.submit(fn) + except RuntimeError: + # Executor was shut down between the get and the submit + # (teardown race). Fall back to inline. + try: + fn() + except Exception as e: # pragma: no cover - fn guards internally + logger.debug("Inline memory background task failed: %s", e) + + def _get_sync_executor(self) -> Optional[ThreadPoolExecutor]: + """Lazily create the single-worker background executor.""" + if self._sync_executor is not None: + return self._sync_executor + with self._sync_executor_lock: + if self._sync_executor is None: + try: + self._sync_executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="mem-sync", ) - else: - provider.sync_turn( - user_content, - assistant_content, - session_id=session_id, - ) - except Exception as e: - logger.warning( - "Memory provider '%s' sync_turn failed: %s", - provider.name, e, - ) + except Exception as e: # pragma: no cover - resource exhaustion + logger.warning("Failed to create memory sync executor: %s", e) + return None + return self._sync_executor + + def flush_pending(self, timeout: Optional[float] = None) -> bool: + """Block until queued sync/prefetch work has drained. + + Single-worker executor means submitting a sentinel and waiting on + it guarantees every previously-submitted task has run. Returns + True if the barrier completed within ``timeout`` (or no executor + exists), False on timeout. Used at real session boundaries and by + tests that need to assert provider state deterministically. 
+ """ + executor = self._sync_executor + if executor is None: + return True + try: + fut = executor.submit(lambda: None) + except RuntimeError: + # Executor already shut down — nothing pending. + return True + try: + fut.result(timeout=timeout) + return True + except Exception: + return False # -- Tools --------------------------------------------------------------- @@ -653,7 +773,15 @@ class MemoryManager: ) def shutdown_all(self) -> None: - """Shut down all providers (reverse order for clean teardown).""" + """Shut down all providers (reverse order for clean teardown). + + Drains the background sync/prefetch executor first (bounded by + ``_SYNC_DRAIN_TIMEOUT_S``) so a turn's final sync has a chance to + land before providers are torn down. The worker threads are + daemon, so anything still wedged past the drain window dies with + the interpreter rather than blocking exit. + """ + self._drain_sync_executor() for provider in reversed(self._providers): try: provider.shutdown() @@ -663,6 +791,52 @@ class MemoryManager: provider.name, e, ) + def _drain_sync_executor(self) -> None: + """Shut down the background executor, waiting briefly for drain. + + Bounded by ``_SYNC_DRAIN_TIMEOUT_S``: a wedged provider must never + hang process/session teardown. We stop accepting new work and + cancel anything still queued, then wait at most the drain timeout + for the currently-running task on a watcher thread. The worker is + daemon, so an over-running task dies with the interpreter. + """ + with self._sync_executor_lock: + executor = self._sync_executor + self._sync_executor = None + if executor is None: + return + try: + # Stop accepting new work and drop anything still queued, but + # do NOT block here — cancel_futures cancels not-yet-started + # tasks; the in-flight one keeps running on its daemon thread. + executor.shutdown(wait=False, cancel_futures=True) + except TypeError: + # Older Python without cancel_futures kwarg. 
+ try: + executor.shutdown(wait=False) + except Exception as e: # pragma: no cover + logger.debug("Memory sync executor shutdown failed: %s", e) + return + except Exception as e: # pragma: no cover + logger.debug("Memory sync executor shutdown failed: %s", e) + return + # Give an in-flight sync a bounded chance to finish on a watcher + # thread so we don't block the caller past the drain timeout. + drainer = threading.Thread( + target=lambda: self._bounded_executor_wait(executor), + daemon=True, + name="mem-sync-drain", + ) + drainer.start() + drainer.join(timeout=_SYNC_DRAIN_TIMEOUT_S) + + @staticmethod + def _bounded_executor_wait(executor: ThreadPoolExecutor) -> None: + try: + executor.shutdown(wait=True) + except Exception as e: # pragma: no cover + logger.debug("Memory sync executor drain wait failed: %s", e) + def initialize_all(self, session_id: str, **kwargs) -> None: """Initialize all providers. diff --git a/tests/agent/test_memory_async_sync.py b/tests/agent/test_memory_async_sync.py new file mode 100644 index 00000000000..7ff293e43fc --- /dev/null +++ b/tests/agent/test_memory_async_sync.py @@ -0,0 +1,138 @@ +"""Regression guard: end-of-turn memory sync must not block the turn. + +Before this fix, ``MemoryManager.sync_all`` / ``queue_prefetch_all`` looped +``provider.sync_turn`` / ``provider.queue_prefetch`` INLINE on the +turn-completion path. A provider making a blocking network/daemon call (a +misconfigured Hindsight daemon was observed blocking ~298s before failing) +held ``run_conversation`` open long after the user saw their response, so +every interface (CLI, TUI, gateway) kept the agent marked "running" for +minutes and any follow-up message triggered an aggressive interrupt that +dropped the message. + +The fix dispatches provider work to a single-worker background executor. +``sync_all`` / ``queue_prefetch_all`` return immediately; the work completes +(or fails, logged) in the background. 
``flush_pending`` provides a barrier +for session boundaries and deterministic tests. ``shutdown_all`` drains the +executor with a bounded timeout so a wedged provider can't hang teardown. +""" +import time + +import pytest + +from agent.memory_provider import MemoryProvider +from agent.memory_manager import MemoryManager + + +class _SlowProvider(MemoryProvider): + """Provider whose sync/prefetch block, simulating a slow backend.""" + + _name = "slow" + + def __init__(self, delay: float = 1.0): + self._delay = delay + self.sync_done = False + self.prefetch_done = False + + @property + def name(self) -> str: + return self._name + + def initialize(self, session_id: str = "", **kwargs) -> None: + pass + + def is_available(self) -> bool: + return True + + def system_prompt_block(self) -> str: + return "" + + def prefetch(self, query, *, session_id: str = "") -> str: + return "" + + def queue_prefetch(self, query, *, session_id: str = "") -> None: + time.sleep(self._delay) + self.prefetch_done = True + + def sync_turn(self, user_content, assistant_content, *, session_id: str = "", messages=None) -> None: + time.sleep(self._delay) + self.sync_done = True + + def get_tool_schemas(self): + return [] + + def handle_tool_call(self, tool_name, args, **kwargs) -> str: + return "" + + +def test_sync_all_does_not_block_on_slow_provider(): + """The crux of the fix: a slow provider must NOT stall the caller.""" + mgr = MemoryManager() + mgr.add_provider(_SlowProvider(delay=2.0)) + + t0 = time.time() + mgr.sync_all("hi", "hey", session_id="s1") + mgr.queue_prefetch_all("hi", session_id="s1") + elapsed = time.time() - t0 + + # Provider blocks 2s per call inline; off-thread dispatch returns ~instantly. 
+ assert elapsed < 0.5, f"turn-completion path blocked {elapsed:.2f}s" + + +def test_background_work_still_completes(): + """Dispatching off-thread must not silently drop the write.""" + mgr = MemoryManager() + p = _SlowProvider(delay=0.1) + mgr.add_provider(p) + + mgr.sync_all("hi", "hey", session_id="s1") + mgr.queue_prefetch_all("hi", session_id="s1") + + assert mgr.flush_pending(timeout=10) is True + assert p.sync_done is True + assert p.prefetch_done is True + + +def test_flush_pending_no_executor_is_true(): + """flush_pending must be a no-op (return True) before any sync ran.""" + mgr = MemoryManager() + assert mgr.flush_pending(timeout=1) is True + + +def test_no_providers_does_not_create_executor(): + """Builtin-only / no-provider sessions must not spawn an executor.""" + mgr = MemoryManager() + mgr.sync_all("hi", "hey") + mgr.queue_prefetch_all("hi") + assert mgr._sync_executor is None + + +def test_shutdown_all_is_bounded_with_wedged_provider(): + """A provider that never returns must not hang teardown.""" + mgr = MemoryManager() + mgr.add_provider(_SlowProvider(delay=30.0)) + mgr.sync_all("hi", "hey") + + t0 = time.time() + mgr.shutdown_all() + elapsed = time.time() - t0 + + # Bounded by _SYNC_DRAIN_TIMEOUT_S (5s) plus a little slack. 
+ assert elapsed < 8.0, f"shutdown blocked {elapsed:.1f}s on wedged provider" + + +def test_writes_are_serialized_in_order(): + """Single-worker executor must preserve turn ordering (N before N+1).""" + order = [] + + class _OrderProvider(_SlowProvider): + _name = "order" + + def sync_turn(self, user_content, assistant_content, *, session_id="", messages=None): + order.append(user_content) + + mgr = MemoryManager() + mgr.add_provider(_OrderProvider(delay=0.0)) + for i in range(5): + mgr.sync_all(f"turn-{i}", "resp", session_id="s1") + assert mgr.flush_pending(timeout=10) is True + assert order == [f"turn-{i}" for i in range(5)] diff --git a/tests/agent/test_memory_provider.py b/tests/agent/test_memory_provider.py index bb84c4253f4..e12122724ad 100644 --- a/tests/agent/test_memory_provider.py +++ b/tests/agent/test_memory_provider.py @@ -229,6 +229,7 @@ class TestMemoryManager: mgr.add_provider(p2) mgr.queue_prefetch_all("next turn") + mgr.flush_pending(timeout=5) assert p1.queued_prefetches == ["next turn"] assert p2.queued_prefetches == ["next turn"] @@ -240,6 +241,7 @@ class TestMemoryManager: mgr.add_provider(p2) mgr.sync_all("user msg", "assistant msg") + mgr.flush_pending(timeout=5) assert p1.synced_turns == [("user msg", "assistant msg")] assert p2.synced_turns == [("user msg", "assistant msg")] @@ -253,7 +255,7 @@ class TestMemoryManager: ] mgr.sync_all("user msg", "assistant msg", session_id="sess-1", messages=messages) - + mgr.flush_pending(timeout=5) assert p.synced_turns == [("user msg", "assistant msg", "sess-1", messages)] def test_sync_all_omits_messages_for_legacy_provider(self): @@ -262,7 +264,7 @@ class TestMemoryManager: mgr.add_provider(p) mgr.sync_all("user msg", "assistant msg", messages=[{"role": "tool"}]) - + mgr.flush_pending(timeout=5) assert p.synced_turns == [("user msg", "assistant msg")] def test_sync_failure_doesnt_block_others(self): @@ -275,6 +277,7 @@ class TestMemoryManager: mgr.add_provider(p2) mgr.sync_all("user", "assistant") + 
mgr.flush_pending(timeout=5) # p1 failed but p2 still synced assert p2.synced_turns == [("user", "assistant")] diff --git a/tests/agent/test_memory_session_switch.py b/tests/agent/test_memory_session_switch.py index a40654fa579..ca04aa8875e 100644 --- a/tests/agent/test_memory_session_switch.py +++ b/tests/agent/test_memory_session_switch.py @@ -179,6 +179,7 @@ def test_sync_all_propagates_session_id_to_providers(): p = _RecordingProvider() mm.add_provider(p) mm.sync_all("hello", "world", session_id="sess-42") + mm.flush_pending(timeout=5) assert p.sync_calls == [ {"user": "hello", "asst": "world", "session_id": "sess-42"} ] @@ -189,6 +190,7 @@ def test_queue_prefetch_all_propagates_session_id_to_providers(): p = _RecordingProvider() mm.add_provider(p) mm.queue_prefetch_all("next query", session_id="sess-42") + mm.flush_pending(timeout=5) assert p.queue_calls == [{"query": "next query", "session_id": "sess-42"}]