feat(hindsight): make observation scopes configurable on retain

Adds an observation_scopes config key (and HINDSIGHT_RETAIN_OBSERVATION_SCOPES
env var) so retained memories can opt into per_tag / all_combinations /
custom scoping instead of Hindsight's default combined pass.

The setting is threaded through _build_retain_kwargs so all three retain paths
honor it: auto-retain and flush-on-switch already use aretain_batch; the tool
retain path is switched from aretain to aretain_batch (functionally equivalent,
since aretain just wraps a single-item batch) because aretain doesn't accept
the observation_scopes parameter.
This commit is contained in:
Nicolò Boschi 2026-06-15 11:55:07 +02:00 committed by Teknium
parent 8844e091c1
commit a376ca0081
2 changed files with 143 additions and 13 deletions

View file

@ -22,6 +22,7 @@ from plugins.memory.hindsight import (
RETAIN_SCHEMA,
_load_config,
_build_embedded_profile_env,
_normalize_observation_scopes,
_normalize_retain_tags,
_resolve_bank_id_template,
_sanitize_bank_segment,
@ -40,7 +41,8 @@ def _clean_env(monkeypatch):
"HINDSIGHT_API_KEY", "HINDSIGHT_API_URL", "HINDSIGHT_BANK_ID",
"HINDSIGHT_BUDGET", "HINDSIGHT_MODE", "HINDSIGHT_TIMEOUT",
"HINDSIGHT_IDLE_TIMEOUT", "HINDSIGHT_LLM_API_KEY",
"HINDSIGHT_RETAIN_TAGS", "HINDSIGHT_RETAIN_SOURCE",
"HINDSIGHT_RETAIN_TAGS", "HINDSIGHT_RETAIN_OBSERVATION_SCOPES",
"HINDSIGHT_RETAIN_SOURCE",
"HINDSIGHT_RETAIN_USER_PREFIX", "HINDSIGHT_RETAIN_ASSISTANT_PREFIX",
):
monkeypatch.delenv(key, raising=False)
@ -153,6 +155,44 @@ def test_normalize_retain_tags_accepts_json_array_string():
assert _normalize_retain_tags(value) == ["agent:fakeassistantname", "source_system:hermes-agent"]
def test_normalize_observation_scopes_empty_is_none():
    """Empty, missing, and whitespace-only inputs all normalize to ``None``."""
    for blank in ("", None, " "):
        assert _normalize_observation_scopes(blank) is None
def test_normalize_observation_scopes_keywords_pass_through():
    """Keyword inputs come back as the bare keyword, surrounding whitespace removed."""
    expected_by_input = {
        "per_tag": "per_tag",
        "combined": "combined",
        " all_combinations ": "all_combinations",
    }
    for raw, expected in expected_by_input.items():
        assert _normalize_observation_scopes(raw) == expected
def test_normalize_observation_scopes_unknown_keyword_is_none():
    """An unknown keyword string normalizes to ``None``."""
    normalized = _normalize_observation_scopes("nonsense")
    assert normalized is None
def test_normalize_observation_scopes_json_list_of_lists():
    """A JSON-encoded list of tag lists normalizes to the equivalent nested lists."""
    scopes = [["user:alice"], ["team:eng"], ["user:alice", "team:eng"]]
    encoded = json.dumps(scopes)
    assert _normalize_observation_scopes(encoded) == scopes
def test_normalize_observation_scopes_flat_list_is_single_scope():
    """A flat list of tags normalizes to one scope holding all of those tags."""
    flat_tags = ["user:alice", "team:eng"]
    normalized = _normalize_observation_scopes(flat_tags)
    assert normalized == [["user:alice", "team:eng"]]
def test_normalize_observation_scopes_list_of_lists():
    """An already-nested list of scopes normalizes to an equal nested list."""
    nested = [["user:alice"], ["team:eng"]]
    normalized = _normalize_observation_scopes(nested)
    assert normalized == [["user:alice"], ["team:eng"]]
# ---------------------------------------------------------------------------
# Schema tests
# ---------------------------------------------------------------------------
@ -198,6 +238,7 @@ class TestConfig:
assert provider._recall_max_tokens == 4096
assert provider._recall_max_input_chars == 800
assert provider._tags is None
assert provider._observation_scopes is None
assert provider._recall_tags is None
# Default recall narrowed to observation-only; world/experience are
# aggregate facts that often crowd out concrete-event signal during
@ -225,6 +266,16 @@ class TestConfig:
p = provider_with_config(recall_types=[])
assert p._recall_types == ["observation"]
def test_observation_scopes_keyword_config(self, provider_with_config):
    """A keyword ``observation_scopes`` config value ends up on the provider."""
    configured = provider_with_config(observation_scopes="per_tag")
    stored = configured._observation_scopes
    assert stored == "per_tag"
def test_observation_scopes_custom_list_config(self, provider_with_config):
    """A nested-list ``observation_scopes`` config value ends up on the provider."""
    configured = provider_with_config(observation_scopes=[["user:alice"], ["team:eng"]])
    stored = configured._observation_scopes
    assert stored == [["user:alice"], ["team:eng"]]
def test_custom_config_values(self, provider_with_config):
p = provider_with_config(
retain_tags=["tag1", "tag2"],
@ -468,16 +519,20 @@ class TestToolHandlers:
"hindsight_retain", {"content": "user likes dark mode"}
))
assert result["result"] == "Memory stored successfully."
provider._client.aretain.assert_called_once()
call_kwargs = provider._client.aretain.call_args.kwargs
provider._client.aretain_batch.assert_called_once()
call_kwargs = provider._client.aretain_batch.call_args.kwargs
assert call_kwargs["bank_id"] == "test-bank"
assert call_kwargs["content"] == "user likes dark mode"
item = call_kwargs["items"][0]
assert item["content"] == "user likes dark mode"
# bank_id/retain_async are call-level args, never item keys.
assert "bank_id" not in item
assert "retain_async" not in item
# Retains through the tool handler with retain_tags configured and checks the
# tags that reach the client call.
# NOTE(review): this diff view has lost its +/- markers. The two `aretain`
# lines below look like the removed pre-change assertions and the two
# `aretain_batch` lines like their replacements — confirm against the commit.
def test_retain_with_tags(self, provider_with_config):
p = provider_with_config(retain_tags=["pref", "ui"])
p.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
call_kwargs = p._client.aretain.call_args.kwargs
assert call_kwargs["tags"] == ["pref", "ui"]
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["tags"] == ["pref", "ui"]
def test_retain_merges_per_call_tags_with_config_tags(self, provider_with_config):
p = provider_with_config(retain_tags=["pref", "ui"])
@ -485,13 +540,24 @@ class TestToolHandlers:
"hindsight_retain",
{"content": "likes dark mode", "tags": ["client:x", "ui"]},
)
call_kwargs = p._client.aretain.call_args.kwargs
assert call_kwargs["tags"] == ["pref", "ui", "client:x"]
item = p._client.aretain_batch.call_args.kwargs["items"][0]
assert item["tags"] == ["pref", "ui", "client:x"]
# Retains through the tool handler on the default provider and checks that no
# "tags" key reaches the client call.
# NOTE(review): this diff view has lost its +/- markers. The two `aretain`
# lines below look like the removed pre-change assertions and the two
# `aretain_batch` lines like their replacements — confirm against the commit.
def test_retain_without_tags(self, provider):
provider.handle_tool_call("hindsight_retain", {"content": "hello"})
call_kwargs = provider._client.aretain.call_args.kwargs
assert "tags" not in call_kwargs
item = provider._client.aretain_batch.call_args.kwargs["items"][0]
assert "tags" not in item
def test_retain_passes_observation_scopes(self, provider_with_config):
    """Configured observation scopes appear on the item sent to ``aretain_batch``."""
    scoped_provider = provider_with_config(observation_scopes="per_tag")
    scoped_provider.handle_tool_call("hindsight_retain", {"content": "likes dark mode"})
    batch_kwargs = scoped_provider._client.aretain_batch.call_args.kwargs
    first_item = batch_kwargs["items"][0]
    assert first_item["observation_scopes"] == "per_tag"
def test_retain_omits_observation_scopes_by_default(self, provider):
    """With the default provider, the retained item has no ``observation_scopes`` key."""
    provider.handle_tool_call("hindsight_retain", {"content": "hello"})
    batch_kwargs = provider._client.aretain_batch.call_args.kwargs
    first_item = batch_kwargs["items"][0]
    assert "observation_scopes" not in first_item
def test_retain_missing_content(self, provider):
result = json.loads(provider.handle_tool_call(
@ -557,7 +623,7 @@ class TestToolHandlers:
assert "error" in result
def test_retain_error_handling(self, provider):
provider._client.aretain.side_effect = RuntimeError("connection failed")
provider._client.aretain_batch.side_effect = RuntimeError("connection failed")
result = json.loads(provider.handle_tool_call(
"hindsight_retain", {"content": "test"}
))