feat(hindsight): feature parity, setup wizard, and config improvements

Port missing features from the hindsight-hermes external integration
package into the native plugin. Only touches plugin files — no core
changes.

Features:
- Tags on retain/recall (tags, recall_tags, recall_tags_match)
- Recall config (recall_max_tokens, recall_max_input_chars, recall_types,
  recall_prompt_preamble)
- Retain controls (retain_every_n_turns, auto_retain, auto_recall,
  retain_async via aretain_batch, retain_context)
- Bank config via Banks API (bank_mission, bank_retain_mission)
- Structured JSON retain with per-message timestamps
- Full session accumulation with document_id for dedup
- Custom post_setup() wizard with curses picker
- Mode-aware dep install (hindsight-client for cloud, hindsight-all for local)
- local_external mode and openai_compatible LLM provider
- OpenRouter support with auto base URL
- Auto-upgrade of hindsight-client to >=0.4.22 on session start
- Comprehensive debug logging across all operations
- 46 unit tests
- Updated README and website docs
This commit is contained in:
Nicolò Boschi 2026-04-09 07:27:31 +02:00 committed by Teknium
parent d97f6cec7f
commit 25757d631b
5 changed files with 1072 additions and 73 deletions

View file

@ -0,0 +1,598 @@
"""Tests for the Hindsight memory provider plugin.
Tests cover config loading, tool handlers (tags, max_tokens, types),
prefetch (auto_recall, preamble, query truncation), sync_turn (auto_retain,
turn counting, tags), and schema completeness.
"""
import json
import threading
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from plugins.memory.hindsight import (
HindsightMemoryProvider,
RECALL_SCHEMA,
REFLECT_SCHEMA,
RETAIN_SCHEMA,
_load_config,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
"""Ensure no stale env vars leak between tests."""
for key in (
"HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID",
"HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_LLM_API_KEY",
):
monkeypatch.delenv(key, raising=False)
def _make_mock_client():
"""Create a mock Hindsight client with async methods."""
client = MagicMock()
client.aretain = AsyncMock()
client.arecall = AsyncMock(
return_value=SimpleNamespace(
results=[
SimpleNamespace(text="Memory 1"),
SimpleNamespace(text="Memory 2"),
]
)
)
client.areflect = AsyncMock(
return_value=SimpleNamespace(text="Synthesized answer")
)
client.aretain_batch = AsyncMock()
client.aclose = AsyncMock()
return client
@pytest.fixture()
def provider(tmp_path, monkeypatch):
"""Create an initialized HindsightMemoryProvider with a mock client."""
config = {
"mode": "cloud",
"apiKey": "test-key",
"api_url": "http://localhost:9999",
"bank_id": "test-bank",
"budget": "mid",
"memory_mode": "hybrid",
}
config_path = tmp_path / "hindsight" / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(json.dumps(config))
monkeypatch.setattr(
"plugins.memory.hindsight.get_hermes_home", lambda: tmp_path
)
p = HindsightMemoryProvider()
p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli")
p._client = _make_mock_client()
return p
@pytest.fixture()
def provider_with_config(tmp_path, monkeypatch):
"""Create a provider factory that accepts custom config overrides."""
def _make(**overrides):
config = {
"mode": "cloud",
"apiKey": "test-key",
"api_url": "http://localhost:9999",
"bank_id": "test-bank",
"budget": "mid",
"memory_mode": "hybrid",
}
config.update(overrides)
config_path = tmp_path / "hindsight" / "config.json"
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(json.dumps(config))
monkeypatch.setattr(
"plugins.memory.hindsight.get_hermes_home", lambda: tmp_path
)
p = HindsightMemoryProvider()
p.initialize(session_id="test-session", hermes_home=str(tmp_path), platform="cli")
p._client = _make_mock_client()
return p
return _make
# ---------------------------------------------------------------------------
# Schema tests
# ---------------------------------------------------------------------------
class TestSchemas:
def test_retain_schema_has_content(self):
assert RETAIN_SCHEMA["name"] == "hindsight_retain"
assert "content" in RETAIN_SCHEMA["parameters"]["properties"]
assert "content" in RETAIN_SCHEMA["parameters"]["required"]
def test_recall_schema_has_query(self):
assert RECALL_SCHEMA["name"] == "hindsight_recall"
assert "query" in RECALL_SCHEMA["parameters"]["properties"]
assert "query" in RECALL_SCHEMA["parameters"]["required"]
def test_reflect_schema_has_query(self):
assert REFLECT_SCHEMA["name"] == "hindsight_reflect"
assert "query" in REFLECT_SCHEMA["parameters"]["properties"]
def test_get_tool_schemas_returns_three(self, provider):
schemas = provider.get_tool_schemas()
assert len(schemas) == 3
names = {s["name"] for s in schemas}
assert names == {"hindsight_retain", "hindsight_recall", "hindsight_reflect"}
def test_context_mode_returns_no_tools(self, provider_with_config):
p = provider_with_config(memory_mode="context")
assert p.get_tool_schemas() == []
# ---------------------------------------------------------------------------
# Config tests
# ---------------------------------------------------------------------------
class TestConfig:
def test_default_values(self, provider):
assert provider._auto_retain is True
assert provider._auto_recall is True
assert provider._retain_every_n_turns == 1
assert provider._recall_max_tokens == 4096
assert provider._recall_max_input_chars == 800
assert provider._tags is None
assert provider._recall_tags is None
assert provider._bank_mission == ""
assert provider._bank_retain_mission is None
assert provider._retain_context == "conversation between Hermes Agent and the User"
def test_custom_config_values(self, provider_with_config):
p = provider_with_config(
tags=["tag1", "tag2"],
recall_tags=["recall-tag"],
recall_tags_match="all",
auto_retain=False,
auto_recall=False,
retain_every_n_turns=3,
retain_context="custom-ctx",
bank_retain_mission="Extract key facts",
recall_max_tokens=2048,
recall_types=["world", "experience"],
recall_prompt_preamble="Custom preamble:",
recall_max_input_chars=500,
bank_mission="Test agent mission",
)
assert p._tags == ["tag1", "tag2"]
assert p._recall_tags == ["recall-tag"]
assert p._recall_tags_match == "all"
assert p._auto_retain is False
assert p._auto_recall is False
assert p._retain_every_n_turns == 3
assert p._retain_context == "custom-ctx"
assert p._bank_retain_mission == "Extract key facts"
assert p._recall_max_tokens == 2048
assert p._recall_types == ["world", "experience"]
assert p._recall_prompt_preamble == "Custom preamble:"
assert p._recall_max_input_chars == 500
assert p._bank_mission == "Test agent mission"
def test_config_from_env_fallback(self, tmp_path, monkeypatch):
"""When no config file exists, falls back to env vars."""
monkeypatch.setattr(
"plugins.memory.hindsight.get_hermes_home",
lambda: tmp_path / "nonexistent",
)
monkeypatch.setenv("HINDSIGHT_MODE", "cloud")
monkeypatch.setenv("HINDSIGHT_API_KEY", "env-key")
monkeypatch.setenv("HINDSIGHT_BANK_ID", "env-bank")
monkeypatch.setenv("HINDSIGHT_BUDGET", "high")
cfg = _load_config()
assert cfg["apiKey"] == "env-key"
assert cfg["banks"]["hermes"]["bankId"] == "env-bank"
assert cfg["banks"]["hermes"]["budget"] == "high"
# ---------------------------------------------------------------------------
# Tool handler tests
# ---------------------------------------------------------------------------
class TestToolHandlers:
def test_retain_success(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_retain", {"content": "user likes dark mode"}
))
assert result["result"] == "Memory stored successfully."
provider._client.aretain.assert_called_once()
call_kwargs = provider._client.aretain.call_args.kwargs
assert call_kwargs["bank_id"] == "test-bank"
assert call_kwargs["content"] == "user likes dark mode"
def test_retain_with_tags(self, provider_with_config):
p = provider_with_config(tags=["pref", "ui"])
p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
call_kwargs = p._client.aretain.call_args.kwargs
assert call_kwargs["tags"] == ["pref", "ui"]
def test_retain_without_tags(self, provider):
provider.handle_tool_call("hindsight_retain", {"content": "hello"})
call_kwargs = provider._client.aretain.call_args.kwargs
assert "tags" not in call_kwargs
def test_retain_missing_content(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_retain", {}
))
assert "error" in result
def test_recall_success(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_recall", {"query": "dark mode"}
))
assert "Memory 1" in result["result"]
assert "Memory 2" in result["result"]
def test_recall_passes_max_tokens(self, provider_with_config):
p = provider_with_config(recall_max_tokens=2048)
p.handle_tool_call("hindsight_recall", {"query": "test"})
call_kwargs = p._client.arecall.call_args.kwargs
assert call_kwargs["max_tokens"] == 2048
def test_recall_passes_tags(self, provider_with_config):
p = provider_with_config(recall_tags=["tag1"], recall_tags_match="all")
p.handle_tool_call("hindsight_recall", {"query": "test"})
call_kwargs = p._client.arecall.call_args.kwargs
assert call_kwargs["tags"] == ["tag1"]
assert call_kwargs["tags_match"] == "all"
def test_recall_passes_types(self, provider_with_config):
p = provider_with_config(recall_types=["world", "experience"])
p.handle_tool_call("hindsight_recall", {"query": "test"})
call_kwargs = p._client.arecall.call_args.kwargs
assert call_kwargs["types"] == ["world", "experience"]
def test_recall_no_results(self, provider):
provider._client.arecall.return_value = SimpleNamespace(results=[])
result = json.loads(provider.handle_tool_call(
"hindsight_recall", {"query": "test"}
))
assert result["result"] == "No relevant memories found."
def test_recall_missing_query(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_recall", {}
))
assert "error" in result
def test_reflect_success(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_reflect", {"query": "summarize"}
))
assert result["result"] == "Synthesized answer"
def test_reflect_missing_query(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_reflect", {}
))
assert "error" in result
def test_unknown_tool(self, provider):
result = json.loads(provider.handle_tool_call(
"hindsight_unknown", {}
))
assert "error" in result
def test_retain_error_handling(self, provider):
provider._client.aretain.side_effect = RuntimeError("connection failed")
result = json.loads(provider.handle_tool_call(
"hindsight_retain", {"content": "test"}
))
assert "error" in result
assert "connection failed" in result["error"]
def test_recall_error_handling(self, provider):
provider._client.arecall.side_effect = RuntimeError("timeout")
result = json.loads(provider.handle_tool_call(
"hindsight_recall", {"query": "test"}
))
assert "error" in result
# ---------------------------------------------------------------------------
# Prefetch tests
# ---------------------------------------------------------------------------
class TestPrefetch:
def test_prefetch_returns_empty_when_no_result(self, provider):
assert provider.prefetch("test") == ""
def test_prefetch_default_preamble(self, provider):
provider._prefetch_result = "- some memory"
result = provider.prefetch("test")
assert "Hindsight Memory" in result
assert "- some memory" in result
def test_prefetch_custom_preamble(self, provider_with_config):
p = provider_with_config(recall_prompt_preamble="Custom header:")
p._prefetch_result = "- memory line"
result = p.prefetch("test")
assert result.startswith("Custom header:")
assert "- memory line" in result
def test_queue_prefetch_skipped_in_tools_mode(self, provider_with_config):
p = provider_with_config(memory_mode="tools")
p.queue_prefetch("test")
# Should not start a thread
assert p._prefetch_thread is None
def test_queue_prefetch_skipped_when_auto_recall_off(self, provider_with_config):
p = provider_with_config(auto_recall=False)
p.queue_prefetch("test")
assert p._prefetch_thread is None
def test_queue_prefetch_truncates_query(self, provider_with_config):
p = provider_with_config(recall_max_input_chars=10)
# Mock _run_sync to capture the query
original_query = None
def _capture_recall(**kwargs):
nonlocal original_query
original_query = kwargs.get("query", "")
return SimpleNamespace(results=[])
p._client.arecall = AsyncMock(side_effect=_capture_recall)
long_query = "a" * 100
p.queue_prefetch(long_query)
if p._prefetch_thread:
p._prefetch_thread.join(timeout=5.0)
# The query passed to arecall should be truncated
if original_query is not None:
assert len(original_query) <= 10
def test_queue_prefetch_passes_recall_params(self, provider_with_config):
p = provider_with_config(
recall_tags=["t1"],
recall_tags_match="all",
recall_max_tokens=1024,
recall_types=["world"],
)
p.queue_prefetch("test query")
if p._prefetch_thread:
p._prefetch_thread.join(timeout=5.0)
call_kwargs = p._client.arecall.call_args.kwargs
assert call_kwargs["max_tokens"] == 1024
assert call_kwargs["tags"] == ["t1"]
assert call_kwargs["tags_match"] == "all"
assert call_kwargs["types"] == ["world"]
# ---------------------------------------------------------------------------
# sync_turn tests
# ---------------------------------------------------------------------------
class TestSyncTurn:
def _get_retain_kwargs(self, provider):
"""Helper to get the kwargs from the aretain_batch call."""
return provider._client.aretain_batch.call_args.kwargs
def _get_retain_content(self, provider):
"""Helper to get the raw content string from the first item."""
kwargs = self._get_retain_kwargs(provider)
return kwargs["items"][0]["content"]
def _get_retain_messages(self, provider):
"""Helper to parse the first turn's messages from retained content.
Content is a JSON array of turns: [[msgs...], [msgs...], ...]
For single-turn tests, returns the first turn's messages.
"""
content = self._get_retain_content(provider)
turns = json.loads(content)
return turns[0] if len(turns) == 1 else turns
def test_sync_turn_retains(self, provider):
provider.sync_turn("hello", "hi there")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._client.aretain_batch.assert_called_once()
messages = self._get_retain_messages(provider)
assert len(messages) == 2
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "hello"
assert "timestamp" in messages[0]
assert messages[1]["role"] == "assistant"
assert messages[1]["content"] == "hi there"
assert "timestamp" in messages[1]
def test_sync_turn_skipped_when_auto_retain_off(self, provider_with_config):
p = provider_with_config(auto_retain=False)
p.sync_turn("hello", "hi")
assert p._sync_thread is None
p._client.aretain_batch.assert_not_called()
def test_sync_turn_with_tags(self, provider_with_config):
p = provider_with_config(tags=["conv", "session1"])
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["tags"] == ["conv", "session1"]
def test_sync_turn_uses_aretain_batch(self, provider):
"""sync_turn should use aretain_batch with retain_async."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
provider._client.aretain_batch.assert_called_once()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"] == "test-session"
assert call_kwargs["retain_async"] is True
assert len(call_kwargs["items"]) == 1
assert call_kwargs["items"][0]["context"] == "conversation between Hermes Agent and the User"
def test_sync_turn_custom_context(self, provider_with_config):
p = provider_with_config(retain_context="my-agent")
p.sync_turn("hello", "hi")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["context"] == "my-agent"
def test_sync_turn_every_n_turns(self, provider_with_config):
"""With retain_every_n_turns=3, only retains on every 3rd turn."""
p = provider_with_config(retain_every_n_turns=3)
p.sync_turn("turn1-user", "turn1-asst")
assert p._sync_thread is None # not retained yet
p.sync_turn("turn2-user", "turn2-asst")
assert p._sync_thread is None # not retained yet
p.sync_turn("turn3-user", "turn3-asst")
assert p._sync_thread is not None # retained!
p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.assert_called_once()
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain all 3 turns
assert "turn1-user" in content
assert "turn2-user" in content
assert "turn3-user" in content
def test_sync_turn_accumulates_full_session(self, provider_with_config):
"""Each retain sends the ENTIRE session, not just the latest batch."""
p = provider_with_config(retain_every_n_turns=2)
p.sync_turn("turn1-user", "turn1-asst")
p.sync_turn("turn2-user", "turn2-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
p._client.aretain_batch.reset_mock()
p.sync_turn("turn3-user", "turn3-asst")
p.sync_turn("turn4-user", "turn4-asst")
if p._sync_thread:
p._sync_thread.join(timeout=5.0)
content = p._client.aretain_batch.call_args.kwargs["items"][0]["content"]
# Should contain ALL turns from the session
assert "turn1-user" in content
assert "turn2-user" in content
assert "turn3-user" in content
assert "turn4-user" in content
def test_sync_turn_passes_document_id(self, provider):
"""sync_turn should pass session_id as document_id for dedup."""
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["document_id"] == "test-session"
def test_sync_turn_error_does_not_raise(self, provider):
"""Errors in sync_turn should be swallowed (non-blocking)."""
provider._client.aretain_batch.side_effect = RuntimeError("network error")
provider.sync_turn("hello", "hi")
if provider._sync_thread:
provider._sync_thread.join(timeout=5.0)
# Should not raise
# ---------------------------------------------------------------------------
# System prompt tests
# ---------------------------------------------------------------------------
class TestSystemPrompt:
def test_hybrid_mode_prompt(self, provider):
block = provider.system_prompt_block()
assert "Hindsight Memory" in block
assert "hindsight_recall" in block
assert "automatically injected" in block
def test_context_mode_prompt(self, provider_with_config):
p = provider_with_config(memory_mode="context")
block = p.system_prompt_block()
assert "context mode" in block
assert "hindsight_recall" not in block
def test_tools_mode_prompt(self, provider_with_config):
p = provider_with_config(memory_mode="tools")
block = p.system_prompt_block()
assert "tools mode" in block
assert "hindsight_recall" in block
# ---------------------------------------------------------------------------
# Config schema tests
# ---------------------------------------------------------------------------
class TestConfigSchema:
def test_schema_has_all_new_fields(self, provider):
schema = provider.get_config_schema()
keys = {f["key"] for f in schema}
expected_keys = {
"mode", "api_url", "api_key", "llm_provider", "llm_api_key",
"llm_model", "bank_id", "bank_mission", "bank_retain_mission",
"recall_budget", "memory_mode", "recall_prefetch_method",
"tags", "recall_tags", "recall_tags_match",
"auto_recall", "auto_retain",
"retain_every_n_turns", "retain_async",
"retain_context",
"recall_max_tokens", "recall_max_input_chars",
"recall_prompt_preamble",
}
assert expected_keys.issubset(keys), f"Missing: {expected_keys - keys}"
# ---------------------------------------------------------------------------
# Availability tests
# ---------------------------------------------------------------------------
class TestAvailability:
def test_available_with_api_key(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"plugins.memory.hindsight.get_hermes_home",
lambda: tmp_path / "nonexistent",
)
monkeypatch.setenv("HINDSIGHT_API_KEY", "test-key")
p = HindsightMemoryProvider()
assert p.is_available()
def test_not_available_without_config(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"plugins.memory.hindsight.get_hermes_home",
lambda: tmp_path / "nonexistent",
)
p = HindsightMemoryProvider()
assert not p.is_available()
def test_available_in_local_mode(self, tmp_path, monkeypatch):
monkeypatch.setattr(
"plugins.memory.hindsight.get_hermes_home",
lambda: tmp_path / "nonexistent",
)
monkeypatch.setenv("HINDSIGHT_MODE", "local")
p = HindsightMemoryProvider()
assert p.is_available()