diff --git a/hermes_state.py b/hermes_state.py index 68387ede17..2d8a0fd4af 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -1249,10 +1249,37 @@ class SessionDB: try: with self._lock: ctx_cursor = self._conn.execute( - """SELECT role, content FROM messages - WHERE session_id = ? AND id >= ? - 1 AND id <= ? + 1 - ORDER BY id""", - (match["session_id"], match["id"], match["id"]), + """WITH target AS ( + SELECT session_id, timestamp, id + FROM messages + WHERE id = ? + ) + SELECT role, content + FROM ( + SELECT m.id, m.timestamp, m.role, m.content + FROM messages m + JOIN target t ON t.session_id = m.session_id + WHERE (m.timestamp < t.timestamp) + OR (m.timestamp = t.timestamp AND m.id < t.id) + ORDER BY m.timestamp DESC, m.id DESC + LIMIT 1 + ) + UNION ALL + SELECT role, content + FROM messages + WHERE id = ? + UNION ALL + SELECT role, content + FROM ( + SELECT m.id, m.timestamp, m.role, m.content + FROM messages m + JOIN target t ON t.session_id = m.session_id + WHERE (m.timestamp > t.timestamp) + OR (m.timestamp = t.timestamp AND m.id > t.id) + ORDER BY m.timestamp ASC, m.id ASC + LIMIT 1 + )""", + (match["id"], match["id"]), ) context_msgs = [ {"role": r["role"], "content": (r["content"] or "")[:200]} diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 72cf47e076..dfb2445c55 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -365,6 +365,25 @@ class TestFTS5Search: assert isinstance(results[0]["context"], list) assert len(results[0]["context"]) > 0 + def test_search_context_uses_session_neighbors_when_ids_are_interleaved(self, db): + db.create_session(session_id="s1", source="cli") + db.create_session(session_id="s2", source="cli") + + db.append_message("s1", role="user", content="before needle") + db.append_message("s2", role="user", content="other session message") + db.append_message("s1", role="assistant", content="needle match") + db.append_message("s2", role="assistant", content="another other session message") + db.append_message("s1", role="user", content="after needle") + + results = db.search_messages('"needle match"') + needle_result = next(r for r in results if r["session_id"] == "s1" and "needle match" in r["snippet"]) + + assert [msg["content"] for msg in needle_result["context"]] == [ + "before needle", + "needle match", + "after needle", + ] + def test_search_special_chars_do_not_crash(self, db): """FTS5 special characters in queries must not raise OperationalError.""" db.create_session(session_id="s1", source="cli")