fix(memory): add write origin metadata

This commit is contained in:
helix4u 2026-04-24 13:34:27 -06:00 committed by Teknium
parent 14b27bb68c
commit 6a957a74bc
5 changed files with 184 additions and 5 deletions

View file

@ -31,6 +31,7 @@ from __future__ import annotations
import json
import logging
import re
import inspect
from typing import Any, Dict, List, Optional
from agent.memory_provider import MemoryProvider
@ -312,7 +313,39 @@ class MemoryManager:
)
return "\n\n".join(parts)
def on_memory_write(self, action: str, target: str, content: str) -> None:
@staticmethod
def _provider_memory_write_metadata_mode(provider: MemoryProvider) -> str:
"""Return how to pass metadata to a provider's memory-write hook."""
try:
signature = inspect.signature(provider.on_memory_write)
except (TypeError, ValueError):
return "keyword"
params = list(signature.parameters.values())
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
return "keyword"
if "metadata" in signature.parameters:
return "keyword"
accepted = [
p for p in params
if p.kind in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
]
if len(accepted) >= 4:
return "positional"
return "legacy"
def on_memory_write(
self,
action: str,
target: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Notify external providers when the built-in memory tool writes.
Skips the builtin provider itself (it's the source of the write).
@ -321,6 +354,14 @@ class MemoryManager:
if provider.name == "builtin":
continue
try:
metadata_mode = self._provider_memory_write_metadata_mode(provider)
if metadata_mode == "keyword":
provider.on_memory_write(
action, target, content, metadata=dict(metadata or {})
)
elif metadata_mode == "positional":
provider.on_memory_write(action, target, content, dict(metadata or {}))
else:
provider.on_memory_write(action, target, content)
except Exception as e:
logger.debug(

View file

@ -26,7 +26,7 @@ Optional hooks (override to opt in):
on_turn_start(turn, message, **kwargs) per-turn tick with runtime context
on_session_end(messages) end-of-session extraction
on_pre_compress(messages) -> str extract before context compression
on_memory_write(action, target, content) mirror built-in memory writes
on_memory_write(action, target, content, metadata=None) mirror built-in memory writes
on_delegation(task, result, **kwargs) parent-side observation of subagent work
"""
@ -34,7 +34,7 @@ from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@ -220,12 +220,21 @@ class MemoryProvider(ABC):
should all have ``env_var`` set and this method stays no-op).
"""
def on_memory_write(self, action: str, target: str, content: str) -> None:
def on_memory_write(
self,
action: str,
target: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Called when the built-in memory tool writes an entry.
action: 'add', 'replace', or 'remove'
target: 'memory' or 'user'
content: the entry content
metadata: structured provenance for the write, when available. Common
keys include ``write_origin``, ``execution_context``, ``session_id``,
``parent_session_id``, ``platform``, and ``tool_name``.
Use to mirror built-in memory writes to your backend.
"""

View file

@ -1437,6 +1437,8 @@ class AIAgent:
# Track conversation messages for session logging
self._session_messages: List[Dict[str, Any]] = []
self._memory_write_origin = "assistant_tool"
self._memory_write_context = "foreground"
# Cached system prompt -- built once per session, only rebuilt on compression
self._cached_system_prompt: Optional[str] = None
@ -3075,7 +3077,10 @@ class AIAgent:
quiet_mode=True,
platform=self.platform,
provider=self.provider,
parent_session_id=self.session_id,
)
review_agent._memory_write_origin = "background_review"
review_agent._memory_write_context = "background_review"
review_agent._memory_store = self._memory_store
review_agent._memory_enabled = self._memory_enabled
review_agent._user_profile_enabled = self._user_profile_enabled
@ -3124,6 +3129,32 @@ class AIAgent:
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
t.start()
def _build_memory_write_metadata(
self,
*,
write_origin: Optional[str] = None,
execution_context: Optional[str] = None,
task_id: Optional[str] = None,
tool_call_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Build provenance metadata for external memory-provider mirrors."""
metadata: Dict[str, Any] = {
"write_origin": write_origin or getattr(self, "_memory_write_origin", "assistant_tool"),
"execution_context": (
execution_context
or getattr(self, "_memory_write_context", "foreground")
),
"session_id": self.session_id or "",
"parent_session_id": self._parent_session_id or "",
"platform": self.platform or os.environ.get("HERMES_SESSION_SOURCE", "cli"),
"tool_name": "memory",
}
if task_id:
metadata["task_id"] = task_id
if tool_call_id:
metadata["tool_call_id"] = tool_call_id
return {k: v for k, v in metadata.items() if v not in (None, "")}
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
"""Rewrite the current-turn user message before persistence/return.
@ -7812,6 +7843,19 @@ class AIAgent:
old_text=args.get("old_text"),
store=self._memory_store,
)
if self._memory_manager and args.get("action") in ("add", "replace"):
try:
self._memory_manager.on_memory_write(
args.get("action", ""),
flush_target,
args.get("content", ""),
metadata=self._build_memory_write_metadata(
write_origin="memory_flush",
execution_context="flush_memories",
),
)
except Exception:
pass
if not self.quiet_mode:
print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}")
except Exception as e:
@ -8043,6 +8087,10 @@ class AIAgent:
function_args.get("action", ""),
target,
function_args.get("content", ""),
metadata=self._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=tool_call_id,
),
)
except Exception:
pass
@ -8554,6 +8602,10 @@ class AIAgent:
function_args.get("action", ""),
target,
function_args.get("content", ""),
metadata=self._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", None),
),
)
except Exception:
pass

View file

@ -77,6 +77,13 @@ class FakeMemoryProvider(MemoryProvider):
self.memory_writes.append((action, target, content))
class MetadataMemoryProvider(FakeMemoryProvider):
"""Provider that opts into write metadata."""
def on_memory_write(self, action, target, content, metadata=None):
self.memory_writes.append((action, target, content, metadata or {}))
# ---------------------------------------------------------------------------
# MemoryProvider ABC tests
# ---------------------------------------------------------------------------
@ -862,6 +869,51 @@ class TestOnMemoryWriteBridge:
mgr.on_memory_write("add", "memory", "new fact")
assert p.memory_writes == [("add", "memory", "new fact")]
def test_on_memory_write_metadata_passed_to_opt_in_provider(self):
"""Providers that accept metadata receive structured write provenance."""
mgr = MemoryManager()
p = MetadataMemoryProvider("ext")
mgr.add_provider(p)
mgr.on_memory_write(
"add",
"memory",
"new fact",
metadata={
"write_origin": "assistant_tool",
"execution_context": "foreground",
"session_id": "sess-1",
},
)
assert p.memory_writes == [
(
"add",
"memory",
"new fact",
{
"write_origin": "assistant_tool",
"execution_context": "foreground",
"session_id": "sess-1",
},
)
]
def test_on_memory_write_metadata_keeps_legacy_provider_compatible(self):
"""Old 3-arg providers keep working when the manager receives metadata."""
mgr = MemoryManager()
p = FakeMemoryProvider("ext")
mgr.add_provider(p)
mgr.on_memory_write(
"add",
"user",
"legacy provider fact",
metadata={"write_origin": "assistant_tool"},
)
assert p.memory_writes == [("add", "user", "legacy provider fact")]
def test_on_memory_write_replace(self):
"""on_memory_write fires for 'replace' actions."""
mgr = MemoryManager()

View file

@ -233,6 +233,31 @@ class TestFlushMemoriesUsesAuxiliaryClient:
assert call_kwargs.kwargs["target"] == "notes"
assert "dark mode" in call_kwargs.kwargs["content"]
def test_flush_bridges_memory_write_metadata(self, monkeypatch):
"""Flush memory writes notify external providers with flush provenance."""
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
agent._memory_manager = MagicMock()
agent.session_id = "sess-flush"
agent.platform = "cli"
mock_response = _chat_response_with_memory_call()
with patch("agent.auxiliary_client.call_llm", return_value=mock_response):
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi"},
{"role": "user", "content": "Note this"},
]
with patch("tools.memory_tool.memory_tool", return_value="Saved."):
agent.flush_memories(messages)
agent._memory_manager.on_memory_write.assert_called_once()
call_kwargs = agent._memory_manager.on_memory_write.call_args
assert call_kwargs.args[:3] == ("add", "notes", "User prefers dark mode.")
assert call_kwargs.kwargs["metadata"]["write_origin"] == "memory_flush"
assert call_kwargs.kwargs["metadata"]["execution_context"] == "flush_memories"
assert call_kwargs.kwargs["metadata"]["session_id"] == "sess-flush"
def test_flush_strips_artifacts_from_messages(self, monkeypatch):
"""After flush, the flush prompt and any response should be removed from messages."""
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")