From 6a957a74bc03c5269dd34b235d96beb1f27ad02b Mon Sep 17 00:00:00 2001 From: helix4u <4317663+helix4u@users.noreply.github.com> Date: Fri, 24 Apr 2026 13:34:27 -0600 Subject: [PATCH] fix(memory): add write origin metadata --- agent/memory_manager.py | 45 ++++++++++++++++- agent/memory_provider.py | 15 ++++-- run_agent.py | 52 ++++++++++++++++++++ tests/agent/test_memory_provider.py | 52 ++++++++++++++++++++ tests/run_agent/test_flush_memories_codex.py | 25 ++++++++++ 5 files changed, 184 insertions(+), 5 deletions(-) diff --git a/agent/memory_manager.py b/agent/memory_manager.py index 2435c3f24..62cbd6ae1 100644 --- a/agent/memory_manager.py +++ b/agent/memory_manager.py @@ -31,6 +31,7 @@ from __future__ import annotations import json import logging import re +import inspect from typing import Any, Dict, List, Optional from agent.memory_provider import MemoryProvider @@ -312,7 +313,39 @@ class MemoryManager: ) return "\n\n".join(parts) - def on_memory_write(self, action: str, target: str, content: str) -> None: + @staticmethod + def _provider_memory_write_metadata_mode(provider: MemoryProvider) -> str: + """Return how to pass metadata to a provider's memory-write hook.""" + try: + signature = inspect.signature(provider.on_memory_write) + except (TypeError, ValueError): + return "keyword" + + params = list(signature.parameters.values()) + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params): + return "keyword" + if "metadata" in signature.parameters: + return "keyword" + + accepted = [ + p for p in params + if p.kind in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ] + if len(accepted) >= 4: + return "positional" + return "legacy" + + def on_memory_write( + self, + action: str, + target: str, + content: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: """Notify external providers when the built-in memory tool writes. Skips the builtin provider itself (it's the source of the write). @@ -321,7 +354,15 @@ class MemoryManager: if provider.name == "builtin": continue try: - provider.on_memory_write(action, target, content) + metadata_mode = self._provider_memory_write_metadata_mode(provider) + if metadata_mode == "keyword": + provider.on_memory_write( + action, target, content, metadata=dict(metadata or {}) + ) + elif metadata_mode == "positional": + provider.on_memory_write(action, target, content, dict(metadata or {})) + else: + provider.on_memory_write(action, target, content) except Exception as e: logger.debug( "Memory provider '%s' on_memory_write failed: %s", diff --git a/agent/memory_provider.py b/agent/memory_provider.py index 24593e334..535338f4e 100644 --- a/agent/memory_provider.py +++ b/agent/memory_provider.py @@ -26,7 +26,7 @@ Optional hooks (override to opt in): on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context on_session_end(messages) — end-of-session extraction on_pre_compress(messages) -> str — extract before context compression - on_memory_write(action, target, content) — mirror built-in memory writes + on_memory_write(action, target, content, metadata=None) — mirror built-in memory writes on_delegation(task, result, **kwargs) — parent-side observation of subagent work """ @@ -34,7 +34,7 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional logger = logging.getLogger(__name__) @@ -220,12 +220,21 @@ class MemoryProvider(ABC): should all have ``env_var`` set and this method stays no-op). """ - def on_memory_write(self, action: str, target: str, content: str) -> None: + def on_memory_write( + self, + action: str, + target: str, + content: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> None: """Called when the built-in memory tool writes an entry. action: 'add', 'replace', or 'remove' target: 'memory' or 'user' content: the entry content + metadata: structured provenance for the write, when available. Common + keys include ``write_origin``, ``execution_context``, ``session_id``, + ``parent_session_id``, ``platform``, and ``tool_name``. Use to mirror built-in memory writes to your backend. """ diff --git a/run_agent.py b/run_agent.py index 3daa05db6..22ead36f5 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1437,6 +1437,8 @@ class AIAgent: # Track conversation messages for session logging self._session_messages: List[Dict[str, Any]] = [] + self._memory_write_origin = "assistant_tool" + self._memory_write_context = "foreground" # Cached system prompt -- built once per session, only rebuilt on compression self._cached_system_prompt: Optional[str] = None @@ -3075,7 +3077,10 @@ class AIAgent: quiet_mode=True, platform=self.platform, provider=self.provider, + parent_session_id=self.session_id, ) + review_agent._memory_write_origin = "background_review" + review_agent._memory_write_context = "background_review" review_agent._memory_store = self._memory_store review_agent._memory_enabled = self._memory_enabled review_agent._user_profile_enabled = self._user_profile_enabled @@ -3124,6 +3129,32 @@ class AIAgent: t = threading.Thread(target=_run_review, daemon=True, name="bg-review") t.start() + def _build_memory_write_metadata( + self, + *, + write_origin: Optional[str] = None, + execution_context: Optional[str] = None, + task_id: Optional[str] = None, + tool_call_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Build provenance metadata for external memory-provider mirrors.""" + metadata: Dict[str, Any] = { + "write_origin": write_origin or getattr(self, "_memory_write_origin", "assistant_tool"), + "execution_context": ( + execution_context + or getattr(self, "_memory_write_context", "foreground") + ), + "session_id": self.session_id or "", + "parent_session_id": self._parent_session_id or "", + "platform": self.platform or os.environ.get("HERMES_SESSION_SOURCE", "cli"), + "tool_name": "memory", + } + if task_id: + metadata["task_id"] = task_id + if tool_call_id: + metadata["tool_call_id"] = tool_call_id + return {k: v for k, v in metadata.items() if v not in (None, "")} + def _apply_persist_user_message_override(self, messages: List[Dict]) -> None: """Rewrite the current-turn user message before persistence/return. @@ -7812,6 +7843,19 @@ class AIAgent: old_text=args.get("old_text"), store=self._memory_store, ) + if self._memory_manager and args.get("action") in ("add", "replace"): + try: + self._memory_manager.on_memory_write( + args.get("action", ""), + flush_target, + args.get("content", ""), + metadata=self._build_memory_write_metadata( + write_origin="memory_flush", + execution_context="flush_memories", + ), + ) + except Exception: + pass if not self.quiet_mode: print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}") except Exception as e: @@ -8043,6 +8087,10 @@ class AIAgent: function_args.get("action", ""), target, function_args.get("content", ""), + metadata=self._build_memory_write_metadata( + task_id=effective_task_id, + tool_call_id=tool_call_id, + ), ) except Exception: pass @@ -8554,6 +8602,10 @@ class AIAgent: function_args.get("action", ""), target, function_args.get("content", ""), + metadata=self._build_memory_write_metadata( + task_id=effective_task_id, + tool_call_id=getattr(tool_call, "id", None), + ), ) except Exception: pass diff --git a/tests/agent/test_memory_provider.py b/tests/agent/test_memory_provider.py index 5cd0d8ab4..ca39da70f 100644 --- a/tests/agent/test_memory_provider.py +++ b/tests/agent/test_memory_provider.py @@ -77,6 +77,13 @@ class FakeMemoryProvider(MemoryProvider): self.memory_writes.append((action, target, content)) +class MetadataMemoryProvider(FakeMemoryProvider): + """Provider that opts into write metadata.""" + + def on_memory_write(self, action, target, content, metadata=None): + self.memory_writes.append((action, target, content, metadata or {})) + + # --------------------------------------------------------------------------- # MemoryProvider ABC tests # --------------------------------------------------------------------------- @@ -862,6 +869,51 @@ class TestOnMemoryWriteBridge: mgr.on_memory_write("add", "memory", "new fact") assert p.memory_writes == [("add", "memory", "new fact")] + def test_on_memory_write_metadata_passed_to_opt_in_provider(self): + """Providers that accept metadata receive structured write provenance.""" + mgr = MemoryManager() + p = MetadataMemoryProvider("ext") + mgr.add_provider(p) + + mgr.on_memory_write( + "add", + "memory", + "new fact", + metadata={ + "write_origin": "assistant_tool", + "execution_context": "foreground", + "session_id": "sess-1", + }, + ) + + assert p.memory_writes == [ + ( + "add", + "memory", + "new fact", + { + "write_origin": "assistant_tool", + "execution_context": "foreground", + "session_id": "sess-1", + }, + ) + ] + + def test_on_memory_write_metadata_keeps_legacy_provider_compatible(self): + """Old 3-arg providers keep working when the manager receives metadata.""" + mgr = MemoryManager() + p = FakeMemoryProvider("ext") + mgr.add_provider(p) + + mgr.on_memory_write( + "add", + "user", + "legacy provider fact", + metadata={"write_origin": "assistant_tool"}, + ) + + assert p.memory_writes == [("add", "user", "legacy provider fact")] + def test_on_memory_write_replace(self): """on_memory_write fires for 'replace' actions.""" mgr = MemoryManager() diff --git a/tests/run_agent/test_flush_memories_codex.py b/tests/run_agent/test_flush_memories_codex.py index 04e20402f..4879580be 100644 --- a/tests/run_agent/test_flush_memories_codex.py +++ b/tests/run_agent/test_flush_memories_codex.py @@ -233,6 +233,31 @@ class TestFlushMemoriesUsesAuxiliaryClient: assert call_kwargs.kwargs["target"] == "notes" assert "dark mode" in call_kwargs.kwargs["content"] + def test_flush_bridges_memory_write_metadata(self, monkeypatch): + """Flush memory writes notify external providers with flush provenance.""" + agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter") + agent._memory_manager = MagicMock() + agent.session_id = "sess-flush" + agent.platform = "cli" + + mock_response = _chat_response_with_memory_call() + + with patch("agent.auxiliary_client.call_llm", return_value=mock_response): + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi"}, + {"role": "user", "content": "Note this"}, + ] + with patch("tools.memory_tool.memory_tool", return_value="Saved."): + agent.flush_memories(messages) + + agent._memory_manager.on_memory_write.assert_called_once() + call_kwargs = agent._memory_manager.on_memory_write.call_args + assert call_kwargs.args[:3] == ("add", "notes", "User prefers dark mode.") + assert call_kwargs.kwargs["metadata"]["write_origin"] == "memory_flush" + assert call_kwargs.kwargs["metadata"]["execution_context"] == "flush_memories" + assert call_kwargs.kwargs["metadata"]["session_id"] == "sess-flush" + def test_flush_strips_artifacts_from_messages(self, monkeypatch): """After flush, the flush prompt and any response should be removed from messages.""" agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")