mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-07-01 12:02:05 +00:00
fix(gateway): route SessionDB calls through AsyncSessionDB
This commit is contained in:
parent
ea26f22710
commit
0896facce8
12 changed files with 203 additions and 142 deletions
|
|
@ -264,11 +264,18 @@ def make_adapter(platform: Platform, runner=None):
|
|||
|
||||
|
||||
async def send_and_capture(adapter, text: str, platform: Platform, **event_kwargs) -> AsyncMock:
|
||||
"""Send a message through the full e2e flow and return the send mock."""
|
||||
"""Send a message through the full e2e flow and return the send mock.
|
||||
|
||||
Polls for the send rather than waiting a fixed delay: handler DB work now
|
||||
hops to worker threads (AsyncSessionDB), so completion latency varies.
|
||||
"""
|
||||
event = make_event(platform, text, **event_kwargs)
|
||||
adapter.send.reset_mock()
|
||||
await adapter.handle_message(event)
|
||||
await asyncio.sleep(0.3)
|
||||
for _ in range(40): # up to ~2s; returns as soon as the send lands
|
||||
if adapter.send.called:
|
||||
break
|
||||
await asyncio.sleep(0.05)
|
||||
return adapter.send
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,15 @@ from unittest.mock import MagicMock
|
|||
import pytest
|
||||
|
||||
|
||||
def make_async_session_db(sync_mock=None):
|
||||
"""Wrap a sync mock SessionDB in AsyncSessionDB so gateway code that awaits
|
||||
the facade works in tests. Returns (facade, sync_mock); configure return
|
||||
values and assert calls on sync_mock."""
|
||||
from hermes_state import AsyncSessionDB
|
||||
sync_mock = sync_mock if sync_mock is not None else MagicMock()
|
||||
return AsyncSessionDB(sync_mock), sync_mock
|
||||
|
||||
|
||||
def _ensure_telegram_mock() -> None:
|
||||
"""Install a comprehensive telegram mock in sys.modules.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ Verifies that the agent cache correctly:
|
|||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
|
||||
def _make_runner():
|
||||
|
|
@ -1565,8 +1567,11 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
"""
|
||||
|
||||
def _runner_with_db(self, db):
|
||||
from hermes_state import AsyncSessionDB
|
||||
|
||||
runner = _make_runner()
|
||||
runner._session_db = db
|
||||
# The gateway holds the async facade; the production refresh awaits it.
|
||||
runner._session_db = AsyncSessionDB(db)
|
||||
return runner
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -1577,7 +1582,7 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
the cached agent (or either side is None / it's a legacy 2-tuple).
|
||||
"""
|
||||
try:
|
||||
row = runner._session_db.get_session(session_id)
|
||||
row = runner._session_db._db.get_session(session_id)
|
||||
live = row.get("message_count", 0) if row else None
|
||||
except Exception:
|
||||
live = None
|
||||
|
|
@ -1591,7 +1596,8 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
)
|
||||
return not invalidate
|
||||
|
||||
def test_same_process_turns_preserve_cached_agent(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_same_process_turns_preserve_cached_agent(self, tmp_path):
|
||||
"""The regression guard: consecutive same-process turns must REUSE
|
||||
the cached agent (prompt cache preserved), not rebuild every turn.
|
||||
|
||||
|
|
@ -1619,7 +1625,7 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
db.append_message("s1", role="user", content="u")
|
||||
db.append_message("s1", role="assistant", content="a")
|
||||
# Post-turn re-baseline (the fix).
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
# Next turn's guard decision.
|
||||
if self._guard_would_reuse(runner, "telegram:s1", "s1"):
|
||||
reuses += 1
|
||||
|
|
@ -1630,7 +1636,8 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
with runner._agent_cache_lock:
|
||||
assert runner._agent_cache["telegram:s1"][0] is agent
|
||||
|
||||
def test_cross_process_write_still_invalidates(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_cross_process_write_still_invalidates(self, tmp_path):
|
||||
"""After the re-baseline, a DIFFERENT process appending to the same
|
||||
session must still flip the guard to rebuild (the #45966 fix holds).
|
||||
"""
|
||||
|
|
@ -1650,7 +1657,7 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
# Our own turn + re-baseline -> reuse next turn.
|
||||
db.append_message("s1", role="user", content="u")
|
||||
db.append_message("s1", role="assistant", content="a")
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
assert self._guard_would_reuse(runner, "telegram:s1", "s1") is True
|
||||
|
||||
# ANOTHER process (e.g. the desktop dashboard backend) appends a turn
|
||||
|
|
@ -1660,10 +1667,11 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
# Guard must now reject reuse so the agent rebuilds from fresh disk.
|
||||
assert self._guard_would_reuse(runner, "telegram:s1", "s1") is False
|
||||
|
||||
def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
|
||||
"""Re-baseline must never crash and must leave legacy 2-tuples and
|
||||
pending-sentinel entries untouched."""
|
||||
from hermes_state import SessionDB
|
||||
from hermes_state import AsyncSessionDB, SessionDB
|
||||
from gateway.run import _AGENT_PENDING_SENTINEL
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "sessions.db")
|
||||
|
|
@ -1673,24 +1681,24 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
|
||||
# No session_db -> no-op, no crash.
|
||||
runner._session_db = None
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
runner._session_db = db
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
runner._session_db = AsyncSessionDB(db)
|
||||
|
||||
# Falsy session_id -> no-op.
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "")
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", None)
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", None)
|
||||
|
||||
# Legacy 2-tuple is left untouched (it opts out of the guard).
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["telegram:s1"] = (object(), "sig")
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
with runner._agent_cache_lock:
|
||||
assert len(runner._agent_cache["telegram:s1"]) == 2
|
||||
|
||||
# Pending sentinel entry is left untouched.
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["telegram:s1"] = (_AGENT_PENDING_SENTINEL, "sig", 0)
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
with runner._agent_cache_lock:
|
||||
assert runner._agent_cache["telegram:s1"][0] is _AGENT_PENDING_SENTINEL
|
||||
assert runner._agent_cache["telegram:s1"][2] == 0
|
||||
|
|
@ -1700,10 +1708,10 @@ class TestAgentCacheMessageCountRebaseline:
|
|||
def get_session(self, _sid):
|
||||
raise RuntimeError("db locked")
|
||||
|
||||
runner._session_db = _BoomDB() # type: ignore[assignment]
|
||||
runner._session_db = AsyncSessionDB(_BoomDB()) # type: ignore[assignment]
|
||||
with runner._agent_cache_lock:
|
||||
runner._agent_cache["telegram:s1"] = (object(), "sig", 5)
|
||||
runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
await runner._refresh_agent_cache_message_count("telegram:s1", "s1")
|
||||
with runner._agent_cache_lock:
|
||||
assert runner._agent_cache["telegram:s1"][2] == 5
|
||||
|
||||
|
|
|
|||
|
|
@ -122,29 +122,40 @@ def test_non_callable_attribute_passes_through():
|
|||
_GATEWAY_FILES = ("gateway/run.py", "gateway/slash_commands.py")
|
||||
# The only legitimate non-loop paths:
|
||||
# - SessionDB.sanitize_title: pure @staticmethod string cleaning, no DB.
|
||||
# - self._session_db._db.<x>: the sync escape, allowed ONLY at construction.
|
||||
_ALLOWED_SYNC_DB_ESCAPES = 1 # exactly the maybe_auto_prune call in __init__
|
||||
# - self._session_db._db.<x>: the sync escape, allowed ONLY where the call is
|
||||
# provably off the event loop — construction (__init__, before the loop
|
||||
# serves) and the run_sync closure (executed in a thread-pool executor).
|
||||
# Three such sites today; a fourth must be justified and this count bumped.
|
||||
_ALLOWED_SYNC_DB_ESCAPES = 3
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
return Path(__file__).resolve().parents[2]
|
||||
|
||||
|
||||
class _RawCallVisitor(ast.NodeVisitor):
|
||||
"""Collect calls of the shape self._session_db.<method>(...).
|
||||
class _RawCallVisitor:
|
||||
"""Collect non-awaited self._session_db.<method>(...) calls in a module.
|
||||
|
||||
Whether the call is awaited is irrelevant to the AST node; an Await wraps
|
||||
the Call. We flag the raw shape and separately exempt the _db. escape and
|
||||
the sanitize_title staticmethod (which is called on the class, not self).
|
||||
An ``await x.y()`` parses as Await(value=Call(...)); those Call nodes are
|
||||
exempt — they're the migrated path. We flag only Calls that are NOT directly
|
||||
awaited, and separately count the self._session_db._db.<x> sync escape. The
|
||||
sanitize_title staticmethod is called on the class (SessionDB.sanitize_title),
|
||||
so it never matches the self._session_db.<method> shape.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.raw_calls = [] # (method, lineno)
|
||||
def __init__(self, tree: ast.AST):
|
||||
self.raw_calls = [] # (method, lineno) — non-awaited
|
||||
self.db_escapes = [] # self._session_db._db.<x> sites (lineno)
|
||||
|
||||
def visit_Call(self, node: ast.Call):
|
||||
func = node.func
|
||||
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute):
|
||||
awaited = {id(n.value) for n in ast.walk(tree)
|
||||
if isinstance(n, ast.Await) and isinstance(n.value, ast.Call)}
|
||||
|
||||
for node in ast.walk(tree):
|
||||
if not isinstance(node, ast.Call):
|
||||
continue
|
||||
func = node.func
|
||||
if not (isinstance(func, ast.Attribute) and isinstance(func.value, ast.Attribute)):
|
||||
continue
|
||||
inner = func.value
|
||||
# self._session_db._db.<method>(...) -> sync escape
|
||||
if (
|
||||
|
|
@ -155,21 +166,19 @@ class _RawCallVisitor(ast.NodeVisitor):
|
|||
and inner.value.value.id == "self"
|
||||
):
|
||||
self.db_escapes.append(inner.lineno)
|
||||
# self._session_db.<method>(...) -> raw loop call
|
||||
# self._session_db.<method>(...) not wrapped in await -> raw loop call
|
||||
elif (
|
||||
inner.attr == "_session_db"
|
||||
and isinstance(inner.value, ast.Name)
|
||||
and inner.value.id == "self"
|
||||
and id(node) not in awaited
|
||||
):
|
||||
self.raw_calls.append((func.attr, node.lineno))
|
||||
self.generic_visit(node)
|
||||
|
||||
|
||||
def _scan(rel_path: str) -> _RawCallVisitor:
|
||||
source = (_repo_root() / rel_path).read_text(encoding="utf-8")
|
||||
visitor = _RawCallVisitor()
|
||||
visitor.visit(ast.parse(source))
|
||||
return visitor
|
||||
return _RawCallVisitor(ast.parse(source))
|
||||
|
||||
|
||||
def test_no_raw_session_db_calls_on_gateway_loop():
|
||||
|
|
@ -189,15 +198,16 @@ def test_no_raw_session_db_calls_on_gateway_loop():
|
|||
)
|
||||
|
||||
|
||||
def test_sync_db_escape_confined_to_construction():
|
||||
"""The self._session_db._db. sync escape must stay confined to one site.
|
||||
def test_sync_db_escape_confined_to_off_loop_sites():
|
||||
"""The self._session_db._db. sync escape must stay confined to known sites.
|
||||
|
||||
It is legitimate only at construction (before the loop serves traffic).
|
||||
More than one occurrence means a blocking call leaked back onto the loop
|
||||
through the escape hatch.
|
||||
It is legitimate only where the call is provably off the loop: construction
|
||||
(before the loop serves) and the run_sync executor closure. More occurrences
|
||||
than the reviewed count means a blocking call may have leaked back onto the
|
||||
loop through the escape hatch.
|
||||
"""
|
||||
total = sum(len(_scan(rel).db_escapes) for rel in _GATEWAY_FILES)
|
||||
assert total <= _ALLOWED_SYNC_DB_ESCAPES, (
|
||||
f"self._session_db._db. sync escape used {total} times; "
|
||||
f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction only) is allowed."
|
||||
f"at most {_ALLOWED_SYNC_DB_ESCAPES} (construction + run_sync) is allowed."
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,14 @@ The Discord gateway heartbeat was stalling because the handoff watcher
|
|||
SQLite-backed ``SessionDB`` directly on the asyncio event loop every 2s
|
||||
('Shard ID None heartbeat blocked for more than N seconds').
|
||||
|
||||
The fix (mirroring PR #40782) wraps every blocking ``SessionDB`` call inside
|
||||
the watcher loop in ``asyncio.to_thread(...)`` so the SQLite I/O runs on a
|
||||
worker thread and never blocks the event loop / Discord heartbeat.
|
||||
The fix routes every blocking ``SessionDB`` call in the watcher through the
|
||||
``AsyncSessionDB`` facade, which offloads each call via ``asyncio.to_thread`` so
|
||||
the SQLite I/O runs on a worker thread and never blocks the event loop / Discord
|
||||
heartbeat.
|
||||
|
||||
These tests assert that behaviour contract. They are mutation-survivable:
|
||||
reverting any ``asyncio.to_thread(self._session_db.<call>)`` wrap back to a
|
||||
direct synchronous call on the loop makes the relevant assertion fail.
|
||||
reverting any ``await self._session_db.<call>(...)`` back to a direct synchronous
|
||||
call on the loop makes the relevant assertion fail.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
|
@ -62,9 +63,15 @@ class _RecordingSessionDB:
|
|||
|
||||
|
||||
def _make_fake_runner(session_db, *, fail_process=False):
|
||||
"""Build a minimal object that exposes exactly what the loop body touches."""
|
||||
"""Build a minimal object that exposes exactly what the loop body touches.
|
||||
|
||||
The watcher now talks to the SessionDB through the AsyncSessionDB facade,
|
||||
so wrap the recording stand-in the same way the gateway does.
|
||||
"""
|
||||
from hermes_state import AsyncSessionDB
|
||||
|
||||
fake = types.SimpleNamespace()
|
||||
fake._session_db = session_db
|
||||
fake._session_db = AsyncSessionDB(session_db)
|
||||
# _running yields True for the first loop check, then False so the loop
|
||||
# exits after a single tick.
|
||||
states = iter([True, False])
|
||||
|
|
@ -141,21 +148,23 @@ async def test_watcher_offloads_fail_handoff_to_thread(monkeypatch):
|
|||
async def test_watcher_wraps_calls_via_asyncio_to_thread(monkeypatch):
|
||||
"""Explicitly assert the offload goes through asyncio.to_thread.
|
||||
|
||||
Patches ``run.asyncio.to_thread`` and records which SessionDB callables
|
||||
were handed to it. Mutation-survivable: dropping any wrap removes its
|
||||
callable from the recorded set.
|
||||
Patches the AsyncSessionDB facade's ``asyncio.to_thread`` (it lives in
|
||||
hermes_state) and records which SessionDB callables were handed to it.
|
||||
Mutation-survivable: dropping any await removes its callable from the set.
|
||||
"""
|
||||
import hermes_state
|
||||
|
||||
db = _RecordingSessionDB(loop_thread_ident=-1)
|
||||
fake = _make_fake_runner(db, fail_process=False)
|
||||
|
||||
wrapped = []
|
||||
real_to_thread = run.asyncio.to_thread
|
||||
real_to_thread = hermes_state.asyncio.to_thread
|
||||
|
||||
async def _spy_to_thread(func, *args, **kwargs):
|
||||
wrapped.append(getattr(func, "__name__", repr(func)))
|
||||
return await real_to_thread(func, *args, **kwargs)
|
||||
|
||||
monkeypatch.setattr(run.asyncio, "to_thread", _spy_to_thread)
|
||||
monkeypatch.setattr(hermes_state.asyncio, "to_thread", _spy_to_thread)
|
||||
|
||||
await _run_one_tick(fake, monkeypatch)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import pytest
|
|||
|
||||
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
||||
from gateway.platforms.base import MessageEvent
|
||||
from hermes_state import AsyncSessionDB
|
||||
from gateway.session import (
|
||||
SessionContext,
|
||||
SessionEntry,
|
||||
|
|
@ -343,16 +344,16 @@ def _make_runner(current_source: SessionSource, entries: list[SessionEntry]):
|
|||
runner._clear_session_boundary_security_state = MagicMock()
|
||||
runner._evict_cached_agent = MagicMock()
|
||||
runner._queue_depth = MagicMock(return_value=0)
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.list_sessions_rich.return_value = [
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.list_sessions_rich.return_value = [
|
||||
{"id": entry.session_id, "title": entry.display_name, "preview": ""}
|
||||
for entry in entries
|
||||
]
|
||||
runner._session_db.resolve_resume_session_id.side_effect = lambda sid: sid
|
||||
runner._session_db.get_session_title.side_effect = lambda sid: {
|
||||
runner._session_db._db.resolve_resume_session_id.side_effect = lambda sid: sid
|
||||
runner._session_db._db.get_session_title.side_effect = lambda sid: {
|
||||
entry.session_id: entry.display_name for entry in entries
|
||||
}.get(sid)
|
||||
runner._session_db.get_session.return_value = None
|
||||
runner._session_db._db.get_session.return_value = None
|
||||
return runner
|
||||
|
||||
|
||||
|
|
@ -388,7 +389,7 @@ async def test_matrix_resume_does_not_cross_rooms_by_default():
|
|||
entry_a = _entry(source_a, "session-a", "Project A Plan")
|
||||
entry_b = _entry(source_b, "session-b", "Project B Plan")
|
||||
runner = _make_runner(source_b, [entry_a, entry_b])
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-a"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
|
||||
|
||||
result = await runner._handle_resume_command(_event("/resume Project A Plan", source_b))
|
||||
|
||||
|
|
@ -406,7 +407,7 @@ async def test_matrix_resume_allows_same_room_session():
|
|||
source_b, "session-b-current", "Current Project B"
|
||||
)
|
||||
runner.session_store.switch_session.return_value = entry_b
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-b-old"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-b-old"
|
||||
|
||||
result = await runner._handle_resume_command(_event("/resume Project B Plan", source_b))
|
||||
|
||||
|
|
@ -423,14 +424,14 @@ async def test_matrix_resume_quoted_title_same_room():
|
|||
source_b, "session-b-current", "Current Project B"
|
||||
)
|
||||
runner.session_store.switch_session.return_value = entry_b
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-b-old"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-b-old"
|
||||
|
||||
result = await runner._handle_resume_command(
|
||||
_event('/resume "Project B Plan"', source_b)
|
||||
)
|
||||
|
||||
assert "Resumed session" in result
|
||||
runner._session_db.resolve_session_by_title.assert_called_once_with("Project B Plan")
|
||||
runner._session_db._db.resolve_session_by_title.assert_called_once_with("Project B Plan")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -440,7 +441,7 @@ async def test_matrix_resume_quoted_title_cross_room_blocked():
|
|||
entry_a = _entry(source_a, "session-a", "Project A Plan")
|
||||
entry_b = _entry(source_b, "session-b", "Project B Plan")
|
||||
runner = _make_runner(source_b, [entry_a, entry_b])
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-a"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
|
||||
|
||||
result = await runner._handle_resume_command(
|
||||
_event('/resume "Project A Plan"', source_b)
|
||||
|
|
@ -471,7 +472,7 @@ async def test_matrix_resume_cross_room_requires_explicit_flag_and_warns():
|
|||
entry_b = _entry(source_b, "session-b", "Project B Plan")
|
||||
runner = _make_runner(source_b, [entry_a, entry_b])
|
||||
runner.session_store.switch_session.return_value = entry_a
|
||||
runner._session_db.resolve_session_by_title.return_value = "session-a"
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "session-a"
|
||||
|
||||
result = await runner._handle_resume_command(
|
||||
_event("/resume --cross-room Project A Plan", source_b)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from hermes_state import AsyncSessionDB
|
||||
"""Regression tests for approval-state cleanup on session boundaries."""
|
||||
|
||||
from datetime import datetime
|
||||
|
|
@ -86,9 +87,9 @@ def _make_resume_runner():
|
|||
runner.session_store.get_or_create_session.return_value = current_entry
|
||||
runner.session_store.switch_session.return_value = resumed_entry
|
||||
runner.session_store.load_transcript.return_value = []
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.resolve_session_by_title.return_value = "resumed-session"
|
||||
runner._session_db.get_session_title.return_value = "Resumed Work"
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.resolve_session_by_title.return_value = "resumed-session"
|
||||
runner._session_db._db.get_session_title.return_value = "Resumed Work"
|
||||
return runner, session_key
|
||||
|
||||
|
||||
|
|
@ -116,9 +117,9 @@ def _make_branch_runner():
|
|||
{"role": "assistant", "content": "world"},
|
||||
]
|
||||
runner.session_store.switch_session.return_value = branched_entry
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = "Current Work"
|
||||
runner._session_db.get_next_title_in_lineage.return_value = "Current Work #2"
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.get_session_title.return_value = "Current Work"
|
||||
runner._session_db._db.get_next_title_in_lineage.return_value = "Current Work #2"
|
||||
return runner, session_key
|
||||
|
||||
|
||||
|
|
@ -208,7 +209,7 @@ async def test_branch_preserves_persisted_assistant_metadata():
|
|||
result = await runner._handle_branch_command(_make_event("/branch"))
|
||||
|
||||
assert "Branched to" in result
|
||||
append_calls = runner._session_db.append_message.call_args_list
|
||||
append_calls = runner._session_db._db.append_message.call_args_list
|
||||
assert len(append_calls) == 2
|
||||
assistant_kwargs = append_calls[1].kwargs
|
||||
assert assistant_kwargs["role"] == "assistant"
|
||||
|
|
|
|||
|
|
@ -171,8 +171,12 @@ async def test_second_message_during_sentinel_queued_not_duplicate():
|
|||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
# Start first message (will block at barrier)
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
# Yield so task1 enters slow_inner and sentinel is set
|
||||
await asyncio.sleep(0)
|
||||
# Yield until task1 has claimed the sentinel (it crosses a few awaits
|
||||
# before the claim; don't assume a fixed number of scheduler slices).
|
||||
for _ in range(50):
|
||||
await asyncio.sleep(0)
|
||||
if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL:
|
||||
break
|
||||
|
||||
# Verify sentinel is set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
|
@ -417,7 +421,10 @@ async def test_stop_during_sentinel_force_cleans_session():
|
|||
|
||||
with patch.object(GatewayRunner, "_handle_message_with_agent", slow_inner):
|
||||
task1 = asyncio.create_task(runner._handle_message(event1))
|
||||
await asyncio.sleep(0)
|
||||
for _ in range(50):
|
||||
await asyncio.sleep(0)
|
||||
if runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL:
|
||||
break
|
||||
|
||||
# Sentinel should be set
|
||||
assert runner._running_agents.get(session_key) is _AGENT_PENDING_SENTINEL
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from hermes_state import AsyncSessionDB
|
||||
"""Tests for gateway /status behavior and token persistence."""
|
||||
|
||||
from datetime import datetime
|
||||
|
|
@ -53,11 +54,11 @@ def _make_runner(session_entry: SessionEntry, *, platform: Platform = Platform.T
|
|||
runner._session_run_generation = {}
|
||||
runner._pending_messages = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session_title.return_value = None
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.get_session_title.return_value = None
|
||||
# Default: no DB row → /status reports 0 tokens. Tests that exercise
|
||||
# the populated path override this.
|
||||
runner._session_db.get_session.return_value = None
|
||||
runner._session_db._db.get_session.return_value = None
|
||||
runner._reasoning_config = None
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
|
|
@ -86,7 +87,7 @@ async def test_status_command_reports_running_agent_without_interrupt(monkeypatc
|
|||
)
|
||||
runner = _make_runner(session_entry)
|
||||
# Token total comes from the SQLite SessionDB, not SessionEntry.
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 200,
|
||||
"output_tokens": 121,
|
||||
"cache_read_tokens": 0,
|
||||
|
|
@ -118,7 +119,7 @@ async def test_status_command_includes_session_title_when_present():
|
|||
total_tokens=321,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session_title.return_value = "My titled session"
|
||||
runner._session_db._db.get_session_title.return_value = "My titled session"
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
|
|
@ -141,7 +142,7 @@ async def test_status_command_reads_token_totals_from_session_db():
|
|||
total_tokens=0, # SessionEntry never gets written to — always 0.
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 250,
|
||||
"cache_read_tokens": 500,
|
||||
|
|
@ -169,7 +170,7 @@ async def test_status_command_tokens_zero_when_session_db_row_missing():
|
|||
total_tokens=999, # This should be ignored.
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = None
|
||||
runner._session_db._db.get_session.return_value = None
|
||||
|
||||
result = await runner._handle_message(_make_event("/status"))
|
||||
|
||||
|
|
@ -188,7 +189,7 @@ async def test_status_command_includes_live_agent_model_and_context():
|
|||
total_tokens=0,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 1000,
|
||||
"output_tokens": 250,
|
||||
"cache_read_tokens": 0,
|
||||
|
|
@ -228,7 +229,7 @@ async def test_status_command_includes_persisted_model_and_context_when_agent_no
|
|||
last_prompt_tokens=24_000,
|
||||
)
|
||||
runner = _make_runner(session_entry)
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"input_tokens": 2000,
|
||||
"output_tokens": 500,
|
||||
"cache_read_tokens": 0,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from hermes_state import AsyncSessionDB
|
||||
"""Tests for gateway /usage command — agent cache lookup and output fields."""
|
||||
|
||||
import threading
|
||||
|
|
@ -197,8 +198,8 @@ class TestUsageAccountSection:
|
|||
@pytest.mark.asyncio
|
||||
async def test_usage_command_uses_persisted_provider_when_agent_not_running(self, monkeypatch):
|
||||
runner = _make_runner(SK)
|
||||
runner._session_db = MagicMock()
|
||||
runner._session_db.get_session.return_value = {
|
||||
runner._session_db = AsyncSessionDB(MagicMock())
|
||||
runner._session_db._db.get_session.return_value = {
|
||||
"billing_provider": "openai-codex",
|
||||
"billing_base_url": "https://chatgpt.com/backend-api/codex",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue