def _refresh_agent_cache_message_count(
    self, session_key: str, session_id: Optional[str]
) -> None:
    """Re-baseline a cached agent's stored ``message_count`` after a turn.

    The cross-process coherence guard (#45966) compares the session's
    on-disk ``message_count`` with the count stored next to the cached
    agent and rebuilds the agent on a mismatch.  That snapshot is taken at
    agent-build time, before the turn writes its own rows.  Without a
    refresh, this process's own writes would grow the count and make the
    guard fire on the very next turn, defeating the prompt caching the
    agent cache exists to protect.

    Call this once a turn has completed and its rows have been flushed to
    the SessionDB.  The current on-disk count replaces element 2 of the
    cache entry; the agent, the signature and any further trailing elements
    of the entry are kept exactly as they are.

    Entries that are left untouched:

    * no entry under ``session_key`` (e.g. evicted mid-turn) -- nothing is
      created;
    * legacy 2-tuples, which opt out of the guard;
    * pending-sentinel placeholders.

    Fail-safe: a runner without a session DB / cache / lock, a falsy
    ``session_id``, a missing session row, or a DB probe that raises all
    leave the snapshot as-is.  The worst case is one unnecessary rebuild on
    the next turn.

    NOTE(review): this method has no reference to the agent that ran the
    turn, so it cannot tell a rebuilt entry from the original one; any live
    entry under ``session_key`` is re-baselined.  Likewise the count read
    here is the session total, so rows another process appended *during*
    the turn are absorbed into the new baseline -- confirm that is
    acceptable for the guard.

    Args:
        session_key: Key of the entry in the agent cache.
        session_id: Live SessionDB id whose ``message_count`` is read;
            ``None`` or empty makes the call a no-op.
    """
    # Read every collaborator defensively so the helper honours its
    # fail-safe contract even on a partially initialised runner.
    session_db = getattr(self, "_session_db", None)
    cache_lock = getattr(self, "_agent_cache_lock", None)
    cache = getattr(self, "_agent_cache", None)
    if session_db is None or not session_id:
        return
    if cache_lock is None or cache is None:
        return

    # Probe the DB outside the cache lock so a slow query never blocks
    # other cache users.
    try:
        row = session_db.get_session(session_id)
        live_count = row.get("message_count", 0) if row else None
    except Exception:
        # Deliberate best-effort: never let a probe failure break the turn,
        # but leave a trace, because a persistently failing probe means a
        # rebuild on every turn.
        import logging

        logging.getLogger(__name__).debug(
            "message_count re-baseline skipped for session %s",
            session_id,
            exc_info=True,
        )
        return
    if live_count is None:
        return

    with cache_lock:
        entry = cache.get(session_key)
        # Only live entries that carry a count snapshot take part.
        if not isinstance(entry, tuple) or len(entry) < 3:
            return
        if entry[0] is _AGENT_PENDING_SENTINEL:
            return
        if entry[2] != live_count:
            # Swap only the count element; slicing keeps any elements
            # beyond the third instead of truncating the entry.
            cache[session_key] = entry[:2] + (live_count,) + entry[3:]
class TestAgentCacheMessageCountRebaseline:
    """Same-process writes must not trip the cross-process guard (#45966).

    The guard stores ``message_count`` next to the cached agent at build
    time, i.e. before the turn has written its own rows.  Unless the count
    is re-baselined after the turn, the gateway's own writes make the next
    turn look like an external change and the agent is rebuilt every turn,
    which throws away per-conversation prompt caching.

    ``_refresh_agent_cache_message_count`` performs that re-baseline.  The
    tests below exercise it against a real SessionDB and a mirror of the
    guard condition, covering both directions: own writes keep the cached
    agent, foreign writes still invalidate it.
    """

    def _runner_with_db(self, db):
        """Build a runner via ``_make_runner`` and attach ``db`` to it."""
        gateway_runner = _make_runner()
        gateway_runner._session_db = db
        return gateway_runner

    @staticmethod
    def _count_on_disk(db, session_id):
        """Return the session's stored ``message_count`` (0 if no row)."""
        row = db.get_session(session_id)
        if not row:
            return 0
        return row.get("message_count", 0)

    @staticmethod
    def _set_entry(runner, session_key, entry):
        """Store ``entry`` in the agent cache while holding the cache lock."""
        with runner._agent_cache_lock:
            runner._agent_cache[session_key] = entry

    @staticmethod
    def _get_entry(runner, session_key):
        """Fetch the cache entry for ``session_key`` under the cache lock."""
        with runner._agent_cache_lock:
            return runner._agent_cache[session_key]

    @staticmethod
    def _guard_would_reuse(runner, session_key, session_id):
        """Replicate the cache-hit guard and report whether it would reuse.

        The cached agent is reused unless both the on-disk count and the
        stored snapshot are known and differ.  Legacy 2-tuples carry no
        snapshot and therefore always count as reusable.
        """
        try:
            row = runner._session_db.get_session(session_id)
            on_disk = row.get("message_count", 0) if row else None
        except Exception:
            on_disk = None
        with runner._agent_cache_lock:
            entry = runner._agent_cache.get(session_key)
        snapshot = entry[2] if entry and len(entry) > 2 else None
        if snapshot is None or on_disk is None:
            return True
        return on_disk == snapshot

    def test_same_process_turns_preserve_cached_agent(self, tmp_path):
        """Back-to-back turns from this process keep the cached agent.

        Walks the real lifecycle: the snapshot is taken at build time, each
        turn appends its own rows, the post-turn re-baseline runs, and the
        following turn's guard must then decide to reuse.
        """
        from hermes_state import SessionDB

        store = SessionDB(db_path=tmp_path / "sessions.db")
        store.create_session("s1", source="telegram")
        runner = self._runner_with_db(store)
        cached_agent = object()

        # Cache miss on turn 1: the entry is built with the count as it was
        # BEFORE the turn wrote anything.
        snapshot_at_build = self._count_on_disk(store, "s1")
        self._set_entry(
            runner, "telegram:s1", (cached_agent, "sig", snapshot_at_build)
        )

        decisions = []
        for _ in range(5):
            # The turn flushes its own user + assistant rows ...
            store.append_message("s1", role="user", content="u")
            store.append_message("s1", role="assistant", content="a")
            # ... the re-baseline under test runs ...
            runner._refresh_agent_cache_message_count("telegram:s1", "s1")
            # ... and the next turn consults the guard.
            decisions.append(
                self._guard_would_reuse(runner, "telegram:s1", "s1")
            )

        # Every follow-on turn reuses; without the re-baseline none would.
        assert decisions.count(True) == 5
        # The very same agent object is still in the cache.
        assert self._get_entry(runner, "telegram:s1")[0] is cached_agent

    def test_cross_process_write_still_invalidates(self, tmp_path):
        """A write from another process must still force a rebuild.

        Re-baselining after our own turn must not weaken the #45966 guard.
        """
        from hermes_state import SessionDB

        store = SessionDB(db_path=tmp_path / "sessions.db")
        store.create_session("s1", source="telegram")
        runner = self._runner_with_db(store)
        cached_agent = object()

        self._set_entry(
            runner,
            "telegram:s1",
            (cached_agent, "sig", self._count_on_disk(store, "s1")),
        )

        # Own turn followed by the re-baseline: the next turn reuses.
        store.append_message("s1", role="user", content="u")
        store.append_message("s1", role="assistant", content="a")
        runner._refresh_agent_cache_message_count("telegram:s1", "s1")
        assert self._guard_would_reuse(runner, "telegram:s1", "s1") is True

        # Some other process (e.g. the desktop dashboard backend) appends to
        # the SAME session in the shared DB; no re-baseline happens for it.
        store.append_message(
            "s1", role="user", content="external from dashboard"
        )

        # The guard has to refuse reuse so the agent is rebuilt from disk.
        assert self._guard_would_reuse(runner, "telegram:s1", "s1") is False

    def test_rebaseline_is_fail_safe_and_skips_legacy_and_pending(self, tmp_path):
        """The re-baseline never raises and ignores legacy/pending entries."""
        from hermes_state import SessionDB
        from gateway.run import _AGENT_PENDING_SENTINEL

        store = SessionDB(db_path=tmp_path / "sessions.db")
        store.create_session("s1", source="telegram")
        store.append_message("s1", role="user", content="hi")
        runner = self._runner_with_db(store)

        # Without a session DB the call is a harmless no-op.
        runner._session_db = None
        runner._refresh_agent_cache_message_count("telegram:s1", "s1")
        runner._session_db = store

        # Empty or missing session ids are ignored as well.
        for falsy_session_id in ("", None):
            runner._refresh_agent_cache_message_count(
                "telegram:s1", falsy_session_id
            )

        # A legacy 2-tuple opts out of the guard and must stay a 2-tuple.
        self._set_entry(runner, "telegram:s1", (object(), "sig"))
        runner._refresh_agent_cache_message_count("telegram:s1", "s1")
        assert len(self._get_entry(runner, "telegram:s1")) == 2

        # A pending-sentinel placeholder keeps both its marker and count.
        self._set_entry(
            runner, "telegram:s1", (_AGENT_PENDING_SENTINEL, "sig", 0)
        )
        runner._refresh_agent_cache_message_count("telegram:s1", "s1")
        pending_entry = self._get_entry(runner, "telegram:s1")
        assert pending_entry[0] is _AGENT_PENDING_SENTINEL
        assert pending_entry[2] == 0

        # A DB probe that raises is swallowed and changes nothing.
        class _BoomDB:
            def get_session(self, _sid):
                raise RuntimeError("db locked")

        runner._session_db = _BoomDB()  # type: ignore[assignment]
        self._set_entry(runner, "telegram:s1", (object(), "sig", 5))
        runner._refresh_agent_cache_message_count("telegram:s1", "s1")
        assert self._get_entry(runner, "telegram:s1")[2] == 5