fix(hindsight): scope document_id per process to avoid resume overwrite (#6602)

Reusing session_id as document_id caused data loss on /resume: when
the session is loaded again, _session_turns starts empty and the next
retain replaces the entire previously stored content.

Now each process lifecycle gets its own document_id formed as
{session_id}-{startup_timestamp}, so:
- Same session, same process: turns accumulate into one document (existing behavior)
- Resume (new process, same session): writes a new document, old one preserved
- Forks: child process gets its own document; parent's doc is untouched

Also adds session lineage tags so all processes for the same session
(or its parent) can still be filtered together via recall:
- session:<session_id> on every retain
- parent:<parent_session_id> when initialized with parent_session_id

Closes #6602
This commit is contained in:
Nicolò Boschi 2026-04-09 16:25:09 +02:00 committed by Teknium
parent 3a86f70969
commit f9c6c5ab84
2 changed files with 142 additions and 5 deletions

View file

@ -362,6 +362,8 @@ class HindsightMemoryProvider(MemoryProvider):
self._prefetch_thread = None self._prefetch_thread = None
self._sync_thread = None self._sync_thread = None
self._session_id = "" self._session_id = ""
self._parent_session_id = ""
self._document_id = ""
# Tags # Tags
self._tags: list[str] | None = None self._tags: list[str] | None = None
@ -679,6 +681,15 @@ class HindsightMemoryProvider(MemoryProvider):
def initialize(self, session_id: str, **kwargs) -> None: def initialize(self, session_id: str, **kwargs) -> None:
self._session_id = str(session_id or "").strip() self._session_id = str(session_id or "").strip()
self._parent_session_id = str(kwargs.get("parent_session_id", "") or "").strip()
# Each process lifecycle gets its own document_id. Reusing session_id
# alone caused overwrites on /resume — the reloaded session starts
# with an empty _session_turns, so the next retain would replace the
# previously stored content. session_id stays in tags so processes
# for the same session remain filterable together.
start_ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
self._document_id = f"{self._session_id}-{start_ts}"
# Check client version and auto-upgrade if needed # Check client version and auto-upgrade if needed
try: try:
@ -1026,6 +1037,12 @@ class HindsightMemoryProvider(MemoryProvider):
len(self._session_turns), sum(len(t) for t in self._session_turns)) len(self._session_turns), sum(len(t) for t in self._session_turns))
content = "[" + ",".join(self._session_turns) + "]" content = "[" + ",".join(self._session_turns) + "]"
lineage_tags: list[str] = []
if self._session_id:
lineage_tags.append(f"session:{self._session_id}")
if self._parent_session_id:
lineage_tags.append(f"parent:{self._parent_session_id}")
def _sync(): def _sync():
try: try:
client = self._get_client() client = self._get_client()
@ -1036,15 +1053,16 @@ class HindsightMemoryProvider(MemoryProvider):
message_count=len(self._session_turns) * 2, message_count=len(self._session_turns) * 2,
turn_index=self._turn_index, turn_index=self._turn_index,
), ),
tags=lineage_tags or None,
) )
item.pop("bank_id", None) item.pop("bank_id", None)
item.pop("retain_async", None) item.pop("retain_async", None)
logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d", logger.debug("Hindsight retain: bank=%s, doc=%s, async=%s, content_len=%d, num_turns=%d",
self._bank_id, self._session_id, self._retain_async, len(content), len(self._session_turns)) self._bank_id, self._document_id, self._retain_async, len(content), len(self._session_turns))
self._run_sync(client.aretain_batch( self._run_sync(client.aretain_batch(
bank_id=self._bank_id, bank_id=self._bank_id,
items=[item], items=[item],
document_id=self._session_id, document_id=self._document_id,
retain_async=self._retain_async, retain_async=self._retain_async,
)) ))
logger.debug("Hindsight retain succeeded") logger.debug("Hindsight retain succeeded")

View file

@ -549,12 +549,12 @@ class TestSyncTurn:
p._client.aretain_batch.assert_called_once() p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs call_kwargs = p._client.aretain_batch.call_args.kwargs
assert call_kwargs["bank_id"] == "test-bank" assert call_kwargs["bank_id"] == "test-bank"
assert call_kwargs["document_id"] == "session-1" assert call_kwargs["document_id"].startswith("session-1-")
assert call_kwargs["retain_async"] is True assert call_kwargs["retain_async"] is True
assert len(call_kwargs["items"]) == 1 assert len(call_kwargs["items"]) == 1
item = call_kwargs["items"][0] item = call_kwargs["items"][0]
assert item["context"] == "conversation between Hermes Agent and the User" assert item["context"] == "conversation between Hermes Agent and the User"
assert item["tags"] == ["conv", "session1"] assert item["tags"] == ["conv", "session1", "session:session-1"]
content = json.loads(item["content"]) content = json.loads(item["content"])
assert len(content) == 1 assert len(content) == 1
assert content[0][0]["role"] == "user" assert content[0][0]["role"] == "user"
@ -582,6 +582,36 @@ class TestSyncTurn:
assert p._sync_thread is None assert p._sync_thread is None
p._client.aretain_batch.assert_not_called() p._client.aretain_batch.assert_not_called()
def test_sync_turn_with_tags(self, provider_with_config):
p = provider_with_config(retain_tags=["conv", "session1"])
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert "conv" in item["tags"]
assert "session1" in item["tags"]
assert "session:test-session" in item["tags"]
def test_sync_turn_uses_aretain_batch(self, provider):
"""sync_turn should use aretain_batch with retain_async."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._client.aretain_batch.assert_called_once()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"].startswith("test-session-")
assert call_kwargs["retain_async"] is True
assert len(call_kwargs["items"]) == 1
assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User"
def test_sync_turn_custom_context(self, provider_with_config):
p = provider_with_config(retain_context="my-agent")
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["context"] == "my-agent"
def test_sync_turn_every_n_turns(self, provider_with_config): def test_sync_turn_every_n_turns(self, provider_with_config):
p = provider_with_config(retain_every_n_turns=3, retain_async=False) p = provider_with_config(retain_every_n_turns=3, retain_async=False)
p.sync_turn("turn1-user", "turn1-asst") p.sync_turn("turn1-user", "turn1-asst")
@ -592,7 +622,7 @@ class TestSyncTurn:
p._sync_thread.join(timeout=5.0) p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.assert_called_once() p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs call_kwargs = p._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"] == "test-session" assert call_kwargs["document_id"].startswith("test-session-")
assert call_kwargs["retain_async"] is False assert call_kwargs["retain_async"] is False
item = call_kwargs["items"][0] item = call_kwargs["items"][0]
content = json.loads(item["content"]) content = json.loads(item["content"])
@ -604,6 +634,95 @@ class TestSyncTurn:
assert item["metadata"]["turn_index"] == "3" assert item["metadata"]["turn_index"] == "3"
assert item["metadata"]["message_count"] == "6" assert item["metadata"]["message_count"] == "6"
def test_sync_turn_accumulates_full_session(self, provider_with_config):
"""Each retain sends the ENTIRE session, not just the latest batch."""
p = provider_with_config(retain_every_n_turns=2)
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.reset_mock()
p.sync_turn("turn3-user", "turn3-asst")
p.sync_turn("turn4-user", "turn4-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain ALL turns from the session
assert "turn1-user" in content
assert "turn2-user" in content
assert "turn3-user" in content
assert "turn4-user" in content
def test_sync_turn_passes_document_id(self, provider):
"""sync_turn should pass document_id (session_id + per-startup ts)."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
call_kwargs = provider._client.aretain_batch.call_args.kwargs
# Format: {session_id}-{YYYYMMDD_HHMMSS_microseconds}
assert call_kwargs["document_id"].startswith("test-session-")
assert call_kwargs["document_id"] == provider._document_id
def test_resume_creates_new_document(self, tmp_path, monkeypatch):
"""Resuming a session (re-initializing) gets a new document_id
so previously stored content is not overwritten."""
config = {"mode": "cloud", "apiKey": "k", "api_url": "http://x", "bank_id": "b"}
config_path = tmp_path / "hindsight" / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(json.dumps(config))
monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)
p1 = HindsightMemoryProvider()
p1.initialize(session_id="resumed-session", hermes_home=str(tmp_path), platform="cli")
# Sleep just enough that the microsecond timestamp differs
import time
time.sleep(0.001)
p2 = HindsightMemoryProvider()
p2.initialize(session_id="resumed-session", hermes_home=str(tmp_path), platform="cli")
# Same session, but each process gets its own document_id
assert p1._document_id != p2._document_id
assert p1._document_id.startswith("resumed-session-")
assert p2._document_id.startswith("resumed-session-")
def test_sync_turn_session_tag(self, provider):
"""Each retain should be tagged with session:<id> for filtering."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
item = provider._client.aretain_batch.call_args.kwargs["items"][0]
assert "session:test-session" in item["tags"]
def test_sync_turn_parent_session_tag(self, tmp_path, monkeypatch):
"""When initialized with parent_session_id, parent tag is added."""
config = {"mode": "cloud", "apiKey": "k", "api_url": "http://x", "bank_id": "b"}
config_path = tmp_path / "hindsight" / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(json.dumps(config))
monkeypatch.setattr("plugins.memory.hindsight.get_hermes_home", lambda: tmp_path)
p = HindsightMemoryProvider()
p.initialize(
session_id="child-session",
hermes_home=str(tmp_path),
platform="cli",
parent_session_id="parent-session",
)
p._client = _make_mock_client()
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert "session:child-session" in item["tags"]
assert "parent:parent-session" in item["tags"]
def test_sync_turn_error_does_not_raise(self, provider): def test_sync_turn_error_does_not_raise(self, provider):
provider._client.aretain_batch.side_effect = RuntimeError("network error") provider._client.aretain_batch.side_effect = RuntimeError("network error")
provider.sync_turn("hello", "hi") provider.sync_turn("hello", "hi")