fix(memory): add write origin metadata

This commit is contained in:
helix4u 2026-04-24 13:34:27 -06:00 committed by Teknium
parent 14b27bb68c
commit 6a957a74bc
5 changed files with 184 additions and 5 deletions

View file

@ -31,6 +31,7 @@ from __future__ import annotations
import json import json
import logging import logging
import re import re
import inspect
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from agent.memory_provider import MemoryProvider from agent.memory_provider import MemoryProvider
@ -312,7 +313,39 @@ class MemoryManager:
) )
return "\n\n".join(parts) return "\n\n".join(parts)
def on_memory_write(self, action: str, target: str, content: str) -> None: @staticmethod
def _provider_memory_write_metadata_mode(provider: MemoryProvider) -> str:
    """Return how to pass metadata to a provider's memory-write hook.

    The signature of ``provider.on_memory_write`` is inspected and one of
    three calling conventions is selected:

    * ``"keyword"`` -- call with ``metadata=...``. Chosen when the hook
      cannot be introspected, accepts ``**kwargs``, or declares a parameter
      named ``metadata`` that can be bound by keyword.
    * ``"positional"`` -- pass metadata as the fourth positional argument.
      Chosen when the hook has at least four parameters that can actually
      be bound positionally.
    * ``"legacy"`` -- call with ``(action, target, content)`` only.
    """
    try:
        signature = inspect.signature(provider.on_memory_write)
    except (TypeError, ValueError):
        # Non-introspectable callables get the keyword form.
        return "keyword"

    params = list(signature.parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return "keyword"

    # A ``metadata`` parameter only supports the keyword form when it is
    # bindable by name; a positional-only ``metadata`` would raise TypeError
    # on ``metadata=...`` and must fall through to the positional check.
    metadata_param = signature.parameters.get("metadata")
    if metadata_param is not None and metadata_param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    ):
        return "keyword"

    # Keyword-only parameters cannot receive a fourth positional argument,
    # so they must not count toward the positional threshold.
    positional = [
        p for p in params
        if p.kind in (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
    ]
    if len(positional) >= 4:
        return "positional"
    return "legacy"
def on_memory_write(
self,
action: str,
target: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Notify external providers when the built-in memory tool writes. """Notify external providers when the built-in memory tool writes.
Skips the builtin provider itself (it's the source of the write). Skips the builtin provider itself (it's the source of the write).
@ -321,7 +354,15 @@ class MemoryManager:
if provider.name == "builtin": if provider.name == "builtin":
continue continue
try: try:
provider.on_memory_write(action, target, content) metadata_mode = self._provider_memory_write_metadata_mode(provider)
if metadata_mode == "keyword":
provider.on_memory_write(
action, target, content, metadata=dict(metadata or {})
)
elif metadata_mode == "positional":
provider.on_memory_write(action, target, content, dict(metadata or {}))
else:
provider.on_memory_write(action, target, content)
except Exception as e: except Exception as e:
logger.debug( logger.debug(
"Memory provider '%s' on_memory_write failed: %s", "Memory provider '%s' on_memory_write failed: %s",

View file

@ -26,7 +26,7 @@ Optional hooks (override to opt in):
on_turn_start(turn, message, **kwargs) per-turn tick with runtime context on_turn_start(turn, message, **kwargs) per-turn tick with runtime context
on_session_end(messages) end-of-session extraction on_session_end(messages) end-of-session extraction
on_pre_compress(messages) -> str extract before context compression on_pre_compress(messages) -> str extract before context compression
on_memory_write(action, target, content) mirror built-in memory writes on_memory_write(action, target, content, metadata=None) mirror built-in memory writes
on_delegation(task, result, **kwargs) parent-side observation of subagent work on_delegation(task, result, **kwargs) parent-side observation of subagent work
""" """
@ -34,7 +34,7 @@ from __future__ import annotations
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -220,12 +220,21 @@ class MemoryProvider(ABC):
should all have ``env_var`` set and this method stays no-op). should all have ``env_var`` set and this method stays no-op).
""" """
def on_memory_write(self, action: str, target: str, content: str) -> None: def on_memory_write(
self,
action: str,
target: str,
content: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Called when the built-in memory tool writes an entry. """Called when the built-in memory tool writes an entry.
action: 'add', 'replace', or 'remove' action: 'add', 'replace', or 'remove'
target: 'memory' or 'user' target: 'memory' or 'user'
content: the entry content content: the entry content
metadata: structured provenance for the write, when available. Common
keys include ``write_origin``, ``execution_context``, ``session_id``,
``parent_session_id``, ``platform``, and ``tool_name``.
Use to mirror built-in memory writes to your backend. Use to mirror built-in memory writes to your backend.
""" """

View file

@ -1437,6 +1437,8 @@ class AIAgent:
# Track conversation messages for session logging # Track conversation messages for session logging
self._session_messages: List[Dict[str, Any]] = [] self._session_messages: List[Dict[str, Any]] = []
self._memory_write_origin = "assistant_tool"
self._memory_write_context = "foreground"
# Cached system prompt -- built once per session, only rebuilt on compression # Cached system prompt -- built once per session, only rebuilt on compression
self._cached_system_prompt: Optional[str] = None self._cached_system_prompt: Optional[str] = None
@ -3075,7 +3077,10 @@ class AIAgent:
quiet_mode=True, quiet_mode=True,
platform=self.platform, platform=self.platform,
provider=self.provider, provider=self.provider,
parent_session_id=self.session_id,
) )
review_agent._memory_write_origin = "background_review"
review_agent._memory_write_context = "background_review"
review_agent._memory_store = self._memory_store review_agent._memory_store = self._memory_store
review_agent._memory_enabled = self._memory_enabled review_agent._memory_enabled = self._memory_enabled
review_agent._user_profile_enabled = self._user_profile_enabled review_agent._user_profile_enabled = self._user_profile_enabled
@ -3124,6 +3129,32 @@ class AIAgent:
t = threading.Thread(target=_run_review, daemon=True, name="bg-review") t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
t.start() t.start()
def _build_memory_write_metadata(
    self,
    *,
    write_origin: Optional[str] = None,
    execution_context: Optional[str] = None,
    task_id: Optional[str] = None,
    tool_call_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Build provenance metadata for external memory-provider mirrors.

    Args:
        write_origin: Overrides the agent's ``_memory_write_origin``
            (which itself defaults to ``"assistant_tool"``).
        execution_context: Overrides the agent's ``_memory_write_context``
            (which itself defaults to ``"foreground"``).
        task_id: Added under ``"task_id"`` when truthy.
        tool_call_id: Added under ``"tool_call_id"`` when truthy.

    Returns:
        A dict with ``write_origin``, ``execution_context``, ``session_id``,
        ``parent_session_id``, ``platform`` and ``tool_name`` keys, plus the
        optional ids, with every ``None`` / empty-string value removed.
    """
    # Every attribute is read defensively: the visible call sites swallow
    # exceptions, so an AttributeError here would silently drop the provider
    # notification instead of surfacing.
    platform = getattr(self, "platform", None) or os.environ.get(
        "HERMES_SESSION_SOURCE", "cli"
    )
    metadata: Dict[str, Any] = {
        "write_origin": write_origin
        or getattr(self, "_memory_write_origin", "assistant_tool"),
        "execution_context": execution_context
        or getattr(self, "_memory_write_context", "foreground"),
        "session_id": getattr(self, "session_id", None) or "",
        "parent_session_id": getattr(self, "_parent_session_id", None) or "",
        "platform": platform,
        "tool_name": "memory",
    }
    if task_id:
        metadata["task_id"] = task_id
    if tool_call_id:
        metadata["tool_call_id"] = tool_call_id
    # Drop unset fields so providers only see keys that carry information.
    return {
        key: value for key, value in metadata.items() if value not in (None, "")
    }
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None: def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
"""Rewrite the current-turn user message before persistence/return. """Rewrite the current-turn user message before persistence/return.
@ -7812,6 +7843,19 @@ class AIAgent:
old_text=args.get("old_text"), old_text=args.get("old_text"),
store=self._memory_store, store=self._memory_store,
) )
if self._memory_manager and args.get("action") in ("add", "replace"):
try:
self._memory_manager.on_memory_write(
args.get("action", ""),
flush_target,
args.get("content", ""),
metadata=self._build_memory_write_metadata(
write_origin="memory_flush",
execution_context="flush_memories",
),
)
except Exception:
pass
if not self.quiet_mode: if not self.quiet_mode:
print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}") print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}")
except Exception as e: except Exception as e:
@ -8043,6 +8087,10 @@ class AIAgent:
function_args.get("action", ""), function_args.get("action", ""),
target, target,
function_args.get("content", ""), function_args.get("content", ""),
metadata=self._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=tool_call_id,
),
) )
except Exception: except Exception:
pass pass
@ -8554,6 +8602,10 @@ class AIAgent:
function_args.get("action", ""), function_args.get("action", ""),
target, target,
function_args.get("content", ""), function_args.get("content", ""),
metadata=self._build_memory_write_metadata(
task_id=effective_task_id,
tool_call_id=getattr(tool_call, "id", None),
),
) )
except Exception: except Exception:
pass pass

View file

@ -77,6 +77,13 @@ class FakeMemoryProvider(MemoryProvider):
self.memory_writes.append((action, target, content)) self.memory_writes.append((action, target, content))
class MetadataMemoryProvider(FakeMemoryProvider):
    """Fake provider whose write hook also accepts and records ``metadata``."""

    def on_memory_write(self, action, target, content, metadata=None):
        # Record a 4-tuple; missing or empty metadata is stored as ``{}``.
        recorded_metadata = metadata or {}
        entry = (action, target, content, recorded_metadata)
        self.memory_writes.append(entry)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# MemoryProvider ABC tests # MemoryProvider ABC tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@ -862,6 +869,51 @@ class TestOnMemoryWriteBridge:
mgr.on_memory_write("add", "memory", "new fact") mgr.on_memory_write("add", "memory", "new fact")
assert p.memory_writes == [("add", "memory", "new fact")] assert p.memory_writes == [("add", "memory", "new fact")]
def test_on_memory_write_metadata_passed_to_opt_in_provider(self):
    """A provider whose hook takes ``metadata`` receives the provenance dict."""
    provenance = {
        "write_origin": "assistant_tool",
        "execution_context": "foreground",
        "session_id": "sess-1",
    }
    # Snapshot the expectation before the call so it is independent of the
    # dict object handed to the manager.
    expected_writes = [("add", "memory", "new fact", dict(provenance))]

    manager = MemoryManager()
    provider = MetadataMemoryProvider("ext")
    manager.add_provider(provider)

    manager.on_memory_write("add", "memory", "new fact", metadata=provenance)

    assert provider.memory_writes == expected_writes
def test_on_memory_write_metadata_keeps_legacy_provider_compatible(self):
    """A three-argument provider still gets called when metadata is supplied."""
    manager = MemoryManager()
    legacy_provider = FakeMemoryProvider("ext")
    manager.add_provider(legacy_provider)

    provenance = {"write_origin": "assistant_tool"}
    manager.on_memory_write(
        "add", "user", "legacy provider fact", metadata=provenance
    )

    # The legacy hook records plain 3-tuples, with no metadata element.
    expected_writes = [("add", "user", "legacy provider fact")]
    assert legacy_provider.memory_writes == expected_writes
def test_on_memory_write_replace(self): def test_on_memory_write_replace(self):
"""on_memory_write fires for 'replace' actions.""" """on_memory_write fires for 'replace' actions."""
mgr = MemoryManager() mgr = MemoryManager()

View file

@ -233,6 +233,31 @@ class TestFlushMemoriesUsesAuxiliaryClient:
assert call_kwargs.kwargs["target"] == "notes" assert call_kwargs.kwargs["target"] == "notes"
assert "dark mode" in call_kwargs.kwargs["content"] assert "dark mode" in call_kwargs.kwargs["content"]
def test_flush_bridges_memory_write_metadata(self, monkeypatch):
    """Flushing memories mirrors the write to providers with flush provenance."""
    agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
    agent._memory_manager = MagicMock()
    agent.session_id = "sess-flush"
    agent.platform = "cli"

    mock_response = _chat_response_with_memory_call()
    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi"},
        {"role": "user", "content": "Note this"},
    ]

    # Both the auxiliary LLM call and the memory tool are stubbed for the flush.
    with patch("agent.auxiliary_client.call_llm", return_value=mock_response), \
            patch("tools.memory_tool.memory_tool", return_value="Saved."):
        agent.flush_memories(messages)

    bridge = agent._memory_manager.on_memory_write
    bridge.assert_called_once()
    recorded_call = bridge.call_args
    assert recorded_call.args[:3] == ("add", "notes", "User prefers dark mode.")
    assert recorded_call.kwargs["metadata"]["write_origin"] == "memory_flush"
    assert recorded_call.kwargs["metadata"]["execution_context"] == "flush_memories"
    assert recorded_call.kwargs["metadata"]["session_id"] == "sess-flush"
def test_flush_strips_artifacts_from_messages(self, monkeypatch): def test_flush_strips_artifacts_from_messages(self, monkeypatch):
"""After flush, the flush prompt and any response should be removed from messages.""" """After flush, the flush prompt and any response should be removed from messages."""
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter") agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")