mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(memory): add write origin metadata
This commit is contained in:
parent
14b27bb68c
commit
6a957a74bc
5 changed files with 184 additions and 5 deletions
|
|
@ -31,6 +31,7 @@ from __future__ import annotations
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import inspect
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from agent.memory_provider import MemoryProvider
|
from agent.memory_provider import MemoryProvider
|
||||||
|
|
@ -312,7 +313,39 @@ class MemoryManager:
|
||||||
)
|
)
|
||||||
return "\n\n".join(parts)
|
return "\n\n".join(parts)
|
||||||
|
|
||||||
def on_memory_write(self, action: str, target: str, content: str) -> None:
|
@staticmethod
|
||||||
|
def _provider_memory_write_metadata_mode(provider: MemoryProvider) -> str:
|
||||||
|
"""Return how to pass metadata to a provider's memory-write hook."""
|
||||||
|
try:
|
||||||
|
signature = inspect.signature(provider.on_memory_write)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return "keyword"
|
||||||
|
|
||||||
|
params = list(signature.parameters.values())
|
||||||
|
if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
|
||||||
|
return "keyword"
|
||||||
|
if "metadata" in signature.parameters:
|
||||||
|
return "keyword"
|
||||||
|
|
||||||
|
accepted = [
|
||||||
|
p for p in params
|
||||||
|
if p.kind in (
|
||||||
|
inspect.Parameter.POSITIONAL_ONLY,
|
||||||
|
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||||
|
inspect.Parameter.KEYWORD_ONLY,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
if len(accepted) >= 4:
|
||||||
|
return "positional"
|
||||||
|
return "legacy"
|
||||||
|
|
||||||
|
def on_memory_write(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
target: str,
|
||||||
|
content: str,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
"""Notify external providers when the built-in memory tool writes.
|
"""Notify external providers when the built-in memory tool writes.
|
||||||
|
|
||||||
Skips the builtin provider itself (it's the source of the write).
|
Skips the builtin provider itself (it's the source of the write).
|
||||||
|
|
@ -321,6 +354,14 @@ class MemoryManager:
|
||||||
if provider.name == "builtin":
|
if provider.name == "builtin":
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
metadata_mode = self._provider_memory_write_metadata_mode(provider)
|
||||||
|
if metadata_mode == "keyword":
|
||||||
|
provider.on_memory_write(
|
||||||
|
action, target, content, metadata=dict(metadata or {})
|
||||||
|
)
|
||||||
|
elif metadata_mode == "positional":
|
||||||
|
provider.on_memory_write(action, target, content, dict(metadata or {}))
|
||||||
|
else:
|
||||||
provider.on_memory_write(action, target, content)
|
provider.on_memory_write(action, target, content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ Optional hooks (override to opt in):
|
||||||
on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context
|
on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context
|
||||||
on_session_end(messages) — end-of-session extraction
|
on_session_end(messages) — end-of-session extraction
|
||||||
on_pre_compress(messages) -> str — extract before context compression
|
on_pre_compress(messages) -> str — extract before context compression
|
||||||
on_memory_write(action, target, content) — mirror built-in memory writes
|
on_memory_write(action, target, content, metadata=None) — mirror built-in memory writes
|
||||||
on_delegation(task, result, **kwargs) — parent-side observation of subagent work
|
on_delegation(task, result, **kwargs) — parent-side observation of subagent work
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
@ -34,7 +34,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -220,12 +220,21 @@ class MemoryProvider(ABC):
|
||||||
should all have ``env_var`` set and this method stays no-op).
|
should all have ``env_var`` set and this method stays no-op).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def on_memory_write(self, action: str, target: str, content: str) -> None:
|
def on_memory_write(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
target: str,
|
||||||
|
content: str,
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> None:
|
||||||
"""Called when the built-in memory tool writes an entry.
|
"""Called when the built-in memory tool writes an entry.
|
||||||
|
|
||||||
action: 'add', 'replace', or 'remove'
|
action: 'add', 'replace', or 'remove'
|
||||||
target: 'memory' or 'user'
|
target: 'memory' or 'user'
|
||||||
content: the entry content
|
content: the entry content
|
||||||
|
metadata: structured provenance for the write, when available. Common
|
||||||
|
keys include ``write_origin``, ``execution_context``, ``session_id``,
|
||||||
|
``parent_session_id``, ``platform``, and ``tool_name``.
|
||||||
|
|
||||||
Use to mirror built-in memory writes to your backend.
|
Use to mirror built-in memory writes to your backend.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
52
run_agent.py
52
run_agent.py
|
|
@ -1437,6 +1437,8 @@ class AIAgent:
|
||||||
|
|
||||||
# Track conversation messages for session logging
|
# Track conversation messages for session logging
|
||||||
self._session_messages: List[Dict[str, Any]] = []
|
self._session_messages: List[Dict[str, Any]] = []
|
||||||
|
self._memory_write_origin = "assistant_tool"
|
||||||
|
self._memory_write_context = "foreground"
|
||||||
|
|
||||||
# Cached system prompt -- built once per session, only rebuilt on compression
|
# Cached system prompt -- built once per session, only rebuilt on compression
|
||||||
self._cached_system_prompt: Optional[str] = None
|
self._cached_system_prompt: Optional[str] = None
|
||||||
|
|
@ -3075,7 +3077,10 @@ class AIAgent:
|
||||||
quiet_mode=True,
|
quiet_mode=True,
|
||||||
platform=self.platform,
|
platform=self.platform,
|
||||||
provider=self.provider,
|
provider=self.provider,
|
||||||
|
parent_session_id=self.session_id,
|
||||||
)
|
)
|
||||||
|
review_agent._memory_write_origin = "background_review"
|
||||||
|
review_agent._memory_write_context = "background_review"
|
||||||
review_agent._memory_store = self._memory_store
|
review_agent._memory_store = self._memory_store
|
||||||
review_agent._memory_enabled = self._memory_enabled
|
review_agent._memory_enabled = self._memory_enabled
|
||||||
review_agent._user_profile_enabled = self._user_profile_enabled
|
review_agent._user_profile_enabled = self._user_profile_enabled
|
||||||
|
|
@ -3124,6 +3129,32 @@ class AIAgent:
|
||||||
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
|
t = threading.Thread(target=_run_review, daemon=True, name="bg-review")
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
def _build_memory_write_metadata(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
write_origin: Optional[str] = None,
|
||||||
|
execution_context: Optional[str] = None,
|
||||||
|
task_id: Optional[str] = None,
|
||||||
|
tool_call_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Build provenance metadata for external memory-provider mirrors."""
|
||||||
|
metadata: Dict[str, Any] = {
|
||||||
|
"write_origin": write_origin or getattr(self, "_memory_write_origin", "assistant_tool"),
|
||||||
|
"execution_context": (
|
||||||
|
execution_context
|
||||||
|
or getattr(self, "_memory_write_context", "foreground")
|
||||||
|
),
|
||||||
|
"session_id": self.session_id or "",
|
||||||
|
"parent_session_id": self._parent_session_id or "",
|
||||||
|
"platform": self.platform or os.environ.get("HERMES_SESSION_SOURCE", "cli"),
|
||||||
|
"tool_name": "memory",
|
||||||
|
}
|
||||||
|
if task_id:
|
||||||
|
metadata["task_id"] = task_id
|
||||||
|
if tool_call_id:
|
||||||
|
metadata["tool_call_id"] = tool_call_id
|
||||||
|
return {k: v for k, v in metadata.items() if v not in (None, "")}
|
||||||
|
|
||||||
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
|
def _apply_persist_user_message_override(self, messages: List[Dict]) -> None:
|
||||||
"""Rewrite the current-turn user message before persistence/return.
|
"""Rewrite the current-turn user message before persistence/return.
|
||||||
|
|
||||||
|
|
@ -7812,6 +7843,19 @@ class AIAgent:
|
||||||
old_text=args.get("old_text"),
|
old_text=args.get("old_text"),
|
||||||
store=self._memory_store,
|
store=self._memory_store,
|
||||||
)
|
)
|
||||||
|
if self._memory_manager and args.get("action") in ("add", "replace"):
|
||||||
|
try:
|
||||||
|
self._memory_manager.on_memory_write(
|
||||||
|
args.get("action", ""),
|
||||||
|
flush_target,
|
||||||
|
args.get("content", ""),
|
||||||
|
metadata=self._build_memory_write_metadata(
|
||||||
|
write_origin="memory_flush",
|
||||||
|
execution_context="flush_memories",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}")
|
print(f" 🧠 Memory flush: saved to {args.get('target', 'memory')}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -8043,6 +8087,10 @@ class AIAgent:
|
||||||
function_args.get("action", ""),
|
function_args.get("action", ""),
|
||||||
target,
|
target,
|
||||||
function_args.get("content", ""),
|
function_args.get("content", ""),
|
||||||
|
metadata=self._build_memory_write_metadata(
|
||||||
|
task_id=effective_task_id,
|
||||||
|
tool_call_id=tool_call_id,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
@ -8554,6 +8602,10 @@ class AIAgent:
|
||||||
function_args.get("action", ""),
|
function_args.get("action", ""),
|
||||||
target,
|
target,
|
||||||
function_args.get("content", ""),
|
function_args.get("content", ""),
|
||||||
|
metadata=self._build_memory_write_metadata(
|
||||||
|
task_id=effective_task_id,
|
||||||
|
tool_call_id=getattr(tool_call, "id", None),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,13 @@ class FakeMemoryProvider(MemoryProvider):
|
||||||
self.memory_writes.append((action, target, content))
|
self.memory_writes.append((action, target, content))
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataMemoryProvider(FakeMemoryProvider):
|
||||||
|
"""Provider that opts into write metadata."""
|
||||||
|
|
||||||
|
def on_memory_write(self, action, target, content, metadata=None):
|
||||||
|
self.memory_writes.append((action, target, content, metadata or {}))
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# MemoryProvider ABC tests
|
# MemoryProvider ABC tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
@ -862,6 +869,51 @@ class TestOnMemoryWriteBridge:
|
||||||
mgr.on_memory_write("add", "memory", "new fact")
|
mgr.on_memory_write("add", "memory", "new fact")
|
||||||
assert p.memory_writes == [("add", "memory", "new fact")]
|
assert p.memory_writes == [("add", "memory", "new fact")]
|
||||||
|
|
||||||
|
def test_on_memory_write_metadata_passed_to_opt_in_provider(self):
|
||||||
|
"""Providers that accept metadata receive structured write provenance."""
|
||||||
|
mgr = MemoryManager()
|
||||||
|
p = MetadataMemoryProvider("ext")
|
||||||
|
mgr.add_provider(p)
|
||||||
|
|
||||||
|
mgr.on_memory_write(
|
||||||
|
"add",
|
||||||
|
"memory",
|
||||||
|
"new fact",
|
||||||
|
metadata={
|
||||||
|
"write_origin": "assistant_tool",
|
||||||
|
"execution_context": "foreground",
|
||||||
|
"session_id": "sess-1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert p.memory_writes == [
|
||||||
|
(
|
||||||
|
"add",
|
||||||
|
"memory",
|
||||||
|
"new fact",
|
||||||
|
{
|
||||||
|
"write_origin": "assistant_tool",
|
||||||
|
"execution_context": "foreground",
|
||||||
|
"session_id": "sess-1",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def test_on_memory_write_metadata_keeps_legacy_provider_compatible(self):
|
||||||
|
"""Old 3-arg providers keep working when the manager receives metadata."""
|
||||||
|
mgr = MemoryManager()
|
||||||
|
p = FakeMemoryProvider("ext")
|
||||||
|
mgr.add_provider(p)
|
||||||
|
|
||||||
|
mgr.on_memory_write(
|
||||||
|
"add",
|
||||||
|
"user",
|
||||||
|
"legacy provider fact",
|
||||||
|
metadata={"write_origin": "assistant_tool"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert p.memory_writes == [("add", "user", "legacy provider fact")]
|
||||||
|
|
||||||
def test_on_memory_write_replace(self):
|
def test_on_memory_write_replace(self):
|
||||||
"""on_memory_write fires for 'replace' actions."""
|
"""on_memory_write fires for 'replace' actions."""
|
||||||
mgr = MemoryManager()
|
mgr = MemoryManager()
|
||||||
|
|
|
||||||
|
|
@ -233,6 +233,31 @@ class TestFlushMemoriesUsesAuxiliaryClient:
|
||||||
assert call_kwargs.kwargs["target"] == "notes"
|
assert call_kwargs.kwargs["target"] == "notes"
|
||||||
assert "dark mode" in call_kwargs.kwargs["content"]
|
assert "dark mode" in call_kwargs.kwargs["content"]
|
||||||
|
|
||||||
|
def test_flush_bridges_memory_write_metadata(self, monkeypatch):
|
||||||
|
"""Flush memory writes notify external providers with flush provenance."""
|
||||||
|
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
||||||
|
agent._memory_manager = MagicMock()
|
||||||
|
agent.session_id = "sess-flush"
|
||||||
|
agent.platform = "cli"
|
||||||
|
|
||||||
|
mock_response = _chat_response_with_memory_call()
|
||||||
|
|
||||||
|
with patch("agent.auxiliary_client.call_llm", return_value=mock_response):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi"},
|
||||||
|
{"role": "user", "content": "Note this"},
|
||||||
|
]
|
||||||
|
with patch("tools.memory_tool.memory_tool", return_value="Saved."):
|
||||||
|
agent.flush_memories(messages)
|
||||||
|
|
||||||
|
agent._memory_manager.on_memory_write.assert_called_once()
|
||||||
|
call_kwargs = agent._memory_manager.on_memory_write.call_args
|
||||||
|
assert call_kwargs.args[:3] == ("add", "notes", "User prefers dark mode.")
|
||||||
|
assert call_kwargs.kwargs["metadata"]["write_origin"] == "memory_flush"
|
||||||
|
assert call_kwargs.kwargs["metadata"]["execution_context"] == "flush_memories"
|
||||||
|
assert call_kwargs.kwargs["metadata"]["session_id"] == "sess-flush"
|
||||||
|
|
||||||
def test_flush_strips_artifacts_from_messages(self, monkeypatch):
|
def test_flush_strips_artifacts_from_messages(self, monkeypatch):
|
||||||
"""After flush, the flush prompt and any response should be removed from messages."""
|
"""After flush, the flush prompt and any response should be removed from messages."""
|
||||||
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
agent = _make_agent(monkeypatch, api_mode="chat_completions", provider="openrouter")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue