mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-07-01 12:02:05 +00:00
fix(agent): persist compression backoff across resume (#54465)
This commit is contained in:
parent
5edfda5088
commit
f2ccb2859f
9 changed files with 557 additions and 23 deletions
|
|
@ -1665,6 +1665,12 @@ def init_agent(
|
|||
abort_on_summary_failure=compression_abort_on_summary_failure,
|
||||
max_tokens=agent.max_tokens,
|
||||
)
|
||||
_bind_session_state = getattr(agent.context_compressor, "bind_session_state", None)
|
||||
if callable(_bind_session_state):
|
||||
try:
|
||||
_bind_session_state(session_db=session_db, session_id=agent.session_id)
|
||||
except Exception:
|
||||
pass
|
||||
agent.compression_enabled = compression_enabled
|
||||
agent.compression_in_place = compression_in_place
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ Improvements over v2:
|
|||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
|
@ -638,6 +639,7 @@ class ContextCompressor(ContextEngine):
|
|||
self._last_compression_savings_pct = 100.0
|
||||
self._ineffective_compression_count = 0
|
||||
self._summary_failure_cooldown_until = 0.0 # transient errors must not block a fresh session
|
||||
self._last_summary_error = None
|
||||
self.last_real_prompt_tokens = 0
|
||||
self.last_compression_rough_tokens = 0
|
||||
self.last_rough_tokens_when_real_prompt_fit = 0
|
||||
|
|
@ -659,6 +661,104 @@ class ContextCompressor(ContextEngine):
|
|||
"""
|
||||
self._previous_summary = None
|
||||
|
||||
def bind_session_state(self, session_db: Any = None, session_id: str = "") -> None:
|
||||
"""Bind the current session row so durable cooldowns can round-trip."""
|
||||
self._session_db = session_db
|
||||
self._session_id = session_id or ""
|
||||
self._summary_failure_cooldown_until = 0.0
|
||||
self._last_summary_error = None
|
||||
self.get_active_compression_failure_cooldown()
|
||||
|
||||
def on_session_start(self, session_id: str, **kwargs) -> None:
|
||||
"""Bind session-scoped compression state for a new or resumed session."""
|
||||
super().on_session_start(session_id, **kwargs)
|
||||
self.bind_session_state(kwargs.get("session_db", self._session_db), session_id)
|
||||
|
||||
def get_active_compression_failure_cooldown(self) -> Optional[Dict[str, Any]]:
|
||||
"""Return the live compression-failure cooldown for the bound session."""
|
||||
now_mono = time.monotonic()
|
||||
if self._summary_failure_cooldown_until > now_mono:
|
||||
return {
|
||||
"cooldown_until": time.time() + (
|
||||
self._summary_failure_cooldown_until - now_mono
|
||||
),
|
||||
"remaining_seconds": self._summary_failure_cooldown_until - now_mono,
|
||||
"error": self._last_summary_error,
|
||||
}
|
||||
|
||||
session_db = self._session_db
|
||||
session_id = self._session_id
|
||||
if not session_db or not session_id:
|
||||
return None
|
||||
|
||||
getter = getattr(session_db, "get_compression_failure_cooldown", None)
|
||||
if getter is None:
|
||||
return None
|
||||
try:
|
||||
state = getter(session_id)
|
||||
except sqlite3.Error as exc:
|
||||
logger.debug("compression failure cooldown lookup failed: %s", exc)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
if not state:
|
||||
return None
|
||||
|
||||
remaining_seconds = float(state.get("remaining_seconds") or 0.0)
|
||||
if remaining_seconds <= 0:
|
||||
return None
|
||||
|
||||
self._summary_failure_cooldown_until = now_mono + remaining_seconds
|
||||
self._last_summary_error = state.get("error")
|
||||
return {
|
||||
"cooldown_until": float(state.get("cooldown_until") or 0.0),
|
||||
"remaining_seconds": remaining_seconds,
|
||||
"error": self._last_summary_error,
|
||||
}
|
||||
|
||||
def _record_compression_failure_cooldown(
|
||||
self,
|
||||
cooldown_seconds: float,
|
||||
error: Optional[str],
|
||||
) -> None:
|
||||
cooldown_until = time.time() + cooldown_seconds
|
||||
self._summary_failure_cooldown_until = time.monotonic() + cooldown_seconds
|
||||
self._last_summary_error = error
|
||||
|
||||
session_db = self._session_db
|
||||
session_id = self._session_id
|
||||
if not session_db or not session_id:
|
||||
return
|
||||
|
||||
recorder = getattr(session_db, "record_compression_failure_cooldown", None)
|
||||
if recorder is None:
|
||||
return
|
||||
try:
|
||||
recorder(session_id, cooldown_until, error)
|
||||
except sqlite3.Error as exc:
|
||||
logger.debug("compression failure cooldown persist failed: %s", exc)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _clear_compression_failure_cooldown(self) -> None:
|
||||
self._summary_failure_cooldown_until = 0.0
|
||||
self._last_summary_error = None
|
||||
|
||||
session_db = self._session_db
|
||||
session_id = self._session_id
|
||||
if not session_db or not session_id:
|
||||
return
|
||||
|
||||
clearer = getattr(session_db, "clear_compression_failure_cooldown", None)
|
||||
if clearer is None:
|
||||
return
|
||||
try:
|
||||
clearer(session_id)
|
||||
except sqlite3.Error as exc:
|
||||
logger.debug("compression failure cooldown clear failed: %s", exc)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def update_model(
|
||||
self,
|
||||
model: str,
|
||||
|
|
@ -863,6 +963,8 @@ class ContextCompressor(ContextEngine):
|
|||
self.awaiting_real_usage_after_compression = False
|
||||
|
||||
self.summary_model = summary_model_override or ""
|
||||
self._session_db: Any = None
|
||||
self._session_id: str = ""
|
||||
|
||||
# Stores the previous compaction summary for iterative updates
|
||||
self._previous_summary: Optional[str] = None
|
||||
|
|
@ -1691,7 +1793,7 @@ This compaction should PRIORITISE preserving all information related to the focu
|
|||
summary = redact_sensitive_text(content.strip())
|
||||
# Store for iterative updates on next compaction
|
||||
self._previous_summary = summary
|
||||
self._summary_failure_cooldown_until = 0.0
|
||||
self._clear_compression_failure_cooldown()
|
||||
self._summary_model_fallen_back = False
|
||||
self._last_summary_error = None
|
||||
self._last_summary_auth_failure = False
|
||||
|
|
@ -1711,7 +1813,10 @@ This compaction should PRIORITISE preserving all information related to the focu
|
|||
# a main-model retry before any cooldown. (#11978, #11914)
|
||||
if isinstance(e, RuntimeError) and "no llm provider configured" in str(e).lower():
|
||||
# No provider configured — long cooldown, unlikely to self-resolve
|
||||
self._summary_failure_cooldown_until = time.monotonic() + _SUMMARY_FAILURE_COOLDOWN_SECONDS
|
||||
self._record_compression_failure_cooldown(
|
||||
_SUMMARY_FAILURE_COOLDOWN_SECONDS,
|
||||
"no auxiliary LLM provider configured",
|
||||
)
|
||||
self._last_summary_error = "no auxiliary LLM provider configured"
|
||||
logger.warning("Context compression: no provider available for "
|
||||
"summary. Middle turns will be dropped without summary "
|
||||
|
|
@ -1823,10 +1928,10 @@ This compaction should PRIORITISE preserving all information related to the focu
|
|||
# streaming premature-close) — shorter cooldown for JSON decode and
|
||||
# streaming-closed since those conditions can self-resolve quickly.
|
||||
_transient_cooldown = 30 if (_is_json_decode or _is_streaming_closed) else 60
|
||||
self._summary_failure_cooldown_until = time.monotonic() + _transient_cooldown
|
||||
err_text = str(e).strip() or e.__class__.__name__
|
||||
if len(err_text) > 220:
|
||||
err_text = err_text[:217].rstrip() + "..."
|
||||
self._record_compression_failure_cooldown(_transient_cooldown, err_text)
|
||||
self._last_summary_error = err_text
|
||||
# A terminal connection/network failure (we reach this branch only
|
||||
# after any main-model fallback has already been tried or is
|
||||
|
|
@ -2405,8 +2510,8 @@ This compaction should PRIORITISE preserving all information related to the focu
|
|||
# Manual /compress (force=True) bypasses the failure cooldown so the
|
||||
# user can retry immediately after an auto-compress abort. Without
|
||||
# this, /compress would silently no-op for 30-60s after a failure.
|
||||
if force and self._summary_failure_cooldown_until > 0.0:
|
||||
self._summary_failure_cooldown_until = 0.0
|
||||
if force:
|
||||
self._clear_compression_failure_cooldown()
|
||||
n_messages = len(messages)
|
||||
# Only need head + 3 tail messages minimum (token budget decides the real tail size)
|
||||
_min_for_compress = self._protect_head_size(messages) + 3 + 1
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ import logging
|
|||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Tuple
|
||||
|
|
@ -71,6 +72,53 @@ def _compression_lock_holder(agent: Any) -> str:
|
|||
)
|
||||
|
||||
|
||||
class _CompressionLockLeaseRefresher:
|
||||
def __init__(
|
||||
self,
|
||||
db: Any,
|
||||
session_id: str,
|
||||
holder: str,
|
||||
ttl_seconds: float,
|
||||
refresh_interval_seconds: float | None = None,
|
||||
) -> None:
|
||||
self._db = db
|
||||
self._session_id = session_id
|
||||
self._holder = holder
|
||||
self._ttl_seconds = ttl_seconds
|
||||
if refresh_interval_seconds is None:
|
||||
refresh_interval_seconds = max(1.0, min(60.0, ttl_seconds / 2.0))
|
||||
self._refresh_interval_seconds = max(0.1, float(refresh_interval_seconds))
|
||||
self._stop = threading.Event()
|
||||
self._thread = threading.Thread(
|
||||
target=self._run,
|
||||
name="compression-lock-refresh",
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
def start(self) -> "_CompressionLockLeaseRefresher":
|
||||
self._thread.start()
|
||||
return self
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop.set()
|
||||
if self._thread.is_alive() and threading.current_thread() is not self._thread:
|
||||
self._thread.join(timeout=1.0)
|
||||
|
||||
def _run(self) -> None:
|
||||
while not self._stop.wait(self._refresh_interval_seconds):
|
||||
try:
|
||||
refreshed = self._db.refresh_compression_lock(
|
||||
self._session_id,
|
||||
self._holder,
|
||||
ttl_seconds=self._ttl_seconds,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug("compression lock refresh failed: %s", exc)
|
||||
refreshed = False
|
||||
if not refreshed:
|
||||
break
|
||||
|
||||
|
||||
def check_compression_model_feasibility(agent: Any) -> None:
|
||||
"""Warn at session start if the auxiliary compression model's context
|
||||
window is smaller than the main model's compression threshold.
|
||||
|
|
@ -420,11 +468,17 @@ def compress_context(
|
|||
# and proceed with compression. Skipping the lock risks a rare
|
||||
# concurrent-compression session fork; an infinite no-progress loop
|
||||
# that never compresses at all is strictly worse.
|
||||
try:
|
||||
_lock_ttl = float(getattr(agent, "_compression_lock_ttl_seconds", 300.0) or 300.0)
|
||||
except (TypeError, ValueError):
|
||||
_lock_ttl = 300.0
|
||||
_lock_refresh_interval = getattr(agent, "_compression_lock_refresh_interval", None)
|
||||
_lock_refresher: Optional[_CompressionLockLeaseRefresher] = None
|
||||
if _lock_db is not None and _lock_sid:
|
||||
_lock_holder = _compression_lock_holder(agent)
|
||||
try:
|
||||
_lock_acquired = _lock_db.try_acquire_compression_lock(
|
||||
_lock_sid, _lock_holder
|
||||
_lock_sid, _lock_holder, ttl_seconds=_lock_ttl
|
||||
)
|
||||
except Exception as _lock_err:
|
||||
# Broken/absent lock subsystem (version skew, etc.). Log once
|
||||
|
|
@ -467,9 +521,18 @@ def compress_context(
|
|||
if not _existing_sp:
|
||||
_existing_sp = agent._build_system_prompt(system_message)
|
||||
return messages, _existing_sp
|
||||
_lock_refresher = _CompressionLockLeaseRefresher(
|
||||
_lock_db,
|
||||
_lock_sid,
|
||||
_lock_holder,
|
||||
_lock_ttl,
|
||||
_lock_refresh_interval,
|
||||
).start()
|
||||
|
||||
def _release_lock() -> None:
|
||||
"""Release the lock keyed on the OLD session_id (before rotation)."""
|
||||
if _lock_refresher is not None:
|
||||
_lock_refresher.stop()
|
||||
if _lock_db is not None and _lock_sid and _lock_holder:
|
||||
try:
|
||||
_lock_db.release_compression_lock(_lock_sid, _lock_holder)
|
||||
|
|
|
|||
|
|
@ -360,6 +360,12 @@ def build_turn_context(
|
|||
if _last >= 0 and _preflight_tokens > _last:
|
||||
_compressor.last_prompt_tokens = _preflight_tokens
|
||||
|
||||
_compression_cooldown = getattr(
|
||||
_compressor,
|
||||
"get_active_compression_failure_cooldown",
|
||||
lambda: None,
|
||||
)()
|
||||
|
||||
if _preflight_deferred:
|
||||
logger.info(
|
||||
"Skipping preflight compression: rough estimate ~%s >= %s, "
|
||||
|
|
@ -368,6 +374,13 @@ def build_turn_context(
|
|||
f"{_compressor.threshold_tokens:,}",
|
||||
f"{_compressor.last_real_prompt_tokens:,}",
|
||||
)
|
||||
elif _compression_cooldown:
|
||||
logger.info(
|
||||
"Skipping preflight compression: same-session cooldown active "
|
||||
"(~%s seconds remaining, session %s)",
|
||||
int(_compression_cooldown.get("remaining_seconds", 0.0)),
|
||||
agent.session_id or "none",
|
||||
)
|
||||
elif _compressor.should_compress(_preflight_tokens):
|
||||
logger.info(
|
||||
"Preflight compression: ~%s tokens >= %s threshold (model %s, ctx %s)",
|
||||
|
|
|
|||
113
hermes_state.py
113
hermes_state.py
|
|
@ -675,6 +675,8 @@ CREATE TABLE IF NOT EXISTS sessions (
|
|||
handoff_state TEXT,
|
||||
handoff_platform TEXT,
|
||||
handoff_error TEXT,
|
||||
compression_failure_cooldown_until REAL,
|
||||
compression_failure_error TEXT,
|
||||
rewind_count INTEGER NOT NULL DEFAULT 0,
|
||||
archived INTEGER NOT NULL DEFAULT 0,
|
||||
FOREIGN KEY (parent_session_id) REFERENCES sessions(id)
|
||||
|
|
@ -1722,6 +1724,88 @@ class SessionDB:
|
|||
)
|
||||
|
||||
self._execute_write(_do)
|
||||
|
||||
def record_compression_failure_cooldown(
|
||||
self,
|
||||
session_id: str,
|
||||
cooldown_until: float,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Persist the active compression-failure cooldown for a session."""
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"UPDATE sessions SET compression_failure_cooldown_until = ?, "
|
||||
"compression_failure_error = ? WHERE id = ?",
|
||||
(cooldown_until, error, session_id),
|
||||
)
|
||||
|
||||
try:
|
||||
self._execute_write(_do)
|
||||
except sqlite3.Error as exc:
|
||||
logger.warning(
|
||||
"record_compression_failure_cooldown(%s) failed: %s",
|
||||
session_id, exc,
|
||||
)
|
||||
|
||||
def get_compression_failure_cooldown(
|
||||
self,
|
||||
session_id: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Return the active compression-failure cooldown for ``session_id``."""
|
||||
if not session_id:
|
||||
return None
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
row = self._conn.execute(
|
||||
"SELECT compression_failure_cooldown_until, compression_failure_error "
|
||||
"FROM sessions WHERE id = ?",
|
||||
(session_id,),
|
||||
).fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
cooldown_until = (
|
||||
row["compression_failure_cooldown_until"]
|
||||
if isinstance(row, sqlite3.Row)
|
||||
else row[0]
|
||||
)
|
||||
if cooldown_until is None:
|
||||
return None
|
||||
cooldown_until = float(cooldown_until)
|
||||
if cooldown_until <= now:
|
||||
return None
|
||||
error = (
|
||||
row["compression_failure_error"]
|
||||
if isinstance(row, sqlite3.Row)
|
||||
else row[1]
|
||||
)
|
||||
return {
|
||||
"cooldown_until": cooldown_until,
|
||||
"remaining_seconds": cooldown_until - now,
|
||||
"error": error,
|
||||
}
|
||||
|
||||
def clear_compression_failure_cooldown(self, session_id: str) -> None:
|
||||
"""Clear any persisted compression-failure cooldown for a session."""
|
||||
if not session_id:
|
||||
return
|
||||
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"UPDATE sessions SET compression_failure_cooldown_until = NULL, "
|
||||
"compression_failure_error = NULL WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
try:
|
||||
self._execute_write(_do)
|
||||
except sqlite3.Error as exc:
|
||||
logger.warning(
|
||||
"clear_compression_failure_cooldown(%s) failed: %s",
|
||||
session_id, exc,
|
||||
)
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
# Compression locks
|
||||
# ──────────────────────────────────────────────────────────────────────
|
||||
|
|
@ -1743,6 +1827,35 @@ class SessionDB:
|
|||
# the compress() call plus the rotation. ``holder`` identifies the
|
||||
# current owner (pid:tid:nonce) for diagnostics; the lock is recovered
|
||||
# via ``expires_at`` if the holder process crashed without releasing.
|
||||
def refresh_compression_lock(
|
||||
self,
|
||||
session_id: str,
|
||||
holder: str,
|
||||
ttl_seconds: float = 300.0,
|
||||
) -> bool:
|
||||
"""Extend the compression lock lease if ``holder`` still owns it."""
|
||||
if not session_id or not holder:
|
||||
return False
|
||||
now = time.time()
|
||||
expires_at = now + ttl_seconds
|
||||
|
||||
def _do(conn):
|
||||
cur = conn.execute(
|
||||
"UPDATE compression_locks SET expires_at = ? "
|
||||
"WHERE session_id = ? AND holder = ? AND expires_at >= ?",
|
||||
(expires_at, session_id, holder, now),
|
||||
)
|
||||
return cur.rowcount > 0
|
||||
|
||||
try:
|
||||
return bool(self._execute_write(_do))
|
||||
except sqlite3.Error as exc:
|
||||
logger.warning(
|
||||
"refresh_compression_lock(%s) failed: %s",
|
||||
session_id, exc,
|
||||
)
|
||||
return False
|
||||
|
||||
def try_acquire_compression_lock(
|
||||
self,
|
||||
session_id: str,
|
||||
|
|
|
|||
|
|
@ -177,6 +177,54 @@ def test_skipped_compression_returns_messages_unchanged(tmp_path: Path) -> None:
|
|||
agent.context_compressor.compress.assert_not_called()
|
||||
|
||||
|
||||
def test_lock_refresh_keeps_owner_live_past_initial_ttl(tmp_path: Path, monkeypatch) -> None:
|
||||
"""The owning compression call must keep its lease alive while it runs."""
|
||||
real_try_acquire = SessionDB.try_acquire_compression_lock
|
||||
|
||||
def _short_ttl(self, session_id: str, holder: str, ttl_seconds: float = 300.0) -> bool:
|
||||
return real_try_acquire(self, session_id, holder, ttl_seconds=1.0)
|
||||
|
||||
monkeypatch.setattr(SessionDB, "try_acquire_compression_lock", _short_ttl)
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
|
||||
parent_sid = "REFRESH_TEST"
|
||||
db.create_session(parent_sid, source="discord")
|
||||
|
||||
agent_a = _build_agent_with_db(db, parent_sid)
|
||||
agent_a._compression_lock_ttl_seconds = 1.0
|
||||
agent_a._compression_lock_refresh_interval = 0.25
|
||||
|
||||
def _slow_compress(*_a, **_kw):
|
||||
time.sleep(2.0)
|
||||
return [
|
||||
{"role": "user", "content": "[CONTEXT COMPACTION] summary"},
|
||||
{"role": "user", "content": "tail"},
|
||||
]
|
||||
|
||||
agent_a.context_compressor.compress.side_effect = _slow_compress
|
||||
messages = [{"role": "user", "content": f"m{i}"} for i in range(20)]
|
||||
|
||||
def run(agent):
|
||||
agent._compress_context(messages, "sys", approx_tokens=120_000)
|
||||
|
||||
t_a = threading.Thread(target=run, args=(agent_a,), name="refresh_owner")
|
||||
t_a.start()
|
||||
deadline = time.time() + 2.0
|
||||
while db.get_compression_lock_holder(parent_sid) is None and time.time() < deadline:
|
||||
time.sleep(0.05)
|
||||
assert db.get_compression_lock_holder(parent_sid) is not None
|
||||
time.sleep(1.2)
|
||||
assert db.try_acquire_compression_lock(
|
||||
parent_sid, "refresh_probe", ttl_seconds=1.0
|
||||
) is False, "live owner lease expired and was reclaimable before compression finished"
|
||||
t_a.join(timeout=10)
|
||||
|
||||
assert not t_a.is_alive()
|
||||
assert _count_children(db, parent_sid) == 1
|
||||
assert db.get_compression_lock_holder(parent_sid) is None
|
||||
|
||||
|
||||
class _NoLockSubsystemDB:
|
||||
"""Wraps a real SessionDB but simulates a pre-#34351 version skew.
|
||||
|
||||
|
|
@ -244,7 +292,7 @@ def test_missing_lock_subsystem_fails_open_not_infinite_loop(tmp_path: Path) ->
|
|||
assert agent.session_id != parent_sid
|
||||
|
||||
|
||||
def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None:
|
||||
def test_review_fork_disables_compression_to_prevent_stale_parent_fork(tmp_path: Path) -> None:
|
||||
"""The background-review fork must set ``compression_enabled = False``
|
||||
so it can never compress the parent it shares a session_id with
|
||||
(issue #38727).
|
||||
|
|
@ -270,8 +318,6 @@ def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None
|
|||
``AIAgent.run_conversation`` patched (so no LLM call happens) and
|
||||
captures the constructed review agent to assert the flag.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
import agent.background_review as br
|
||||
|
||||
captured = {}
|
||||
|
|
@ -283,21 +329,20 @@ def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None
|
|||
|
||||
parent_sid = "REVIEW_FORK_FLAG_TEST"
|
||||
|
||||
with tempfile.TemporaryDirectory() as td:
|
||||
db = SessionDB(db_path=Path(td) / "state.db")
|
||||
db.create_session(parent_sid, source="discord")
|
||||
parent = _build_agent_with_db(db, parent_sid)
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session(parent_sid, source="discord")
|
||||
parent = _build_agent_with_db(db, parent_sid)
|
||||
|
||||
# The worker does a local ``from run_agent import AIAgent``; patching
|
||||
# the class method covers that import path.
|
||||
from run_agent import AIAgent
|
||||
# The worker does a local ``from run_agent import AIAgent``; patching
|
||||
# the class method covers that import path.
|
||||
from run_agent import AIAgent
|
||||
|
||||
with patch.object(AIAgent, "run_conversation", _fake_run_conversation):
|
||||
br._run_review_in_thread(
|
||||
parent,
|
||||
[{"role": "user", "content": "hi"}],
|
||||
"review this conversation",
|
||||
)
|
||||
with patch.object(AIAgent, "run_conversation", _fake_run_conversation):
|
||||
br._run_review_in_thread(
|
||||
parent,
|
||||
[{"role": "user", "content": "hi"}],
|
||||
"review this conversation",
|
||||
)
|
||||
|
||||
assert captured, (
|
||||
"_run_review_in_thread never reached run_conversation — the spawn path "
|
||||
|
|
@ -314,3 +359,4 @@ def test_review_fork_disables_compression_to_prevent_stale_parent_fork() -> None
|
|||
"conversation_loop.py only short-circuit when compression_enabled is "
|
||||
"False — this flag MUST be cleared on the review fork."
|
||||
)
|
||||
db.close()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Tests for agent/context_compressor.py — compression logic, thresholds, truncation fallback."""
|
||||
|
||||
import pytest
|
||||
import time
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.context_compressor import (
|
||||
|
|
@ -8,6 +9,7 @@ from agent.context_compressor import (
|
|||
HISTORICAL_TASK_HEADING,
|
||||
SUMMARY_PREFIX,
|
||||
)
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
@ -1464,6 +1466,60 @@ class TestAbortOnSummaryFailure:
|
|||
assert c._summary_failure_cooldown_until == 0.0
|
||||
assert len(result) < len(msgs)
|
||||
|
||||
def test_force_true_bypasses_persisted_session_cooldown(self, tmp_path):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "summary text"
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("s1", "cli")
|
||||
db.record_compression_failure_cooldown("s1", time.time() + 999.0, "timeout")
|
||||
|
||||
c = self._make_compressor()
|
||||
c.bind_session_state(db, "s1")
|
||||
msgs = self._make_msgs()
|
||||
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_llm:
|
||||
result = c.compress(msgs, current_tokens=999999, force=True)
|
||||
|
||||
mock_llm.assert_called()
|
||||
assert c._last_compress_aborted is False
|
||||
assert len(result) < len(msgs)
|
||||
assert db.get_compression_failure_cooldown("s1") is None
|
||||
|
||||
def test_success_clears_persisted_session_cooldown(self, tmp_path):
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "summary text"
|
||||
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("s1", "cli")
|
||||
db.record_compression_failure_cooldown("s1", time.time() + 999.0, "timeout")
|
||||
|
||||
c = self._make_compressor()
|
||||
c.bind_session_state(db, "s1")
|
||||
c._summary_failure_cooldown_until = 0.0
|
||||
msgs = self._make_msgs()
|
||||
|
||||
with patch("agent.context_compressor.call_llm", return_value=mock_response) as mock_llm:
|
||||
result = c.compress(msgs, current_tokens=999999)
|
||||
|
||||
mock_llm.assert_called()
|
||||
assert c._last_compress_aborted is False
|
||||
assert len(result) < len(msgs)
|
||||
assert db.get_compression_failure_cooldown("s1") is None
|
||||
|
||||
def test_session_end_does_not_clear_persisted_session_cooldown(self, tmp_path):
|
||||
db = SessionDB(db_path=tmp_path / "state.db")
|
||||
db.create_session("s1", "cli")
|
||||
db.record_compression_failure_cooldown("s1", time.time() + 999.0, "timeout")
|
||||
|
||||
c = self._make_compressor()
|
||||
c.bind_session_state(db, "s1")
|
||||
c.on_session_end("s1", [])
|
||||
|
||||
assert db.get_compression_failure_cooldown("s1") is not None
|
||||
|
||||
|
||||
class TestSummaryPrefixNormalization:
|
||||
def test_legacy_prefix_is_replaced(self):
|
||||
|
|
|
|||
|
|
@ -9,11 +9,13 @@ confirm the prologue produces the right ``TurnContext`` and applies the
|
|||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.context_compressor import ContextCompressor
|
||||
from agent.turn_context import TurnContext, build_turn_context
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
class _FakeTodoStore:
|
||||
|
|
@ -101,6 +103,33 @@ class _FakeAgent:
|
|||
self._persist_calls += 1
|
||||
|
||||
|
||||
def _make_agent_with_cooldown(db_path, session_id, *, cooldown_until=None):
|
||||
agent = _FakeAgent()
|
||||
agent.compression_enabled = True
|
||||
agent._emit_status = MagicMock()
|
||||
agent._compress_context = MagicMock(
|
||||
side_effect=lambda messages, *_a, **_k: (messages, "SYSTEM")
|
||||
)
|
||||
|
||||
db = SessionDB(db_path=db_path)
|
||||
db.create_session(session_id, source="cli")
|
||||
if cooldown_until is not None:
|
||||
db.record_compression_failure_cooldown(session_id, cooldown_until, "timeout")
|
||||
|
||||
with patch("agent.context_compressor.get_model_context_length", return_value=100000):
|
||||
compressor = ContextCompressor(
|
||||
model="test/model",
|
||||
threshold_percent=0.85,
|
||||
protect_first_n=2,
|
||||
protect_last_n=2,
|
||||
quiet_mode=True,
|
||||
)
|
||||
compressor.bind_session_state(db, session_id)
|
||||
agent.context_compressor = compressor
|
||||
agent._session_db = db
|
||||
return agent
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_runtime_main():
|
||||
"""``build_turn_context`` calls ``auxiliary_client.set_runtime_main`` as a
|
||||
|
|
@ -285,3 +314,53 @@ def test_between_turns_refresh_no_churn_when_unchanged():
|
|||
|
||||
assert agent.tools is same # not replaced → no churn
|
||||
|
||||
|
||||
def test_preflight_skips_when_persisted_cooldown_survives_restart(tmp_path):
|
||||
agent = _make_agent_with_cooldown(
|
||||
tmp_path / "state.db",
|
||||
"sess-1",
|
||||
cooldown_until=4_000_000_000.0,
|
||||
)
|
||||
|
||||
with patch("agent.turn_context._should_run_preflight_estimate", return_value=True), \
|
||||
patch("agent.turn_context.estimate_request_tokens_rough", return_value=999_999):
|
||||
ctx = _build(agent)
|
||||
|
||||
assert isinstance(ctx, TurnContext)
|
||||
agent._emit_status.assert_not_called()
|
||||
agent._compress_context.assert_not_called()
|
||||
|
||||
|
||||
def test_preflight_still_runs_for_other_session_with_same_db(tmp_path):
|
||||
db_path = tmp_path / "state.db"
|
||||
_make_agent_with_cooldown(
|
||||
db_path,
|
||||
"sess-1",
|
||||
cooldown_until=4_000_000_000.0,
|
||||
)
|
||||
agent = _make_agent_with_cooldown(db_path, "sess-2")
|
||||
|
||||
with patch("agent.turn_context._should_run_preflight_estimate", return_value=True), \
|
||||
patch("agent.turn_context.estimate_request_tokens_rough", return_value=999_999):
|
||||
ctx = _build(agent)
|
||||
|
||||
assert isinstance(ctx, TurnContext)
|
||||
agent._emit_status.assert_called_once()
|
||||
agent._compress_context.assert_called()
|
||||
|
||||
|
||||
def test_expired_cooldown_allows_preflight(tmp_path):
|
||||
agent = _make_agent_with_cooldown(
|
||||
tmp_path / "state.db",
|
||||
"sess-1",
|
||||
cooldown_until=1.0,
|
||||
)
|
||||
|
||||
with patch("agent.turn_context._should_run_preflight_estimate", return_value=True), \
|
||||
patch("agent.turn_context.estimate_request_tokens_rough", return_value=999_999):
|
||||
ctx = _build(agent)
|
||||
|
||||
assert isinstance(ctx, TurnContext)
|
||||
agent._emit_status.assert_called_once()
|
||||
agent._compress_context.assert_called()
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import sqlite3
|
|||
import time
|
||||
import pytest
|
||||
|
||||
import hermes_state
|
||||
from hermes_state import SCHEMA_SQL, SCHEMA_VERSION, SessionDB
|
||||
|
||||
|
||||
|
|
@ -4725,3 +4726,55 @@ def test_gateway_session_recovery_reopens_legacy_agent_close_rows(db):
|
|||
chat_id="chat-1",
|
||||
chat_type="dm",
|
||||
) is None
|
||||
|
||||
|
||||
def test_compression_failure_cooldown_round_trips_and_clears(db):
|
||||
db.create_session("s1", "cli")
|
||||
|
||||
cooldown_until = time.time() + 60.0
|
||||
db.record_compression_failure_cooldown("s1", cooldown_until, "timeout")
|
||||
|
||||
state = db.get_compression_failure_cooldown("s1")
|
||||
assert state is not None
|
||||
assert state["cooldown_until"] == cooldown_until
|
||||
assert state["error"] == "timeout"
|
||||
|
||||
db.clear_compression_failure_cooldown("s1")
|
||||
assert db.get_compression_failure_cooldown("s1") is None
|
||||
|
||||
row = db.get_session("s1")
|
||||
assert row["compression_failure_cooldown_until"] is None
|
||||
assert row["compression_failure_error"] is None
|
||||
|
||||
|
||||
def test_expired_compression_failure_cooldown_is_ignored(db):
|
||||
db.create_session("s1", "cli")
|
||||
|
||||
db.record_compression_failure_cooldown("s1", time.time() - 60.0, "stale")
|
||||
|
||||
assert db.get_compression_failure_cooldown("s1") is None
|
||||
|
||||
|
||||
def test_refresh_compression_lock_requires_holder_and_preserves_reclaimability(db, monkeypatch):
|
||||
db.create_session("s1", "cli")
|
||||
|
||||
monkeypatch.setattr(hermes_state.time, "time", lambda: 1000.0)
|
||||
assert db.try_acquire_compression_lock("s1", "holder-a", ttl_seconds=10.0) is True
|
||||
|
||||
original_expires = db._conn.execute(
|
||||
"SELECT expires_at FROM compression_locks WHERE session_id = ?",
|
||||
("s1",),
|
||||
).fetchone()[0]
|
||||
|
||||
monkeypatch.setattr(hermes_state.time, "time", lambda: 1005.0)
|
||||
assert db.refresh_compression_lock("s1", "holder-a", ttl_seconds=10.0) is True
|
||||
refreshed_expires = db._conn.execute(
|
||||
"SELECT expires_at FROM compression_locks WHERE session_id = ?",
|
||||
("s1",),
|
||||
).fetchone()[0]
|
||||
assert refreshed_expires > original_expires
|
||||
|
||||
assert db.refresh_compression_lock("s1", "holder-b", ttl_seconds=10.0) is False
|
||||
|
||||
monkeypatch.setattr(hermes_state.time, "time", lambda: 1016.0)
|
||||
assert db.try_acquire_compression_lock("s1", "holder-b", ttl_seconds=10.0) is True
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue