mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
fix(gateway): dispatch memory session-end hooks on expiry
This commit is contained in:
parent
764536b684
commit
aa267b472a
2 changed files with 72 additions and 0 deletions
|
|
@ -740,6 +740,47 @@ class GatewayRunner:
|
|||
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
def _notify_session_end_memory_hooks(
|
||||
self,
|
||||
history: list[dict[str, Any]],
|
||||
*,
|
||||
old_session_id: str,
|
||||
session_key: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Dispatch gateway session-end hooks to the live cached agent, if any."""
|
||||
if not history or not session_key:
|
||||
return
|
||||
|
||||
live_agent = None
|
||||
_cache_lock = getattr(self, "_agent_cache_lock", None)
|
||||
_agent_cache = getattr(self, "_agent_cache", None)
|
||||
if _cache_lock is not None and _agent_cache is not None:
|
||||
with _cache_lock:
|
||||
_cached = _agent_cache.get(session_key)
|
||||
live_agent = (
|
||||
_cached[0]
|
||||
if isinstance(_cached, tuple)
|
||||
else _cached if _cached else None
|
||||
)
|
||||
|
||||
if live_agent is None:
|
||||
_running_agents = getattr(self, "_running_agents", None)
|
||||
if isinstance(_running_agents, dict):
|
||||
live_agent = _running_agents.get(session_key)
|
||||
|
||||
_memory_manager = getattr(live_agent, "_memory_manager", None)
|
||||
if _memory_manager is None:
|
||||
return
|
||||
|
||||
try:
|
||||
_memory_manager.on_session_end(history)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Gateway on_session_end dispatch failed for session %s: %s",
|
||||
old_session_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
def _flush_memories_for_session(
|
||||
self,
|
||||
old_session_id: str,
|
||||
|
|
@ -758,6 +799,11 @@ class GatewayRunner:
|
|||
|
||||
try:
|
||||
history = self.session_store.load_transcript(old_session_id)
|
||||
self._notify_session_end_memory_hooks(
|
||||
history or [],
|
||||
old_session_id=old_session_id,
|
||||
session_key=session_key,
|
||||
)
|
||||
if not history or len(history) < 4:
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -238,3 +238,29 @@ class TestFlushPromptStructure:
|
|||
assert "Save any important facts" in flush_prompt
|
||||
assert "consider saving it as a skill" in flush_prompt
|
||||
assert "Do NOT respond to the user" in flush_prompt
|
||||
|
||||
|
||||
class TestSessionEndHooks:
|
||||
"""Gateway session expiry/reset should dispatch live memory hooks."""
|
||||
|
||||
def test_live_memory_manager_receives_on_session_end(self, tmp_path, monkeypatch):
|
||||
runner, tmp_agent, _ = _make_flush_context(monkeypatch)
|
||||
|
||||
live_agent = MagicMock()
|
||||
live_agent._memory_manager = MagicMock()
|
||||
|
||||
import threading
|
||||
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._agent_cache = {"session-key": (live_agent, "sig")}
|
||||
runner._running_agents = {}
|
||||
|
||||
with (
|
||||
patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}),
|
||||
patch("gateway.run._resolve_gateway_model", return_value="test-model"),
|
||||
patch.dict("sys.modules", {"tools.memory_tool": MagicMock(get_memory_dir=lambda: tmp_path)}),
|
||||
):
|
||||
runner._flush_memories_for_session("session_hooked", session_key="session-key")
|
||||
|
||||
live_agent._memory_manager.on_session_end.assert_called_once_with(_TRANSCRIPT_4_MSGS)
|
||||
tmp_agent.run_conversation.assert_called_once()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue