From 67351625316bb091b16f879da7af7e40d99181f8 Mon Sep 17 00:00:00 2001 From: yoniebans Date: Mon, 29 Jun 2026 13:07:03 +0200 Subject: [PATCH] fix(gateway): offload the Telegram topic-recovery helper tree off the loop The topic-mode helpers (_telegram_topic_mode_enabled, _recover_telegram_topic_thread_id, _record/_sync_telegram_topic_binding, _is_telegram_topic_lane/_root_lobby, _normalize_source_for_session_key, _telegram_topic_new_header, _schedule_telegram_topic_title_rename, and the base.py _apply_topic_recovery hook) each run a synchronous SessionDB read or write. They reach the event loop through async handlers, so a contended state.db froze the loop the same way the handoff watcher did. These helpers already run off-loop in the run_sync thread-pool closure, so they are proven thread-safe there. Rather than colour them async, loop-side callers now invoke them via asyncio.to_thread(...); the executor callers are unchanged. Inside the helpers the SessionDB handle is unwrapped to the sync door (getattr(db, '_db', db)) since they always run on a worker thread, and AIAgent construction + query_session_listing are handed the sync SessionDB directly. base.py wraps its single _apply_topic_recovery call in to_thread. The guard is now alias-aware (catches db = getattr(self, '_session_db', None); db.method(...)) and enforces the offload contract: the offloaded sync helpers may never be called bare on the loop. Sibling test fixtures wrap their injected SessionDB in AsyncSessionDB to match how the gateway holds it. --- gateway/platforms/base.py | 3 +- gateway/run.py | 36 ++-- gateway/slash_commands.py | 24 +-- .../test_35809_auto_reset_clean_context.py | 26 ++- tests/gateway/test_async_session_db.py | 187 ++++++++++++++++-- tests/gateway/test_resume_command.py | 4 + tests/gateway/test_telegram_topic_mode.py | 11 +- tests/gateway/test_title_command.py | 8 +- 8 files changed, 249 insertions(+), 50 deletions(-) diff --git a/gateway/platforms/base.py b/gateway/platforms/base.py index 31914b5861e..de7ec492329 100644 --- a/gateway/platforms/base.py +++ b/gateway/platforms/base.py @@ -4335,7 +4335,8 @@ class BasePlatformAdapter(ABC): # Rewrite ``event.source.thread_id`` via the installed recovery hook # (Telegram DM topic mode) so the session key, guard checks, and # downstream delivery all agree on the same lane. - self._apply_topic_recovery(event) + # Offloaded: the sync hook must not block the loop. + await asyncio.to_thread(self._apply_topic_recovery, event) session_key = build_session_key( event.source, diff --git a/gateway/run.py b/gateway/run.py index 6a0a1ca0146..04b67d2bae2 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -3255,6 +3255,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew session_db = getattr(self, "_session_db", None) if session_db is None: return False + # Runs off-loop (always via asyncio.to_thread); use the sync handle. + session_db = getattr(session_db, "_db", session_db) try: raw = session_db.is_telegram_topic_mode_enabled( chat_id=str(source.chat_id), @@ -3352,6 +3354,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew session_db = getattr(self, "_session_db", None) if session_db is None or not source.chat_id or not source.thread_id: return + # Runs off-loop (always via asyncio.to_thread); use the sync handle. + session_db = getattr(session_db, "_db", session_db) session_db.bind_telegram_topic( chat_id=str(source.chat_id), thread_id=str(source.thread_id), @@ -3420,6 +3424,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew session_db = getattr(self, "_session_db", None) if session_db is None: return None + # Runs off-loop (always via asyncio.to_thread); use the sync handle. + session_db = getattr(session_db, "_db", session_db) try: bindings = session_db.list_telegram_topic_bindings_for_chat( chat_id=str(source.chat_id), @@ -8686,7 +8692,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew break if canonical == "new": - if self._is_telegram_topic_root_lobby(source): + if await asyncio.to_thread(self._is_telegram_topic_root_lobby, source): return self._telegram_topic_root_new_message() async def _do_reset(): return await self._handle_reset_command(event) @@ -9147,7 +9153,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # No bare text matching — "yes" in normal conversation must not trigger # execution of a dangerous command. - if self._is_telegram_topic_root_lobby(source): + if await asyncio.to_thread(self._is_telegram_topic_root_lobby, source): # Debounce the lobby reminder so a user who forgets about # topic mode and fires ten prompts doesn't get ten copies. if self._should_send_telegram_lobby_reminder(source): @@ -9628,7 +9634,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # Topic-mode DMs: rewrite a stale/foreign thread_id to the user's # last-active topic so a cross-topic Reply or stripped plain reply # doesn't fragment the conversation across sessions. - recovered = self._recover_telegram_topic_thread_id(source) + recovered = await asyncio.to_thread(self._recover_telegram_topic_thread_id, source) if recovered is not None: logger.info( "telegram topic recovery: chat=%s user=%s %r -> %s", @@ -9643,7 +9649,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew session_entry = self.session_store.get_or_create_session(source) session_key = session_entry.session_key self._cache_session_source(session_key, source) - if self._is_telegram_topic_lane(source): + if await asyncio.to_thread(self._is_telegram_topic_lane, source): try: binding = (await self._session_db.get_telegram_topic_binding( chat_id=str(source.chat_id), @@ -9691,12 +9697,13 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew bound_session_id and bound_session_id != str(binding.get("session_id") or "") ): - self._sync_telegram_topic_binding( + await asyncio.to_thread( + self._sync_telegram_topic_binding, source, session_entry, reason="compression-tip-walk", ) else: try: - self._record_telegram_topic_binding(source, session_entry) + await asyncio.to_thread(self._record_telegram_topic_binding, source, session_entry) except Exception: logger.debug("Failed to record Telegram topic binding", exc_info=True) # Capture and immediately consume was_auto_reset so it does not @@ -10067,7 +10074,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew skip_memory=True, enabled_toolsets=["memory"], session_id=session_entry.session_id, - session_db=self._session_db, + session_db=getattr(self._session_db, "_db", self._session_db), ) try: # The hygiene agent rotates the session @@ -10100,7 +10107,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if _hyg_rotated: session_entry.session_id = _hyg_new_sid self.session_store._save() - self._sync_telegram_topic_binding( + await asyncio.to_thread( + self._sync_telegram_topic_binding, source, session_entry, reason="hygiene-compression", ) @@ -10491,7 +10499,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if agent_result.get("session_id") and agent_result["session_id"] != session_entry.session_id: session_entry.session_id = agent_result["session_id"] self.session_store._save() - self._sync_telegram_topic_binding( + await asyncio.to_thread( + self._sync_telegram_topic_binding, source, session_entry, reason="agent-result-compression", ) @@ -10694,7 +10703,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # forever (#35809 — regression of the #9893/#10063 auto-reset). # No-op on non-topic lanes. session_entry = new_entry - self._sync_telegram_topic_binding( + await asyncio.to_thread( + self._sync_telegram_topic_binding, source, session_entry, reason="compression-exhausted-reset", ) response = (response or "") + ( @@ -12085,7 +12095,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew chat_name=source.chat_name, chat_type=source.chat_type, thread_id=source.thread_id, - session_db=self._session_db, + session_db=getattr(self._session_db, "_db", self._session_db), fallback_model=self._fallback_model, ) try: @@ -12304,7 +12314,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew title: str, ) -> None: """Best-effort rename of a Telegram DM topic when Hermes auto-titles a session.""" - if not self._is_telegram_topic_lane(source) or not source.chat_id or not source.thread_id: + if not await asyncio.to_thread(self._is_telegram_topic_lane, source) or not source.chat_id or not source.thread_id: return # Operator can fully disable per-topic auto-rename via @@ -16462,7 +16472,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew chat_type=source.chat_type, thread_id=source.thread_id, gateway_session_key=session_key, - session_db=self._session_db, + session_db=getattr(self._session_db, "_db", self._session_db), fallback_model=self._fallback_model, ) if _cache_lock and _cache is not None: diff --git a/gateway/slash_commands.py b/gateway/slash_commands.py index fbcfa048ef3..5b63e591a2c 100644 --- a/gateway/slash_commands.py +++ b/gateway/slash_commands.py @@ -228,11 +228,11 @@ class GatewaySlashCommandsMixin: session_info = "" if new_entry: - header = self._telegram_topic_new_header(source) or t("gateway.reset.header_default") + header = await asyncio.to_thread(self._telegram_topic_new_header, source) or t("gateway.reset.header_default") else: # No existing session, just create one new_entry = self.session_store.get_or_create_session(source, force_new=True) - header = self._telegram_topic_new_header(source) or t("gateway.reset.header_new") + header = await asyncio.to_thread(self._telegram_topic_new_header, source) or t("gateway.reset.header_new") # Set session title if provided with /new _title_arg = event.get_command_args().strip() @@ -262,9 +262,9 @@ class GatewaySlashCommandsMixin: # uses the freshly-created session. Without this, the binding # still points at the old session and the binding-lookup at the # top of _handle_message_with_agent would switch right back. - if self._is_telegram_topic_lane(source) and new_entry is not None: + if await asyncio.to_thread(self._is_telegram_topic_lane, source) and new_entry is not None: try: - self._record_telegram_topic_binding(source, new_entry) + await asyncio.to_thread(self._record_telegram_topic_binding, source, new_entry) except Exception: logger.debug("Failed to rebind Telegram topic after /new", exc_info=True) @@ -1175,7 +1175,7 @@ class GatewaySlashCommandsMixin: # (Telegram DM topic recovery) before deriving the override key, so # the override is stored under the key the next message turn reads # (#30479). - source = self._normalize_source_for_session_key(source) + source = await asyncio.to_thread(self._normalize_source_for_session_key, source) session_key = self._session_key_for_source(source) override = self._session_model_overrides.get(session_key, {}) if override: @@ -2331,7 +2331,7 @@ class GatewaySlashCommandsMixin: # Normalize the source (Telegram DM topic recovery) before deriving # the override key so storage matches the key the next message turn # reads — same fix as /model (#30479). - _reasoning_source = self._normalize_source_for_session_key(event.source) + _reasoning_source = await asyncio.to_thread(self._normalize_source_for_session_key, event.source) session_key = self._session_key_for_source(_reasoning_source) self._show_reasoning = self._load_show_reasoning() self._reasoning_config = self._resolve_session_reasoning_config( @@ -2825,7 +2825,7 @@ class GatewaySlashCommandsMixin: skip_memory=True, enabled_toolsets=["memory"], session_id=session_entry.session_id, - session_db=self._session_db, + session_db=getattr(self._session_db, "_db", self._session_db), ) try: tmp_agent._print_fn = lambda *a, **kw: None @@ -2870,7 +2870,8 @@ class GatewaySlashCommandsMixin: if rotated: session_entry.session_id = new_session_id self.session_store._save() - self._sync_telegram_topic_binding( + await asyncio.to_thread( + self._sync_telegram_topic_binding, source, session_entry, reason="compress-command", ) @@ -3090,7 +3091,7 @@ class GatewaySlashCommandsMixin: ) if callable(schedule_rename): try: - schedule_rename(source, session_id, sanitized) + await asyncio.to_thread(schedule_rename, source, session_id, sanitized) except Exception: logger.debug( "Failed to rename Telegram topic from /title", @@ -3300,8 +3301,9 @@ class GatewaySlashCommandsMixin: return await self._handle_resume_command(resume_event) current_entry = self.session_store.get_or_create_session(source) - rows = query_session_listing( - self._session_db, + rows = await asyncio.to_thread( + query_session_listing, + getattr(self._session_db, "_db", self._session_db), source=source.platform.value if source.platform else None, current_session_id=current_entry.session_id, include_all_sources=include_all, diff --git a/tests/gateway/test_35809_auto_reset_clean_context.py b/tests/gateway/test_35809_auto_reset_clean_context.py index 3ce021b5b71..bf753cc7528 100644 --- a/tests/gateway/test_35809_auto_reset_clean_context.py +++ b/tests/gateway/test_35809_auto_reset_clean_context.py @@ -102,13 +102,25 @@ class TestAutoResetBlockReSyncsBinding: """The block must re-sync the topic binding so the next inbound message cannot ``switch_session`` back onto the bloated compressed child.""" block = _find_compression_exhausted_reset_block() - sync_calls = [ - sub - for sub in ast.walk(block) - if isinstance(sub, ast.Call) - and isinstance(sub.func, ast.Attribute) - and sub.func.attr == "_sync_telegram_topic_binding" - ] + + def _references_helper(node): + # Direct call: self._sync_telegram_topic_binding(...) + if ( + isinstance(node, ast.Call) + and isinstance(node.func, ast.Attribute) + and node.func.attr == "_sync_telegram_topic_binding" + ): + return True + # Offloaded: await asyncio.to_thread(self._sync_telegram_topic_binding, ...) + # — the helper is passed as an argument, not the call's func. + if ( + isinstance(node, ast.Attribute) + and node.attr == "_sync_telegram_topic_binding" + ): + return True + return False + + sync_calls = [sub for sub in ast.walk(block) if _references_helper(sub)] assert sync_calls, ( "gateway/run.py auto-reset block does not call " "_sync_telegram_topic_binding after reset_session. Without it the " diff --git a/tests/gateway/test_async_session_db.py b/tests/gateway/test_async_session_db.py index fa459cb31d7..4a3674757d3 100644 --- a/tests/gateway/test_async_session_db.py +++ b/tests/gateway/test_async_session_db.py @@ -128,33 +128,97 @@ _GATEWAY_FILES = ("gateway/run.py", "gateway/slash_commands.py") # Three such sites today; a fourth must be justified and this count bumped. _ALLOWED_SYNC_DB_ESCAPES = 3 +# Sync helpers that touch SessionDB but are NEVER invoked bare on the loop: +# every loop-side call wraps them in ``asyncio.to_thread(...)`` and the only +# bare calls live in the run_sync thread-pool closure. Their DB calls therefore +# run off-loop. The guard exempts their bodies AND enforces the contract — see +# test_offloaded_helpers_never_called_bare_on_loop. Adding a helper here without +# wrapping its loop call sites makes that test fail. +_OFFLOADED_SYNC_HELPERS = frozenset({ + "_telegram_topic_mode_enabled", + "_is_telegram_topic_lane", + "_is_telegram_topic_root_lobby", + "_recover_telegram_topic_thread_id", + "_normalize_source_for_session_key", + "_record_telegram_topic_binding", + "_sync_telegram_topic_binding", + "_telegram_topic_new_header", + "_schedule_telegram_topic_title_rename", + "_apply_topic_recovery", +}) + def _repo_root() -> Path: return Path(__file__).resolve().parents[2] class _RawCallVisitor: - """Collect non-awaited self._session_db.<method>(...) calls in a module. + """Collect non-awaited SessionDB calls reachable on the gateway loop. - An ``await x.y()`` parses as Await(value=Call(...)); those Call nodes are - exempt — they're the migrated path. We flag only Calls that are NOT directly - awaited, and separately count the self._session_db._db.<x> sync escape. The - sanitize_title staticmethod is called on the class (SessionDB.sanitize_title), - so it never matches the self._session_db.<method> shape. + Catches both shapes: + * direct: self._session_db.<method>(...) + * aliased: db = getattr(self, "_session_db", None) / db = self._session_db + then db.<method>(...) + An ``await x.y()`` is Await(value=Call(...)); those Calls are exempt (the + migrated path). The self._session_db._db.<x> sync escape is counted + separately. SessionDB.sanitize_title is a staticmethod called on the class, + so it never matches either shape. + + Alias detection scans, per function scope, for locals bound to the gateway's + _session_db (incl. closures that bind it off a captured ``self``-like param), + then flags non-awaited calls on those names. The literal-grep blind spot that + let six loop-reachable calls hide behind ``getattr(self, "_session_db")`` is + exactly what this closes. """ def __init__(self, tree: ast.AST): - self.raw_calls = [] # (method, lineno) — non-awaited + self.raw_calls = [] # (method, lineno) — direct, non-awaited, on-loop + self.alias_calls = [] # (method, lineno) — via a _session_db-bound local, on-loop self.db_escapes = [] # self._session_db._db.<x> sites (lineno) + # BARE self.<helper>(...) call sites of offloaded helpers — i.e. the + # helper is actually *called*, not passed to asyncio.to_thread (which + # references it as an attribute, producing no Call node here). Each is + # (helper, lineno, enclosing_fn) for the contract test. + self.bare_helper_calls = [] awaited = {id(n.value) for n in ast.walk(tree) if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)} + alias_names = self._collect_alias_names(tree) + # Map each node to the name of the function whose body lexically encloses + # it, so DB calls inside an offloaded helper (which runs off-loop) are + # exempt while bare on-loop calls are not. + enclosing = self._enclosing_fn_map(tree) + ancestry = self._ancestor_fns(tree) # id(node) -> frozenset of enclosing fn names for node in ast.walk(tree): if not isinstance(node, ast.Call): continue func = node.func - if not (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute)): + if not isinstance(func, ast.Attribute): + continue + encl_fn = enclosing.get(id(node)) + in_offloaded_helper = encl_fn in _OFFLOADED_SYNC_HELPERS + # Bare call of an offloaded helper (self._helper(...)). A to_thread + # offload passes the helper as an attribute arg, not a Call, so it + # never lands here — exactly the distinction the contract test needs. + if ( + isinstance(func.value, ast.Name) and func.value.id == "self" + and func.attr in _OFFLOADED_SYNC_HELPERS + ): + self.bare_helper_calls.append( + (func.attr, node.lineno, ancestry.get(id(node), frozenset())) + ) + # alias.<method>(...) -> aliased loop call (var bound to _session_db) + if ( + isinstance(func.value, ast.Name) + and func.value.id in alias_names + and func.attr not in ("_db",) + and id(node) not in awaited + and not in_offloaded_helper + ): + self.alias_calls.append((func.attr, node.lineno)) + continue + if not isinstance(func.value, ast.Attribute): continue inner = func.value # self._session_db._db.<method>(...) -> sync escape @@ -172,9 +236,76 @@ class _RawCallVisitor: and isinstance(inner.value, ast.Name) and inner.value.id == "self" and id(node) not in awaited + and not in_offloaded_helper ): self.raw_calls.append((func.attr, node.lineno)) + @staticmethod + def _enclosing_fn_map(tree: ast.AST) -> dict: + """Map id(node) -> name of the nearest lexically-enclosing function.""" + out = {} + + def walk(node, fn_name): + this_fn = fn_name + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + this_fn = node.name + for child in ast.iter_child_nodes(node): + out[id(child)] = this_fn + walk(child, this_fn) + + walk(tree, None) + return out + + @staticmethod + def _ancestor_fns(tree: ast.AST) -> dict: + """Map id(node) -> frozenset of ALL enclosing function names (any depth).""" + out = {} + + def walk(node, stack): + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + stack = stack + (node.name,) + for child in ast.iter_child_nodes(node): + out[id(child)] = frozenset(stack) + walk(child, stack) + + walk(tree, ()) + return out + + @staticmethod + def _is_session_db_source(value: ast.AST) -> bool: + """True if an assignment RHS resolves to <obj>._session_db. + + Matches both ``<obj>._session_db`` and ``getattr(<obj>, "_session_db", ...)`` + where <obj> is any Name (covers ``self`` and captured closure params like + ``_self``). Excludes the ``._db`` sync handle. + """ + if isinstance(value, ast.Attribute): + return value.attr == "_session_db" and isinstance(value.value, ast.Name) + if ( + isinstance(value, ast.Call) + and isinstance(value.func, ast.Name) + and value.func.id == "getattr" + and len(value.args) >= 2 + and isinstance(value.args[1], ast.Constant) + and value.args[1].value == "_session_db" + ): + return True + return False + + @classmethod + def _collect_alias_names(cls, tree: ast.AST) -> set: + names = set() + for node in ast.walk(tree): + if isinstance(node, ast.Assign) and cls._is_session_db_source(node.value): + for tgt in node.targets: + if isinstance(tgt, ast.Name): + names.add(tgt.id) + elif isinstance(node, ast.AnnAssign) and node.value is not None \ + and cls._is_session_db_source(node.value) \ + and isinstance(node.target, ast.Name): + names.add(node.target.id) + return names + def _scan(rel_path: str) -> _RawCallVisitor: source = (_repo_root() / rel_path).read_text(encoding="utf-8") @@ -182,19 +313,22 @@ def _scan(rel_path: str) -> _RawCallVisitor: def test_no_raw_session_db_calls_on_gateway_loop(): - """Fail if any raw self._session_db.<method>( appears in gateway files. + """Fail if any non-awaited SessionDB call appears in gateway files. - Every loop-reachable DB call must go through AsyncSessionDB (await). The - sanitize_title staticmethod is called on the class, not self, so it is not - matched here; the _db. construction escape is checked separately below. + Every loop-reachable DB call must go through AsyncSessionDB (await), whether + spelled directly (self._session_db.<method>(...)) or via a local alias + (db = getattr(self, "_session_db", None); db.<method>(...)). The + sanitize_title staticmethod is called on the class, not self/an alias, so it + is not matched; the _db. sync escape is checked separately below. """ violations = [] for rel in _GATEWAY_FILES: v = _scan(rel) violations.extend(f"{rel}:{ln} self._session_db.{m}(" for m, ln in v.raw_calls) + violations.extend(f"{rel}:{ln} <alias>.{m}( (binds _session_db)" for m, ln in v.alias_calls) assert not violations, ( - "Raw SessionDB calls on the gateway loop — route through AsyncSessionDB " - "(await self._session_db.<method>(...)):\n " + "\n ".join(violations) + "Non-awaited SessionDB calls on the gateway loop — route through " + "AsyncSessionDB (await ...):\n " + "\n ".join(violations) ) @@ -211,3 +345,28 @@ def test_sync_db_escape_confined_to_off_loop_sites(): f"self._session_db._db. sync escape used {total} times; " f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed." ) + + +def test_offloaded_helpers_never_called_bare_on_loop(): + """The offloaded sync helpers must never be called bare on the event loop. + + They touch SessionDB synchronously, so a bare ``self._helper(...)`` on the + loop would freeze it. The contract: loop-side callers wrap them in + ``await asyncio.to_thread(self._helper, ...)`` (which references the helper + as an attribute — no Call node — so it never appears here). A bare call is + only legitimate when it runs off-loop: inside the ``run_sync`` thread-pool + closure, or inside another offloaded helper (sync->sync, same thread). Any + other bare call means a helper whose body the guard exempts is being invoked + on the loop anyway — re-freezing the loop through the exemption. + """ + off_loop_ok = _OFFLOADED_SYNC_HELPERS | {"run_sync"} + violations = [] + for rel in _GATEWAY_FILES: + v = _scan(rel) + for helper, ln, ancestors in v.bare_helper_calls: + if not (ancestors & off_loop_ok): + violations.append(f"{rel}:{ln} bare self.{helper}( on the loop") + assert not violations, ( + "Offloaded sync helper called bare on the gateway loop — wrap in " + "await asyncio.to_thread(self.<helper>, ...):\n " + "\n ".join(violations) + ) diff --git a/tests/gateway/test_resume_command.py b/tests/gateway/test_resume_command.py index 02ac2a449c8..bd52768830e 100644 --- a/tests/gateway/test_resume_command.py +++ b/tests/gateway/test_resume_command.py @@ -39,6 +39,10 @@ def _make_runner(session_db=None, current_session_id="current_session_001", runner.adapters = {} runner.config = SimpleNamespace(platforms={}) runner._voice_mode = {} + # Gateway holds the async facade; the slash handlers await it. + if session_db is not None: + from hermes_state import AsyncSessionDB + session_db = AsyncSessionDB(session_db) runner._session_db = session_db runner._running_agents = {} runner._is_user_authorized = lambda _source: True diff --git a/tests/gateway/test_telegram_topic_mode.py b/tests/gateway/test_telegram_topic_mode.py index c887153508c..37a769bf678 100644 --- a/tests/gateway/test_telegram_topic_mode.py +++ b/tests/gateway/test_telegram_topic_mode.py @@ -123,6 +123,10 @@ def _make_runner(session_db=None): runner._busy_ack_ts = {} runner._session_model_overrides = {} runner._pending_model_notes = {} + # Gateway holds the async facade; the slash handlers await it. + if session_db is not None: + from hermes_state import AsyncSessionDB + session_db = AsyncSessionDB(session_db) runner._session_db = session_db runner._reasoning_config = None runner._provider_routing = {} @@ -1399,7 +1403,8 @@ def test_session_split_restores_source_thread_id_from_binding(tmp_path): ) runner = object.__new__(GatewayRunner) - runner._session_db = db + from hermes_state import AsyncSessionDB + runner._session_db = AsyncSessionDB(db) # Build a source that looks like it came from a synthetic/recovered event: # platform and chat_type match a Telegram DM, but thread_id is None. @@ -1416,7 +1421,9 @@ def test_session_split_restores_source_thread_id_from_binding(tmp_path): and runner._session_db is not None ): try: - _binding = runner._session_db.get_telegram_topic_binding_by_session( + # Mirror production: this block runs in the run_sync executor, so it + # uses the sync handle (self._session_db._db), not the async facade. + _binding = runner._session_db._db.get_telegram_topic_binding_by_session( session_id="sess-split-new", ) if _binding and _binding.get("thread_id"): diff --git a/tests/gateway/test_title_command.py b/tests/gateway/test_title_command.py index 168fc1e708c..580b4974bf0 100644 --- a/tests/gateway/test_title_command.py +++ b/tests/gateway/test_title_command.py @@ -32,6 +32,10 @@ def _make_runner(session_db=None): runner = object.__new__(GatewayRunner) runner.adapters = {} runner._voice_mode = {} + # Gateway holds the async facade; the slash handlers await it. + if session_db is not None: + from hermes_state import AsyncSessionDB + session_db = AsyncSessionDB(session_db) runner._session_db = session_db # Mock session_store that returns a session entry with a known session_id @@ -296,7 +300,7 @@ class TestResetCommandWithTitle: runner._running_agents = {} runner._pending_messages = {} runner._pending_approvals = {} - runner._session_db = MagicMock() + runner._session_db = AsyncMock() runner._agent_cache = {} runner._agent_cache_lock = None runner._is_user_authorized = lambda _source: True @@ -356,7 +360,7 @@ class TestResetCommandWithTitle: runner._running_agents = {} runner._pending_messages = {} runner._pending_approvals = {} - runner._session_db = MagicMock() + runner._session_db = AsyncMock() runner._session_db.set_session_title.side_effect = ValueError( "Title 'Dup' is already in use by session abc-123" )