From aa267b472a9369fd59b96d67458434248372a4e6 Mon Sep 17 00:00:00 2001 From: LeonSGP43 Date: Fri, 17 Apr 2026 08:59:07 +0800 Subject: [PATCH] fix(gateway): dispatch memory session-end hooks on expiry --- gateway/run.py | 46 +++++++++++++++++++ .../gateway/test_flush_memory_stale_guard.py | 26 +++++++++++ 2 files changed, 72 insertions(+) diff --git a/gateway/run.py b/gateway/run.py index ba7ea43ad..81cc2867b 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -740,6 +740,47 @@ class GatewayRunner: # ----------------------------------------------------------------- + def _notify_session_end_memory_hooks( + self, + history: list[dict[str, Any]], + *, + old_session_id: str, + session_key: Optional[str] = None, + ) -> None: + """Dispatch gateway session-end hooks to the live cached agent, if any.""" + if not history or not session_key: + return + + live_agent = None + _cache_lock = getattr(self, "_agent_cache_lock", None) + _agent_cache = getattr(self, "_agent_cache", None) + if _cache_lock is not None and _agent_cache is not None: + with _cache_lock: + _cached = _agent_cache.get(session_key) + live_agent = ( + _cached[0] + if isinstance(_cached, tuple) + else _cached if _cached else None + ) + + if live_agent is None: + _running_agents = getattr(self, "_running_agents", None) + if isinstance(_running_agents, dict): + live_agent = _running_agents.get(session_key) + + _memory_manager = getattr(live_agent, "_memory_manager", None) + if _memory_manager is None: + return + + try: + _memory_manager.on_session_end(history) + except Exception as exc: + logger.warning( + "Gateway on_session_end dispatch failed for session %s: %s", + old_session_id, + exc, + ) + def _flush_memories_for_session( self, old_session_id: str, @@ -758,6 +799,11 @@ class GatewayRunner: try: history = self.session_store.load_transcript(old_session_id) + self._notify_session_end_memory_hooks( + history or [], + old_session_id=old_session_id, + session_key=session_key, + ) if not history or len(history) < 4: return diff --git a/tests/gateway/test_flush_memory_stale_guard.py b/tests/gateway/test_flush_memory_stale_guard.py index c4e4e1fb6..4b7884520 100644 --- a/tests/gateway/test_flush_memory_stale_guard.py +++ b/tests/gateway/test_flush_memory_stale_guard.py @@ -238,3 +238,29 @@ class TestFlushPromptStructure: assert "Save any important facts" in flush_prompt assert "consider saving it as a skill" in flush_prompt assert "Do NOT respond to the user" in flush_prompt + + +class TestSessionEndHooks: + """Gateway session expiry/reset should dispatch live memory hooks.""" + + def test_live_memory_manager_receives_on_session_end(self, tmp_path, monkeypatch): + runner, tmp_agent, _ = _make_flush_context(monkeypatch) + + live_agent = MagicMock() + live_agent._memory_manager = MagicMock() + + import threading + + runner._agent_cache_lock = threading.Lock() + runner._agent_cache = {"session-key": (live_agent, "sig")} + runner._running_agents = {} + + with ( + patch("gateway.run._resolve_runtime_agent_kwargs", return_value={"api_key": "k"}), + patch("gateway.run._resolve_gateway_model", return_value="test-model"), + patch.dict("sys.modules", {"tools.memory_tool": MagicMock(get_memory_dir=lambda: tmp_path)}), + ): + runner._flush_memories_for_session("session_hooked", session_key="session-key") + + live_agent._memory_manager.on_session_end.assert_called_once_with(_TRANSCRIPT_4_MSGS) + tmp_agent.run_conversation.assert_called_once()