From 0896facce8139636be6de462fd286cf84b0069cb Mon Sep 17 00:00:00 2001 From: yoniebans Date: Mon, 29 Jun 2026 12:04:44 +0200 Subject: [PATCH] fix(gateway): route SessionDB calls through AsyncSessionDB --- gateway/run.py | 66 ++++++++++--------- gateway/slash_commands.py | 55 ++++++++-------- tests/e2e/conftest.py | 11 +++- tests/gateway/conftest.py | 9 +++ tests/gateway/test_agent_cache.py | 40 ++++++----- tests/gateway/test_async_session_db.py | 56 +++++++++------- .../gateway/test_handoff_watcher_async_db.py | 33 ++++++---- .../test_matrix_project_context_isolation.py | 23 +++---- .../test_session_boundary_security_state.py | 15 +++-- tests/gateway/test_session_race_guard.py | 13 +++- tests/gateway/test_status_command.py | 19 +++--- tests/gateway/test_usage_command.py | 5 +- 12 files changed, 203 insertions(+), 142 deletions(-) diff --git a/gateway/run.py b/gateway/run.py index 1164717391b..fb0cf0ae4d0 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2777,8 +2777,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # Initialize session database for session_search tool support self._session_db = None try: - from hermes_state import SessionDB - self._session_db = SessionDB() + from hermes_state import AsyncSessionDB, SessionDB + self._session_db = AsyncSessionDB(SessionDB()) except Exception as e: # WARNING (not DEBUG) so the failure appears in errors.log — matches # cli.py's handling of the same init path. Users hitting NFS-mounted @@ -2799,7 +2799,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew from hermes_cli.config import load_config as _load_full_config _sess_cfg = (_load_full_config().get("sessions") or {}) if _sess_cfg.get("auto_prune", False): - self._session_db.maybe_auto_prune_and_vacuum( + # Construction-time, before the loop serves traffic; sync DB is fine. + self._session_db._db.maybe_auto_prune_and_vacuum( retention_days=int(_sess_cfg.get("retention_days", 90)), min_interval_hours=int(_sess_cfg.get("min_interval_hours", 24)), vacuum=bool(_sess_cfg.get("vacuum_after_prune", True)), @@ -6578,23 +6579,23 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if self._session_db is None: await asyncio.sleep(interval) continue - pending = await asyncio.to_thread(self._session_db.list_pending_handoffs) + pending = await self._session_db.list_pending_handoffs() for row in pending: session_id = row.get("id") if not session_id: continue - if not await asyncio.to_thread(self._session_db.claim_handoff, session_id): + if not await self._session_db.claim_handoff(session_id): # Another tick or another gateway already claimed it. continue try: await self._process_handoff(row) - await asyncio.to_thread(self._session_db.complete_handoff, session_id) + await self._session_db.complete_handoff(session_id) except Exception as exc: logger.warning( "Handoff for session %s failed: %s", session_id, exc, exc_info=True, ) - await asyncio.to_thread(self._session_db.fail_handoff, session_id, str(exc)) + await self._session_db.fail_handoff(session_id, str(exc)) except asyncio.CancelledError: raise except Exception as exc: @@ -7443,8 +7444,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # old gateway's connection holding the WAL lock until Python # actually exits — causing 'database is locked' errors when # the new gateway tries to open the same file. - for _db_holder in (self, getattr(self, "session_store", None)): - _db = getattr(_db_holder, "_db", None) if _db_holder else None + # ``self`` holds the DB at ``_session_db`` (an AsyncSessionDB facade); + # unwrap to the sync handle. ``session_store`` holds it at ``_db``. + _self_db = getattr(self, "_session_db", None) + _self_db = getattr(_self_db, "_db", _self_db) + for _db in (_self_db, getattr(getattr(self, "session_store", None), "_db", None)): if _db is None or not hasattr(_db, "close"): continue try: @@ -9641,10 +9645,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew self._cache_session_source(session_key, source) if self._is_telegram_topic_lane(source): try: - binding = self._session_db.get_telegram_topic_binding( + binding = (await self._session_db.get_telegram_topic_binding( chat_id=str(source.chat_id), thread_id=str(source.thread_id), - ) if self._session_db else None + )) if self._session_db else None except Exception: logger.debug("Failed to read Telegram topic binding", exc_info=True) binding = None @@ -9658,7 +9662,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # a compression parent, so this is cheap and safe. if bound_session_id and self._session_db is not None: try: - canonical_session_id = self._session_db.get_compression_tip( + canonical_session_id = await self._session_db.get_compression_tip( bound_session_id, ) except Exception: @@ -10450,7 +10454,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew # prompt caching. Refreshing here makes the guard fire only on a # DIFFERENT process's writes. Uses the (possibly compaction- # updated) live session_id. Fail-safe inside the helper. - self._refresh_agent_cache_message_count( + await self._refresh_agent_cache_message_count( session_key, session_entry.session_id ) @@ -12481,7 +12485,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew "5. /topic inside a topic restores an old session into it." ) - def _disable_telegram_topic_mode_for_chat(self, source: SessionSource) -> str: + async def _disable_telegram_topic_mode_for_chat(self, source: SessionSource) -> str: """Cleanly disable topic mode for a chat via /topic off.""" if not self._session_db: from hermes_state import format_session_db_unavailable @@ -12491,7 +12495,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew return "Could not determine chat ID." # No-op if never enabled. try: - currently_enabled = self._session_db.is_telegram_topic_mode_enabled( + currently_enabled = await self._session_db.is_telegram_topic_mode_enabled( chat_id=chat_id, user_id=str(source.user_id or ""), ) @@ -12500,7 +12504,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if not currently_enabled: return "Multi-session topic mode is not currently enabled for this chat." try: - self._session_db.disable_telegram_topic_mode(chat_id=chat_id) + await self._session_db.disable_telegram_topic_mode(chat_id=chat_id) except Exception as exc: logger.exception("Failed to disable Telegram topic mode") return f"Failed to disable topic mode: {exc}" @@ -12518,7 +12522,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew ) - def _telegram_topic_root_status_message(self, source: SessionSource) -> str: + async def _telegram_topic_root_status_message(self, source: SessionSource) -> str: lines = [ "Telegram multi-session topics are enabled.", "", @@ -12528,7 +12532,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew "", ] try: - sessions = self._session_db.list_unlinked_telegram_sessions_for_user( + sessions = await self._session_db.list_unlinked_telegram_sessions_for_user( chat_id=str(source.chat_id), user_id=str(source.user_id), limit=10, @@ -12567,11 +12571,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew async def _restore_telegram_topic_session(self, event: MessageEvent, raw_session_id: str) -> str: """Restore an existing Telegram-owned Hermes session into this topic.""" source = event.source - session_id = self._session_db.resolve_session_id(raw_session_id.strip()) + session_id = await self._session_db.resolve_session_id(raw_session_id.strip()) if not session_id: return f"Session not found: {raw_session_id.strip()}" - session = self._session_db.get_session(session_id) + session = await self._session_db.get_session(session_id) if not session: return f"Session not found: {raw_session_id.strip()}" if str(session.get("source") or "") != "telegram": @@ -12579,8 +12583,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if str(session.get("user_id") or "") != str(source.user_id): return "That session does not belong to this Telegram user." - linked = self._session_db.is_telegram_session_linked_to_topic(session_id=session_id) - current_binding = self._session_db.get_telegram_topic_binding( + linked = await self._session_db.is_telegram_session_linked_to_topic(session_id=session_id) + current_binding = await self._session_db.get_telegram_topic_binding( chat_id=str(source.chat_id), thread_id=str(source.thread_id), ) @@ -12590,7 +12594,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew session_key = self._session_key_for_source(source) try: - self._session_db.bind_telegram_topic( + await self._session_db.bind_telegram_topic( chat_id=str(source.chat_id), thread_id=str(source.thread_id), user_id=str(source.user_id), @@ -12603,10 +12607,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew return "That session is already linked to another Telegram topic." raise - title = self._session_db.get_session_title(session_id) or session_id + title = await self._session_db.get_session_title(session_id) or session_id last_assistant = None try: - for message in reversed(self._session_db.get_messages(session_id)): + for message in reversed(await self._session_db.get_messages(session_id)): if message.get("role") == "assistant" and message.get("content"): last_assistant = str(message.get("content")) break @@ -14631,7 +14635,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if release_running_state: self._release_running_agent_state(session_key) - def _refresh_agent_cache_message_count( + async def _refresh_agent_cache_message_count( self, session_key: str, session_id: Optional[str] ) -> None: """Re-baseline a cached agent's stored message_count after THIS turn. @@ -14663,7 +14667,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew if not _cache_lock or _cache is None: return try: - _sess_row = self._session_db.get_session(session_id) + _sess_row = await self._session_db.get_session(session_id) _live = _sess_row.get("message_count", 0) if _sess_row else None except Exception: return @@ -16346,7 +16350,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew _current_msg_count = None if self._session_db is not None and session_id: try: - _sess_row = self._session_db.get_session(session_id) + # run_sync is off-loop (executor); sync DB is fine. + _sess_row = self._session_db._db.get_session(session_id) if _sess_row: _current_msg_count = _sess_row.get("message_count", 0) except Exception: @@ -17044,7 +17049,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew and self._session_db is not None ): try: - _binding = self._session_db.get_telegram_topic_binding_by_session( + # run_sync is off-loop (executor); sync DB is fine. + _binding = self._session_db._db.get_telegram_topic_binding_by_session( session_id=agent_session_id, ) if _binding and _binding.get("thread_id"): @@ -17169,7 +17175,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew title, ) maybe_auto_title( - self._session_db, + getattr(self._session_db, "_db", self._session_db), effective_session_id, message, final_response, diff --git a/gateway/slash_commands.py b/gateway/slash_commands.py index aa523158d17..c952013378f 100644 --- a/gateway/slash_commands.py +++ b/gateway/slash_commands.py @@ -246,7 +246,7 @@ class GatewaySlashCommandsMixin: _title_note = t("gateway.reset.title_rejected", error=str(e)) if sanitized: try: - self._session_db.set_session_title(new_entry.session_id, sanitized) + await self._session_db.set_session_title(new_entry.session_id, sanitized) header = t("gateway.reset.header_titled", title=sanitized) except ValueError as e: _title_note = t("gateway.reset.title_error_untitled", error=str(e)) @@ -498,11 +498,11 @@ class GatewaySlashCommandsMixin: db_total_tokens = 0 if self._session_db: try: - title = self._session_db.get_session_title(session_entry.session_id) + title = await self._session_db.get_session_title(session_entry.session_id) except Exception: title = None try: - row = self._session_db.get_session(session_entry.session_id) + row = await self._session_db.get_session(session_entry.session_id) if isinstance(row, dict): session_row = row db_total_tokens = ( @@ -2983,7 +2983,7 @@ class GatewaySlashCommandsMixin: # /topic off — clean disable path so users don't have to edit the DB. if args.lower() in {"off", "disable", "stop"}: - return self._disable_telegram_topic_mode_for_chat(source) + return await self._disable_telegram_topic_mode_for_chat(source) if args: if not source.thread_id: @@ -3004,7 +3004,7 @@ class GatewaySlashCommandsMixin: return t("gateway.topic.topics_user_disallowed") try: - self._session_db.enable_telegram_topic_mode( + await self._session_db.enable_telegram_topic_mode( chat_id=str(source.chat_id), user_id=str(source.user_id), has_topics_enabled=capabilities.get("has_topics_enabled"), @@ -3019,7 +3019,7 @@ class GatewaySlashCommandsMixin: if source.thread_id: try: - binding = self._session_db.get_telegram_topic_binding( + binding = await self._session_db.get_telegram_topic_binding( chat_id=str(source.chat_id), thread_id=str(source.thread_id), ) @@ -3030,7 +3030,7 @@ class GatewaySlashCommandsMixin: session_id = str(binding.get("session_id") or "") title = None try: - title = self._session_db.get_session_title(session_id) + title = await self._session_db.get_session_title(session_id) except Exception: title = None session_label = title or t("gateway.topic.untitled_session") @@ -3041,7 +3041,7 @@ class GatewaySlashCommandsMixin: ) return t("gateway.topic.thread_ready") - return self._telegram_topic_root_status_message(source) + return await self._telegram_topic_root_status_message(source) async def _handle_title_command(self, event: MessageEvent) -> str: """Handle /title command — set or show the current session's title.""" @@ -3055,11 +3055,11 @@ class GatewaySlashCommandsMixin: # Ensure session exists in SQLite DB (it may only exist in session_store # if this is the first command in a new session) - existing_title = self._session_db.get_session_title(session_id) + existing_title = await self._session_db.get_session_title(session_id) if existing_title is None: # Session doesn't exist in DB yet — create it try: - self._session_db.create_session( + await self._session_db.create_session( session_id=session_id, source=source.platform.value if source.platform else "unknown", user_id=source.user_id, @@ -3071,14 +3071,15 @@ class GatewaySlashCommandsMixin: if title_arg: # Sanitize the title before setting try: - sanitized = self._session_db.sanitize_title(title_arg) + from hermes_state import SessionDB + sanitized = SessionDB.sanitize_title(title_arg) except ValueError as e: return t("gateway.shared.warn_passthrough", error=e) if not sanitized: return t("gateway.title.empty_after_clean") # Set the title try: - if self._session_db.set_session_title(session_id, sanitized): + if await self._session_db.set_session_title(session_id, sanitized): # Propagate the user-chosen title to the visible Telegram # forum topic name too. Auto-generated titles already rename # the topic; without this, /title only updated the DB title @@ -3102,7 +3103,7 @@ class GatewaySlashCommandsMixin: return t("gateway.shared.warn_passthrough", error=e) else: # Show the current title and session ID - title = self._session_db.get_session_title(session_id) + title = await self._session_db.get_session_title(session_id) if title: return t("gateway.title.current_with_title", session_id=session_id, title=title) else: @@ -3135,15 +3136,15 @@ class GatewaySlashCommandsMixin: ): name = name[1:-1].strip() - def _list_titled_sessions() -> list[dict]: + async def _list_titled_sessions() -> list[dict]: user_source = source.platform.value if source.platform else None - sessions = self._session_db.list_sessions_rich(source=user_source, limit=10) + sessions = await self._session_db.list_sessions_rich(source=user_source, limit=10) return [s for s in sessions if s.get("title")][:10] if not name: # List recent titled sessions for this user/platform try: - titled = _list_titled_sessions() + titled = await _list_titled_sessions() if source.platform == Platform.MATRIX and not allow_all: scoped = [] for s in titled: @@ -3174,7 +3175,7 @@ class GatewaySlashCommandsMixin: # Resolve a numbered choice or a title to a session ID. if name.isdigit(): try: - titled = _list_titled_sessions() + titled = await _list_titled_sessions() if source.platform == Platform.MATRIX and not allow_all: scoped = [] for s in titled: @@ -3194,17 +3195,17 @@ class GatewaySlashCommandsMixin: else: # Try direct session ID lookup first (so `/resume ` # works in the gateway, not just `/resume `). - session = self._session_db.get_session(name) + session = await self._session_db.get_session(name) if session: target_id = session["id"] else: - target_id = self._session_db.resolve_session_by_title(name) + target_id = await self._session_db.resolve_session_by_title(name) if not target_id: return t("gateway.resume.not_found", name=name) # Compression creates child continuations that hold the live transcript. # Follow that chain so gateway /resume matches CLI behavior (#15000). try: - target_id = self._session_db.resolve_resume_session_id(target_id) + target_id = await self._session_db.resolve_resume_session_id(target_id) except Exception as e: logger.debug("Failed to resolve resume continuation for %s: %s", target_id, e) @@ -3255,7 +3256,7 @@ class GatewaySlashCommandsMixin: self._evict_cached_agent(session_key) # Get the title for confirmation - title = self._session_db.get_session_title(target_id) or name + title = await self._session_db.get_session_title(target_id) or name # Count messages for context history = self.session_store.load_transcript(target_id) @@ -3356,9 +3357,9 @@ class GatewaySlashCommandsMixin: if branch_name: branch_title = branch_name else: - current_title = self._session_db.get_session_title(current_entry.session_id) + current_title = await self._session_db.get_session_title(current_entry.session_id) base = current_title or "branch" - branch_title = self._session_db.get_next_title_in_lineage(base) + branch_title = await self._session_db.get_next_title_in_lineage(base) parent_session_id = current_entry.session_id @@ -3368,7 +3369,7 @@ class GatewaySlashCommandsMixin: # /sessions even after the parent is reopened and re-ended with a # different end_reason (e.g. tui_shutdown overwriting 'branched'). try: - self._session_db.create_session( + await self._session_db.create_session( session_id=new_session_id, source=source.platform.value if source.platform else "gateway", model=(self.config.get("model", {}) or {}).get("default") if isinstance(self.config, dict) else None, @@ -3382,7 +3383,7 @@ class GatewaySlashCommandsMixin: # Copy conversation history to the new session for msg in history: try: - self._session_db.append_message( + await self._session_db.append_message( session_id=new_session_id, role=msg.get("role", "user"), content=msg.get("content"), @@ -3401,7 +3402,7 @@ class GatewaySlashCommandsMixin: # Set title try: - self._session_db.set_session_title(new_session_id, branch_title) + await self._session_db.set_session_title(new_session_id, branch_title) except Exception: pass @@ -3484,7 +3485,7 @@ class GatewaySlashCommandsMixin: if not provider and getattr(self, "_session_db", None) is not None: try: _entry_for_billing = self.session_store.get_or_create_session(source) - persisted = self._session_db.get_session(_entry_for_billing.session_id) or {} + persisted = await self._session_db.get_session(_entry_for_billing.session_id) or {} except Exception: persisted = {} provider = provider or persisted.get("billing_provider") diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index dcbbb1a1cb8..193f7f125d8 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -264,11 +264,18 @@ def make_adapter(platform: Platform, runner=None): async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock: - """Send a message through the full e2e flow and return the send mock.""" + """Send a message through the full e2e flow and return the send mock. + + Polls for the send rather than waiting a fixed delay: handler DB work now + hops to worker threads (AsyncSessionDB), so completion latency varies. + """ event = make_event(platform, text, **event_kwargs) adapter.send.reset_mock() await adapter.handle_message(event) - await asyncio.sleep(0.3) + for _ in range(40): # up to ~2s; returns as soon as the send lands + if adapter.send.called: + break + await asyncio.sleep(0.05) return adapter.send diff --git a/tests/gateway/conftest.py b/tests/gateway/conftest.py index a16eb76a6fe..d1f47c700b2 100644 --- a/tests/gateway/conftest.py +++ b/tests/gateway/conftest.py @@ -39,6 +39,15 @@ from unittest.mock import MagicMock import pytest +def make_async_session_db(sync_mock=None): + """Wrap a sync mock SessionDB in AsyncSessionDB so gateway code that awaits + the facade works in tests. Returns (facade, sync_mock); configure return + values and assert calls on sync_mock.""" + from hermes_state import AsyncSessionDB + sync_mock = sync_mock if sync_mock is not None else MagicMock() + return AsyncSessionDB(sync_mock), sync_mock + + def _ensure_telegram_mock() -> None: """Install a comprehensive telegram mock in sys.modules. diff --git a/tests/gateway/test_agent_cache.py b/tests/gateway/test_agent_cache.py index 54b0fe08794..bba92d37aa0 100644 --- a/tests/gateway/test_agent_cache.py +++ b/tests/gateway/test_agent_cache.py @@ -12,6 +12,8 @@ Verifies that the agent cache correctly: import threading from unittest.mock import MagicMock, patch +import pytest + def _make_runner(): @@ -1565,8 +1567,11 @@ class TestAgentCacheMessageCountRebaseline: """ def _runner_with_db(self, db): + from hermes_state import AsyncSessionDB + runner = _make_runner() - runner._session_db = db + # The gateway holds the async facade; the production refresh awaits it. + runner._session_db = AsyncSessionDB(db) return runner @staticmethod @@ -1577,7 +1582,7 @@ class TestAgentCacheMessageCountRebaseline: the cached agent (or either side is None / it's a legacy 2-tuple). """ try: - row = runner._session_db.get_session(session_id) + row = runner._session_db._db.get_session(session_id) live = row.get("message_count", 0) if row else None except Exception: live = None @@ -1591,7 +1596,8 @@ class TestAgentCacheMessageCountRebaseline: ) return not invalidate - def test_same_process_turns_preserve_cached_agent(self, tmp_path): + @pytest.mark.asyncio + async def test_same_process_turns_preserve_cached_agent(self, tmp_path): """The regression guard: consecutive same-process turns must REUSE the cached agent (prompt cache preserved), not rebuild every turn. @@ -1619,7 +1625,7 @@ class TestAgentCacheMessageCountRebaseline: db.append_message("s1", role="user", content="u") db.append_message("s1", role="assistant", content="a") # Post-turn re-baseline (the fix). - runner._refresh_agent_cache_message_count("telegram:s1", "s1") + await runner._refresh_agent_cache_message_count("telegram:s1", "s1") # Next turn's guard decision. if self._guard_would_reuse(runner, "telegram:s1", "s1"): reuses += 1 @@ -1630,7 +1636,8 @@ class TestAgentCacheMessageCountRebaseline: with runner._agent_cache_lock: assert runner._agent_cache["telegram:s1"][0] is agent - def test_cross_process_write_still_invalidates(self, tmp_path): + @pytest.mark.asyncio + async def test_cross_process_write_still_invalidates(self, tmp_path): """After the re-baseline, a DIFFERENT process appending to the same session must still flip the guard to rebuild (the #45966 fix holds). """ @@ -1650,7 +1657,7 @@ class TestAgentCacheMessageCountRebaseline: # Our own turn + re-baseline -> reuse next turn. db.append_message("s1", role="user", content="u") db.append_message("s1", role="assistant", content="a") - runner._refresh_agent_cache_message_count("telegram:s1", "s1") + await runner._refresh_agent_cache_message_count("telegram:s1", "s1") assert self._guard_would_reuse(runner, "telegram:s1", "s1") is True # ANOTHER process (e.g. the desktop dashboard backend) appends a turn @@ -1660,10 +1667,11 @@ class TestAgentCacheMessageCountRebaseline: # Guard must now reject reuse so the agent rebuilds from fresh disk. assert self._guard_would_reuse(runner, "telegram:s1", "s1") is False - def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path): + @pytest.mark.asyncio + async def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path): """Re-baseline must never crash and must leave legacy 2-tuples and pending-sentinel entries untouched.""" - from hermes_state import SessionDB + from hermes_state import AsyncSessionDB, SessionDB from gateway.run import _AGENT_PENDING_SENTINEL db = SessionDB(db_path=tmp_path / "sessions.db") @@ -1673,24 +1681,24 @@ class TestAgentCacheMessageCountRebaseline: # No session_db -> no-op, no crash. runner._session_db = None - runner._refresh_agent_cache_message_count("telegram:s1", "s1") - runner._session_db = db + await runner._refresh_agent_cache_message_count("telegram:s1", "s1") + runner._session_db = AsyncSessionDB(db) # Falsy session_id -> no-op. - runner._refresh_agent_cache_message_count("telegram:s1", "") - runner._refresh_agent_cache_message_count("telegram:s1", None) + await runner._refresh_agent_cache_message_count("telegram:s1", "") + await runner._refresh_agent_cache_message_count("telegram:s1", None) # Legacy 2-tuple is left untouched (it opts out of the guard). with runner._agent_cache_lock: runner._agent_cache["telegram:s1"] = (object(), "sig") - runner._refresh_agent_cache_message_count("telegram:s1", "s1") + await runner._refresh_agent_cache_message_count("telegram:s1", "s1") with runner._agent_cache_lock: assert len(runner._agent_cache["telegram:s1"]) == 2 # Pending sentinel entry is left untouched. with runner._agent_cache_lock: runner._agent_cache["telegram:s1"] = (_AGENT_PENDING_SENTINEL, "sig", 0) - runner._refresh_agent_cache_message_count("telegram:s1", "s1") + await runner._refresh_agent_cache_message_count("telegram:s1", "s1") with runner._agent_cache_lock: assert runner._agent_cache["telegram:s1"][0] is _AGENT_PENDING_SENTINEL assert runner._agent_cache["telegram:s1"][2] == 0 @@ -1700,10 +1708,10 @@ class TestAgentCacheMessageCountRebaseline: def get_session(self, _sid): raise RuntimeError("db locked") - runner._session_db = _BoomDB() # type: ignore[assignment] + runner._session_db = AsyncSessionDB(_BoomDB()) # type: ignore[assignment] with runner._agent_cache_lock: runner._agent_cache["telegram:s1"] = (object(), "sig", 5) - runner._refresh_agent_cache_message_count("telegram:s1", "s1") + await runner._refresh_agent_cache_message_count("telegram:s1", "s1") with runner._agent_cache_lock: assert runner._agent_cache["telegram:s1"][2] == 5 diff --git a/tests/gateway/test_async_session_db.py b/tests/gateway/test_async_session_db.py index 897ceb1cf00..fa459cb31d7 100644 --- a/tests/gateway/test_async_session_db.py +++ b/tests/gateway/test_async_session_db.py @@ -122,29 +122,40 @@ def test_non_callable_attribute_passes_through(): _GATEWAY_FILES = ("gateway/run.py", "gateway/slash_commands.py") # The only legitimate non-loop paths: # - SessionDB.sanitize_title: pure @staticmethod string cleaning, no DB. -# - self._session_db._db.<x>: the sync escape, allowed ONLY at construction. -_ALLOWED_SYNC_DB_ESCAPES = 1 # exactly the maybe_auto_prune call in __init__ +# - self._session_db._db.<x>: the sync escape, allowed ONLY where the call is +# provably off the event loop — construction (__init__, before the loop +# serves) and the run_sync closure (executed in a thread-pool executor). +# Three such sites today; a fourth must be justified and this count bumped. +_ALLOWED_SYNC_DB_ESCAPES = 3 def _repo_root() -> Path: return Path(__file__).resolve().parents[2] -class _RawCallVisitor(ast.NodeVisitor): - """Collect calls of the shape self._session_db.<method>(...). +class _RawCallVisitor: + """Collect non-awaited self._session_db.<method>(...) calls in a module. - Whether the call is awaited is irrelevant to the AST node; an Await wraps - the Call. We flag the raw shape and separately exempt the _db. escape and - the sanitize_title staticmethod (which is called on the class, not self). + An ``await x.y()`` parses as Await(value=Call(...)); those Call nodes are + exempt — they're the migrated path. We flag only Calls that are NOT directly + awaited, and separately count the self._session_db._db.<x> sync escape. The + sanitize_title staticmethod is called on the class (SessionDB.sanitize_title), + so it never matches the self._session_db.<method> shape. """ - def __init__(self): - self.raw_calls = [] # (method, lineno) + def __init__(self, tree: ast.AST): + self.raw_calls = [] # (method, lineno) — non-awaited self.db_escapes = [] # self._session_db._db.<x> sites (lineno) - def visit_Call(self, node: ast.Call): - func = node.func - if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute): + awaited = {id(n.value) for n in ast.walk(tree) + if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)} + + for node in ast.walk(tree): + if not isinstance(node, ast.Call): + continue + func = node.func + if not (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute)): + continue inner = func.value # self._session_db._db.<method>(...) -> sync escape if ( @@ -155,21 +166,19 @@ class _RawCallVisitor(ast.NodeVisitor): and inner.value.value.id == "self" ): self.db_escapes.append(inner.lineno) - # self._session_db.<method>(...) -> raw loop call + # self._session_db.<method>(...) not wrapped in await -> raw loop call elif ( inner.attr == "_session_db" and isinstance(inner.value, ast.Name) and inner.value.id == "self" + and id(node) not in awaited ): self.raw_calls.append((func.attr, node.lineno)) - self.generic_visit(node) def _scan(rel_path: str) -> _RawCallVisitor: source = (_repo_root() / rel_path).read_text(encoding="utf-8") - visitor = _RawCallVisitor() - visitor.visit(ast.parse(source)) - return visitor + return _RawCallVisitor(ast.parse(source)) def test_no_raw_session_db_calls_on_gateway_loop(): @@ -189,15 +198,16 @@ def test_no_raw_session_db_calls_on_gateway_loop(): ) -def test_sync_db_escape_confined_to_construction(): - """The self._session_db._db. sync escape must stay confined to one site. +def test_sync_db_escape_confined_to_off_loop_sites(): + """The self._session_db._db. sync escape must stay confined to known sites. - It is legitimate only at construction (before the loop serves traffic). - More than one occurrence means a blocking call leaked back onto the loop - through the escape hatch. + It is legitimate only where the call is provably off the loop: construction + (before the loop serves) and the run_sync executor closure. More occurrences + than the reviewed count means a blocking call may have leaked back onto the + loop through the escape hatch. """ total = sum(len(_scan(rel).db_escapes) for rel in _GATEWAY_FILES) assert total <= _ALLOWED_SYNC_DB_ESCAPES, ( f"self._session_db._db. sync escape used {total} times; " - f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction only) is allowed." + f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed." ) diff --git a/tests/gateway/test_handoff_watcher_async_db.py b/tests/gateway/test_handoff_watcher_async_db.py index c10093d07a5..dc7382dcf49 100644 --- a/tests/gateway/test_handoff_watcher_async_db.py +++ b/tests/gateway/test_handoff_watcher_async_db.py @@ -5,13 +5,14 @@ The Discord gateway heartbeat was stalling because the handoff watcher SQLite-backed ``SessionDB`` directly on the asyncio event loop every 2s ('Shard ID None heartbeat blocked for more than N seconds'). -The fix (mirroring PR #40782) wraps every blocking ``SessionDB`` call inside -the watcher loop in ``asyncio.to_thread(...)`` so the SQLite I/O runs on a -worker thread and never blocks the event loop / Discord heartbeat. +The fix routes every blocking ``SessionDB`` call in the watcher through the +``AsyncSessionDB`` facade, which offloads each call via ``asyncio.to_thread`` so +the SQLite I/O runs on a worker thread and never blocks the event loop / Discord +heartbeat. These tests assert that behaviour contract. They are mutation-survivable: -reverting any ``asyncio.to_thread(self._session_db.<call>)`` wrap back to a -direct synchronous call on the loop makes the relevant assertion fail. +reverting any ``await self._session_db.<call>(...)`` back to a direct synchronous +call on the loop makes the relevant assertion fail. """ import asyncio @@ -62,9 +63,15 @@ class _RecordingSessionDB: def _make_fake_runner(session_db, *, fail_process=False): - """Build a minimal object that exposes exactly what the loop body touches.""" + """Build a minimal object that exposes exactly what the loop body touches. + + The watcher now talks to the SessionDB through the AsyncSessionDB facade, + so wrap the recording stand-in the same way the gateway does. + """ + from hermes_state import AsyncSessionDB + fake = types.SimpleNamespace() - fake._session_db = session_db + fake._session_db = AsyncSessionDB(session_db) # _running yields True for the first loop check, then False so the loop # exits after a single tick. states = iter([True, False]) @@ -141,21 +148,23 @@ async def test_watcher_offloads_fail_handoff_to_thread(monkeypatch): async def test_watcher_wraps_calls_via_asyncio_to_thread(monkeypatch): """Explicitly assert the offload goes through asyncio.to_thread. - Patches ``run.asyncio.to_thread`` and records which SessionDB callables - were handed to it. Mutation-survivable: dropping any wrap removes its - callable from the recorded set. + Patches the AsyncSessionDB facade's ``asyncio.to_thread`` (it lives in + hermes_state) and records which SessionDB callables were handed to it. + Mutation-survivable: dropping any await removes its callable from the set. """ + import hermes_state + db = _RecordingSessionDB(loop_thread_ident=-1) fake = _make_fake_runner(db, fail_process=False) wrapped = [] - real_to_thread = run.asyncio.to_thread + real_to_thread = hermes_state.asyncio.to_thread async def _spy_to_thread(func, *args, **kwargs): wrapped.append(getattr(func, "__name__", repr(func))) return await real_to_thread(func, *args, **kwargs) - monkeypatch.setattr(run.asyncio, "to_thread", _spy_to_thread) + monkeypatch.setattr(hermes_state.asyncio, "to_thread", _spy_to_thread) await _run_one_tick(fake, monkeypatch) diff --git a/tests/gateway/test_matrix_project_context_isolation.py b/tests/gateway/test_matrix_project_context_isolation.py index 943a367d67b..00341a8036c 100644 --- a/tests/gateway/test_matrix_project_context_isolation.py +++ b/tests/gateway/test_matrix_project_context_isolation.py @@ -12,6 +12,7 @@ import pytest from gateway.config import GatewayConfig, Platform, PlatformConfig from gateway.platforms.base import MessageEvent +from hermes_state import AsyncSessionDB from gateway.session import ( SessionContext, SessionEntry, @@ -343,16 +344,16 @@ def _make_runner(current_source: SessionSource, entries: list[SessionEntry]): runner._clear_session_boundary_security_state = MagicMock() runner._evict_cached_agent = MagicMock() runner._queue_depth = MagicMock(return_value=0) - runner._session_db = MagicMock() - runner._session_db.list_sessions_rich.return_value = [ + runner._session_db = AsyncSessionDB(MagicMock()) + runner._session_db._db.list_sessions_rich.return_value = [ {"id": entry.session_id, "title": entry.display_name, "preview": ""} for entry in entries ] - runner._session_db.resolve_resume_session_id.side_effect = lambda sid: sid - runner._session_db.get_session_title.side_effect = lambda sid: { + runner._session_db._db.resolve_resume_session_id.side_effect = lambda sid: sid + runner._session_db._db.get_session_title.side_effect = lambda sid: { entry.session_id: entry.display_name for entry in entries }.get(sid) - runner._session_db.get_session.return_value = None + runner._session_db._db.get_session.return_value = None return runner @@ -388,7 +389,7 @@ async def test_matrix_resume_does_not_cross_rooms_by_default(): entry_a = _entry(source_a, "session-a", "Project A Plan") entry_b = _entry(source_b, "session-b", "Project B Plan") runner = _make_runner(source_b, [entry_a, entry_b]) - runner._session_db.resolve_session_by_title.return_value = "session-a" + runner._session_db._db.resolve_session_by_title.return_value = "session-a" result = await runner._handle_resume_command(_event("/resume Project A Plan", source_b)) @@ -406,7 +407,7 @@ async def test_matrix_resume_allows_same_room_session(): source_b, "session-b-current", "Current Project B" ) runner.session_store.switch_session.return_value = entry_b - runner._session_db.resolve_session_by_title.return_value = "session-b-old" + runner._session_db._db.resolve_session_by_title.return_value = "session-b-old" result = await runner._handle_resume_command(_event("/resume Project B Plan", source_b)) @@ -423,14 +424,14 @@ async def test_matrix_resume_quoted_title_same_room(): source_b, "session-b-current", "Current Project B" ) runner.session_store.switch_session.return_value = entry_b - runner._session_db.resolve_session_by_title.return_value = "session-b-old" + runner._session_db._db.resolve_session_by_title.return_value = "session-b-old" result = await runner._handle_resume_command( _event('/resume "Project B Plan"', source_b) ) assert "Resumed session" in result - runner._session_db.resolve_session_by_title.assert_called_once_with("Project B Plan") + runner._session_db._db.resolve_session_by_title.assert_called_once_with("Project B Plan") @pytest.mark.asyncio @@ -440,7 +441,7 @@ async def test_matrix_resume_quoted_title_cross_room_blocked(): entry_a = _entry(source_a, "session-a", "Project A Plan") entry_b = _entry(source_b, "session-b", "Project B Plan") runner = _make_runner(source_b, [entry_a, entry_b]) - runner._session_db.resolve_session_by_title.return_value = "session-a" + runner._session_db._db.resolve_session_by_title.return_value = "session-a" result = await runner._handle_resume_command( _event('/resume "Project A Plan"', source_b) @@ -471,7 +472,7 @@ async def test_matrix_resume_cross_room_requires_explicit_flag_and_warns(): entry_b = _entry(source_b, "session-b", "Project B Plan") runner = _make_runner(source_b, [entry_a, entry_b]) runner.session_store.switch_session.return_value = entry_a - runner._session_db.resolve_session_by_title.return_value = "session-a" + runner._session_db._db.resolve_session_by_title.return_value = "session-a" result = await runner._handle_resume_command( _event("/resume --cross-room Project A Plan", source_b) diff --git a/tests/gateway/test_session_boundary_security_state.py b/tests/gateway/test_session_boundary_security_state.py index 0899d177c4d..f3862aac6a7 100644 --- a/tests/gateway/test_session_boundary_security_state.py +++ b/tests/gateway/test_session_boundary_security_state.py @@ -1,3 +1,4 @@ +from hermes_state import AsyncSessionDB """Regression tests for approval-state cleanup on session boundaries.""" from datetime import datetime @@ -86,9 +87,9 @@ def _make_resume_runner(): runner.session_store.get_or_create_session.return_value = current_entry runner.session_store.switch_session.return_value = resumed_entry runner.session_store.load_transcript.return_value = [] - runner._session_db = MagicMock() - runner._session_db.resolve_session_by_title.return_value = "resumed-session" - runner._session_db.get_session_title.return_value = "Resumed Work" + runner._session_db = AsyncSessionDB(MagicMock()) + runner._session_db._db.resolve_session_by_title.return_value = "resumed-session" + runner._session_db._db.get_session_title.return_value = "Resumed Work" return runner, session_key @@ -116,9 +117,9 @@ def _make_branch_runner(): {"role": "assistant", "content": "world"}, ] runner.session_store.switch_session.return_value = branched_entry - runner._session_db = MagicMock() - runner._session_db.get_session_title.return_value = "Current Work" - runner._session_db.get_next_title_in_lineage.return_value = "Current Work #2" + runner._session_db = AsyncSessionDB(MagicMock()) + runner._session_db._db.get_session_title.return_value = "Current Work" + runner._session_db._db.get_next_title_in_lineage.return_value = "Current Work #2" return runner, session_key @@ -208,7 +209,7 @@ async def test_branch_preserves_persisted_assistant_metadata(): result = await runner._handle_branch_command(_make_event("/branch")) assert "Branched to" in result - append_calls = runner._session_db.append_message.call_args_list + append_calls = runner._session_db._db.append_message.call_args_list assert len(append_calls) == 2 assistant_kwargs = append_calls[1].kwargs assert assistant_kwargs["role"] == "assistant" diff --git a/tests/gateway/test_session_race_guard.py b/tests/gateway/test_session_race_guard.py index 80ec02c22f0..9a9c0bf7d08 100644 --- a/tests/gateway/test_session_race_guard.py +++ b/tests/gateway/test_session_race_guard.py @@ -171,8 +171,12 @@ async def test_second_message_during_sentinel_queued_not_duplicate(): with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner): # Start first message (will block at barrier) task1 = asyncio.create_task(runner._handle_message(event1)) - # Yield so task1 enters slow_inner and sentinel is set - await asyncio.sleep(0) + # Yield until task1 has claimed the sentinel (it crosses a few awaits + # before the claim; don't assume a fixed number of scheduler slices). + for _ in range(50): + await asyncio.sleep(0) + if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL: + break # Verify sentinel is set assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL @@ -417,7 +421,10 @@ async def test_stop_during_sentinel_force_cleans_session(): with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner): task1 = asyncio.create_task(runner._handle_message(event1)) - await asyncio.sleep(0) + for _ in range(50): + await asyncio.sleep(0) + if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL: + break # Sentinel should be set assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL diff --git a/tests/gateway/test_status_command.py b/tests/gateway/test_status_command.py index 39ea4e3ff13..cadeb9ca706 100644 --- a/tests/gateway/test_status_command.py +++ b/tests/gateway/test_status_command.py @@ -1,3 +1,4 @@ +from hermes_state import AsyncSessionDB """Tests for gateway /status behavior and token persistence.""" from datetime import datetime @@ -53,11 +54,11 @@ def _make_runner(session_entry: SessionEntry, *, platform: Platform = Platform.T runner._session_run_generation = {} runner._pending_messages = {} runner._pending_approvals = {} - runner._session_db = MagicMock() - runner._session_db.get_session_title.return_value = None + runner._session_db = AsyncSessionDB(MagicMock()) + runner._session_db._db.get_session_title.return_value = None # Default: no DB row → /status reports 0 tokens. Tests that exercise # the populated path override this. - runner._session_db.get_session.return_value = None + runner._session_db._db.get_session.return_value = None runner._reasoning_config = None runner._provider_routing = {} runner._fallback_model = None @@ -86,7 +87,7 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc ) runner = _make_runner(session_entry) # Token total comes from the SQLite SessionDB, not SessionEntry. - runner._session_db.get_session.return_value = { + runner._session_db._db.get_session.return_value = { "input_tokens": 200, "output_tokens": 121, "cache_read_tokens": 0, @@ -118,7 +119,7 @@ async def test_status_command_includes_session_title_when_present(): total_tokens=321, ) runner = _make_runner(session_entry) - runner._session_db.get_session_title.return_value = "My titled session" + runner._session_db._db.get_session_title.return_value = "My titled session" result = await runner._handle_message(_make_event("/status")) @@ -141,7 +142,7 @@ async def test_status_command_reads_token_totals_from_session_db(): total_tokens=0, # SessionEntry never gets written to — always 0. ) runner = _make_runner(session_entry) - runner._session_db.get_session.return_value = { + runner._session_db._db.get_session.return_value = { "input_tokens": 1000, "output_tokens": 250, "cache_read_tokens": 500, @@ -169,7 +170,7 @@ async def test_status_command_tokens_zero_when_session_db_row_missing(): total_tokens=999, # This should be ignored. ) runner = _make_runner(session_entry) - runner._session_db.get_session.return_value = None + runner._session_db._db.get_session.return_value = None result = await runner._handle_message(_make_event("/status")) @@ -188,7 +189,7 @@ async def test_status_command_includes_live_agent_model_and_context(): total_tokens=0, ) runner = _make_runner(session_entry) - runner._session_db.get_session.return_value = { + runner._session_db._db.get_session.return_value = { "input_tokens": 1000, "output_tokens": 250, "cache_read_tokens": 0, @@ -228,7 +229,7 @@ async def test_status_command_includes_persisted_model_and_context_when_agent_no last_prompt_tokens=24_000, ) runner = _make_runner(session_entry) - runner._session_db.get_session.return_value = { + runner._session_db._db.get_session.return_value = { "input_tokens": 2000, "output_tokens": 500, "cache_read_tokens": 0, diff --git a/tests/gateway/test_usage_command.py b/tests/gateway/test_usage_command.py index d58c57613dd..40cbe3192ff 100644 --- a/tests/gateway/test_usage_command.py +++ b/tests/gateway/test_usage_command.py @@ -1,3 +1,4 @@ +from hermes_state import AsyncSessionDB """Tests for gateway /usage command — agent cache lookup and output fields.""" import threading @@ -197,8 +198,8 @@ class TestUsageAccountSection: @pytest.mark.asyncio async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch): runner = _make_runner(SK) - runner._session_db = MagicMock() - runner._session_db.get_session.return_value = { + runner._session_db = AsyncSessionDB(MagicMock()) + runner._session_db._db.get_session.return_value = { "billing_provider": "openai-codex", "billing_base_url": "https://chatgpt.com/backend-api/codex", }