fix(agent): rewind flush cursor exactly when repair compacts before the cursor

Follow-up to the #44837 clamp: a min() clamp only fixes cursor overshoot
past the new end of the list. When repair_message_sequence drops/merges
messages at indexes below the cursor, the clamp leaves the cursor pointing
past unflushed rows and the turn-end flush silently skips them.

Extract repair_message_sequence_with_cursor(): snapshot the flushed prefix
by object identity before repair, then recompute the cursor as the count
of surviving flushed messages. Falls back to the clamp when no snapshot is
available. Keeps the safety guard in _flush_messages_to_session_db.

Adds targeted tests for overshoot, before-cursor compaction, no-repair,
bare-agent, and the flush guard.
This commit is contained in:
Teknium 2026-06-12 15:47:34 -07:00
parent 5d0408d9fe
commit 8905ee6b8a
3 changed files with 143 additions and 9 deletions

View file

@ -445,6 +445,45 @@ def repair_message_sequence(agent, messages: List[Dict]) -> int:
return repairs
def repair_message_sequence_with_cursor(agent, messages: List[Dict]) -> int:
"""Run :func:`repair_message_sequence` and keep the SessionDB flush
cursor consistent with the compacted list (#44837).
``repair_message_sequence`` merges/drops messages in place, shrinking
the list. ``_last_flushed_db_idx`` (the DB-write cursor) indexes into
that list, so after compaction it can point past the new end the
turn-end flush would then skip the assistant/tool chain entirely or
past unflushed messages shifted to lower indexes.
Repair preserves object identity for surviving messages, so counting
the survivors from the previously-flushed prefix gives the exact new
cursor even when messages are dropped/merged at indexes *before* the
cursor a plain ``min()`` clamp would silently skip that many
unflushed rows. Falls back to the clamp when no prefix snapshot is
available.
Returns the number of repairs made (same as ``repair_message_sequence``).
"""
pre_repair_flushed_ids = None
flush_cursor = getattr(agent, "_last_flushed_db_idx", None)
if isinstance(flush_cursor, int) and flush_cursor > 0:
pre_repair_flushed_ids = {id(m) for m in messages[:flush_cursor]}
repairs = repair_message_sequence(agent, messages)
if repairs > 0 and hasattr(agent, "_last_flushed_db_idx"):
if pre_repair_flushed_ids is not None:
agent._last_flushed_db_idx = sum(
1 for m in messages if id(m) in pre_repair_flushed_ids
)
else:
agent._last_flushed_db_idx = min(
agent._last_flushed_db_idx, len(messages)
)
return repairs
def strip_think_blocks(agent, content: str) -> str:
"""Remove reasoning/thinking blocks from content, returning only visible text.

View file

@ -595,21 +595,17 @@ def run_conversation(
# landed after an orphan tool result). Most providers return
# empty content on malformed sequences, which would otherwise
# retrigger the empty-retry loop indefinitely.
repaired_seq = agent._repair_message_sequence(messages)
# repair_message_sequence_with_cursor also recomputes the SessionDB
# flush cursor (_last_flushed_db_idx) when repair compacts the list,
# so the turn-end flush doesn't skip the assistant/tool chain (#44837).
from agent.agent_runtime_helpers import repair_message_sequence_with_cursor
repaired_seq = repair_message_sequence_with_cursor(agent, messages)
if repaired_seq > 0:
request_logger.info(
"Repaired %s message-alternation violations before request (session=%s)",
repaired_seq,
agent.session_id or "-",
)
# Clamp the SessionDB flush cursor after compaction. If repair
# merged or dropped messages, _last_flushed_db_idx may now point
# past the new end of `messages`, causing turn-end flush to skip
# the assistant/tool chain entirely (#44837).
if hasattr(agent, "_last_flushed_db_idx"):
agent._last_flushed_db_idx = min(
agent._last_flushed_db_idx, len(messages)
)
api_messages = []
for idx, msg in enumerate(messages):

View file

@ -199,3 +199,102 @@ def test_repair_preserves_system_messages():
AIAgent._repair_message_sequence(agent, messages)
assert messages == original
# ── repair_message_sequence_with_cursor (#44837) ───────────────────────────
from agent.agent_runtime_helpers import repair_message_sequence_with_cursor
def test_cursor_clamped_when_compaction_shrinks_below_cursor():
"""Cursor past the new end of the list must come back in range so the
turn-end flush doesn't skip the assistant/tool chain (#44837)."""
agent = _bare_agent()
messages = [
{"role": "user", "content": "first"},
{"role": "user", "content": "second"},
]
agent._last_flushed_db_idx = 2 # both rows already flushed
repairs = repair_message_sequence_with_cursor(agent, messages)
assert repairs == 1
assert len(messages) == 1
assert agent._last_flushed_db_idx == 1
def test_cursor_rewinds_when_compaction_happens_before_cursor():
"""Repair that drops/merges messages at indexes BELOW the cursor must
rewind it by the number removed, or unflushed rows get skipped.
A plain min() clamp does NOT catch this case."""
agent = _bare_agent()
flushed_a = {"role": "user", "content": "first"}
flushed_b = {"role": "user", "content": "second"} # merged into flushed_a
unflushed_assistant = {"role": "assistant", "content": "answer"}
messages = [flushed_a, flushed_b, unflushed_assistant]
agent._last_flushed_db_idx = 2 # the two user rows are flushed
repairs = repair_message_sequence_with_cursor(agent, messages)
assert repairs == 1
assert len(messages) == 2
# Cursor must now point at the assistant (index 1), not stay at 2 —
# min(2, len=2) would leave it at 2 and the flush would skip it.
assert agent._last_flushed_db_idx == 1
assert messages[agent._last_flushed_db_idx] is unflushed_assistant
def test_cursor_untouched_when_no_repairs():
agent = _bare_agent()
messages = [
{"role": "user", "content": "hi"},
{"role": "assistant", "content": "hello"},
]
agent._last_flushed_db_idx = 1
repairs = repair_message_sequence_with_cursor(agent, messages)
assert repairs == 0
assert agent._last_flushed_db_idx == 1
def test_cursor_helper_safe_without_cursor_attribute():
"""Bare agents (no _last_flushed_db_idx) must not crash."""
agent = _bare_agent()
messages = [
{"role": "user", "content": "a"},
{"role": "user", "content": "b"},
]
repairs = repair_message_sequence_with_cursor(agent, messages)
assert repairs == 1
assert not hasattr(agent, "_last_flushed_db_idx")
def test_flush_guard_clamps_overshooting_cursor():
"""_flush_messages_to_session_db safety net: an overshooting cursor must
not produce a negative-start slice that skips everything (#44837)."""
class _DB:
def __init__(self):
self.rows = []
def append_message(self, **kw):
self.rows.append(kw)
agent = _bare_agent()
agent._session_db = _DB()
agent._session_db_created = True
agent.session_id = "s1"
agent._persist_user_message_override = None
agent._last_flushed_db_idx = 5 # stale — past end of compacted list
messages = [
{"role": "user", "content": "q"},
{"role": "assistant", "content": "a"},
]
AIAgent._flush_messages_to_session_db(agent, messages, conversation_history=[])
# min(5, 2) = 2 → nothing skipped below start_idx, cursor settles at 2
assert agent._last_flushed_db_idx == 2