mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(memory): add write origin metadata
This commit is contained in:
parent
14b27bb68c
commit
6a957a74bc
5 changed files with 184 additions and 5 deletions
|
|
@ -31,6 +31,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
import inspect
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.memory_provider import MemoryProvider
|
||||
|
|
@ -312,7 +313,39 @@ class MemoryManager:
|
|||
)
|
||||
return "\n\n".join(parts)
|
||||
|
||||
def on_memory_write(self, action: str, target: str, content: str) -> None:
|
||||
@staticmethod
|
||||
def _provider_memory_write_metadata_mode(provider: MemoryProvider) -> str:
|
||||
"""Return how to pass metadata to a provider's memory-write hook."""
|
||||
try:
|
||||
signature = inspect.signature(provider.on_memory_write)
|
||||
except (TypeError, ValueError):
|
||||
return "keyword"
|
||||
|
||||
params = list(signature.parameters.values())
|
||||
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
|
||||
return "keyword"
|
||||
if "metadata" in signature.parameters:
|
||||
return "keyword"
|
||||
|
||||
accepted = [
|
||||
p for p in params
|
||||
if p.kind in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
inspect.Parameter.KEYWORD_ONLY,
|
||||
)
|
||||
]
|
||||
if len(accepted) >= 4:
|
||||
return "positional"
|
||||
return "legacy"
|
||||
|
||||
def on_memory_write(
|
||||
self,
|
||||
action: str,
|
||||
target: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Notify external providers when the built-in memory tool writes.
|
||||
|
||||
Skips the builtin provider itself (it's the source of the write).
|
||||
|
|
@ -321,7 +354,15 @@ class MemoryManager:
|
|||
if provider.name == "builtin":
|
||||
continue
|
||||
try:
|
||||
provider.on_memory_write(action, target, content)
|
||||
metadata_mode = self._provider_memory_write_metadata_mode(provider)
|
||||
if metadata_mode == "keyword":
|
||||
provider.on_memory_write(
|
||||
action, target, content, metadata=dict(metadata or {})
|
||||
)
|
||||
elif metadata_mode == "positional":
|
||||
provider.on_memory_write(action, target, content, dict(metadata or {}))
|
||||
else:
|
||||
provider.on_memory_write(action, target, content)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
"Memory provider '%s' on_memory_write failed: %s",
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ Optional hooks (override to opt in):
|
|||
on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context
|
||||
on_session_end(messages) — end-of-session extraction
|
||||
on_pre_compress(messages) -> str — extract before context compression
|
||||
on_memory_write(action, target, content) — mirror built-in memory writes
|
||||
on_memory_write(action, target, content, metadata=None) — mirror built-in memory writes
|
||||
on_delegation(task, result, **kwargs) — parent-side observation of subagent work
|
||||
"""
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -220,12 +220,21 @@ class MemoryProvider(ABC):
|
|||
should all have ``env_var`` set and this method stays no-op).
|
||||
"""
|
||||
|
||||
def on_memory_write(self, action: str, target: str, content: str) -> None:
|
||||
def on_memory_write(
|
||||
self,
|
||||
action: str,
|
||||
target: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Called when the built-in memory tool writes an entry.
|
||||
|
||||
action: 'add', 'replace', or 'remove'
|
||||
target: 'memory' or 'user'
|
||||
content: the entry content
|
||||
metadata: structured provenance for the write, when available. Common
|
||||
keys include ``write_origin``, ``execution_context``, ``session_id``,
|
||||
``parent_session_id``, ``platform``, and ``tool_name``.
|
||||
|
||||
Use to mirror built-in memory writes to your backend.
|
||||
"""
|
||||
|
|
|
|||
52
run_agent.py
52
run_agent.py
|
|
@ -1437,6 +1437,8 @@ class AIAgent:
|
|||
|
||||
# Track conversation messages for session logging
|
||||
self._session_messages: List[Dict[str, Any]] = []
|
||||
self._memory_write_origin = "assistant_tool"
|
||||
self._memory_write_context = "foreground"
|
||||
|
||||
# Cached system prompt -- built once per session, only rebuilt on compression
|
||||
self._cached_system_prompt: Optional[str] = None
|
||||
|
|
@ -3075,7 +3077,10 @@ class AIAgent:
|
|||
quiet_mode=True,
|
||||
platform=self.platform,
|
||||
provider=self.provider,
|
||||
parent_session_id=self.session_id,
|
||||
)
|
||||
review_agent._memory_write_origin = "background_review"
|
||||
review_agent._memory_write_context = "background_review"
|
||||
review_agent._memory_store = self._memory_store
|
||||
review_agent._memory_enabled = self._memory_enabled
|
||||
review_agent._user_profile_enabled = self._user_profile_enabled
|
||||
|
|
@ -3124,6 +3129,32 @@ class AIAgent:
|
|||
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
|
||||
t.start()
|
||||
|
||||
def _build_memory_write_metadata(
|
||||
self,
|
||||
*,
|
||||
write_origin: Optional[str] = None,
|
||||
execution_context: Optional[str] = None,
|
||||
task_id: Optional[str] = None,
|
||||
tool_call_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build provenance metadata for external memory-provider mirrors."""
|
||||
metadata: Dict[str, Any] = {
|
||||
"write_origin": write_origin or getattr(self, "_memory_write_origin", "assistant_tool"),
|
||||
"execution_context": (
|
||||
execution_context
|
||||
or getattr(self, "_memory_write_context", "foreground")
|
||||
),
|
||||
"session_id": self.session_id or "",
|
||||
"parent_session_id": self._parent_session_id or "",
|
||||
"platform": self.platform or os.environ.get("HERMES_SESSION_SOURCE", "cli"),
|
||||
"tool_name": "memory",
|
||||
}
|
||||
if task_id:
|
||||
metadata["task_id"] = task_id
|
||||
if tool_call_id:
|
||||
metadata["tool_call_id"] = tool_call_id
|
||||
return {k: v for k, v in metadata.items() if v not in (None, "")}
|
||||
|
||||
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
|
||||
"""Rewrite the current-turn user message before persistence/return.
|
||||
|
||||
|
|
@ -7812,6 +7843,19 @@ class AIAgent:
|
|||
old_text=args.get("old_text"),
|
||||
store=self._memory_store,
|
||||
)
|
||||
if self._memory_manager and args.get("action") in ("add", "replace"):
|
||||
try:
|
||||
self._memory_manager.on_memory_write(
|
||||
args.get("action", ""),
|
||||
flush_target,
|
||||
args.get("content", ""),
|
||||
metadata=self._build_memory_write_metadata(
|
||||
write_origin="memory_flush",
|
||||
execution_context="flush_memories",
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
if not self.quiet_mode:
|
||||
print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}")
|
||||
except Exception as e:
|
||||
|
|
@ -8043,6 +8087,10 @@ class AIAgent:
|
|||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
metadata=self._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
@ -8554,6 +8602,10 @@ class AIAgent:
|
|||
function_args.get("action", ""),
|
||||
target,
|
||||
function_args.get("content", ""),
|
||||
metadata=self._build_memory_write_metadata(
|
||||
task_id=effective_task_id,
|
||||
tool_call_id=getattr(tool_call, "id", None),
|
||||
),
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -77,6 +77,13 @@ class FakeMemoryProvider(MemoryProvider):
|
|||
self.memory_writes.append((action, target, content))
|
||||
|
||||
|
||||
class MetadataMemoryProvider(FakeMemoryProvider):
|
||||
"""Provider that opts into write metadata."""
|
||||
|
||||
def on_memory_write(self, action, target, content, metadata=None):
|
||||
self.memory_writes.append((action, target, content, metadata or {}))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryProvider ABC tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
@ -862,6 +869,51 @@ class TestOnMemoryWriteBridge:
|
|||
mgr.on_memory_write("add", "memory", "new fact")
|
||||
assert p.memory_writes == [("add", "memory", "new fact")]
|
||||
|
||||
def test_on_memory_write_metadata_passed_to_opt_in_provider(self):
|
||||
"""Providers that accept metadata receive structured write provenance."""
|
||||
mgr = MemoryManager()
|
||||
p = MetadataMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.on_memory_write(
|
||||
"add",
|
||||
"memory",
|
||||
"new fact",
|
||||
metadata={
|
||||
"write_origin": "assistant_tool",
|
||||
"execution_context": "foreground",
|
||||
"session_id": "sess-1",
|
||||
},
|
||||
)
|
||||
|
||||
assert p.memory_writes == [
|
||||
(
|
||||
"add",
|
||||
"memory",
|
||||
"new fact",
|
||||
{
|
||||
"write_origin": "assistant_tool",
|
||||
"execution_context": "foreground",
|
||||
"session_id": "sess-1",
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
def test_on_memory_write_metadata_keeps_legacy_provider_compatible(self):
|
||||
"""Old 3-arg providers keep working when the manager receives metadata."""
|
||||
mgr = MemoryManager()
|
||||
p = FakeMemoryProvider("ext")
|
||||
mgr.add_provider(p)
|
||||
|
||||
mgr.on_memory_write(
|
||||
"add",
|
||||
"user",
|
||||
"legacy provider fact",
|
||||
metadata={"write_origin": "assistant_tool"},
|
||||
)
|
||||
|
||||
assert p.memory_writes == [("add", "user", "legacy provider fact")]
|
||||
|
||||
def test_on_memory_write_replace(self):
|
||||
"""on_memory_write fires for 'replace' actions."""
|
||||
mgr = MemoryManager()
|
||||
|
|
|
|||
|
|
@ -233,6 +233,31 @@ class TestFlushMemoriesUsesAuxiliaryClient:
|
|||
assert call_kwargs.kwargs["target"] == "notes"
|
||||
assert "dark mode" in call_kwargs.kwargs["content"]
|
||||
|
||||
def test_flush_bridges_memory_write_metadata(self, monkeypatch):
|
||||
"""Flush memory writes notify external providers with flush provenance."""
|
||||
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
||||
agent._memory_manager = MagicMock()
|
||||
agent.session_id = "sess-flush"
|
||||
agent.platform = "cli"
|
||||
|
||||
mock_response = _chat_response_with_memory_call()
|
||||
|
||||
with patch("agent.auxiliary_client.call_llm", return_value=mock_response):
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi"},
|
||||
{"role": "user", "content": "Note this"},
|
||||
]
|
||||
with patch("tools.memory_tool.memory_tool", return_value="Saved."):
|
||||
agent.flush_memories(messages)
|
||||
|
||||
agent._memory_manager.on_memory_write.assert_called_once()
|
||||
call_kwargs = agent._memory_manager.on_memory_write.call_args
|
||||
assert call_kwargs.args[:3] == ("add", "notes", "User prefers dark mode.")
|
||||
assert call_kwargs.kwargs["metadata"]["write_origin"] == "memory_flush"
|
||||
assert call_kwargs.kwargs["metadata"]["execution_context"] == "flush_memories"
|
||||
assert call_kwargs.kwargs["metadata"]["session_id"] == "sess-flush"
|
||||
|
||||
def test_flush_strips_artifacts_from_messages(self, monkeypatch):
|
||||
"""After flush, the flush prompt and any response should be removed from messages."""
|
||||
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue