mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-06-15 09:21:36 +00:00
fix(agent): persist repaired-turn responses (#46071)
This commit is contained in:
parent
723c2331bd
commit
2b4873f7fb
3 changed files with 195 additions and 19 deletions
51
run_agent.py
51
run_agent.py
|
|
@ -1548,9 +1548,10 @@ class AIAgent:
|
|||
def _flush_messages_to_session_db(self, messages: List[Dict], conversation_history: List[Dict] = None):
|
||||
"""Persist any un-flushed messages to the SQLite session store.
|
||||
|
||||
Uses _last_flushed_db_idx to track which messages have already been
|
||||
written, so repeated calls (from multiple exit paths) only write
|
||||
truly new messages — preventing the duplicate-write bug (#860).
|
||||
Uses per-session message identity tracking so repeated calls (from
|
||||
multiple exit paths) only write truly new messages — preventing the
|
||||
duplicate-write bug (#860) without relying on positional slices that
|
||||
can drift after message-sequence repair.
|
||||
"""
|
||||
if not self._session_db:
|
||||
return
|
||||
|
|
@ -1559,14 +1560,41 @@ class AIAgent:
|
|||
# Retry row creation if the earlier attempt failed transiently.
|
||||
if not self._session_db_created:
|
||||
self._ensure_db_session()
|
||||
start_idx = len(conversation_history) if conversation_history else 0
|
||||
# Guard against the flush cursor overshooting the message list.
|
||||
# This can happen when repair_message_sequence compacts the list
|
||||
# (merging consecutive users, dropping stray tools) after the
|
||||
# cursor was set. Fall back to start_idx so we don't skip
|
||||
# persisting the assistant/tool chain (#44837).
|
||||
flush_from = max(start_idx, min(self._last_flushed_db_idx, len(messages)))
|
||||
for msg in messages[flush_from:]:
|
||||
# Positional flushing used to slice at
|
||||
# max(len(conversation_history), _last_flushed_db_idx). That
|
||||
# assumes the live `messages` list is the original history plus a
|
||||
# new tail. repair_message_sequence can shrink/merge the history
|
||||
# copy before the final flush, making len(conversation_history)
|
||||
# larger than len(messages); the slice is then empty and delivered
|
||||
# assistant responses never reach state.db (#46053).
|
||||
#
|
||||
# Track object identities instead. `messages` is a shallow copy of
|
||||
# `conversation_history`, so history dicts are skipped by identity,
|
||||
# and new dicts appended during this turn are written once even if
|
||||
# repair compacts the list around them.
|
||||
current_session_id = getattr(self, "session_id", None)
|
||||
flushed_session_id = getattr(self, "_flushed_db_message_session_id", None)
|
||||
if flushed_session_id != current_session_id or self._last_flushed_db_idx == 0:
|
||||
self._flushed_db_message_ids = set()
|
||||
self._flushed_db_message_session_id = current_session_id
|
||||
flushed_ids = getattr(self, "_flushed_db_message_ids", None)
|
||||
if not isinstance(flushed_ids, set):
|
||||
flushed_ids = set()
|
||||
self._flushed_db_message_ids = flushed_ids
|
||||
history_ids = {
|
||||
id(item) for item in (conversation_history or [])
|
||||
if isinstance(item, dict)
|
||||
}
|
||||
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
msg_id = id(msg)
|
||||
if msg_id in flushed_ids:
|
||||
continue
|
||||
if msg_id in history_ids:
|
||||
flushed_ids.add(msg_id)
|
||||
continue
|
||||
role = msg.get("role", "unknown")
|
||||
content = msg.get("content")
|
||||
# Persist multimodal tool results as their text summary only —
|
||||
|
|
@ -1605,6 +1633,7 @@ class AIAgent:
|
|||
codex_reasoning_items=msg.get("codex_reasoning_items") if role == "assistant" else None,
|
||||
codex_message_items=msg.get("codex_message_items") if role == "assistant" else None,
|
||||
)
|
||||
flushed_ids.add(msg_id)
|
||||
self._last_flushed_db_idx = len(messages)
|
||||
except Exception as e:
|
||||
logger.warning("Session DB append_message failed: %s", e)
|
||||
|
|
|
|||
|
|
@ -101,7 +101,7 @@ class TestFlushAfterCompression:
|
|||
)
|
||||
|
||||
def test_flush_with_stale_history_loses_messages(self):
|
||||
"""Demonstrates the bug condition: stale conversation_history causes data loss."""
|
||||
"""Stale conversation_history no longer causes data loss."""
|
||||
from hermes_state import SessionDB
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
|
@ -120,17 +120,14 @@ class TestFlushAfterCompression:
|
|||
{"role": "assistant", "content": "continuing..."},
|
||||
]
|
||||
|
||||
# Bug: passing a conversation_history longer than compressed messages
|
||||
# Stale history longer than messages: the old positional flush
|
||||
# sliced past the end and dropped both messages (#46053).
|
||||
stale_history = [{"role": "user", "content": f"msg{i}"} for i in range(100)]
|
||||
agent._flush_messages_to_session_db(compressed, stale_history)
|
||||
|
||||
rows = db.get_messages("new-session")
|
||||
# With the stale history, flush_from = max(100, 0) = 100
|
||||
# But compressed only has 2 entries → messages[100:] = empty
|
||||
assert len(rows) == 0, (
|
||||
"Expected 0 messages with stale conversation_history "
|
||||
"(this test verifies the bug condition exists)"
|
||||
)
|
||||
assert len(rows) == 2
|
||||
assert [row["content"] for row in rows] == ["summary", "continuing..."]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
150
tests/run_agent/test_identity_flush.py
Normal file
150
tests/run_agent/test_identity_flush.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
"""Regression tests for identity-based SessionDB flushing (#46053)."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
SESSION_ID = "test-identity-flush"
|
||||
|
||||
|
||||
def _make_agent(session_db, session_id=SESSION_ID):
    """Build a quiet AIAgent wired to *session_db* and ensure its session row.

    The OpenRouter API key is patched into the environment only for the
    duration of agent construction, so the test never needs real credentials.
    """
    # Keyword arguments are collected up front so the constructor call stays
    # a single line; the values mirror a minimal, side-effect-free agent.
    agent_kwargs = {
        "api_key": "test-key",
        "base_url": "https://openrouter.ai/api/v1",
        "model": "test/model",
        "quiet_mode": True,
        "session_db": session_db,
        "session_id": session_id,
        "skip_context_files": True,
        "skip_memory": True,
    }
    with patch.dict(os.environ, {"OPENROUTER_API_KEY": "test-key"}):
        # Imported lazily so the patched environment is in place first.
        from run_agent import AIAgent

        instance = AIAgent(**agent_kwargs)
        # Create the session row now so later flushes have a session to
        # attach messages to.
        instance._ensure_db_session()
    return instance
|
||||
|
||||
|
||||
def _contents(db, session_id=SESSION_ID):
    """Return the ``content`` field of every stored message for *session_id*.

    Rows are returned in whatever order ``db.get_messages`` yields them.
    """
    stored = []
    for record in db.get_messages(session_id):
        stored.append(record["content"])
    return stored
|
||||
|
||||
|
||||
class TestIdentityFlush:
    """Regression coverage for identity-based SessionDB flushing (#46053)."""

    @staticmethod
    def _seed(store, rows):
        """Write *rows* straight into *store* under SESSION_ID, bypassing the agent."""
        for row in rows:
            store.append_message(
                session_id=SESSION_ID,
                role=row["role"],
                content=row["content"],
            )

    def test_repair_shrunk_messages_below_history_length_still_persists_assistant(self):
        """A live list shorter than the history must still flush its new messages."""
        from hermes_state import SessionDB

        with tempfile.TemporaryDirectory() as workdir:
            store = SessionDB(db_path=Path(workdir) / "t.db")
            try:
                agent = _make_agent(store)

                # Pretend these rows were already loaded from state.db.
                prior = [{"role": "user", "content": f"u{i}"} for i in range(6)]
                self._seed(store, prior)

                # Simulates repair_message_sequence having folded the six
                # history rows into one dict before this turn appended the
                # new user/assistant pair.
                merged_text = "\n\n".join(f"u{i}" for i in range(6))
                live = [
                    {"role": "user", "content": merged_text},
                    {"role": "user", "content": "new question"},
                    {"role": "assistant", "content": "new answer"},
                ]
                assert len(prior) > len(live)

                # A positional flush would start at or past len(live) here
                # and drop the assistant; identity tracking must not.
                agent._last_flushed_db_idx = len(prior)
                agent._flush_messages_to_session_db(live, prior)

                persisted = _contents(store)
                assert "new question" in persisted
                assert "new answer" in persisted
            finally:
                store.close()

    def test_overlapping_turn_stale_cursor_does_not_drop_assistant(self):
        """An overshooting cursor left on a cached agent must not hide new dicts."""
        from hermes_state import SessionDB

        with tempfile.TemporaryDirectory() as workdir:
            store = SessionDB(db_path=Path(workdir) / "t.db")
            try:
                agent = _make_agent(store)
                prior = [
                    {"role": "user", "content": "old question"},
                    {"role": "assistant", "content": "old answer"},
                ]
                self._seed(store, prior)

                # Unpacking keeps the very same history dict objects in the
                # live list, which is what identity-based skipping relies on.
                live = [
                    *prior,
                    {"role": "user", "content": "current question"},
                    {"role": "assistant", "content": "current answer"},
                ]
                agent._last_flushed_db_idx = len(live) + 10
                agent._flush_messages_to_session_db(live, prior)

                expected = [
                    "old question",
                    "old answer",
                    "current question",
                    "current answer",
                ]
                assert _contents(store) == expected
            finally:
                store.close()

    def test_repeated_flush_same_turn_writes_once(self):
        """Flushing the same list repeatedly keeps the #860 dedup guarantee."""
        from hermes_state import SessionDB

        with tempfile.TemporaryDirectory() as workdir:
            store = SessionDB(db_path=Path(workdir) / "t.db")
            try:
                agent = _make_agent(store)
                live = [{"role": "user", "content": "q"}]

                agent._flush_messages_to_session_db(live, [])
                live.append({"role": "assistant", "content": "a"})
                # Two further flushes: the first writes the new assistant
                # dict, the second must be a no-op.
                for _ in range(2):
                    agent._flush_messages_to_session_db(live, [])

                assert _contents(store) == ["q", "a"]
            finally:
                store.close()

    def test_cursor_reset_starts_new_turn_identity_window(self):
        """Resetting _last_flushed_db_idx to 0 (as the gateway does) opens a fresh window."""
        from hermes_state import SessionDB

        with tempfile.TemporaryDirectory() as workdir:
            store = SessionDB(db_path=Path(workdir) / "t.db")
            try:
                agent = _make_agent(store)
                turn_one = [
                    {"role": "user", "content": "q1"},
                    {"role": "assistant", "content": "a1"},
                ]
                agent._flush_messages_to_session_db(turn_one, [])

                # Fresh dict copies: the second turn's history shares no
                # object identity with what was flushed in the first turn.
                prior = [entry.copy() for entry in turn_one]
                turn_two = [
                    *prior,
                    {"role": "user", "content": "q2"},
                    {"role": "assistant", "content": "a2"},
                ]
                agent._last_flushed_db_idx = 0
                agent._flush_messages_to_session_db(turn_two, prior)

                assert _contents(store) == ["q1", "a1", "q2", "a2"]
            finally:
                store.close()
|
||||
Loading…
Add table
Add a link
Reference in a new issue