fix(gateway): route SessionDB calls through AsyncSessionDB

This commit is contained in:
yoniebans 2026-06-29 12:04:44 +02:00 committed by Teknium
parent ea26f22710
commit 0896facce8
12 changed files with 203 additions and 142 deletions

View file

@ -2777,8 +2777,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
# Initialize session database for session_search tool support
self._session_db = None
try:
from hermes_state import SessionDB
self._session_db = SessionDB()
from hermes_state import AsyncSessionDB, SessionDB
self._session_db = AsyncSessionDB(SessionDB())
except Exception as e:
# WARNING (not DEBUG) so the failure appears in errors.log — matches
# cli.py's handling of the same init path. Users hitting NFS-mounted
@ -2799,7 +2799,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
from hermes_cli.config import load_config as _load_full_config
_sess_cfg = (_load_full_config().get("sessions") or {})
if _sess_cfg.get("auto_prune", False):
self._session_db.maybe_auto_prune_and_vacuum(
# Construction-time, before the loop serves traffic; sync DB is fine.
self._session_db._db.maybe_auto_prune_and_vacuum(
retention_days=int(_sess_cfg.get("retention_days", 90)),
min_interval_hours=int(_sess_cfg.get("min_interval_hours", 24)),
vacuum=bool(_sess_cfg.get("vacuum_after_prune", True)),
@ -6578,23 +6579,23 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if self._session_db is None:
await asyncio.sleep(interval)
continue
pending = await asyncio.to_thread(self._session_db.list_pending_handoffs)
pending = await self._session_db.list_pending_handoffs()
for row in pending:
session_id = row.get("id")
if not session_id:
continue
if not await asyncio.to_thread(self._session_db.claim_handoff, session_id):
if not await self._session_db.claim_handoff(session_id):
# Another tick or another gateway already claimed it.
continue
try:
await self._process_handoff(row)
await asyncio.to_thread(self._session_db.complete_handoff, session_id)
await self._session_db.complete_handoff(session_id)
except Exception as exc:
logger.warning(
"Handoff for session %s failed: %s",
session_id, exc, exc_info=True,
)
await asyncio.to_thread(self._session_db.fail_handoff, session_id, str(exc))
await self._session_db.fail_handoff(session_id, str(exc))
except asyncio.CancelledError:
raise
except Exception as exc:
@ -7443,8 +7444,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
# old gateway's connection holding the WAL lock until Python
# actually exits — causing 'database is locked' errors when
# the new gateway tries to open the same file.
for _db_holder in (self, getattr(self, "session_store", None)):
_db = getattr(_db_holder, "_db", None) if _db_holder else None
# ``self`` holds the DB at ``_session_db`` (an AsyncSessionDB facade);
# unwrap to the sync handle. ``session_store`` holds it at ``_db``.
_self_db = getattr(self, "_session_db", None)
_self_db = getattr(_self_db, "_db", _self_db)
for _db in (_self_db, getattr(getattr(self, "session_store", None), "_db", None)):
if _db is None or not hasattr(_db, "close"):
continue
try:
@ -9641,10 +9645,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
self._cache_session_source(session_key, source)
if self._is_telegram_topic_lane(source):
try:
binding = self._session_db.get_telegram_topic_binding(
binding = (await self._session_db.get_telegram_topic_binding(
chat_id=str(source.chat_id),
thread_id=str(source.thread_id),
) if self._session_db else None
)) if self._session_db else None
except Exception:
logger.debug("Failed to read Telegram topic binding", exc_info=True)
binding = None
@ -9658,7 +9662,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
# a compression parent, so this is cheap and safe.
if bound_session_id and self._session_db is not None:
try:
canonical_session_id = self._session_db.get_compression_tip(
canonical_session_id = await self._session_db.get_compression_tip(
bound_session_id,
)
except Exception:
@ -10450,7 +10454,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
# prompt caching. Refreshing here makes the guard fire only on a
# DIFFERENT process's writes. Uses the (possibly compaction-
# updated) live session_id. Fail-safe inside the helper.
self._refresh_agent_cache_message_count(
await self._refresh_agent_cache_message_count(
session_key, session_entry.session_id
)
@ -12481,7 +12485,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
"5. /topic <id> inside a topic restores an old session into it."
)
def _disable_telegram_topic_mode_for_chat(self, source: SessionSource) -> str:
async def _disable_telegram_topic_mode_for_chat(self, source: SessionSource) -> str:
"""Cleanly disable topic mode for a chat via /topic off."""
if not self._session_db:
from hermes_state import format_session_db_unavailable
@ -12491,7 +12495,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
return "Could not determine chat ID."
# No-op if never enabled.
try:
currently_enabled = self._session_db.is_telegram_topic_mode_enabled(
currently_enabled = await self._session_db.is_telegram_topic_mode_enabled(
chat_id=chat_id,
user_id=str(source.user_id or ""),
)
@ -12500,7 +12504,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if not currently_enabled:
return "Multi-session topic mode is not currently enabled for this chat."
try:
self._session_db.disable_telegram_topic_mode(chat_id=chat_id)
await self._session_db.disable_telegram_topic_mode(chat_id=chat_id)
except Exception as exc:
logger.exception("Failed to disable Telegram topic mode")
return f"Failed to disable topic mode: {exc}"
@ -12518,7 +12522,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
)
def _telegram_topic_root_status_message(self, source: SessionSource) -> str:
async def _telegram_topic_root_status_message(self, source: SessionSource) -> str:
lines = [
"Telegram multi-session topics are enabled.",
"",
@ -12528,7 +12532,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
"",
]
try:
sessions = self._session_db.list_unlinked_telegram_sessions_for_user(
sessions = await self._session_db.list_unlinked_telegram_sessions_for_user(
chat_id=str(source.chat_id),
user_id=str(source.user_id),
limit=10,
@ -12567,11 +12571,11 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
async def _restore_telegram_topic_session(self, event: MessageEvent, raw_session_id: str) -> str:
"""Restore an existing Telegram-owned Hermes session into this topic."""
source = event.source
session_id = self._session_db.resolve_session_id(raw_session_id.strip())
session_id = await self._session_db.resolve_session_id(raw_session_id.strip())
if not session_id:
return f"Session not found: {raw_session_id.strip()}"
session = self._session_db.get_session(session_id)
session = await self._session_db.get_session(session_id)
if not session:
return f"Session not found: {raw_session_id.strip()}"
if str(session.get("source") or "") != "telegram":
@ -12579,8 +12583,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if str(session.get("user_id") or "") != str(source.user_id):
return "That session does not belong to this Telegram user."
linked = self._session_db.is_telegram_session_linked_to_topic(session_id=session_id)
current_binding = self._session_db.get_telegram_topic_binding(
linked = await self._session_db.is_telegram_session_linked_to_topic(session_id=session_id)
current_binding = await self._session_db.get_telegram_topic_binding(
chat_id=str(source.chat_id),
thread_id=str(source.thread_id),
)
@ -12590,7 +12594,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
session_key = self._session_key_for_source(source)
try:
self._session_db.bind_telegram_topic(
await self._session_db.bind_telegram_topic(
chat_id=str(source.chat_id),
thread_id=str(source.thread_id),
user_id=str(source.user_id),
@ -12603,10 +12607,10 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
return "That session is already linked to another Telegram topic."
raise
title = self._session_db.get_session_title(session_id) or session_id
title = await self._session_db.get_session_title(session_id) or session_id
last_assistant = None
try:
for message in reversed(self._session_db.get_messages(session_id)):
for message in reversed(await self._session_db.get_messages(session_id)):
if message.get("role") == "assistant" and message.get("content"):
last_assistant = str(message.get("content"))
break
@ -14631,7 +14635,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if release_running_state:
self._release_running_agent_state(session_key)
def _refresh_agent_cache_message_count(
async def _refresh_agent_cache_message_count(
self, session_key: str, session_id: Optional[str]
) -> None:
"""Re-baseline a cached agent's stored message_count after THIS turn.
@ -14663,7 +14667,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
if not _cache_lock or _cache is None:
return
try:
_sess_row = self._session_db.get_session(session_id)
_sess_row = await self._session_db.get_session(session_id)
_live = _sess_row.get("message_count", 0) if _sess_row else None
except Exception:
return
@ -16346,7 +16350,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
_current_msg_count = None
if self._session_db is not None and session_id:
try:
_sess_row = self._session_db.get_session(session_id)
# run_sync is off-loop (executor); sync DB is fine.
_sess_row = self._session_db._db.get_session(session_id)
if _sess_row:
_current_msg_count = _sess_row.get("message_count", 0)
except Exception:
@ -17044,7 +17049,8 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
and self._session_db is not None
):
try:
_binding = self._session_db.get_telegram_topic_binding_by_session(
# run_sync is off-loop (executor); sync DB is fine.
_binding = self._session_db._db.get_telegram_topic_binding_by_session(
session_id=agent_session_id,
)
if _binding and _binding.get("thread_id"):
@ -17169,7 +17175,7 @@ class GatewayRunner(GatewayAuthorizationMixin, GatewayKanbanWatchersMixin, Gatew
title,
)
maybe_auto_title(
self._session_db,
getattr(self._session_db, "_db", self._session_db),
effective_session_id,
message,
final_response,

View file

@ -246,7 +246,7 @@ class GatewaySlashCommandsMixin:
_title_note = t("gateway.reset.title_rejected", error=str(e))
if sanitized:
try:
self._session_db.set_session_title(new_entry.session_id, sanitized)
await self._session_db.set_session_title(new_entry.session_id, sanitized)
header = t("gateway.reset.header_titled", title=sanitized)
except ValueError as e:
_title_note = t("gateway.reset.title_error_untitled", error=str(e))
@ -498,11 +498,11 @@ class GatewaySlashCommandsMixin:
db_total_tokens = 0
if self._session_db:
try:
title = self._session_db.get_session_title(session_entry.session_id)
title = await self._session_db.get_session_title(session_entry.session_id)
except Exception:
title = None
try:
row = self._session_db.get_session(session_entry.session_id)
row = await self._session_db.get_session(session_entry.session_id)
if isinstance(row, dict):
session_row = row
db_total_tokens = (
@ -2983,7 +2983,7 @@ class GatewaySlashCommandsMixin:
# /topic off — clean disable path so users don't have to edit the DB.
if args.lower() in {"off", "disable", "stop"}:
return self._disable_telegram_topic_mode_for_chat(source)
return await self._disable_telegram_topic_mode_for_chat(source)
if args:
if not source.thread_id:
@ -3004,7 +3004,7 @@ class GatewaySlashCommandsMixin:
return t("gateway.topic.topics_user_disallowed")
try:
self._session_db.enable_telegram_topic_mode(
await self._session_db.enable_telegram_topic_mode(
chat_id=str(source.chat_id),
user_id=str(source.user_id),
has_topics_enabled=capabilities.get("has_topics_enabled"),
@ -3019,7 +3019,7 @@ class GatewaySlashCommandsMixin:
if source.thread_id:
try:
binding = self._session_db.get_telegram_topic_binding(
binding = await self._session_db.get_telegram_topic_binding(
chat_id=str(source.chat_id),
thread_id=str(source.thread_id),
)
@ -3030,7 +3030,7 @@ class GatewaySlashCommandsMixin:
session_id = str(binding.get("session_id") or "")
title = None
try:
title = self._session_db.get_session_title(session_id)
title = await self._session_db.get_session_title(session_id)
except Exception:
title = None
session_label = title or t("gateway.topic.untitled_session")
@ -3041,7 +3041,7 @@ class GatewaySlashCommandsMixin:
)
return t("gateway.topic.thread_ready")
return self._telegram_topic_root_status_message(source)
return await self._telegram_topic_root_status_message(source)
async def _handle_title_command(self, event: MessageEvent) -> str:
"""Handle /title command — set or show the current session's title."""
@ -3055,11 +3055,11 @@ class GatewaySlashCommandsMixin:
# Ensure session exists in SQLite DB (it may only exist in session_store
# if this is the first command in a new session)
existing_title = self._session_db.get_session_title(session_id)
existing_title = await self._session_db.get_session_title(session_id)
if existing_title is None:
# Session doesn't exist in DB yet — create it
try:
self._session_db.create_session(
await self._session_db.create_session(
session_id=session_id,
source=source.platform.value if source.platform else "unknown",
user_id=source.user_id,
@ -3071,14 +3071,15 @@ class GatewaySlashCommandsMixin:
if title_arg:
# Sanitize the title before setting
try:
sanitized = self._session_db.sanitize_title(title_arg)
from hermes_state import SessionDB
sanitized = SessionDB.sanitize_title(title_arg)
except ValueError as e:
return t("gateway.shared.warn_passthrough", error=e)
if not sanitized:
return t("gateway.title.empty_after_clean")
# Set the title
try:
if self._session_db.set_session_title(session_id, sanitized):
if await self._session_db.set_session_title(session_id, sanitized):
# Propagate the user-chosen title to the visible Telegram
# forum topic name too. Auto-generated titles already rename
# the topic; without this, /title only updated the DB title
@ -3102,7 +3103,7 @@ class GatewaySlashCommandsMixin:
return t("gateway.shared.warn_passthrough", error=e)
else:
# Show the current title and session ID
title = self._session_db.get_session_title(session_id)
title = await self._session_db.get_session_title(session_id)
if title:
return t("gateway.title.current_with_title", session_id=session_id, title=title)
else:
@ -3135,15 +3136,15 @@ class GatewaySlashCommandsMixin:
):
name = name[1:-1].strip()
def _list_titled_sessions() -> list[dict]:
async def _list_titled_sessions() -> list[dict]:
user_source = source.platform.value if source.platform else None
sessions = self._session_db.list_sessions_rich(source=user_source, limit=10)
sessions = await self._session_db.list_sessions_rich(source=user_source, limit=10)
return [s for s in sessions if s.get("title")][:10]
if not name:
# List recent titled sessions for this user/platform
try:
titled = _list_titled_sessions()
titled = await _list_titled_sessions()
if source.platform == Platform.MATRIX and not allow_all:
scoped = []
for s in titled:
@ -3174,7 +3175,7 @@ class GatewaySlashCommandsMixin:
# Resolve a numbered choice or a title to a session ID.
if name.isdigit():
try:
titled = _list_titled_sessions()
titled = await _list_titled_sessions()
if source.platform == Platform.MATRIX and not allow_all:
scoped = []
for s in titled:
@ -3194,17 +3195,17 @@ class GatewaySlashCommandsMixin:
else:
# Try direct session ID lookup first (so `/resume <session_id>`
# works in the gateway, not just `/resume <title>`).
session = self._session_db.get_session(name)
session = await self._session_db.get_session(name)
if session:
target_id = session["id"]
else:
target_id = self._session_db.resolve_session_by_title(name)
target_id = await self._session_db.resolve_session_by_title(name)
if not target_id:
return t("gateway.resume.not_found", name=name)
# Compression creates child continuations that hold the live transcript.
# Follow that chain so gateway /resume matches CLI behavior (#15000).
try:
target_id = self._session_db.resolve_resume_session_id(target_id)
target_id = await self._session_db.resolve_resume_session_id(target_id)
except Exception as e:
logger.debug("Failed to resolve resume continuation for %s: %s", target_id, e)
@ -3255,7 +3256,7 @@ class GatewaySlashCommandsMixin:
self._evict_cached_agent(session_key)
# Get the title for confirmation
title = self._session_db.get_session_title(target_id) or name
title = await self._session_db.get_session_title(target_id) or name
# Count messages for context
history = self.session_store.load_transcript(target_id)
@ -3356,9 +3357,9 @@ class GatewaySlashCommandsMixin:
if branch_name:
branch_title = branch_name
else:
current_title = self._session_db.get_session_title(current_entry.session_id)
current_title = await self._session_db.get_session_title(current_entry.session_id)
base = current_title or "branch"
branch_title = self._session_db.get_next_title_in_lineage(base)
branch_title = await self._session_db.get_next_title_in_lineage(base)
parent_session_id = current_entry.session_id
@ -3368,7 +3369,7 @@ class GatewaySlashCommandsMixin:
# /sessions even after the parent is reopened and re-ended with a
# different end_reason (e.g. tui_shutdown overwriting 'branched').
try:
self._session_db.create_session(
await self._session_db.create_session(
session_id=new_session_id,
source=source.platform.value if source.platform else "gateway",
model=(self.config.get("model", {}) or {}).get("default") if isinstance(self.config, dict) else None,
@ -3382,7 +3383,7 @@ class GatewaySlashCommandsMixin:
# Copy conversation history to the new session
for msg in history:
try:
self._session_db.append_message(
await self._session_db.append_message(
session_id=new_session_id,
role=msg.get("role", "user"),
content=msg.get("content"),
@ -3401,7 +3402,7 @@ class GatewaySlashCommandsMixin:
# Set title
try:
self._session_db.set_session_title(new_session_id, branch_title)
await self._session_db.set_session_title(new_session_id, branch_title)
except Exception:
pass
@ -3484,7 +3485,7 @@ class GatewaySlashCommandsMixin:
if not provider and getattr(self, "_session_db", None) is not None:
try:
_entry_for_billing = self.session_store.get_or_create_session(source)
persisted = self._session_db.get_session(_entry_for_billing.session_id) or {}
persisted = await self._session_db.get_session(_entry_for_billing.session_id) or {}
except Exception:
persisted = {}
provider = provider or persisted.get("billing_provider")

View file

@ -264,11 +264,18 @@ def make_adapter(platform: Platform, runner=None):
async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock:
"""Send a message through the full e2e flow and return the send mock."""
"""Send a message through the full e2e flow and return the send mock.
Polls for the send rather than waiting a fixed delay: handler DB work now
hops to worker threads (AsyncSessionDB), so completion latency varies.
"""
event = make_event(platform, text, **event_kwargs)
adapter.send.reset_mock()
await adapter.handle_message(event)
await asyncio.sleep(0.3)
for _ in range(40): # up to ~2s; returns as soon as the send lands
if adapter.send.called:
break
await asyncio.sleep(0.05)
return adapter.send

View file

@ -39,6 +39,15 @@ from unittest.mock import MagicMock
import pytest
def make_async_session_db(sync_mock=None):
"""Wrap a sync mock SessionDB in AsyncSessionDB so gateway code that awaits
the facade works in tests. Returns (facade, sync_mock); configure return
values and assert calls on sync_mock."""
from hermes_state import AsyncSessionDB
sync_mock = sync_mock if sync_mock is not None else MagicMock()
return AsyncSessionDB(sync_mock), sync_mock
def _ensure_telegram_mock() -> None:
"""Install a comprehensive telegram mock in sys.modules.

View file

@ -12,6 +12,8 @@ Verifies that the agent cache correctly:
import threading
from unittest.mock import MagicMock, patch
import pytest
def _make_runner():
@ -1565,8 +1567,11 @@ class TestAgentCacheMessageCountRebaseline:
"""
def _runner_with_db(self, db):
from hermes_state import AsyncSessionDB
runner = _make_runner()
runner._session_db = db
# The gateway holds the async facade; the production refresh awaits it.
runner._session_db = AsyncSessionDB(db)
return runner
@staticmethod
@ -1577,7 +1582,7 @@ class TestAgentCacheMessageCountRebaseline:
the cached agent (or either side is None / it's a legacy 2-tuple).
"""
try:
row = runner._session_db.get_session(session_id)
row = runner._session_db._db.get_session(session_id)
live = row.get("message_count", 0) if row else None
except Exception:
live = None
@ -1591,7 +1596,8 @@ class TestAgentCacheMessageCountRebaseline:
)
return not invalidate
def test_same_process_turns_preserve_cached_agent(self, tmp_path):
@pytest.mark.asyncio
async def test_same_process_turns_preserve_cached_agent(self, tmp_path):
"""The regression guard: consecutive same-process turns must REUSE
the cached agent (prompt cache preserved), not rebuild every turn.
@ -1619,7 +1625,7 @@ class TestAgentCacheMessageCountRebaseline:
db.append_message("s1", role="user", content="u")
db.append_message("s1", role="assistant", content="a")
# Post-turn re-baseline (the fix).
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
# Next turn's guard decision.
if self._guard_would_reuse(runner, "telegram:s1", "s1"):
reuses += 1
@ -1630,7 +1636,8 @@ class TestAgentCacheMessageCountRebaseline:
with runner._agent_cache_lock:
assert runner._agent_cache["telegram:s1"][0] is agent
def test_cross_process_write_still_invalidates(self, tmp_path):
@pytest.mark.asyncio
async def test_cross_process_write_still_invalidates(self, tmp_path):
"""After the re-baseline, a DIFFERENT process appending to the same
session must still flip the guard to rebuild (the #45966 fix holds).
"""
@ -1650,7 +1657,7 @@ class TestAgentCacheMessageCountRebaseline:
# Our own turn + re-baseline -> reuse next turn.
db.append_message("s1", role="user", content="u")
db.append_message("s1", role="assistant", content="a")
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
assert self._guard_would_reuse(runner, "telegram:s1", "s1") is True
# ANOTHER process (e.g. the desktop dashboard backend) appends a turn
@ -1660,10 +1667,11 @@ class TestAgentCacheMessageCountRebaseline:
# Guard must now reject reuse so the agent rebuilds from fresh disk.
assert self._guard_would_reuse(runner, "telegram:s1", "s1") is False
def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
@pytest.mark.asyncio
async def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
"""Re-baseline must never crash and must leave legacy 2-tuples and
pending-sentinel entries untouched."""
from hermes_state import SessionDB
from hermes_state import AsyncSessionDB, SessionDB
from gateway.run import _AGENT_PENDING_SENTINEL
db = SessionDB(db_path=tmp_path / "sessions.db")
@ -1673,24 +1681,24 @@ class TestAgentCacheMessageCountRebaseline:
# No session_db -> no-op, no crash.
runner._session_db = None
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
runner._session_db = db
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
runner._session_db = AsyncSessionDB(db)
# Falsy session_id -> no-op.
runner._refresh_agent_cache_message_count("telegram:s1", "")
runner._refresh_agent_cache_message_count("telegram:s1", None)
await runner._refresh_agent_cache_message_count("telegram:s1", "")
await runner._refresh_agent_cache_message_count("telegram:s1", None)
# Legacy 2-tuple is left untouched (it opts out of the guard).
with runner._agent_cache_lock:
runner._agent_cache["telegram:s1"] = (object(), "sig")
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
with runner._agent_cache_lock:
assert len(runner._agent_cache["telegram:s1"]) == 2
# Pending sentinel entry is left untouched.
with runner._agent_cache_lock:
runner._agent_cache["telegram:s1"] = (_AGENT_PENDING_SENTINEL, "sig", 0)
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
with runner._agent_cache_lock:
assert runner._agent_cache["telegram:s1"][0] is _AGENT_PENDING_SENTINEL
assert runner._agent_cache["telegram:s1"][2] == 0
@ -1700,10 +1708,10 @@ class TestAgentCacheMessageCountRebaseline:
def get_session(self, _sid):
raise RuntimeError("db locked")
runner._session_db = _BoomDB() # type: ignore[assignment]
runner._session_db = AsyncSessionDB(_BoomDB()) # type: ignore[assignment]
with runner._agent_cache_lock:
runner._agent_cache["telegram:s1"] = (object(), "sig", 5)
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
with runner._agent_cache_lock:
assert runner._agent_cache["telegram:s1"][2] == 5

View file

@ -122,29 +122,40 @@ def test_non_callable_attribute_passes_through():
_GATEWAY_FILES = ("gateway/run.py", "gateway/slash_commands.py")
# The only legitimate non-loop paths:
# - SessionDB.sanitize_title: pure @staticmethod string cleaning, no DB.
# - self._session_db._db.<x>: the sync escape, allowed ONLY at construction.
_ALLOWED_SYNC_DB_ESCAPES = 1 # exactly the maybe_auto_prune call in __init__
# - self._session_db._db.<x>: the sync escape, allowed ONLY where the call is
# provably off the event loop — construction (__init__, before the loop
# serves) and the run_sync closure (executed in a thread-pool executor).
# Three such sites today; a fourth must be justified and this count bumped.
_ALLOWED_SYNC_DB_ESCAPES = 3
def _repo_root() -> Path:
return Path(__file__).resolve().parents[2]
class _RawCallVisitor(ast.NodeVisitor):
"""Collect calls of the shape self._session_db.<method>(...).
class _RawCallVisitor:
"""Collect non-awaited self._session_db.<method>(...) calls in a module.
Whether the call is awaited is irrelevant to the AST node; an Await wraps
the Call. We flag the raw shape and separately exempt the _db. escape and
the sanitize_title staticmethod (which is called on the class, not self).
An ``await x.y()`` parses as Await(value=Call(...)); those Call nodes are
exempt they're the migrated path. We flag only Calls that are NOT directly
awaited, and separately count the self._session_db._db.<x> sync escape. The
sanitize_title staticmethod is called on the class (SessionDB.sanitize_title),
so it never matches the self._session_db.<method> shape.
"""
def __init__(self):
self.raw_calls = [] # (method, lineno)
def __init__(self, tree: ast.AST):
self.raw_calls = [] # (method, lineno) — non-awaited
self.db_escapes = [] # self._session_db._db.<x> sites (lineno)
def visit_Call(self, node: ast.Call):
func = node.func
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute):
awaited = {id(n.value) for n in ast.walk(tree)
if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)}
for node in ast.walk(tree):
if not isinstance(node, ast.Call):
continue
func = node.func
if not (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute)):
continue
inner = func.value
# self._session_db._db.<method>(...) -> sync escape
if (
@ -155,21 +166,19 @@ class _RawCallVisitor(ast.NodeVisitor):
and inner.value.value.id == "self"
):
self.db_escapes.append(inner.lineno)
# self._session_db.<method>(...) -> raw loop call
# self._session_db.<method>(...) not wrapped in await -> raw loop call
elif (
inner.attr == "_session_db"
and isinstance(inner.value, ast.Name)
and inner.value.id == "self"
and id(node) not in awaited
):
self.raw_calls.append((func.attr, node.lineno))
self.generic_visit(node)
def _scan(rel_path: str) -> _RawCallVisitor:
source = (_repo_root() / rel_path).read_text(encoding="utf-8")
visitor = _RawCallVisitor()
visitor.visit(ast.parse(source))
return visitor
return _RawCallVisitor(ast.parse(source))
def test_no_raw_session_db_calls_on_gateway_loop():
@ -189,15 +198,16 @@ def test_no_raw_session_db_calls_on_gateway_loop():
)
def test_sync_db_escape_confined_to_construction():
"""The self._session_db._db. sync escape must stay confined to one site.
def test_sync_db_escape_confined_to_off_loop_sites():
"""The self._session_db._db. sync escape must stay confined to known sites.
It is legitimate only at construction (before the loop serves traffic).
More than one occurrence means a blocking call leaked back onto the loop
through the escape hatch.
It is legitimate only where the call is provably off the loop: construction
(before the loop serves) and the run_sync executor closure. More occurrences
than the reviewed count means a blocking call may have leaked back onto the
loop through the escape hatch.
"""
total = sum(len(_scan(rel).db_escapes) for rel in _GATEWAY_FILES)
assert total <= _ALLOWED_SYNC_DB_ESCAPES, (
f"self._session_db._db. sync escape used {total} times; "
f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction only) is allowed."
f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed."
)

View file

@ -5,13 +5,14 @@ The Discord gateway heartbeat was stalling because the handoff watcher
SQLite-backed ``SessionDB`` directly on the asyncio event loop every 2s
('Shard ID None heartbeat blocked for more than N seconds').
The fix (mirroring PR #40782) wraps every blocking ``SessionDB`` call inside
the watcher loop in ``asyncio.to_thread(...)`` so the SQLite I/O runs on a
worker thread and never blocks the event loop / Discord heartbeat.
The fix routes every blocking ``SessionDB`` call in the watcher through the
``AsyncSessionDB`` facade, which offloads each call via ``asyncio.to_thread`` so
the SQLite I/O runs on a worker thread and never blocks the event loop / Discord
heartbeat.
These tests assert that behaviour contract. They are mutation-survivable:
reverting any ``asyncio.to_thread(self._session_db.<call>)`` wrap back to a
direct synchronous call on the loop makes the relevant assertion fail.
reverting any ``await self._session_db.<call>(...)`` back to a direct synchronous
call on the loop makes the relevant assertion fail.
"""
import asyncio
@ -62,9 +63,15 @@ class _RecordingSessionDB:
def _make_fake_runner(session_db, *, fail_process=False):
"""Build a minimal object that exposes exactly what the loop body touches."""
"""Build a minimal object that exposes exactly what the loop body touches.
The watcher now talks to the SessionDB through the AsyncSessionDB facade,
so wrap the recording stand-in the same way the gateway does.
"""
from hermes_state import AsyncSessionDB
fake = types.SimpleNamespace()
fake._session_db = session_db
fake._session_db = AsyncSessionDB(session_db)
# _running yields True for the first loop check, then False so the loop
# exits after a single tick.
states = iter([True, False])
@ -141,21 +148,23 @@ async def test_watcher_offloads_fail_handoff_to_thread(monkeypatch):
async def test_watcher_wraps_calls_via_asyncio_to_thread(monkeypatch):
"""Explicitly assert the offload goes through asyncio.to_thread.
Patches ``run.asyncio.to_thread`` and records which SessionDB callables
were handed to it. Mutation-survivable: dropping any wrap removes its
callable from the recorded set.
Patches the AsyncSessionDB facade's ``asyncio.to_thread`` (it lives in
hermes_state) and records which SessionDB callables were handed to it.
Mutation-survivable: dropping any await removes its callable from the set.
"""
import hermes_state
db = _RecordingSessionDB(loop_thread_ident=-1)
fake = _make_fake_runner(db, fail_process=False)
wrapped = []
real_to_thread = run.asyncio.to_thread
real_to_thread = hermes_state.asyncio.to_thread
async def _spy_to_thread(func, *args, **kwargs):
wrapped.append(getattr(func, "__name__", repr(func)))
return await real_to_thread(func, *args, **kwargs)
monkeypatch.setattr(run.asyncio, "to_thread", _spy_to_thread)
monkeypatch.setattr(hermes_state.asyncio, "to_thread", _spy_to_thread)
await _run_one_tick(fake, monkeypatch)

View file

@ -12,6 +12,7 @@ import pytest
from gateway.config import GatewayConfig, Platform, PlatformConfig
from gateway.platforms.base import MessageEvent
from hermes_state import AsyncSessionDB
from gateway.session import (
SessionContext,
SessionEntry,
@ -343,16 +344,16 @@ def _make_runner(current_source: SessionSource, entries: list[SessionEntry]):
runner._clear_session_boundary_security_state = MagicMock()
runner._evict_cached_agent = MagicMock()
runner._queue_depth = MagicMock(return_value=0)
runner._session_db = MagicMock()
runner._session_db.list_sessions_rich.return_value = [
runner._session_db = AsyncSessionDB(MagicMock())
runner._session_db._db.list_sessions_rich.return_value = [
{"id": entry.session_id, "title": entry.display_name, "preview": ""}
for entry in entries
]
runner._session_db.resolve_resume_session_id.side_effect = lambda sid: sid
runner._session_db.get_session_title.side_effect = lambda sid: {
runner._session_db._db.resolve_resume_session_id.side_effect = lambda sid: sid
runner._session_db._db.get_session_title.side_effect = lambda sid: {
entry.session_id: entry.display_name for entry in entries
}.get(sid)
runner._session_db.get_session.return_value = None
runner._session_db._db.get_session.return_value = None
return runner
@ -388,7 +389,7 @@ async def test_matrix_resume_does_not_cross_rooms_by_default():
entry_a = _entry(source_a, "session-a", "Project A Plan")
entry_b = _entry(source_b, "session-b", "Project B Plan")
runner = _make_runner(source_b, [entry_a, entry_b])
runner._session_db.resolve_session_by_title.return_value = "session-a"
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
result = await runner._handle_resume_command(_event("/resume Project A Plan", source_b))
@ -406,7 +407,7 @@ async def test_matrix_resume_allows_same_room_session():
source_b, "session-b-current", "Current Project B"
)
runner.session_store.switch_session.return_value = entry_b
runner._session_db.resolve_session_by_title.return_value = "session-b-old"
runner._session_db._db.resolve_session_by_title.return_value = "session-b-old"
result = await runner._handle_resume_command(_event("/resume Project B Plan", source_b))
@ -423,14 +424,14 @@ async def test_matrix_resume_quoted_title_same_room():
source_b, "session-b-current", "Current Project B"
)
runner.session_store.switch_session.return_value = entry_b
runner._session_db.resolve_session_by_title.return_value = "session-b-old"
runner._session_db._db.resolve_session_by_title.return_value = "session-b-old"
result = await runner._handle_resume_command(
_event('/resume "Project B Plan"', source_b)
)
assert "Resumed session" in result
runner._session_db.resolve_session_by_title.assert_called_once_with("Project B Plan")
runner._session_db._db.resolve_session_by_title.assert_called_once_with("Project B Plan")
@pytest.mark.asyncio
@ -440,7 +441,7 @@ async def test_matrix_resume_quoted_title_cross_room_blocked():
entry_a = _entry(source_a, "session-a", "Project A Plan")
entry_b = _entry(source_b, "session-b", "Project B Plan")
runner = _make_runner(source_b, [entry_a, entry_b])
runner._session_db.resolve_session_by_title.return_value = "session-a"
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
result = await runner._handle_resume_command(
_event('/resume "Project A Plan"', source_b)
@ -471,7 +472,7 @@ async def test_matrix_resume_cross_room_requires_explicit_flag_and_warns():
entry_b = _entry(source_b, "session-b", "Project B Plan")
runner = _make_runner(source_b, [entry_a, entry_b])
runner.session_store.switch_session.return_value = entry_a
runner._session_db.resolve_session_by_title.return_value = "session-a"
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
result = await runner._handle_resume_command(
_event("/resume --cross-room Project A Plan", source_b)

View file

@ -1,3 +1,4 @@
from hermes_state import AsyncSessionDB
"""Regression tests for approval-state cleanup on session boundaries."""
from datetime import datetime
@ -86,9 +87,9 @@ def _make_resume_runner():
runner.session_store.get_or_create_session.return_value = current_entry
runner.session_store.switch_session.return_value = resumed_entry
runner.session_store.load_transcript.return_value = []
runner._session_db = MagicMock()
runner._session_db.resolve_session_by_title.return_value = "resumed-session"
runner._session_db.get_session_title.return_value = "Resumed Work"
runner._session_db = AsyncSessionDB(MagicMock())
runner._session_db._db.resolve_session_by_title.return_value = "resumed-session"
runner._session_db._db.get_session_title.return_value = "Resumed Work"
return runner, session_key
@ -116,9 +117,9 @@ def _make_branch_runner():
{"role": "assistant", "content": "world"},
]
runner.session_store.switch_session.return_value = branched_entry
runner._session_db = MagicMock()
runner._session_db.get_session_title.return_value = "Current Work"
runner._session_db.get_next_title_in_lineage.return_value = "Current Work #2"
runner._session_db = AsyncSessionDB(MagicMock())
runner._session_db._db.get_session_title.return_value = "Current Work"
runner._session_db._db.get_next_title_in_lineage.return_value = "Current Work #2"
return runner, session_key
@ -208,7 +209,7 @@ async def test_branch_preserves_persisted_assistant_metadata():
result = await runner._handle_branch_command(_make_event("/branch"))
assert "Branched to" in result
append_calls = runner._session_db.append_message.call_args_list
append_calls = runner._session_db._db.append_message.call_args_list
assert len(append_calls) == 2
assistant_kwargs = append_calls[1].kwargs
assert assistant_kwargs["role"] == "assistant"

View file

@ -171,8 +171,12 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
# Start first message (will block at barrier)
task1 = asyncio.create_task(runner._handle_message(event1))
# Yield so task1 enters slow_inner and sentinel is set
await asyncio.sleep(0)
# Yield until task1 has claimed the sentinel (it crosses a few awaits
# before the claim; don't assume a fixed number of scheduler slices).
for _ in range(50):
await asyncio.sleep(0)
if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL:
break
# Verify sentinel is set
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
@ -417,7 +421,10 @@ async def test_stop_during_sentinel_force_cleans_session():
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
task1 = asyncio.create_task(runner._handle_message(event1))
await asyncio.sleep(0)
for _ in range(50):
await asyncio.sleep(0)
if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL:
break
# Sentinel should be set
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL

View file

@ -1,3 +1,4 @@
from hermes_state import AsyncSessionDB
"""Tests for gateway /status behavior and token persistence."""
from datetime import datetime
@ -53,11 +54,11 @@ def _make_runner(session_entry: SessionEntry, *, platform: Platform = Platform.T
runner._session_run_generation = {}
runner._pending_messages = {}
runner._pending_approvals = {}
runner._session_db = MagicMock()
runner._session_db.get_session_title.return_value = None
runner._session_db = AsyncSessionDB(MagicMock())
runner._session_db._db.get_session_title.return_value = None
# Default: no DB row → /status reports 0 tokens. Tests that exercise
# the populated path override this.
runner._session_db.get_session.return_value = None
runner._session_db._db.get_session.return_value = None
runner._reasoning_config = None
runner._provider_routing = {}
runner._fallback_model = None
@ -86,7 +87,7 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
)
runner = _make_runner(session_entry)
# Token total comes from the SQLite SessionDB, not SessionEntry.
runner._session_db.get_session.return_value = {
runner._session_db._db.get_session.return_value = {
"input_tokens": 200,
"output_tokens": 121,
"cache_read_tokens": 0,
@ -118,7 +119,7 @@ async def test_status_command_includes_session_title_when_present():
total_tokens=321,
)
runner = _make_runner(session_entry)
runner._session_db.get_session_title.return_value = "My titled session"
runner._session_db._db.get_session_title.return_value = "My titled session"
result = await runner._handle_message(_make_event("/status"))
@ -141,7 +142,7 @@ async def test_status_command_reads_token_totals_from_session_db():
total_tokens=0, # SessionEntry never gets written to — always 0.
)
runner = _make_runner(session_entry)
runner._session_db.get_session.return_value = {
runner._session_db._db.get_session.return_value = {
"input_tokens": 1000,
"output_tokens": 250,
"cache_read_tokens": 500,
@ -169,7 +170,7 @@ async def test_status_command_tokens_zero_when_session_db_row_missing():
total_tokens=999, # This should be ignored.
)
runner = _make_runner(session_entry)
runner._session_db.get_session.return_value = None
runner._session_db._db.get_session.return_value = None
result = await runner._handle_message(_make_event("/status"))
@ -188,7 +189,7 @@ async def test_status_command_includes_live_agent_model_and_context():
total_tokens=0,
)
runner = _make_runner(session_entry)
runner._session_db.get_session.return_value = {
runner._session_db._db.get_session.return_value = {
"input_tokens": 1000,
"output_tokens": 250,
"cache_read_tokens": 0,
@ -228,7 +229,7 @@ async def test_status_command_includes_persisted_model_and_context_when_agent_no
last_prompt_tokens=24_000,
)
runner = _make_runner(session_entry)
runner._session_db.get_session.return_value = {
runner._session_db._db.get_session.return_value = {
"input_tokens": 2000,
"output_tokens": 500,
"cache_read_tokens": 0,

View file

@ -1,3 +1,4 @@
from hermes_state import AsyncSessionDB
"""Tests for gateway /usage command — agent cache lookup and output fields."""
import threading
@ -197,8 +198,8 @@ class TestUsageAccountSection:
@pytest.mark.asyncio
async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch):
runner = _make_runner(SK)
runner._session_db = MagicMock()
runner._session_db.get_session.return_value = {
runner._session_db = AsyncSessionDB(MagicMock())
runner._session_db._db.get_session.return_value = {
"billing_provider": "openai-codex",
"billing_base_url": "https://chatgpt.com/backend-api/codex",
}