# Mirror of https://github.com/NousResearch/hermes-agent.git
# (synced 2026-07-01 12:02:05 +00:00; 402 lines, 16 KiB, Python).
"""AsyncSessionDB offload facade + gateway raw-call guard.
|
|
|
|
The gateway runs one asyncio loop for every session; SessionDB is synchronous,
so a raw call on the loop freezes every conversation until it returns.
AsyncSessionDB offloads each call via asyncio.to_thread. These tests pin the
facade's contract and lock the gateway boundary so that a new (39th) raw call
cannot slip back onto the loop.
|
|
"""
|
|
|
|
import ast
|
|
import asyncio
|
|
import threading
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
import hermes_state
|
|
from hermes_state import AsyncSessionDB
|
|
|
|
|
|
class _SpyDB:
    """Synchronous SessionDB stand-in that logs which thread ran each call.

    Every method appends ``(method_name, thread_ident)`` to ``calls`` before
    returning (or raising), so a test can compare the executing thread with
    its own.
    """

    def __init__(self):
        # Plain (non-callable) attribute, read by the pass-through test.
        self.attr = "plain-value"
        # Chronological log of (method name, threading.get_ident()) pairs.
        self.calls = []

    def _ran_on(self, name):
        """Append ``name`` together with the current thread's ident."""
        record = (name, threading.get_ident())
        self.calls.append(record)

    def returns_none(self):
        """Log the call and return ``None``."""
        self._ran_on("returns_none")
        return None

    def returns_bool(self):
        """Log the call and return ``True``."""
        self._ran_on("returns_bool")
        return True

    def returns_str(self):
        """Log the call and return the string ``"title"``."""
        self._ran_on("returns_str")
        return "title"

    def returns_dict(self):
        """Log the call and return a fresh single-key dict."""
        self._ran_on("returns_dict")
        return {"id": "s1"}

    def returns_list(self):
        """Log the call and return a fresh list of two dicts."""
        self._ran_on("returns_list")
        return [{"id": "s1"}, {"id": "s2"}]

    def raises(self):
        """Log the call, then raise ``ValueError("boom")``."""
        self._ran_on("raises")
        raise ValueError("boom")
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Facade behaviour
|
|
# --------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
async def test_offloads_off_calling_thread():
    """A facade call must run on a thread other than the caller's loop thread."""
    spy = _SpyDB()
    facade = AsyncSessionDB(spy)
    loop_thread = threading.get_ident()

    await facade.returns_none()

    # At least one call must have been recorded, and none on the loop thread.
    worker_threads = [thread_id for _, thread_id in spy.calls]
    assert worker_threads
    assert loop_thread not in worker_threads
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_offload_goes_through_to_thread(monkeypatch):
    """The facade must dispatch through ``asyncio.to_thread`` as seen by hermes_state."""
    facade = AsyncSessionDB(_SpyDB())

    # Keep a handle on the genuine implementation so the spy can delegate to it.
    original_to_thread = asyncio.to_thread
    offloaded = []

    async def _recording_to_thread(func, *args, **kwargs):
        # Record the callable's name (or its repr when it has none), then
        # perform the real offload so the facade call still completes.
        label = getattr(func, "__name__", repr(func))
        offloaded.append(label)
        return await original_to_thread(func, *args, **kwargs)

    monkeypatch.setattr(hermes_state.asyncio, "to_thread", _recording_to_thread)
    await facade.returns_str()
    assert "returns_str" in offloaded
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "method,expected",
    [
        ("returns_none", None),
        ("returns_bool", True),
        ("returns_str", "title"),
        ("returns_dict", {"id": "s1"}),
        ("returns_list", [{"id": "s1"}, {"id": "s2"}]),
    ],
)
async def test_returns_underlying_value_unchanged(method, expected):
    """Awaiting a facade method yields a value equal to the wrapped method's."""
    wrapped_call = getattr(AsyncSessionDB(_SpyDB()), method)
    result = await wrapped_call()
    assert result == expected
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_propagates_exception():
    """An exception raised by the wrapped method surfaces at the await site."""
    failing_call = AsyncSessionDB(_SpyDB()).raises
    with pytest.raises(ValueError, match="boom"):
        await failing_call()
|
|
|
|
|
|
def test_non_callable_attribute_passes_through():
    """A non-callable attribute is readable through the facade without awaiting."""
    spy = _SpyDB()
    assert AsyncSessionDB(spy).attr == "plain-value"
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Guard: no raw self._session_db.<method>( on the gateway loop
|
|
# --------------------------------------------------------------------------
|
|
|
|
# Repository-relative paths of the gateway modules scanned by the guard tests.
_GATEWAY_FILES = (
    "gateway/run.py",
    "gateway/slash_commands.py",
)

# The only legitimate non-loop paths:
#   - SessionDB.sanitize_title: pure @staticmethod string cleaning, no DB.
#   - self._session_db._db.<x>: the sync escape, allowed ONLY where the call is
#     provably off the event loop — construction (__init__, before the loop
#     serves) and the run_sync closure (executed in a thread-pool executor).
# Three such sites today; a fourth must be justified and this count bumped.
_ALLOWED_SYNC_DB_ESCAPES = 3

# Sync helpers that touch SessionDB but are NEVER invoked bare on the loop:
# every loop-side call wraps them in ``asyncio.to_thread(...)`` and the only
# bare calls live in the run_sync thread-pool closure. Their DB calls therefore
# run off-loop. The guard exempts their bodies AND enforces the contract — see
# test_offloaded_helpers_never_called_bare_on_loop. Adding a helper here without
# wrapping its loop call sites makes that test fail.
_OFFLOADED_SYNC_HELPERS = frozenset(
    (
        "_telegram_topic_mode_enabled",
        "_is_telegram_topic_lane",
        "_is_telegram_topic_root_lobby",
        "_recover_telegram_topic_thread_id",
        "_normalize_source_for_session_key",
        "_record_telegram_topic_binding",
        "_sync_telegram_topic_binding",
        "_telegram_topic_new_header",
        "_schedule_telegram_topic_title_rename",
        "_apply_topic_recovery",
    )
)
|
|
|
|
|
|
def _repo_root() -> Path:
    """Return the third ancestor of this file's resolved path (``parents[2]``).

    The gateway file paths used by the scan helper are joined onto this
    directory, i.e. it is treated as the repository root.
    """
    this_file = Path(__file__).resolve()
    return this_file.parents[2]
|
|
|
|
|
|
class _RawCallVisitor:
    """Collect non-awaited SessionDB calls reachable on the gateway loop.

    Catches both shapes:
      * direct:  self._session_db.<method>(...)
      * aliased: a local bound to the gateway's _session_db, then
                 <local>.<method>(...). Recognised bindings are
                   db = self._session_db
                   db = getattr(self, "_session_db", None)
                   db: T = self._session_db
                   (db := getattr(self, "_session_db", None))

    An ``await x.y()`` is Await(value=Call(...)); those Calls are exempt (the
    migrated path). The self._session_db._db.<x> sync escape is counted
    separately. SessionDB.sanitize_title is a staticmethod called on the class,
    so it never matches either shape.

    Alias names are collected once for the whole file (NOT per function
    scope): any local name that is bound to a ``_session_db`` source anywhere
    in the module is treated as an alias everywhere in the module. That
    over-approximates, which is the safe direction for a guard. The
    literal-grep blind spot that let loop-reachable calls hide behind
    ``getattr(self, "_session_db")`` is exactly what this closes.

    Known gap: ``<alias>._db.<x>(...)`` (the sync escape taken through an
    alias) is neither flagged as an alias call nor counted in ``db_escapes``;
    only the literal ``self._session_db._db.<x>(...)`` spelling is counted.

    Attributes:
        raw_calls: ``(method, lineno)`` for each direct, non-awaited call
            outside an offloaded helper.
        alias_calls: ``(method, lineno)`` for each non-awaited call on an
            alias name outside an offloaded helper.
        db_escapes: line numbers of ``self._session_db._db.<x>(...)`` sites.
        bare_helper_calls: ``(helper, lineno, enclosing_fn_names)`` for each
            bare ``self.<helper>(...)`` call of an offloaded helper, where
            ``enclosing_fn_names`` is a frozenset of every enclosing function.
    """

    def __init__(self, tree: ast.AST):
        """Walk ``tree`` once and populate the four result lists."""
        self.raw_calls = []    # (method, lineno) — direct, non-awaited, on-loop
        self.alias_calls = []  # (method, lineno) — via a _session_db-bound local, on-loop
        self.db_escapes = []   # self._session_db._db.<x> sites (lineno)
        # BARE self.<helper>(...) call sites of offloaded helpers — i.e. the
        # helper is actually *called*, not passed to asyncio.to_thread (which
        # references it as an attribute, producing no Call node here). Each is
        # (helper, lineno, enclosing_fns) for the contract test.
        self.bare_helper_calls = []

        # Calls that are the direct operand of an ``await`` (migrated path).
        awaited = {
            id(n.value)
            for n in ast.walk(tree)
            if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)
        }
        alias_names = self._collect_alias_names(tree)
        # Nearest enclosing function per node: DB calls lexically inside an
        # offloaded helper (which runs off-loop) are exempt.
        enclosing = self._enclosing_fn_map(tree)
        # Every enclosing function per node, for the bare-helper contract.
        ancestry = self._ancestor_fns(tree)

        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            if not isinstance(func, ast.Attribute):
                continue

            in_offloaded_helper = enclosing.get(id(node)) in _OFFLOADED_SYNC_HELPERS
            on_loop = id(node) not in awaited and not in_offloaded_helper
            receiver = func.value

            if isinstance(receiver, ast.Name):
                # Bare call of an offloaded helper (self._helper(...)). A
                # to_thread offload passes the helper as an attribute arg, not
                # a Call, so it never lands here.
                if receiver.id == "self" and func.attr in _OFFLOADED_SYNC_HELPERS:
                    self.bare_helper_calls.append(
                        (func.attr, node.lineno, ancestry.get(id(node), frozenset()))
                    )
                # alias.<method>(...) -> aliased loop call.
                if receiver.id in alias_names and func.attr != "_db" and on_loop:
                    self.alias_calls.append((func.attr, node.lineno))
                # A Name receiver can match neither attribute-chain shape below.
                continue

            if not isinstance(receiver, ast.Attribute):
                continue

            if receiver.attr == "_db" and self._is_self_session_db(receiver.value):
                # self._session_db._db.<method>(...) -> sync escape.
                self.db_escapes.append(receiver.lineno)
            elif self._is_self_session_db(receiver) and on_loop:
                # self._session_db.<method>(...) not wrapped in await.
                self.raw_calls.append((func.attr, node.lineno))

    @staticmethod
    def _is_self_session_db(node: ast.AST) -> bool:
        """True if ``node`` is exactly the expression ``self._session_db``."""
        return (
            isinstance(node, ast.Attribute)
            and node.attr == "_session_db"
            and isinstance(node.value, ast.Name)
            and node.value.id == "self"
        )

    @staticmethod
    def _enclosing_fn_map(tree: ast.AST) -> dict:
        """Map id(node) -> name of the nearest lexically-enclosing function.

        Nodes outside any function map to ``None``. The root ``tree`` itself
        gets no entry. Every child of a (Async)FunctionDef maps to that
        function.
        """
        out = {}
        pending = [(tree, None)]
        while pending:
            node, fn_name = pending.pop()
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                fn_name = node.name
            for child in ast.iter_child_nodes(node):
                out[id(child)] = fn_name
                pending.append((child, fn_name))
        return out

    @staticmethod
    def _ancestor_fns(tree: ast.AST) -> dict:
        """Map id(node) -> frozenset of ALL enclosing function names (any depth).

        Nodes outside any function map to an empty frozenset. The root
        ``tree`` itself gets no entry.
        """
        out = {}
        pending = [(tree, ())]
        while pending:
            node, stack = pending.pop()
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                stack = stack + (node.name,)
            scope = frozenset(stack)
            for child in ast.iter_child_nodes(node):
                out[id(child)] = scope
                pending.append((child, stack))
        return out

    @staticmethod
    def _is_session_db_source(value: ast.AST) -> bool:
        """True if a binding's RHS resolves to ``<obj>._session_db``.

        Matches:
          * ``<name>._session_db`` where the receiver is a bare Name (covers
            ``self`` and captured closure params like ``_self``);
          * ``getattr(<anything>, "_session_db", ...)`` — the first argument
            is not inspected, only the literal attribute name.

        ``<obj>._session_db._db`` (the sync handle) does not match, because
        its outermost attribute is ``_db``.
        """
        if isinstance(value, ast.Attribute):
            return value.attr == "_session_db" and isinstance(value.value, ast.Name)
        return (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "getattr"
            and len(value.args) >= 2
            and isinstance(value.args[1], ast.Constant)
            and value.args[1].value == "_session_db"
        )

    @classmethod
    def _collect_alias_names(cls, tree: ast.AST) -> set:
        """Return every simple name bound to a ``_session_db`` source in ``tree``.

        Covers plain assignment, annotated assignment with a value, and
        assignment expressions (``name := ...``). Only bare-Name targets are
        collected; tuple/attribute/subscript targets are ignored.
        """
        names = set()
        for node in ast.walk(tree):
            if isinstance(node, ast.Assign):
                targets = node.targets
            elif isinstance(node, (ast.AnnAssign, ast.NamedExpr)):
                targets = [node.target]
            else:
                continue
            # AnnAssign may be a bare declaration with no value.
            if node.value is None or not cls._is_session_db_source(node.value):
                continue
            names.update(t.id for t in targets if isinstance(t, ast.Name))
        return names
|
|
|
|
|
|
def _scan(rel_path: str) -> _RawCallVisitor:
    """Parse the repo file at ``rel_path`` and return its populated visitor.

    ``rel_path`` is resolved against ``_repo_root()`` and read as UTF-8.
    """
    target = _repo_root() / rel_path
    tree = ast.parse(target.read_text(encoding="utf-8"))
    return _RawCallVisitor(tree)
|
|
|
|
|
|
def test_no_raw_session_db_calls_on_gateway_loop():
    """Fail if any non-awaited SessionDB call appears in gateway files.

    Every loop-reachable DB call must go through AsyncSessionDB (await),
    whether it is written directly (self._session_db.<method>(...)) or through
    a local alias (db = getattr(self, "_session_db", None); db.<method>(...)).
    The sanitize_title staticmethod is called on the class rather than on
    self or an alias, so it is not matched; the _db. sync escape has its own
    test.
    """
    violations = []
    for rel in _GATEWAY_FILES:
        visitor = _scan(rel)
        # Direct spellings first, then aliased ones, per file.
        for method, lineno in visitor.raw_calls:
            violations.append(f"{rel}:{lineno} self._session_db.{method}(")
        for method, lineno in visitor.alias_calls:
            violations.append(f"{rel}:{lineno} <alias>.{method}( (binds _session_db)")

    report = "\n ".join(violations)
    assert not violations, (
        "Non-awaited SessionDB calls on the gateway loop — route through "
        "AsyncSessionDB (await ...):\n " + report
    )
|
|
|
|
|
|
def test_sync_db_escape_confined_to_off_loop_sites():
    """The self._session_db._db. sync escape must stay confined to known sites.

    The escape is legitimate only where the call is provably off the loop:
    construction (before the loop serves) and the run_sync executor closure.
    A count above the reviewed ceiling means a blocking call may have leaked
    back onto the loop through the escape hatch.
    """
    total = 0
    for rel in _GATEWAY_FILES:
        total += len(_scan(rel).db_escapes)

    assert total <= _ALLOWED_SYNC_DB_ESCAPES, (
        f"self._session_db._db. sync escape used {total} times; "
        f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed."
    )
|
|
|
|
|
|
def test_offloaded_helpers_never_called_bare_on_loop():
    """The offloaded sync helpers must never be called bare on the event loop.

    These helpers touch SessionDB synchronously, so a bare
    ``self._helper(...)`` on the loop would freeze it. Loop-side callers must
    use ``await asyncio.to_thread(self._helper, ...)``, which references the
    helper as an attribute (no Call node) and so is never recorded. A bare
    call is acceptable only when it runs off-loop: somewhere inside the
    ``run_sync`` thread-pool closure, or inside another offloaded helper
    (sync->sync, same thread). Anything else re-freezes the loop through the
    body exemption the guard grants these helpers.
    """
    # Enclosing-function names under which a bare helper call is off-loop.
    permitted_scopes = _OFFLOADED_SYNC_HELPERS | {"run_sync"}
    violations = [
        f"{rel}:{ln} bare self.{helper}( on the loop"
        for rel in _GATEWAY_FILES
        for helper, ln, ancestors in _scan(rel).bare_helper_calls
        if ancestors.isdisjoint(permitted_scopes)
    ]
    assert not violations, (
        "Offloaded sync helper called bare on the gateway loop — wrap in "
        "await asyncio.to_thread(self.<helper>, ...):\n " + "\n ".join(violations)
    )
|
|
|
|
|
|
# --------------------------------------------------------------------------
|
|
# Interleaving safety: offloading opens await points where coroutines can
|
|
# interleave against the same session rows. The gateway relies on SessionDB's
|
|
# atomic operations (compare-and-set, INSERT OR IGNORE) to stay single-winner.
|
|
# These pin that the defenses hold when driven concurrently through the facade.
|
|
# --------------------------------------------------------------------------
|
|
|
|
@pytest.mark.asyncio
async def test_concurrent_claim_handoff_single_winner(tmp_path):
    """Twenty concurrent ``claim_handoff`` calls must produce exactly one winner."""
    store = hermes_state.SessionDB(db_path=tmp_path / "state.db")
    db = AsyncSessionDB(store)
    sid = "s-handoff"
    await db.create_session(sid, "test")
    await db.request_handoff(sid, "telegram")

    # Launch all claims together so they interleave at the facade's await points.
    claims = [db.claim_handoff(sid) for _ in range(20)]
    results = await asyncio.gather(*claims)

    winners = sum(results)
    assert winners == 1, f"exactly one claim must win, got {winners}"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_concurrent_create_session_idempotent(tmp_path):
    """Twenty concurrent ``create_session`` calls for one id leave a single row."""
    store = hermes_state.SessionDB(db_path=tmp_path / "state.db")
    db = AsyncSessionDB(store)
    sid = "s-create"

    attempts = [db.create_session(sid, "test") for _ in range(20)]
    await asyncio.gather(*attempts)

    rows = await db.list_sessions_rich(limit=100)
    matching = [row for row in rows if row["id"] == sid]
    assert len(matching) == 1
|