"""AsyncSessionDB offload facade + gateway raw-call guard. The gateway runs one asyncio loop for every session; SessionDB is synchronous, so a raw call on the loop freezes every conversation until it returns. AsyncSessionDB offloads each call via asyncio.to_thread. These tests pin the facade's contract and lock the gateway boundary so a 39th raw call can't regress. """ import ast import asyncio import threading from pathlib import Path import pytest import hermes_state from hermes_state import AsyncSessionDB class _SpyDB: """SessionDB stand-in recording the thread each call ran on.""" def __init__(self): self.calls = [] self.attr = "plain-value" def _ran_on(self, name): self.calls.append((name, threading.get_ident())) def returns_none(self): self._ran_on("returns_none") return None def returns_bool(self): self._ran_on("returns_bool") return True def returns_str(self): self._ran_on("returns_str") return "title" def returns_dict(self): self._ran_on("returns_dict") return {"id": "s1"} def returns_list(self): self._ran_on("returns_list") return [{"id": "s1"}, {"id": "s2"}] def raises(self): self._ran_on("raises") raise ValueError("boom") # -------------------------------------------------------------------------- # Facade behaviour # -------------------------------------------------------------------------- @pytest.mark.asyncio async def test_offloads_off_calling_thread(): """A call must execute on a worker thread, not the caller's loop thread.""" db = _SpyDB() facade = AsyncSessionDB(db) caller_ident = threading.get_ident() await facade.returns_none() ran_idents = [ident for _name, ident in db.calls] assert ran_idents and all(i != caller_ident for i in ran_idents) @pytest.mark.asyncio async def test_offload_goes_through_to_thread(monkeypatch): """The offload must route through asyncio.to_thread (where the facade lives).""" db = _SpyDB() facade = AsyncSessionDB(db) seen = [] real = asyncio.to_thread async def _spy(func, *args, **kwargs): seen.append(getattr(func, "__name__", repr(func))) return await real(func, *args, **kwargs) monkeypatch.setattr(hermes_state.asyncio, "to_thread", _spy) await facade.returns_str() assert "returns_str" in seen @pytest.mark.asyncio @pytest.mark.parametrize( "method,expected", [ ("returns_none", None), ("returns_bool", True), ("returns_str", "title"), ("returns_dict", {"id": "s1"}), ("returns_list", [{"id": "s1"}, {"id": "s2"}]), ], ) async def test_returns_underlying_value_unchanged(method, expected): facade = AsyncSessionDB(_SpyDB()) assert await getattr(facade, method)() == expected @pytest.mark.asyncio async def test_propagates_exception(): facade = AsyncSessionDB(_SpyDB()) with pytest.raises(ValueError, match="boom"): await facade.raises() def test_non_callable_attribute_passes_through(): facade = AsyncSessionDB(_SpyDB()) assert facade.attr == "plain-value" # -------------------------------------------------------------------------- # Guard: no raw self._session_db.( on the gateway loop # -------------------------------------------------------------------------- _GATEWAY_FILES = ("gateway/run.py", "gateway/slash_commands.py") # The only legitimate non-loop paths: # - SessionDB.sanitize_title: pure @staticmethod string cleaning, no DB. # - self._session_db._db.: the sync escape, allowed ONLY where the call is # provably off the event loop — construction (__init__, before the loop # serves) and the run_sync closure (executed in a thread-pool executor). # Three such sites today; a fourth must be justified and this count bumped. _ALLOWED_SYNC_DB_ESCAPES = 3 # Sync helpers that touch SessionDB but are NEVER invoked bare on the loop: # every loop-side call wraps them in ``asyncio.to_thread(...)`` and the only # bare calls live in the run_sync thread-pool closure. Their DB calls therefore # run off-loop. The guard exempts their bodies AND enforces the contract — see # test_offloaded_helpers_never_called_bare_on_loop. Adding a helper here without # wrapping its loop call sites makes that test fail. _OFFLOADED_SYNC_HELPERS = frozenset({ "_telegram_topic_mode_enabled", "_is_telegram_topic_lane", "_is_telegram_topic_root_lobby", "_recover_telegram_topic_thread_id", "_normalize_source_for_session_key", "_record_telegram_topic_binding", "_sync_telegram_topic_binding", "_telegram_topic_new_header", "_schedule_telegram_topic_title_rename", "_apply_topic_recovery", }) def _repo_root() -> Path: return Path(__file__).resolve().parents[2] class _RawCallVisitor: """Collect non-awaited SessionDB calls reachable on the gateway loop. Catches both shapes: * direct: self._session_db.(...) * aliased: db = getattr(self, "_session_db", None) / db = self._session_db then db.(...) An ``await x.y()`` is Await(value=Call(...)); those Calls are exempt (the migrated path). The self._session_db._db. sync escape is counted separately. SessionDB.sanitize_title is a staticmethod called on the class, so it never matches either shape. Alias detection scans, per function scope, for locals bound to the gateway's _session_db (incl. closures that bind it off a captured ``self``-like param), then flags non-awaited calls on those names. The literal-grep blind spot that let six loop-reachable calls hide behind ``getattr(self, "_session_db")`` is exactly what this closes. """ def __init__(self, tree: ast.AST): self.raw_calls = [] # (method, lineno) — direct, non-awaited, on-loop self.alias_calls = [] # (method, lineno) — via a _session_db-bound local, on-loop self.db_escapes = [] # self._session_db._db. sites (lineno) # BARE self.(...) call sites of offloaded helpers — i.e. the # helper is actually *called*, not passed to asyncio.to_thread (which # references it as an attribute, producing no Call node here). Each is # (helper, lineno, enclosing_fn) for the contract test. self.bare_helper_calls = [] awaited = {id(n.value) for n in ast.walk(tree) if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)} alias_names = self._collect_alias_names(tree) # Map each node to the name of the function whose body lexically encloses # it, so DB calls inside an offloaded helper (which runs off-loop) are # exempt while bare on-loop calls are not. enclosing = self._enclosing_fn_map(tree) ancestry = self._ancestor_fns(tree) # id(node) -> frozenset of enclosing fn names for node in ast.walk(tree): if not isinstance(node, ast.Call): continue func = node.func if not isinstance(func, ast.Attribute): continue encl_fn = enclosing.get(id(node)) in_offloaded_helper = encl_fn in _OFFLOADED_SYNC_HELPERS # Bare call of an offloaded helper (self._helper(...)). A to_thread # offload passes the helper as an attribute arg, not a Call, so it # never lands here — exactly the distinction the contract test needs. if ( isinstance(func.value, ast.Name) and func.value.id == "self" and func.attr in _OFFLOADED_SYNC_HELPERS ): self.bare_helper_calls.append( (func.attr, node.lineno, ancestry.get(id(node), frozenset())) ) # alias.(...) -> aliased loop call (var bound to _session_db) if ( isinstance(func.value, ast.Name) and func.value.id in alias_names and func.attr not in ("_db",) and id(node) not in awaited and not in_offloaded_helper ): self.alias_calls.append((func.attr, node.lineno)) continue if not isinstance(func.value, ast.Attribute): continue inner = func.value # self._session_db._db.(...) -> sync escape if ( inner.attr == "_db" and isinstance(inner.value, ast.Attribute) and inner.value.attr == "_session_db" and isinstance(inner.value.value, ast.Name) and inner.value.value.id == "self" ): self.db_escapes.append(inner.lineno) # self._session_db.(...) not wrapped in await -> raw loop call elif ( inner.attr == "_session_db" and isinstance(inner.value, ast.Name) and inner.value.id == "self" and id(node) not in awaited and not in_offloaded_helper ): self.raw_calls.append((func.attr, node.lineno)) @staticmethod def _enclosing_fn_map(tree: ast.AST) -> dict: """Map id(node) -> name of the nearest lexically-enclosing function.""" out = {} def walk(node, fn_name): this_fn = fn_name if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): this_fn = node.name for child in ast.iter_child_nodes(node): out[id(child)] = this_fn walk(child, this_fn) walk(tree, None) return out @staticmethod def _ancestor_fns(tree: ast.AST) -> dict: """Map id(node) -> frozenset of ALL enclosing function names (any depth).""" out = {} def walk(node, stack): if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): stack = stack + (node.name,) for child in ast.iter_child_nodes(node): out[id(child)] = frozenset(stack) walk(child, stack) walk(tree, ()) return out @staticmethod def _is_session_db_source(value: ast.AST) -> bool: """True if an assignment RHS resolves to ._session_db. Matches both ``._session_db`` and ``getattr(, "_session_db", ...)`` where is any Name (covers ``self`` and captured closure params like ``_self``). Excludes the ``._db`` sync handle. """ if isinstance(value, ast.Attribute): return value.attr == "_session_db" and isinstance(value.value, ast.Name) if ( isinstance(value, ast.Call) and isinstance(value.func, ast.Name) and value.func.id == "getattr" and len(value.args) >= 2 and isinstance(value.args[1], ast.Constant) and value.args[1].value == "_session_db" ): return True return False @classmethod def _collect_alias_names(cls, tree: ast.AST) -> set: names = set() for node in ast.walk(tree): if isinstance(node, ast.Assign) and cls._is_session_db_source(node.value): for tgt in node.targets: if isinstance(tgt, ast.Name): names.add(tgt.id) elif isinstance(node, ast.AnnAssign) and node.value is not None \ and cls._is_session_db_source(node.value) \ and isinstance(node.target, ast.Name): names.add(node.target.id) return names def _scan(rel_path: str) -> _RawCallVisitor: source = (_repo_root() / rel_path).read_text(encoding="utf-8") return _RawCallVisitor(ast.parse(source)) def test_no_raw_session_db_calls_on_gateway_loop(): """Fail if any non-awaited SessionDB call appears in gateway files. Every loop-reachable DB call must go through AsyncSessionDB (await), whether spelled directly (self._session_db.(...)) or via a local alias (db = getattr(self, "_session_db", None); db.(...)). The sanitize_title staticmethod is called on the class, not self/an alias, so it is not matched; the _db. sync escape is checked separately below. """ violations = [] for rel in _GATEWAY_FILES: v = _scan(rel) violations.extend(f"{rel}:{ln} self._session_db.{m}(" for m, ln in v.raw_calls) violations.extend(f"{rel}:{ln} .{m}( (binds _session_db)" for m, ln in v.alias_calls) assert not violations, ( "Non-awaited SessionDB calls on the gateway loop — route through " "AsyncSessionDB (await ...):\n " + "\n ".join(violations) ) def test_sync_db_escape_confined_to_off_loop_sites(): """The self._session_db._db. sync escape must stay confined to known sites. It is legitimate only where the call is provably off the loop: construction (before the loop serves) and the run_sync executor closure. More occurrences than the reviewed count means a blocking call may have leaked back onto the loop through the escape hatch. """ total = sum(len(_scan(rel).db_escapes) for rel in _GATEWAY_FILES) assert total <= _ALLOWED_SYNC_DB_ESCAPES, ( f"self._session_db._db. sync escape used {total} times; " f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed." ) def test_offloaded_helpers_never_called_bare_on_loop(): """The offloaded sync helpers must never be called bare on the event loop. They touch SessionDB synchronously, so a bare ``self._helper(...)`` on the loop would freeze it. The contract: loop-side callers wrap them in ``await asyncio.to_thread(self._helper, ...)`` (which references the helper as an attribute — no Call node — so it never appears here). A bare call is only legitimate when it runs off-loop: inside the ``run_sync`` thread-pool closure, or inside another offloaded helper (sync->sync, same thread). Any other bare call means a helper whose body the guard exempts is being invoked on the loop anyway — re-freezing the loop through the exemption. """ off_loop_ok = _OFFLOADED_SYNC_HELPERS | {"run_sync"} violations = [] for rel in _GATEWAY_FILES: v = _scan(rel) for helper, ln, ancestors in v.bare_helper_calls: if not (ancestors & off_loop_ok): violations.append(f"{rel}:{ln} bare self.{helper}( on the loop") assert not violations, ( "Offloaded sync helper called bare on the gateway loop — wrap in " "await asyncio.to_thread(self., ...):\n " + "\n ".join(violations) ) # -------------------------------------------------------------------------- # Interleaving safety: offloading opens await points where coroutines can # interleave against the same session rows. The gateway relies on SessionDB's # atomic operations (compare-and-set, INSERT OR IGNORE) to stay single-winner. # These pin that the defenses hold when driven concurrently through the facade. # -------------------------------------------------------------------------- @pytest.mark.asyncio async def test_concurrent_claim_handoff_single_winner(tmp_path): db = AsyncSessionDB(hermes_state.SessionDB(db_path=tmp_path / "state.db")) sid = "s-handoff" await db.create_session(sid, "test") await db.request_handoff(sid, "telegram") results = await asyncio.gather(*(db.claim_handoff(sid) for _ in range(20))) assert sum(results) == 1, f"exactly one claim must win, got {sum(results)}" @pytest.mark.asyncio async def test_concurrent_create_session_idempotent(tmp_path): db = AsyncSessionDB(hermes_state.SessionDB(db_path=tmp_path / "state.db")) sid = "s-create" await asyncio.gather(*(db.create_session(sid, "test") for _ in range(20))) rows = await db.list_sessions_rich(limit=100) assert sum(1 for r in rows if r["id"] == sid) == 1