mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
feat(hindsight): richer session-scoped retain metadata
- Add configurable retain_tags / retain_source / retain_user_prefix / retain_assistant_prefix knobs for native Hindsight. - Thread gateway session identity (user_name, chat_id, chat_name, chat_type, thread_id) through AIAgent and MemoryManager into MemoryProvider.initialize kwargs so providers can scope and tag retained memories. - Hindsight attaches the new identity fields as retain metadata, merges per-call tool tags with configured default tags, and uses the configurable transcript labels for auto-retained turns. Co-authored-by: Abner <abner.the.foreman@agentmail.to>
This commit is contained in:
parent
b8663813b6
commit
b66644f0ec
7 changed files with 387 additions and 150 deletions
|
|
@ -79,6 +79,28 @@ class TestMemoryManagerUserIdThreading:
|
|||
assert p._init_kwargs.get("platform") == "telegram"
|
||||
assert p._init_session_id == "sess-123"
|
||||
|
||||
def test_chat_context_forwarded_to_provider(self):
|
||||
mgr = MemoryManager()
|
||||
p = RecordingProvider()
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.initialize_all(
|
||||
session_id="sess-chat",
|
||||
platform="discord",
|
||||
user_id="discord_u_7",
|
||||
user_name="fakeusername",
|
||||
chat_id="1485316232612941897",
|
||||
chat_name="fakeassistantname-forums",
|
||||
chat_type="thread",
|
||||
thread_id="1491249007475949698",
|
||||
)
|
||||
|
||||
assert p._init_kwargs.get("user_name") == "fakeusername"
|
||||
assert p._init_kwargs.get("chat_id") == "1485316232612941897"
|
||||
assert p._init_kwargs.get("chat_name") == "fakeassistantname-forums"
|
||||
assert p._init_kwargs.get("chat_type") == "thread"
|
||||
assert p._init_kwargs.get("thread_id") == "1491249007475949698"
|
||||
|
||||
def test_no_user_id_when_cli(self):
|
||||
"""CLI sessions should not have user_id in kwargs."""
|
||||
mgr = MemoryManager()
|
||||
|
|
@ -334,3 +356,4 @@ class TestAIAgentUserIdPropagation:
|
|||
agent = object.__new__(AIAgent)
|
||||
agent._user_id = None
|
||||
assert agent._user_id is None
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ turn counting, tags), and schema completeness.
|
|||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
|
@ -18,6 +19,7 @@ from plugins.memory.hindsight import (
|
|||
REFLECT_SCHEMA,
|
||||
RETAIN_SCHEMA,
|
||||
_load_config,
|
||||
_normalize_retain_tags,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -32,14 +34,30 @@ def _clean_env(monkeypatch):
|
|||
for key in (
|
||||
"HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID",
|
||||
"HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_LLM_API_KEY",
|
||||
"HINDSIGHT_RETAIN_TAGS", "HINDSIGHT_RETAIN_SOURCE",
|
||||
"HINDSIGHT_RETAIN_USER_PREFIX", "HINDSIGHT_RETAIN_ASSISTANT_PREFIX",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
def _make_mock_client():
|
||||
"""Create a mock Hindsight client with async methods."""
|
||||
async def _aretain(
|
||||
bank_id,
|
||||
content,
|
||||
timestamp=None,
|
||||
context=None,
|
||||
document_id=None,
|
||||
metadata=None,
|
||||
entities=None,
|
||||
tags=None,
|
||||
update_mode=None,
|
||||
retain_async=None,
|
||||
):
|
||||
return SimpleNamespace(ok=True)
|
||||
|
||||
client = MagicMock()
|
||||
client.aretain = AsyncMock()
|
||||
client.aretain = AsyncMock(side_effect=_aretain)
|
||||
client.arecall = AsyncMock(
|
||||
return_value=SimpleNamespace(
|
||||
results=[
|
||||
|
|
@ -56,6 +74,14 @@ def _make_mock_client():
|
|||
return client
|
||||
|
||||
|
||||
class _FakeSessionDB:
|
||||
def __init__(self, messages=None):
|
||||
self._messages = list(messages or [])
|
||||
|
||||
def get_messages_as_conversation(self, session_id):
|
||||
return list(self._messages)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def provider(tmp_path, monkeypatch):
|
||||
"""Create an initialized HindsightMemoryProvider with a mock client."""
|
||||
|
|
@ -109,6 +135,18 @@ def provider_with_config(tmp_path, monkeypatch):
|
|||
return _make
|
||||
|
||||
|
||||
def test_normalize_retain_tags_accepts_csv_and_dedupes():
|
||||
assert _normalize_retain_tags("agent:fakeassistantname, source_system:hermes-agent, agent:fakeassistantname") == [
|
||||
"agent:fakeassistantname",
|
||||
"source_system:hermes-agent",
|
||||
]
|
||||
|
||||
|
||||
def test_normalize_retain_tags_accepts_json_array_string():
|
||||
value = json.dumps(["agent:fakeassistantname", "source_system:hermes-agent"])
|
||||
assert _normalize_retain_tags(value) == ["agent:fakeassistantname", "source_system:hermes-agent"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -118,6 +156,7 @@ class TestSchemas:
|
|||
def test_retain_schema_has_content(self):
|
||||
assert RETAIN_SCHEMA["name"] == "hindsight_retain"
|
||||
assert "content" in RETAIN_SCHEMA["parameters"]["properties"]
|
||||
assert "tags" in RETAIN_SCHEMA["parameters"]["properties"]
|
||||
assert "content" in RETAIN_SCHEMA["parameters"]["required"]
|
||||
|
||||
def test_recall_schema_has_query(self):
|
||||
|
|
@ -160,7 +199,10 @@ class TestConfig:
|
|||
|
||||
def test_custom_config_values(self, provider_with_config):
|
||||
p = provider_with_config(
|
||||
tags=["tag1", "tag2"],
|
||||
retain_tags=["tag1", "tag2"],
|
||||
retain_source="hermes",
|
||||
retain_user_prefix="User (fakeusername)",
|
||||
retain_assistant_prefix="Assistant (fakeassistantname)",
|
||||
recall_tags=["recall-tag"],
|
||||
recall_tags_match="all",
|
||||
auto_retain=False,
|
||||
|
|
@ -175,6 +217,10 @@ class TestConfig:
|
|||
bank_mission="Test agent mission",
|
||||
)
|
||||
assert p._tags == ["tag1", "tag2"]
|
||||
assert p._retain_tags == ["tag1", "tag2"]
|
||||
assert p._retain_source == "hermes"
|
||||
assert p._retain_user_prefix == "User (fakeusername)"
|
||||
assert p._retain_assistant_prefix == "Assistant (fakeassistantname)"
|
||||
assert p._recall_tags == ["recall-tag"]
|
||||
assert p._recall_tags_match == "all"
|
||||
assert p._auto_retain is False
|
||||
|
|
@ -222,11 +268,20 @@ class TestToolHandlers:
|
|||
assert call_kwargs["content"] == "user likes dark mode"
|
||||
|
||||
def test_retain_with_tags(self, provider_with_config):
|
||||
p = provider_with_config(tags=["pref", "ui"])
|
||||
p = provider_with_config(retain_tags=["pref", "ui"])
|
||||
p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
|
||||
call_kwargs = p._client.aretain.call_args.kwargs
|
||||
assert call_kwargs["tags"] == ["pref", "ui"]
|
||||
|
||||
def test_retain_merges_per_call_tags_with_config_tags(self, provider_with_config):
|
||||
p = provider_with_config(retain_tags=["pref", "ui"])
|
||||
p.handle_tool_call(
|
||||
"hindsight_retain",
|
||||
{"content": "likes dark mode", "tags": ["client:x", "ui"]},
|
||||
)
|
||||
call_kwargs = p._client.aretain.call_args.kwargs
|
||||
assert call_kwargs["tags"] == ["pref", "ui", "client:x"]
|
||||
|
||||
def test_retain_without_tags(self, provider):
|
||||
provider.handle_tool_call("hindsight_retain", {"content": "hello"})
|
||||
call_kwargs = provider._client.aretain.call_args.kwargs
|
||||
|
|
@ -389,38 +444,58 @@ class TestPrefetch:
|
|||
|
||||
|
||||
class TestSyncTurn:
|
||||
def _get_retain_kwargs(self, provider):
|
||||
"""Helper to get the kwargs from the aretain_batch call."""
|
||||
return provider._client.aretain_batch.call_args.kwargs
|
||||
def test_sync_turn_retains_metadata_rich_turn(self, provider_with_config):
|
||||
p = provider_with_config(
|
||||
retain_tags=["conv", "session1"],
|
||||
retain_source="hermes",
|
||||
retain_user_prefix="User (fakeusername)",
|
||||
retain_assistant_prefix="Assistant (fakeassistantname)",
|
||||
)
|
||||
p.initialize(
|
||||
session_id="session-1",
|
||||
platform="discord",
|
||||
user_id="fakeusername-123",
|
||||
user_name="fakeusername",
|
||||
chat_id="1485316232612941897",
|
||||
chat_name="fakeassistantname-forums",
|
||||
chat_type="thread",
|
||||
thread_id="1491249007475949698",
|
||||
agent_identity="fakeassistantname",
|
||||
)
|
||||
p._client = _make_mock_client()
|
||||
|
||||
def _get_retain_content(self, provider):
|
||||
"""Helper to get the raw content string from the first item."""
|
||||
kwargs = self._get_retain_kwargs(provider)
|
||||
return kwargs["items"][0]["content"]
|
||||
p.sync_turn("hello", "hi there")
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
|
||||
def _get_retain_messages(self, provider):
|
||||
"""Helper to parse the first turn's messages from retained content.
|
||||
|
||||
Content is a JSON array of turns: [[msgs...], [msgs...], ...]
|
||||
For single-turn tests, returns the first turn's messages.
|
||||
"""
|
||||
content = self._get_retain_content(provider)
|
||||
turns = json.loads(content)
|
||||
return turns[0] if len(turns) == 1 else turns
|
||||
|
||||
def test_sync_turn_retains(self, provider):
|
||||
provider.sync_turn("hello", "hi there")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._client.aretain_batch.assert_called_once()
|
||||
messages = self._get_retain_messages(provider)
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["role"] == "user"
|
||||
assert messages[0]["content"] == "hello"
|
||||
assert "timestamp" in messages[0]
|
||||
assert messages[1]["role"] == "assistant"
|
||||
assert messages[1]["content"] == "hi there"
|
||||
assert "timestamp" in messages[1]
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = p._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["bank_id"] == "test-bank"
|
||||
assert call_kwargs["document_id"] == "session-1"
|
||||
assert call_kwargs["retain_async"] is True
|
||||
assert len(call_kwargs["items"]) == 1
|
||||
item = call_kwargs["items"][0]
|
||||
assert item["context"] == "conversation between Hermes Agent and the User"
|
||||
assert item["tags"] == ["conv", "session1"]
|
||||
content = json.loads(item["content"])
|
||||
assert len(content) == 1
|
||||
assert content[0][0]["role"] == "user"
|
||||
assert content[0][0]["content"] == "User (fakeusername): hello"
|
||||
assert content[0][1]["role"] == "assistant"
|
||||
assert content[0][1]["content"] == "Assistant (fakeassistantname): hi there"
|
||||
assert item["metadata"]["source"] == "hermes"
|
||||
assert item["metadata"]["session_id"] == "session-1"
|
||||
assert item["metadata"]["platform"] == "discord"
|
||||
assert item["metadata"]["user_id"] == "fakeusername-123"
|
||||
assert item["metadata"]["user_name"] == "fakeusername"
|
||||
assert item["metadata"]["chat_id"] == "1485316232612941897"
|
||||
assert item["metadata"]["chat_name"] == "fakeassistantname-forums"
|
||||
assert item["metadata"]["chat_type"] == "thread"
|
||||
assert item["metadata"]["thread_id"] == "1491249007475949698"
|
||||
assert item["metadata"]["agent_identity"] == "fakeassistantname"
|
||||
assert item["metadata"]["turn_index"] == "1"
|
||||
assert item["metadata"]["message_count"] == "2"
|
||||
assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(\.\d+)?\+00:00", content[0][0]["timestamp"])
|
||||
assert re.fullmatch(r"\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", item["metadata"]["retained_at"])
|
||||
|
||||
def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config):
|
||||
p = provider_with_config(auto_retain=False)
|
||||
|
|
@ -428,93 +503,33 @@ class TestSyncTurn:
|
|||
assert p._sync_thread is None
|
||||
p._client.aretain_batch.assert_not_called()
|
||||
|
||||
def test_sync_turn_with_tags(self, provider_with_config):
|
||||
p = provider_with_config(tags=["conv", "session1"])
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert item["tags"] == ["conv", "session1"]
|
||||
|
||||
def test_sync_turn_uses_aretain_batch(self, provider):
|
||||
"""sync_turn should use aretain_batch with retain_async."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
provider._client.aretain_batch.assert_called_once()
|
||||
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["document_id"] == "test-session"
|
||||
assert call_kwargs["retain_async"] is True
|
||||
assert len(call_kwargs["items"]) == 1
|
||||
assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User"
|
||||
|
||||
def test_sync_turn_custom_context(self, provider_with_config):
|
||||
p = provider_with_config(retain_context="my-agent")
|
||||
p.sync_turn("hello", "hi")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
item = p._client.aretain_batch.call_args.kwargs["items"][0]
|
||||
assert item["context"] == "my-agent"
|
||||
|
||||
def test_sync_turn_every_n_turns(self, provider_with_config):
|
||||
"""With retain_every_n_turns=3, only retains on every 3rd turn."""
|
||||
p = provider_with_config(retain_every_n_turns=3)
|
||||
|
||||
p = provider_with_config(retain_every_n_turns=3, retain_async=False)
|
||||
p.sync_turn("turn1-user", "turn1-asst")
|
||||
assert p._sync_thread is None # not retained yet
|
||||
|
||||
assert p._sync_thread is None
|
||||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
assert p._sync_thread is None # not retained yet
|
||||
|
||||
assert p._sync_thread is None
|
||||
p.sync_turn("turn3-user", "turn3-asst")
|
||||
assert p._sync_thread is not None # retained!
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
|
||||
p._client.aretain_batch.assert_called_once()
|
||||
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
|
||||
# Should contain all 3 turns
|
||||
assert "turn1-user" in content
|
||||
assert "turn2-user" in content
|
||||
assert "turn3-user" in content
|
||||
|
||||
def test_sync_turn_accumulates_full_session(self, provider_with_config):
|
||||
"""Each retain sends the ENTIRE session, not just the latest batch."""
|
||||
p = provider_with_config(retain_every_n_turns=2)
|
||||
|
||||
p.sync_turn("turn1-user", "turn1-asst")
|
||||
p.sync_turn("turn2-user", "turn2-asst")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
|
||||
p._client.aretain_batch.reset_mock()
|
||||
|
||||
p.sync_turn("turn3-user", "turn3-asst")
|
||||
p.sync_turn("turn4-user", "turn4-asst")
|
||||
if p._sync_thread:
|
||||
p._sync_thread.join(timeout=5.0)
|
||||
|
||||
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
|
||||
# Should contain ALL turns from the session
|
||||
assert "turn1-user" in content
|
||||
assert "turn2-user" in content
|
||||
assert "turn3-user" in content
|
||||
assert "turn4-user" in content
|
||||
|
||||
def test_sync_turn_passes_document_id(self, provider):
|
||||
"""sync_turn should pass session_id as document_id for dedup."""
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
call_kwargs = provider._client.aretain_batch.call_args.kwargs
|
||||
call_kwargs = p._client.aretain_batch.call_args.kwargs
|
||||
assert call_kwargs["document_id"] == "test-session"
|
||||
assert call_kwargs["retain_async"] is False
|
||||
item = call_kwargs["items"][0]
|
||||
content = json.loads(item["content"])
|
||||
assert len(content) == 3
|
||||
assert content[-1][0]["role"] == "user"
|
||||
assert content[-1][0]["content"] == "User: turn3-user"
|
||||
assert content[-1][1]["role"] == "assistant"
|
||||
assert content[-1][1]["content"] == "Assistant: turn3-asst"
|
||||
assert item["metadata"]["turn_index"] == "3"
|
||||
assert item["metadata"]["message_count"] == "6"
|
||||
|
||||
def test_sync_turn_error_does_not_raise(self, provider):
|
||||
"""Errors in sync_turn should be swallowed (non-blocking)."""
|
||||
provider._client.aretain_batch.side_effect = RuntimeError("network error")
|
||||
provider.sync_turn("hello", "hi")
|
||||
if provider._sync_thread:
|
||||
provider._sync_thread.join(timeout=5.0)
|
||||
# Should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -555,10 +570,11 @@ class TestConfigSchema:
|
|||
"mode", "api_url", "api_key", "llm_provider", "llm_api_key",
|
||||
"llm_model", "bank_id", "bank_mission", "bank_retain_mission",
|
||||
"recall_budget", "memory_mode", "recall_prefetch_method",
|
||||
"tags", "recall_tags", "recall_tags_match",
|
||||
"retain_tags", "retain_source",
|
||||
"retain_user_prefix", "retain_assistant_prefix",
|
||||
"recall_tags", "recall_tags_match",
|
||||
"auto_recall", "auto_retain",
|
||||
"retain_every_n_turns", "retain_async",
|
||||
"retain_context",
|
||||
"retain_every_n_turns", "retain_async", "retain_context",
|
||||
"recall_max_tokens", "recall_max_input_chars",
|
||||
"recall_prompt_preamble",
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue