feat(hindsight): richer session-scoped retain metadata

- Add configurable retain_tags / retain_source / retain_user_prefix /
  retain_assistant_prefix knobs for native Hindsight.
- Thread gateway session identity (user_name, chat_id, chat_name,
  chat_type, thread_id) through AIAgent and MemoryManager into
  MemoryProvider.initialize kwargs so providers can scope and tag
  retained memories.
- Hindsight attaches the new identity fields as retain metadata,
  merges per-call tool tags with configured default tags, and uses
  the configurable transcript labels for auto-retained turns.

Co-authored-by: Abner <abner.the.foreman@agentmail.to>
This commit is contained in:
Abner 2026-04-22 05:23:50 -07:00 committed by Teknium
parent b8663813b6
commit b66644f0ec
7 changed files with 387 additions and 150 deletions

View file

@ -79,6 +79,28 @@ class TestMemoryManagerUserIdThreading:
assert p._init_kwargs.get("platform") == "telegram"
assert p._init_session_id == "sess-123"
def test_chat_context_forwarded_to_provider(self):
mgr = MemoryManager()
p = RecordingProvider()
mgr.add_provider(p)
mgr.initialize_all(
session_id="sess-chat",
platform="discord",
user_id="discord_u_7",
user_name="fakeusername",
chat_id="1485316232612941897",
chat_name="fakeassistantname-forums",
chat_type="thread",
thread_id="1491249007475949698",
)
assert p._init_kwargs.get("user_name") == "fakeusername"
assert p._init_kwargs.get("chat_id") == "1485316232612941897"
assert p._init_kwargs.get("chat_name") == "fakeassistantname-forums"
assert p._init_kwargs.get("chat_type") == "thread"
assert p._init_kwargs.get("thread_id") == "1491249007475949698"
def test_no_user_id_when_cli(self):
"""CLI sessions should not have user_id in kwargs."""
mgr = MemoryManager()
@ -334,3 +356,4 @@ class TestAIAgentUserIdPropagation:
agent = object.__new__(AIAgent)
agent._user_id = None
assert agent._user_id is None

View file

@ -6,6 +6,7 @@ turn counting, tags), and schema completeness.
"""
import json
import re
import threading
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
@ -18,6 +19,7 @@ from plugins.memory.hindsight import (
REFLECT_SCHEMA,
RETAIN_SCHEMA,
_load_config,
_normalize_retain_tags,
)
@ -32,14 +34,30 @@ def _clean_env(monkeypatch):
for key in (
"HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID",
"HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_LLM_API_KEY",
"HINDSIGHT_RETAIN_TAGS", "HINDSIGHT_RETAIN_SOURCE",
"HINDSIGHT_RETAIN_USER_PREFIX", "HINDSIGHT_RETAIN_ASSISTANT_PREFIX",
):
monkeypatch.delenv(key, raising=False)
def _make_mock_client():
"""Create a mock Hindsight client with async methods."""
async def _aretain(
bank_id,
content,
timestamp=None,
context=None,
document_id=None,
metadata=None,
entities=None,
tags=None,
update_mode=None,
retain_async=None,
):
return SimpleNamespace(ok=True)
client = MagicMock()
client.aretain = AsyncMock()
client.aretain = AsyncMock(side_effect=_aretain)
client.arecall = AsyncMock(
return_value=SimpleNamespace(
results=[
@ -56,6 +74,14 @@ def _make_mock_client():
return client
class _FakeSessionDB:
def __init__(self, messages=None):
self._messages = list(messages or [])
def get_messages_as_conversation(self, session_id):
return list(self._messages)
@pytest.fixture()
def provider(tmp_path, monkeypatch):
"""Create an initialized HindsightMemoryProvider with a mock client."""
@ -109,6 +135,18 @@ def provider_with_config(tmp_path, monkeypatch):
return _make
def test_normalize_retain_tags_accepts_csv_and_dedupes():
assert _normalize_retain_tags("agent:fakeassistantname, source_system:hermes-agent, agent:fakeassistantname") == [
"agent:fakeassistantname",
"source_system:hermes-agent",
]
def test_normalize_retain_tags_accepts_json_array_string():
value = json.dumps(["agent:fakeassistantname", "source_system:hermes-agent"])
assert _normalize_retain_tags(value) == ["agent:fakeassistantname", "source_system:hermes-agent"]
# ---------------------------------------------------------------------------
# Schema tests
# ---------------------------------------------------------------------------
@ -118,6 +156,7 @@ class TestSchemas:
def test_retain_schema_has_content(self):
assert RETAIN_SCHEMA["name"] == "hindsight_retain"
assert "content" in RETAIN_SCHEMA["parameters"]["properties"]
assert "tags" in RETAIN_SCHEMA["parameters"]["properties"]
assert "content" in RETAIN_SCHEMA["parameters"]["required"]
def test_recall_schema_has_query(self):
@ -160,7 +199,10 @@ class TestConfig:
def test_custom_config_values(self, provider_with_config):
p = provider_with_config(
tags=["tag1", "tag2"],
retain_tags=["tag1", "tag2"],
retain_source="hermes",
retain_user_prefix="User (fakeusername)",
retain_assistant_prefix="Assistant (fakeassistantname)",
recall_tags=["recall-tag"],
recall_tags_match="all",
auto_retain=False,
@ -175,6 +217,10 @@ class TestConfig:
bank_mission="Test agent mission",
)
assert p._tags == ["tag1", "tag2"]
assert p._retain_tags == ["tag1", "tag2"]
assert p._retain_source == "hermes"
assert p._retain_user_prefix == "User (fakeusername)"
assert p._retain_assistant_prefix == "Assistant (fakeassistantname)"
assert p._recall_tags == ["recall-tag"]
assert p._recall_tags_match == "all"
assert p._auto_retain is False
@ -222,11 +268,20 @@ class TestToolHandlers:
assert call_kwargs["content"] == "user likes dark mode"
def test_retain_with_tags(self, provider_with_config):
p = provider_with_config(tags=["pref", "ui"])
p = provider_with_config(retain_tags=["pref", "ui"])
p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
call_kwargs = p._client.aretain.call_args.kwargs
assert call_kwargs["tags"] == ["pref", "ui"]
def test_retain_merges_per_call_tags_with_config_tags(self, provider_with_config):
p = provider_with_config(retain_tags=["pref", "ui"])
p.handle_tool_call(
"hindsight_retain",
{"content": "likes dark mode", "tags": ["client:x", "ui"]},
)
call_kwargs = p._client.aretain.call_args.kwargs
assert call_kwargs["tags"] == ["pref", "ui", "client:x"]
def test_retain_without_tags(self, provider):
provider.handle_tool_call("hindsight_retain", {"content": "hello"})
call_kwargs = provider._client.aretain.call_args.kwargs
@ -389,38 +444,58 @@ class TestPrefetch:
class TestSyncTurn:
def _get_retain_kwargs(self, provider):
"""Helper to get the kwargs from the aretain_batch call."""
return provider._client.aretain_batch.call_args.kwargs
def test_sync_turn_retains_metadata_rich_turn(self, provider_with_config):
p = provider_with_config(
retain_tags=["conv", "session1"],
retain_source="hermes",
retain_user_prefix="User (fakeusername)",
retain_assistant_prefix="Assistant (fakeassistantname)",
)
p.initialize(
session_id="session-1",
platform="discord",
user_id="fakeusername-123",
user_name="fakeusername",
chat_id="1485316232612941897",
chat_name="fakeassistantname-forums",
chat_type="thread",
thread_id="1491249007475949698",
agent_identity="fakeassistantname",
)
p._client = _make_mock_client()
def _get_retain_content(self, provider):
"""Helper to get the raw content string from the first item."""
kwargs = self._get_retain_kwargs(provider)
return kwargs["items"][0]["content"]
p.sync_turn("hello", "hi there")
p._sync_thread.join(timeout=5.0)
def _get_retain_messages(self, provider):
"""Helper to parse the first turn's messages from retained content.
Content is a JSON array of turns: [[msgs...], [msgs...], ...]
For single-turn tests, returns the first turn's messages.
"""
content = self._get_retain_content(provider)
turns = json.loads(content)
return turns[0] if len(turns) == 1 else turns
def test_sync_turn_retains(self, provider):
provider.sync_turn("hello", "hi there")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._client.aretain_batch.assert_called_once()
messages = self._get_retain_messages(provider)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "hello"
assert "timestamp" in messages[0]
assert messages[1]["role"] == "assistant"
assert messages[1]["content"] == "hi there"
assert "timestamp" in messages[1]
p._client.aretain_batch.assert_called_once()
call_kwargs = p._client.aretain_batch.call_args.kwargs
assert call_kwargs["bank_id"] == "test-bank"
assert call_kwargs["document_id"] == "session-1"
assert call_kwargs["retain_async"] is True
assert len(call_kwargs["items"]) == 1
item = call_kwargs["items"][0]
assert item["context"] == "conversation between Hermes Agent and the User"
assert item["tags"] == ["conv", "session1"]
content = json.loads(item["content"])
assert len(content) == 1
assert content[0][0]["role"] == "user"
assert content[0][0]["content"] == "User (fakeusername): hello"
assert content[0][1]["role"] == "assistant"
assert content[0][1]["content"] == "Assistant (fakeassistantname): hi there"
assert item["metadata"]["source"] == "hermes"
assert item["metadata"]["session_id"] == "session-1"
assert item["metadata"]["platform"] == "discord"
assert item["metadata"]["user_id"] == "fakeusername-123"
assert item["metadata"]["user_name"] == "fakeusername"
assert item["metadata"]["chat_id"] == "1485316232612941897"
assert item["metadata"]["chat_name"] == "fakeassistantname-forums"
assert item["metadata"]["chat_type"] == "thread"
assert item["metadata"]["thread_id"] == "1491249007475949698"
assert item["metadata"]["agent_identity"] == "fakeassistantname"
assert item["metadata"]["turn_index"] == "1"
assert item["metadata"]["message_count"] == "2"
assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?\+00:00", content[0][0]["timestamp"])
assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", item["metadata"]["retained_at"])
def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config):
p = provider_with_config(auto_retain=False)
@ -428,93 +503,33 @@ class TestSyncTurn:
assert p._sync_thread is None
p._client.aretain_batch.assert_not_called()
def test_sync_turn_with_tags(self, provider_with_config):
p = provider_with_config(tags=["conv", "session1"])
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["tags"] == ["conv", "session1"]
def test_sync_turn_uses_aretain_batch(self, provider):
"""sync_turn should use aretain_batch with retain_async."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._client.aretain_batch.assert_called_once()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"] == "test-session"
assert call_kwargs["retain_async"] is True
assert len(call_kwargs["items"]) == 1
assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User"
def test_sync_turn_custom_context(self, provider_with_config):
p = provider_with_config(retain_context="my-agent")
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["context"] == "my-agent"
def test_sync_turn_every_n_turns(self, provider_with_config):
"""With retain_every_n_turns=3, only retains on every 3rd turn."""
p = provider_with_config(retain_every_n_turns=3)
p = provider_with_config(retain_every_n_turns=3, retain_async=False)
p.sync_turn("turn1-user", "turn1-asst")
assert p._sync_thread is None # not retained yet
assert p._sync_thread is None
p.sync_turn("turn2-user", "turn2-asst")
assert p._sync_thread is None # not retained yet
assert p._sync_thread is None
p.sync_turn("turn3-user", "turn3-asst")
assert p._sync_thread is not None # retained!
p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.assert_called_once()
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain all 3 turns
assert "turn1-user" in content
assert "turn2-user" in content
assert "turn3-user" in content
def test_sync_turn_accumulates_full_session(self, provider_with_config):
"""Each retain sends the ENTIRE session, not just the latest batch."""
p = provider_with_config(retain_every_n_turns=2)
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.reset_mock()
p.sync_turn("turn3-user", "turn3-asst")
p.sync_turn("turn4-user", "turn4-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain ALL turns from the session
assert "turn1-user" in content
assert "turn2-user" in content
assert "turn3-user" in content
assert "turn4-user" in content
def test_sync_turn_passes_document_id(self, provider):
"""sync_turn should pass session_id as document_id for dedup."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
call_kwargs = provider._client.aretain_batch.call_args.kwargs
call_kwargs = p._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"] == "test-session"
assert call_kwargs["retain_async"] is False
item = call_kwargs["items"][0]
content = json.loads(item["content"])
assert len(content) == 3
assert content[-1][0]["role"] == "user"
assert content[-1][0]["content"] == "User: turn3-user"
assert content[-1][1]["role"] == "assistant"
assert content[-1][1]["content"] == "Assistant: turn3-asst"
assert item["metadata"]["turn_index"] == "3"
assert item["metadata"]["message_count"] == "6"
def test_sync_turn_error_does_not_raise(self, provider):
"""Errors in sync_turn should be swallowed (non-blocking)."""
provider._client.aretain_batch.side_effect = RuntimeError("network error")
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
# Should not raise
# ---------------------------------------------------------------------------
@ -555,10 +570,11 @@ class TestConfigSchema:
"mode", "api_url", "api_key", "llm_provider", "llm_api_key",
"llm_model", "bank_id", "bank_mission", "bank_retain_mission",
"recall_budget", "memory_mode", "recall_prefetch_method",
"tags", "recall_tags", "recall_tags_match",
"retain_tags", "retain_source",
"retain_user_prefix", "retain_assistant_prefix",
"recall_tags", "recall_tags_match",
"auto_recall", "auto_retain",
"retain_every_n_turns", "retain_async",
"retain_context",
"retain_every_n_turns", "retain_async", "retain_context",
"recall_max_tokens", "recall_max_input_chars",
"recall_prompt_preamble",
}