diff --git a/agent/agent_init.py b/agent/agent_init.py index 41f7cc11bbb..12597e5050f 100644 --- a/agent/agent_init.py +++ b/agent/agent_init.py @@ -1665,6 +1665,12 @@ def init_agent( abort_on_summary_failure=compression_abort_on_summary_failure, max_tokens=agent.max_tokens, ) + _bind_session_state = getattr(agent.context_compressor, "bind_session_state", None) + if callable(_bind_session_state): + try: + _bind_session_state(session_db=session_db, session_id=agent.session_id) + except Exception: + pass agent.compression_enabled = compression_enabled agent.compression_in_place = compression_in_place diff --git a/agent/context_compressor.py b/agent/context_compressor.py index fbde99bda5f..feaa1c2cc6d 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -19,6 +19,7 @@ Improvements over v2: import hashlib import json import logging +import sqlite3 import re import time from typing import Any, Dict, List, Optional @@ -638,6 +639,7 @@ class ContextCompressor(ContextEngine): self._last_compression_savings_pct = 100.0 self._ineffective_compression_count = 0 self._summary_failure_cooldown_until = 0.0 # transient errors must not block a fresh session + self._last_summary_error = None self.last_real_prompt_tokens = 0 self.last_compression_rough_tokens = 0 self.last_rough_tokens_when_real_prompt_fit = 0 @@ -659,6 +661,104 @@ class ContextCompressor(ContextEngine): """ self._previous_summary = None + def bind_session_state(self, session_db: Any = None, session_id: str = "") -> None: + """Bind the current session row so durable cooldowns can round-trip.""" + self._session_db = session_db + self._session_id = session_id or "" + self._summary_failure_cooldown_until = 0.0 + self._last_summary_error = None + self.get_active_compression_failure_cooldown() + + def on_session_start(self, session_id: str, **kwargs) -> None: + """Bind session-scoped compression state for a new or resumed session.""" + super().on_session_start(session_id, **kwargs) + self.bind_session_state(kwargs.get("session_db", self._session_db), session_id) + + def get_active_compression_failure_cooldown(self) -> Optional[Dict[str, Any]]: + """Return the live compression-failure cooldown for the bound session.""" + now_mono = time.monotonic() + if self._summary_failure_cooldown_until > now_mono: + return { + "cooldown_until": time.time() + ( + self._summary_failure_cooldown_until - now_mono + ), + "remaining_seconds": self._summary_failure_cooldown_until - now_mono, + "error": self._last_summary_error, + } + + session_db = self._session_db + session_id = self._session_id + if not session_db or not session_id: + return None + + getter = getattr(session_db, "get_compression_failure_cooldown", None) + if getter is None: + return None + try: + state = getter(session_id) + except sqlite3.Error as exc: + logger.debug("compression failure cooldown lookup failed: %s", exc) + return None + except Exception: + return None + if not state: + return None + + remaining_seconds = float(state.get("remaining_seconds") or 0.0) + if remaining_seconds <= 0: + return None + + self._summary_failure_cooldown_until = now_mono + remaining_seconds + self._last_summary_error = state.get("error") + return { + "cooldown_until": float(state.get("cooldown_until") or 0.0), + "remaining_seconds": remaining_seconds, + "error": self._last_summary_error, + } + + def _record_compression_failure_cooldown( + self, + cooldown_seconds: float, + error: Optional[str], + ) -> None: + cooldown_until = time.time() + cooldown_seconds + self._summary_failure_cooldown_until = time.monotonic() + cooldown_seconds + self._last_summary_error = error + + session_db = self._session_db + session_id = self._session_id + if not session_db or not session_id: + return + + recorder = getattr(session_db, "record_compression_failure_cooldown", None) + if recorder is None: + return + try: + recorder(session_id, cooldown_until, error) + except sqlite3.Error as exc: + logger.debug("compression failure cooldown persist failed: %s", exc) + except Exception: + pass + + def _clear_compression_failure_cooldown(self) -> None: + self._summary_failure_cooldown_until = 0.0 + self._last_summary_error = None + + session_db = self._session_db + session_id = self._session_id + if not session_db or not session_id: + return + + clearer = getattr(session_db, "clear_compression_failure_cooldown", None) + if clearer is None: + return + try: + clearer(session_id) + except sqlite3.Error as exc: + logger.debug("compression failure cooldown clear failed: %s", exc) + except Exception: + pass + def update_model( self, model: str, @@ -863,6 +963,8 @@ class ContextCompressor(ContextEngine): self.awaiting_real_usage_after_compression = False self.summary_model = summary_model_override or "" + self._session_db: Any = None + self._session_id: str = "" # Stores the previous compaction summary for iterative updates self._previous_summary: Optional[str] = None @@ -1691,7 +1793,7 @@ This compaction should PRIORITISE preserving all information related to the focu summary = redact_sensitive_text(content.strip()) # Store for iterative updates on next compaction self._previous_summary = summary - self._summary_failure_cooldown_until = 0.0 + self._clear_compression_failure_cooldown() self._summary_model_fallen_back = False self._last_summary_error = None self._last_summary_auth_failure = False @@ -1711,7 +1813,10 @@ This compaction should PRIORITISE preserving all information related to the focu # a main-model retry before any cooldown. (#11978, #11914) if isinstance(e, RuntimeError) and "no llm provider configured" in str(e).lower(): # No provider configured — long cooldown, unlikely to self-resolve - self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS + self._record_compression_failure_cooldown( + _SUMMARY_FAILURE_COOLDOWN_SECONDS, + "no auxiliary LLM provider configured", + ) self._last_summary_error = "no auxiliary LLM provider configured" logger.warning("Context compression: no provider available for " "summary. Middle turns will be dropped without summary " @@ -1823,10 +1928,10 @@ This compaction should PRIORITISE preserving all information related to the focu # streaming premature-close) — shorter cooldown for JSON decode and # streaming-closed since those conditions can self-resolve quickly. _transient_cooldown = 30 if (_is_json_decode or _is_streaming_closed) else 60 - self._summary_failure_cooldown_until = time.monotonic() + _transient_cooldown err_text = str(e).strip() or e.__class__.__name__ if len(err_text) > 220: err_text = err_text[:217].rstrip() + "..." + self._record_compression_failure_cooldown(_transient_cooldown, err_text) self._last_summary_error = err_text # A terminal connection/network failure (we reach this branch only # after any main-model fallback has already been tried or is @@ -2405,8 +2510,8 @@ This compaction should PRIORITISE preserving all information related to the focu # Manual /compress (force=True) bypasses the failure cooldown so the # user can retry immediately after an auto-compress abort. Without # this, /compress would silently no-op for 30-60s after a failure. - if force and self._summary_failure_cooldown_until > 0.0: - self._summary_failure_cooldown_until = 0.0 + if force: + self._clear_compression_failure_cooldown() n_messages = len(messages) # Only need head + 3 tail messages minimum (token budget decides the real tail size) _min_for_compress = self._protect_head_size(messages) + 3 + 1 diff --git a/agent/conversation_compression.py b/agent/conversation_compression.py index b16765ea9b4..551bdcdee6f 100644 --- a/agent/conversation_compression.py +++ b/agent/conversation_compression.py @@ -32,6 +32,7 @@ import logging import os import tempfile import uuid +import threading from datetime import datetime from pathlib import Path from typing import Any, Optional, Tuple @@ -71,6 +72,53 @@ def _compression_lock_holder(agent: Any) -> str: ) +class _CompressionLockLeaseRefresher: + def __init__( + self, + db: Any, + session_id: str, + holder: str, + ttl_seconds: float, + refresh_interval_seconds: float | None = None, + ) -> None: + self._db = db + self._session_id = session_id + self._holder = holder + self._ttl_seconds = ttl_seconds + if refresh_interval_seconds is None: + refresh_interval_seconds = max(1.0, min(60.0, ttl_seconds / 2.0)) + self._refresh_interval_seconds = max(0.1, float(refresh_interval_seconds)) + self._stop = threading.Event() + self._thread = threading.Thread( + target=self._run, + name="compression-lock-refresh", + daemon=True, + ) + + def start(self) -> "_CompressionLockLeaseRefresher": + self._thread.start() + return self + + def stop(self) -> None: + self._stop.set() + if self._thread.is_alive() and threading.current_thread() is not self._thread: + self._thread.join(timeout=1.0) + + def _run(self) -> None: + while not self._stop.wait(self._refresh_interval_seconds): + try: + refreshed = self._db.refresh_compression_lock( + self._session_id, + self._holder, + ttl_seconds=self._ttl_seconds, + ) + except Exception as exc: + logger.debug("compression lock refresh failed: %s", exc) + refreshed = False + if not refreshed: + break + + def check_compression_model_feasibility(agent: Any) -> None: """Warn at session start if the auxiliary compression model's context window is smaller than the main model's compression threshold. @@ -420,11 +468,17 @@ def compress_context( # and proceed with compression. Skipping the lock risks a rare # concurrent-compression session fork; an infinite no-progress loop # that never compresses at all is strictly worse. + try: + _lock_ttl = float(getattr(agent, "_compression_lock_ttl_seconds", 300.0) or 300.0) + except (TypeError, ValueError): + _lock_ttl = 300.0 + _lock_refresh_interval = getattr(agent, "_compression_lock_refresh_interval", None) + _lock_refresher: Optional[_CompressionLockLeaseRefresher] = None if _lock_db is not None and _lock_sid: _lock_holder = _compression_lock_holder(agent) try: _lock_acquired = _lock_db.try_acquire_compression_lock( - _lock_sid, _lock_holder + _lock_sid, _lock_holder, ttl_seconds=_lock_ttl ) except Exception as _lock_err: # Broken/absent lock subsystem (version skew, etc.). Log once @@ -467,9 +521,18 @@ def compress_context( if not _existing_sp: _existing_sp = agent._build_system_prompt(system_message) return messages, _existing_sp + _lock_refresher = _CompressionLockLeaseRefresher( + _lock_db, + _lock_sid, + _lock_holder, + _lock_ttl, + _lock_refresh_interval, + ).start() def _release_lock() -> None: """Release the lock keyed on the OLD session_id (before rotation).""" + if _lock_refresher is not None: + _lock_refresher.stop() if _lock_db is not None and _lock_sid and _lock_holder: try: _lock_db.release_compression_lock(_lock_sid, _lock_holder) diff --git a/agent/turn_context.py b/agent/turn_context.py index f53a89a9497..5ec0a0a9065 100644 --- a/agent/turn_context.py +++ b/agent/turn_context.py @@ -360,6 +360,12 @@ def build_turn_context( if _last >= 0 and _preflight_tokens > _last: _compressor.last_prompt_tokens = _preflight_tokens + _compression_cooldown = getattr( + _compressor, + "get_active_compression_failure_cooldown", + lambda: None, + )() + if _preflight_deferred: logger.info( "Skipping preflight compression: rough estimate ~%s >= %s, " @@ -368,6 +374,13 @@ def build_turn_context( f"{_compressor.threshold_tokens:,}", f"{_compressor.last_real_prompt_tokens:,}", ) + elif _compression_cooldown: + logger.info( + "Skipping preflight compression: same-session cooldown active " + "(~%s seconds remaining, session %s)", + int(_compression_cooldown.get("remaining_seconds", 0.0)), + agent.session_id or "none", + ) elif _compressor.should_compress(_preflight_tokens): logger.info( "Preflight compression: ~%s tokens >= %s threshold (model %s, ctx %s)", diff --git a/hermes_state.py b/hermes_state.py index 10b481a05db..b88a7b0c1a4 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -675,6 +675,8 @@ CREATE TABLE IF NOT EXISTS sessions ( handoff_state TEXT, handoff_platform TEXT, handoff_error TEXT, + compression_failure_cooldown_until REAL, + compression_failure_error TEXT, rewind_count INTEGER NOT NULL DEFAULT 0, archived INTEGER NOT NULL DEFAULT 0, FOREIGN KEY (parent_session_id) REFERENCES sessions(id) @@ -1722,6 +1724,88 @@ class SessionDB: ) self._execute_write(_do) + + def record_compression_failure_cooldown( + self, + session_id: str, + cooldown_until: float, + error: Optional[str] = None, + ) -> None: + """Persist the active compression-failure cooldown for a session.""" + if not session_id: + return + + def _do(conn): + conn.execute( + "UPDATE sessions SET compression_failure_cooldown_until = ?, " + "compression_failure_error = ? WHERE id = ?", + (cooldown_until, error, session_id), + ) + + try: + self._execute_write(_do) + except sqlite3.Error as exc: + logger.warning( + "record_compression_failure_cooldown(%s) failed: %s", + session_id, exc, + ) + + def get_compression_failure_cooldown( + self, + session_id: str, + ) -> Optional[Dict[str, Any]]: + """Return the active compression-failure cooldown for ``session_id``.""" + if not session_id: + return None + now = time.time() + with self._lock: + row = self._conn.execute( + "SELECT compression_failure_cooldown_until, compression_failure_error " + "FROM sessions WHERE id = ?", + (session_id,), + ).fetchone() + if row is None: + return None + cooldown_until = ( + row["compression_failure_cooldown_until"] + if isinstance(row, sqlite3.Row) + else row[0] + ) + if cooldown_until is None: + return None + cooldown_until = float(cooldown_until) + if cooldown_until <= now: + return None + error = ( + row["compression_failure_error"] + if isinstance(row, sqlite3.Row) + else row[1] + ) + return { + "cooldown_until": cooldown_until, + "remaining_seconds": cooldown_until - now, + "error": error, + } + + def clear_compression_failure_cooldown(self, session_id: str) -> None: + """Clear any persisted compression-failure cooldown for a session.""" + if not session_id: + return + + def _do(conn): + conn.execute( + "UPDATE sessions SET compression_failure_cooldown_until = NULL, " + "compression_failure_error = NULL WHERE id = ?", + (session_id,), + ) + + try: + self._execute_write(_do) + except sqlite3.Error as exc: + logger.warning( + "clear_compression_failure_cooldown(%s) failed: %s", + session_id, exc, + ) # ────────────────────────────────────────────────────────────────────── # Compression locks # ────────────────────────────────────────────────────────────────────── @@ -1743,6 +1827,35 @@ class SessionDB: # the compress() call plus the rotation. ``holder`` identifies the # current owner (pid:tid:nonce) for diagnostics; the lock is recovered # via ``expires_at`` if the holder process crashed without releasing. + def refresh_compression_lock( + self, + session_id: str, + holder: str, + ttl_seconds: float = 300.0, + ) -> bool: + """Extend the compression lock lease if ``holder`` still owns it.""" + if not session_id or not holder: + return False + now = time.time() + expires_at = now + ttl_seconds + + def _do(conn): + cur = conn.execute( + "UPDATE compression_locks SET expires_at = ? " + "WHERE session_id = ? AND holder = ? AND expires_at >= ?", + (expires_at, session_id, holder, now), + ) + return cur.rowcount > 0 + + try: + return bool(self._execute_write(_do)) + except sqlite3.Error as exc: + logger.warning( + "refresh_compression_lock(%s) failed: %s", + session_id, exc, + ) + return False + def try_acquire_compression_lock( self, session_id: str, diff --git a/tests/agent/test_compression_concurrent_fork.py b/tests/agent/test_compression_concurrent_fork.py index 617ded2e0e0..a8e5ccf97e6 100644 --- a/tests/agent/test_compression_concurrent_fork.py +++ b/tests/agent/test_compression_concurrent_fork.py @@ -177,6 +177,54 @@ def test_skipped_compression_returns_messages_unchanged(tmp_path: Path) -> None: agent.context_compressor.compress.assert_not_called() +def test_lock_refresh_keeps_owner_live_past_initial_ttl(tmp_path: Path, monkeypatch) -> None: + """The owning compression call must keep its lease alive while it runs.""" + real_try_acquire = SessionDB.try_acquire_compression_lock + + def _short_ttl(self, session_id: str, holder: str, ttl_seconds: float = 300.0) -> bool: + return real_try_acquire(self, session_id, holder, ttl_seconds=1.0) + + monkeypatch.setattr(SessionDB, "try_acquire_compression_lock", _short_ttl) + + db = SessionDB(db_path=tmp_path / "state.db") + + parent_sid = "REFRESH_TEST" + db.create_session(parent_sid, source="discord") + + agent_a = _build_agent_with_db(db, parent_sid) + agent_a._compression_lock_ttl_seconds = 1.0 + agent_a._compression_lock_refresh_interval = 0.25 + + def _slow_compress(*_a, **_kw): + time.sleep(2.0) + return [ + {"role": "user", "content": "[CONTEXT COMPACTION] summary"}, + {"role": "user", "content": "tail"}, + ] + + agent_a.context_compressor.compress.side_effect = _slow_compress + messages = [{"role": "user", "content": f"m{i}"} for i in range(20)] + + def run(agent): + agent._compress_context(messages, "sys", approx_tokens=120_000) + + t_a = threading.Thread(target=run, args=(agent_a,), name="refresh_owner") + t_a.start() + deadline = time.time() + 2.0 + while db.get_compression_lock_holder(parent_sid) is None and time.time() < deadline: + time.sleep(0.05) + assert db.get_compression_lock_holder(parent_sid) is not None + time.sleep(1.2) + assert db.try_acquire_compression_lock( + parent_sid, "refresh_probe", ttl_seconds=1.0 + ) is False, "live owner lease expired and was reclaimable before compression finished" + t_a.join(timeout=10) + + assert not t_a.is_alive() + assert _count_children(db, parent_sid) == 1 + assert db.get_compression_lock_holder(parent_sid) is None + + class _NoLockSubsystemDB: """Wraps a real SessionDB but simulates a pre-#34351 version skew. @@ -244,7 +292,7 @@ def test_missing_lock_subsystem_fails_open_not_infinite_loop(tmp_path: Path) -> assert agent.session_id != parent_sid -def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None: +def test_review_fork_disables_compression_to_prevent_stale_parent_fork(tmp_path: Path) -> None: """The background-review fork must set ``compression_enabled = False`` so it can never compress the parent it shares a session_id with (issue #38727). @@ -270,8 +318,6 @@ def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None ``AIAgent.run_conversation`` patched (so no LLM call happens) and captures the constructed review agent to assert the flag. """ - import tempfile - import agent.background_review as br captured = {} @@ -283,21 +329,20 @@ def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None parent_sid = "REVIEW_FORK_FLAG_TEST" - with tempfile.TemporaryDirectory() as td: - db = SessionDB(db_path=Path(td) / "state.db") - db.create_session(parent_sid, source="discord") - parent = _build_agent_with_db(db, parent_sid) + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session(parent_sid, source="discord") + parent = _build_agent_with_db(db, parent_sid) - # The worker does a local ``from run_agent import AIAgent``; patching - # the class method covers that import path. - from run_agent import AIAgent + # The worker does a local ``from run_agent import AIAgent``; patching + # the class method covers that import path. + from run_agent import AIAgent - with patch.object(AIAgent, "run_conversation", _fake_run_conversation): - br._run_review_in_thread( - parent, - [{"role": "user", "content": "hi"}], - "review this conversation", - ) + with patch.object(AIAgent, "run_conversation", _fake_run_conversation): + br._run_review_in_thread( + parent, + [{"role": "user", "content": "hi"}], + "review this conversation", + ) assert captured, ( "_run_review_in_thread never reached run_conversation — the spawn path " @@ -314,3 +359,4 @@ def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None "conversation_loop.py only short-circuit when compression_enabled is " "False — this flag MUST be cleared on the review fork." ) + db.close() diff --git a/tests/agent/test_context_compressor.py b/tests/agent/test_context_compressor.py index 84e7d1847d7..e3d03829e79 100644 --- a/tests/agent/test_context_compressor.py +++ b/tests/agent/test_context_compressor.py @@ -1,6 +1,7 @@ """Tests for agent/context_compressor.py — compression logic, thresholds, truncation fallback.""" import pytest +import time from unittest.mock import patch, MagicMock from agent.context_compressor import ( @@ -8,6 +9,7 @@ from agent.context_compressor import ( HISTORICAL_TASK_HEADING, SUMMARY_PREFIX, ) +from hermes_state import SessionDB @pytest.fixture() @@ -1464,6 +1466,60 @@ class TestAbortOnSummaryFailure: assert c._summary_failure_cooldown_until == 0.0 assert len(result) < len(msgs) + def test_force_true_bypasses_persisted_session_cooldown(self, tmp_path): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "summary text" + + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("s1", "cli") + db.record_compression_failure_cooldown("s1", time.time() + 999.0, "timeout") + + c = self._make_compressor() + c.bind_session_state(db, "s1") + msgs = self._make_msgs() + + with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_llm: + result = c.compress(msgs, current_tokens=999999, force=True) + + mock_llm.assert_called() + assert c._last_compress_aborted is False + assert len(result) < len(msgs) + assert db.get_compression_failure_cooldown("s1") is None + + def test_success_clears_persisted_session_cooldown(self, tmp_path): + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "summary text" + + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("s1", "cli") + db.record_compression_failure_cooldown("s1", time.time() + 999.0, "timeout") + + c = self._make_compressor() + c.bind_session_state(db, "s1") + c._summary_failure_cooldown_until = 0.0 + msgs = self._make_msgs() + + with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_llm: + result = c.compress(msgs, current_tokens=999999) + + mock_llm.assert_called() + assert c._last_compress_aborted is False + assert len(result) < len(msgs) + assert db.get_compression_failure_cooldown("s1") is None + + def test_session_end_does_not_clear_persisted_session_cooldown(self, tmp_path): + db = SessionDB(db_path=tmp_path / "state.db") + db.create_session("s1", "cli") + db.record_compression_failure_cooldown("s1", time.time() + 999.0, "timeout") + + c = self._make_compressor() + c.bind_session_state(db, "s1") + c.on_session_end("s1", []) + + assert db.get_compression_failure_cooldown("s1") is not None + class TestSummaryPrefixNormalization: def test_legacy_prefix_is_replaced(self): diff --git a/tests/agent/test_turn_context.py b/tests/agent/test_turn_context.py index ae602723160..0cd0e91caa7 100644 --- a/tests/agent/test_turn_context.py +++ b/tests/agent/test_turn_context.py @@ -9,11 +9,13 @@ confirm the prologue produces the right ``TurnContext`` and applies the from __future__ import annotations import types -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest +from agent.context_compressor import ContextCompressor from agent.turn_context import TurnContext, build_turn_context +from hermes_state import SessionDB class _FakeTodoStore: @@ -101,6 +103,33 @@ class _FakeAgent: self._persist_calls += 1 +def _make_agent_with_cooldown(db_path, session_id, *, cooldown_until=None): + agent = _FakeAgent() + agent.compression_enabled = True + agent._emit_status = MagicMock() + agent._compress_context = MagicMock( + side_effect=lambda messages, *_a, **_k: (messages, "SYSTEM") + ) + + db = SessionDB(db_path=db_path) + db.create_session(session_id, source="cli") + if cooldown_until is not None: + db.record_compression_failure_cooldown(session_id, cooldown_until, "timeout") + + with patch("agent.context_compressor.get_model_context_length", return_value=100000): + compressor = ContextCompressor( + model="test/model", + threshold_percent=0.85, + protect_first_n=2, + protect_last_n=2, + quiet_mode=True, + ) + compressor.bind_session_state(db, session_id) + agent.context_compressor = compressor + agent._session_db = db + return agent + + @pytest.fixture(autouse=True) def _stub_runtime_main(): """``build_turn_context`` calls ``auxiliary_client.set_runtime_main`` as a @@ -285,3 +314,53 @@ def test_between_turns_refresh_no_churn_when_unchanged(): assert agent.tools is same # not replaced → no churn + +def test_preflight_skips_when_persisted_cooldown_survives_restart(tmp_path): + agent = _make_agent_with_cooldown( + tmp_path / "state.db", + "sess-1", + cooldown_until=4_000_000_000.0, + ) + + with patch("agent.turn_context._should_run_preflight_estimate", return_value=True), \ + patch("agent.turn_context.estimate_request_tokens_rough", return_value=999_999): + ctx = _build(agent) + + assert isinstance(ctx, TurnContext) + agent._emit_status.assert_not_called() + agent._compress_context.assert_not_called() + + +def test_preflight_still_runs_for_other_session_with_same_db(tmp_path): + db_path = tmp_path / "state.db" + _make_agent_with_cooldown( + db_path, + "sess-1", + cooldown_until=4_000_000_000.0, + ) + agent = _make_agent_with_cooldown(db_path, "sess-2") + + with patch("agent.turn_context._should_run_preflight_estimate", return_value=True), \ + patch("agent.turn_context.estimate_request_tokens_rough", return_value=999_999): + ctx = _build(agent) + + assert isinstance(ctx, TurnContext) + agent._emit_status.assert_called_once() + agent._compress_context.assert_called() + + +def test_expired_cooldown_allows_preflight(tmp_path): + agent = _make_agent_with_cooldown( + tmp_path / "state.db", + "sess-1", + cooldown_until=1.0, + ) + + with patch("agent.turn_context._should_run_preflight_estimate", return_value=True), \ + patch("agent.turn_context.estimate_request_tokens_rough", return_value=999_999): + ctx = _build(agent) + + assert isinstance(ctx, TurnContext) + agent._emit_status.assert_called_once() + agent._compress_context.assert_called() + diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 5076182166a..ec15d0be435 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -4,6 +4,7 @@ import sqlite3 import time import pytest +import hermes_state from hermes_state import SCHEMA_SQL, SCHEMA_VERSION, SessionDB @@ -4725,3 +4726,55 @@ def test_gateway_session_recovery_reopens_legacy_agent_close_rows(db): chat_id="chat-1", chat_type="dm", ) is None + + +def test_compression_failure_cooldown_round_trips_and_clears(db): + db.create_session("s1", "cli") + + cooldown_until = time.time() + 60.0 + db.record_compression_failure_cooldown("s1", cooldown_until, "timeout") + + state = db.get_compression_failure_cooldown("s1") + assert state is not None + assert state["cooldown_until"] == cooldown_until + assert state["error"] == "timeout" + + db.clear_compression_failure_cooldown("s1") + assert db.get_compression_failure_cooldown("s1") is None + + row = db.get_session("s1") + assert row["compression_failure_cooldown_until"] is None + assert row["compression_failure_error"] is None + + +def test_expired_compression_failure_cooldown_is_ignored(db): + db.create_session("s1", "cli") + + db.record_compression_failure_cooldown("s1", time.time() - 60.0, "stale") + + assert db.get_compression_failure_cooldown("s1") is None + + +def test_refresh_compression_lock_requires_holder_and_preserves_reclaimability(db, monkeypatch): + db.create_session("s1", "cli") + + monkeypatch.setattr(hermes_state.time, "time", lambda: 1000.0) + assert db.try_acquire_compression_lock("s1", "holder-a", ttl_seconds=10.0) is True + + original_expires = db._conn.execute( + "SELECT expires_at FROM compression_locks WHERE session_id = ?", + ("s1",), + ).fetchone()[0] + + monkeypatch.setattr(hermes_state.time, "time", lambda: 1005.0) + assert db.refresh_compression_lock("s1", "holder-a", ttl_seconds=10.0) is True + refreshed_expires = db._conn.execute( + "SELECT expires_at FROM compression_locks WHERE session_id = ?", + ("s1",), + ).fetchone()[0] + assert refreshed_expires > original_expires + + assert db.refresh_compression_lock("s1", "holder-b", ttl_seconds=10.0) is False + + monkeypatch.setattr(hermes_state.time, "time", lambda: 1016.0) + assert db.try_acquire_compression_lock("s1", "holder-b", ttl_seconds=10.0) is True