diff --git a/agent/memory_manager.py b/agent/memory_manager.py
index 2831eb7bf8..ea9b7425fc 100644
--- a/agent/memory_manager.py
+++ b/agent/memory_manager.py
@@ -402,6 +402,43 @@ class MemoryManager:
                     provider.name, e,
                 )
 
+    def on_session_switch(
+        self,
+        new_session_id: str,
+        *,
+        parent_session_id: str = "",
+        reset: bool = False,
+        **kwargs,
+    ) -> None:
+        """Fan a session_id rotation out to every registered provider.
+
+        Called from every path that reassigns ``AIAgent.session_id``
+        while the providers stay alive: ``/resume``, ``/branch``,
+        ``/reset``, ``/new``, and context compression.
+
+        Providers are not restarted; each one only gets the chance to
+        refresh the per-session state it cached, so that later writes
+        reach the correct session's record. The per-provider contract
+        is described on ``MemoryProvider.on_session_switch``.
+
+        A falsy ``new_session_id`` is ignored. An exception raised by
+        one provider is logged at debug level and does not stop the
+        remaining providers from being notified.
+        """
+        if not new_session_id:
+            return
+        # Built once; ``**`` unpacking hands each provider its own copy.
+        forwarded = {"parent_session_id": parent_session_id, "reset": reset, **kwargs}
+        for provider in self._providers:
+            try:
+                provider.on_session_switch(new_session_id, **forwarded)
+            except Exception as exc:
+                logger.debug(
+                    "Memory provider '%s' on_session_switch failed: %s",
+                    provider.name,
+                    exc,
+                )
+
     def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
         """Notify all providers before context compression.
 
diff --git a/agent/memory_provider.py b/agent/memory_provider.py
index 535338f4ee..1c8dbaf682 100644
--- a/agent/memory_provider.py
+++ b/agent/memory_provider.py
@@ -25,6 +25,7 @@ Lifecycle (called by MemoryManager, wired in run_agent.py):
 Optional hooks (override to opt in):
   on_turn_start(turn, message, **kwargs) — per-turn tick with runtime context
   on_session_end(messages) — end-of-session extraction
+  on_session_switch(new_session_id, **kwargs) — mid-process session_id rotation
   on_pre_compress(messages) -> str — extract before context compression
   on_memory_write(action, target, content, metadata=None) — mirror built-in memory writes
   on_delegation(task, result, **kwargs) — parent-side observation of subagent work
@@ -160,6 +161,54 @@ class MemoryProvider(ABC):
         (CLI exit, /reset, gateway session expiry).
         """
 
+    def on_session_switch(
+        self,
+        new_session_id: str,
+        *,
+        parent_session_id: str = "",
+        reset: bool = False,
+        **kwargs,
+    ) -> None:
+        """Called when the agent switches session_id mid-process.
+
+        Fires on ``/resume``, ``/branch``, ``/reset`` and ``/new`` in the
+        CLI and on context compression — any path that reassigns
+        ``AIAgent.session_id`` without tearing the provider down. In the
+        gateway, ``/resume`` evicts the cached agent instead, so the
+        next message rebuilds it with the new id.
+
+        Providers that cache per-session state in ``initialize()``
+        (``_session_id``, ``_document_id``, accumulated turn buffers,
+        counters) should update or reset that state here so subsequent
+        writes land in the correct session's record.
+
+        Parameters
+        ----------
+        new_session_id:
+            The session_id the agent just switched to.
+        parent_session_id:
+            The previous session_id, if meaningful — set for ``/branch``
+            (fork lineage), context compression (continuation lineage),
+            and ``/resume`` (the session we're leaving). The CLI also
+            passes the outgoing id on ``/reset`` / ``/new``, so check
+            ``reset`` before treating it as lineage. Empty string when
+            no previous id is available.
+        reset:
+            ``True`` when this is a genuinely new conversation, not a
+            resumption of an existing one. Fired by ``/reset`` / ``/new``.
+            Providers should flush accumulated per-session buffers
+            (``_session_turns``, ``_turn_counter``, etc.) when this is
+            set. ``False`` for ``/resume`` / ``/branch`` / compression
+            where the logical conversation continues under the new id.
+        **kwargs:
+            Extra context from the caller. Current callers pass ``reason``
+            (``"new_session"``, ``"resume"``, ``"branch"`` or
+            ``"compression"``) for diagnostics. Overrides should accept
+            and ignore keys they do not recognise.
+
+        Default is no-op for backward compatibility.
+        """
+
     def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
         """Called before context compression discards old messages.
 
diff --git a/cli.py b/cli.py
index 1f2f26fc9f..9c9fe268b3 100644
--- a/cli.py
+++ b/cli.py
@@ -4809,6 +4809,22 @@ class HermesCLI:
                 )
             except Exception:
                 pass
+        # Notify memory providers that session_id rotated to a fresh
+        # conversation. reset=True signals providers to flush accumulated
+        # per-session state (_session_turns, _turn_counter, _document_id).
+        # Fires BEFORE the plugin on_session_reset hook (shell hooks only
+        # see the new id; Python providers see the transition). See #6672.
+        try:
+            _mm = getattr(self.agent, "_memory_manager", None)
+            if _mm is not None:
+                _mm.on_session_switch(
+                    self.session_id,
+                    parent_session_id=old_session_id or "",
+                    reset=True,
+                    reason="new_session",
+                )
+        except Exception:
+            pass
         self._notify_session_boundary("on_session_reset")
 
         if not silent:
@@ -4861,6 +4877,7 @@ class HermesCLI:
             _cprint(" Already on that session.")
             return
 
+        old_session_id = self.session_id
         # End current session
         try:
             self._session_db.end_session(self.session_id, "resumed_other")
@@ -4898,6 +4915,22 @@ class HermesCLI:
             if hasattr(self.agent, "_invalidate_system_prompt"):
                 self.agent._invalidate_system_prompt()
 
+        # Notify memory providers that session_id rotated to a resumed
+        # session. reset=False — the provider's accumulated state is
+        # still valid; it just needs to target the new session_id for
+        # subsequent writes. See #6672.
+ try: + _mm = getattr(self.agent, "_memory_manager", None) + if _mm is not None: + _mm.on_session_switch( + target_id, + parent_session_id=old_session_id or "", + reset=False, + reason="resume", + ) + except Exception: + pass + title_part = f" \"{session_meta['title']}\"" if session_meta.get("title") else "" msg_count = len([m for m in self.conversation_history if m.get("role") == "user"]) if self.conversation_history: @@ -5018,6 +5051,22 @@ class HermesCLI: if hasattr(self.agent, "_invalidate_system_prompt"): self.agent._invalidate_system_prompt() + # Notify memory providers that session_id forked to a new branch. + # reset=False — the branched session carries the transcript + # forward, so provider state tracks the lineage. parent_session_id + # links the branch back to the original. See #6672. + try: + _mm = getattr(self.agent, "_memory_manager", None) + if _mm is not None: + _mm.on_session_switch( + new_session_id, + parent_session_id=parent_session_id or "", + reset=False, + reason="branch", + ) + except Exception: + pass + msg_count = len([m for m in self.conversation_history if m.get("role") == "user"]) _cprint( f" ⑂ Branched session \"{branch_title}\"" diff --git a/gateway/run.py b/gateway/run.py index a37f72b5ec..4948dbbc16 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -7817,6 +7817,13 @@ class GatewayRunner: return "Failed to switch session." self._clear_session_boundary_security_state(session_key) + # Evict any cached agent for this session so the next message + # rebuilds with the correct session_id end-to-end — mirrors + # /branch and /reset. Without this, the cached AIAgent (and its + # memory provider, which cached `_session_id` during initialize()) + # keeps writing into the wrong session's record. See #6672. 
+        self._evict_cached_agent(session_key)
+
         # Get the title for confirmation
         title = self._session_db.get_session_title(target_id) or name
 
diff --git a/plugins/memory/hindsight/__init__.py b/plugins/memory/hindsight/__init__.py
index 1710b74f6a..31a04d5d4a 100644
--- a/plugins/memory/hindsight/__init__.py
+++ b/plugins/memory/hindsight/__init__.py
@@ -1325,6 +1325,56 @@ class HindsightMemoryProvider(MemoryProvider):
 
         return tool_error(f"Unknown tool: {tool_name}")
 
+    def on_session_switch(
+        self,
+        new_session_id: str,
+        *,
+        parent_session_id: str = "",
+        reset: bool = False,
+        **kwargs,
+    ) -> None:
+        """Refresh cached per-session state when the agent rotates session_id.
+
+        Fires on /resume, /branch, /reset, /new, and context compression.
+        Without this hook, initialize()-cached state (``_session_id``,
+        ``_document_id``, ``_session_turns``, ``_turn_counter``) would keep
+        pointing at the previous session and writes would land in the wrong
+        document. See hermes-agent#6672.
+
+        Always update ``_session_id`` so metadata and tags on subsequent
+        retains reflect the active session. Always mint a fresh
+        ``_document_id`` so the new session's retain doesn't overwrite the
+        old session's document on vectorize-io/hindsight#1303. Always clear
+        the accumulated batch buffers (``_session_turns``, ``_turn_counter``,
+        ``_turn_index``) — even for /resume and /branch, the new session's
+        batching must start from zero so an in-flight retain doesn't flush
+        under the wrong ``_document_id``.
+
+        ``parent_session_id`` is recorded for lineage tags on future retains;
+        an empty or whitespace-only value keeps the previously recorded
+        parent. ``reset`` is accepted but not needed for Hindsight's state
+        model — buffer clearing is correct for every session switch, not
+        only /reset.
+        """
+        new_id = str(new_session_id or "").strip()
+        if not new_id:
+            return
+        # Normalise before deciding, so a whitespace-only parent cannot
+        # overwrite a recorded one with "". The getattr default covers
+        # instances on which ``_parent_session_id`` was never assigned.
+        parent_id = str(parent_session_id or "").strip()
+        self._parent_session_id = parent_id or getattr(self, "_parent_session_id", "")
+        self._session_id = new_id
+        start_ts = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+        self._document_id = f"{self._session_id}-{start_ts}"
+        self._session_turns = []
+        self._turn_counter = 0
+        self._turn_index = 0
+        logger.debug(
+            "Hindsight on_session_switch: new_session=%s parent=%s reset=%s doc=%s",
+            self._session_id, self._parent_session_id, reset, self._document_id,
+        )
+
     def shutdown(self) -> None:
         logger.debug("Hindsight shutdown: waiting for background threads")
         for t in (self._prefetch_thread, self._sync_thread):
diff --git a/run_agent.py b/run_agent.py
index 2b386db2ad..895e68644c 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -4565,8 +4565,14 @@ class AIAgent:
         if not (self._memory_manager and final_response and original_user_message):
             return
         try:
-            self._memory_manager.sync_all(original_user_message, final_response)
-            self._memory_manager.queue_prefetch_all(original_user_message)
+            self._memory_manager.sync_all(
+                original_user_message, final_response,
+                session_id=self.session_id or "",
+            )
+            self._memory_manager.queue_prefetch_all(
+                original_user_message,
+                session_id=self.session_id or "",
+            )
         except Exception:
             pass
 
@@ -8938,6 +8944,23 @@ class AIAgent:
             except Exception as _ce_err:
                 logger.debug("context engine on_session_start (compression): %s", _ce_err)
 
+        # Notify memory providers of the compression-driven session_id rotation
+        # so provider-cached per-session state (Hindsight's _document_id,
+        # accumulated turn buffers, counters) refreshes. reset=False because
+        # the logical conversation continues; only the id and DB row rolled
+        # over. See #6672.
+        try:
+            _old_sid = locals().get("old_session_id")
+            if _old_sid and self._memory_manager:
+                self._memory_manager.on_session_switch(
+                    self.session_id or "",
+                    parent_session_id=_old_sid,
+                    reset=False,
+                    reason="compression",
+                )
+        except Exception as _me_err:
+            logger.debug("memory manager on_session_switch (compression): %s", _me_err)
+
         # Warn on repeated compressions (quality degrades with each pass)
         _cc = self.context_compressor.compression_count
         if _cc >= 2:
diff --git a/tests/agent/test_memory_session_switch.py b/tests/agent/test_memory_session_switch.py
new file mode 100644
index 0000000000..1cf945e738
--- /dev/null
+++ b/tests/agent/test_memory_session_switch.py
@@ -0,0 +1,280 @@
+"""Tests for the on_session_switch hook and session_id propagation.
+
+Covers #6672: memory providers must be notified when AIAgent.session_id
+rotates mid-process (via /resume, /branch, /reset, /new, or context
+compression). Without the notification, providers that cache per-session
+state in initialize() (Hindsight, and any plugin that stores session_id
+for scoped writes) keep writing into the old session's record.
+"""
+
+import pytest
+
+from agent.memory_manager import MemoryManager
+from agent.memory_provider import MemoryProvider
+
+
+class _RecordingProvider(MemoryProvider):
+    """Provider that records every lifecycle call for assertion."""
+
+    def __init__(self, name="rec"):
+        self._name = name
+        self.switch_calls: list[dict] = []
+        self.sync_calls: list[dict] = []
+        self.queue_calls: list[dict] = []
+        self.initialize_calls: list[dict] = []
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def is_available(self) -> bool:  # pragma: no cover - unused
+        return True
+
+    def initialize(self, session_id, **kwargs):
+        self.initialize_calls.append({"session_id": session_id, **kwargs})
+
+    def get_tool_schemas(self):
+        return []
+
+    def sync_turn(self, user_content, assistant_content, *, session_id=""):
+        self.sync_calls.append(
+            {"user": user_content, "asst": assistant_content, "session_id": session_id}
+        )
+
+    def queue_prefetch(self, query, *, session_id=""):
+        self.queue_calls.append({"query": query, "session_id": session_id})
+
+    def on_session_switch(
+        self,
+        new_session_id,
+        *,
+        parent_session_id="",
+        reset=False,
+        **kwargs,
+    ):
+        self.switch_calls.append(
+            {
+                "new": new_session_id,
+                "parent": parent_session_id,
+                "reset": reset,
+                "extra": kwargs,
+            }
+        )
+
+
+# ---------------------------------------------------------------------------
+# MemoryProvider ABC — default on_session_switch is a no-op
+# ---------------------------------------------------------------------------
+
+
+class _MinimalProvider(MemoryProvider):
+    """Provider that does NOT override on_session_switch — ABC default must no-op."""
+
+    @property
+    def name(self) -> str:
+        return "minimal"
+
+    def is_available(self) -> bool:
+        return True
+
+    def initialize(self, session_id, **kwargs):  # pragma: no cover - unused
+        pass
+
+    def get_tool_schemas(self):
+        return []
+
+
+def test_abc_default_on_session_switch_is_noop():
+    """Providers that don't override the hook must not raise."""
+    p = _MinimalProvider()
+    # All three call styles must be accepted without raising
+    p.on_session_switch("new-id")
+    p.on_session_switch("new-id", parent_session_id="old-id")
+    p.on_session_switch("new-id", parent_session_id="old-id", reset=True)
+    p.on_session_switch("new-id", parent_session_id="old-id", reset=True, reason="new_session")
+
+
+# ---------------------------------------------------------------------------
+# MemoryManager.on_session_switch — fan-out
+# ---------------------------------------------------------------------------
+
+
+def test_manager_fans_out_to_all_providers():
+    mm = MemoryManager()
+    # Only one external provider is allowed; use the builtin slot for p1.
+    p1 = _RecordingProvider(name="builtin")
+    p2 = _RecordingProvider(name="hindsight")
+    mm.add_provider(p1)
+    mm.add_provider(p2)
+
+    mm.on_session_switch("new-sid", parent_session_id="old-sid", reset=False, reason="resume")
+
+    assert len(p1.switch_calls) == 1
+    assert len(p2.switch_calls) == 1
+    for call in (p1.switch_calls[0], p2.switch_calls[0]):
+        assert call["new"] == "new-sid"
+        assert call["parent"] == "old-sid"
+        assert call["reset"] is False
+        assert call["extra"] == {"reason": "resume"}
+
+
+def test_manager_ignores_empty_session_id():
+    """Empty string session_id must not trigger provider hooks.
+
+    Prevents accidental fires during shutdown when self.session_id may be
+    cleared. Providers expect a meaningful id to switch TO.
+    """
+    mm = MemoryManager()
+    p = _RecordingProvider()
+    mm.add_provider(p)
+    mm.on_session_switch("")
+    mm.on_session_switch(None)  # type: ignore[arg-type]
+    assert p.switch_calls == []
+
+
+def test_manager_isolates_provider_failures():
+    """A provider that raises must not block other providers."""
+
+    class _Broken(_RecordingProvider):
+        def on_session_switch(self, *args, **kwargs):  # type: ignore[override]
+            raise RuntimeError("boom")
+
+    mm = MemoryManager()
+    # MemoryManager rejects a second external provider, so pair broken
+    # (builtin slot) with a good external one.
+    broken = _Broken(name="builtin")
+    good = _RecordingProvider(name="good")
+    mm.add_provider(broken)
+    mm.add_provider(good)
+
+    # Must not raise — exceptions in one provider are swallowed + logged
+    mm.on_session_switch("new-sid", parent_session_id="old-sid")
+    assert len(good.switch_calls) == 1
+    assert good.switch_calls[0]["new"] == "new-sid"
+
+
+def test_manager_reset_flag_preserved():
+    mm = MemoryManager()
+    p = _RecordingProvider()
+    mm.add_provider(p)
+    mm.on_session_switch("new-sid", reset=True, reason="new_session")
+    assert p.switch_calls[0]["reset"] is True
+    assert p.switch_calls[0]["extra"] == {"reason": "new_session"}
+
+
+# ---------------------------------------------------------------------------
+# MemoryManager.sync_all / queue_prefetch_all — session_id propagation
+# ---------------------------------------------------------------------------
+
+
+def test_sync_all_propagates_session_id_to_providers():
+    """run_agent.py's sync_all call must pass session_id through to providers.
+
+    Without this, a provider that updates _session_id defensively in
+    sync_turn (as Hindsight does at hindsight/__init__.py:1199) never
+    sees the new id and keeps writing under the old one.
+    """
+    mm = MemoryManager()
+    p = _RecordingProvider()
+    mm.add_provider(p)
+    mm.sync_all("hello", "world", session_id="sess-42")
+    assert p.sync_calls == [
+        {"user": "hello", "asst": "world", "session_id": "sess-42"}
+    ]
+
+
+def test_queue_prefetch_all_propagates_session_id_to_providers():
+    mm = MemoryManager()
+    p = _RecordingProvider()
+    mm.add_provider(p)
+    mm.queue_prefetch_all("next query", session_id="sess-42")
+    assert p.queue_calls == [{"query": "next query", "session_id": "sess-42"}]
+
+
+# ---------------------------------------------------------------------------
+# Hindsight reference implementation — state-flush semantics
+# ---------------------------------------------------------------------------
+
+
+def _make_hindsight_provider():
+    """Build a bare HindsightMemoryProvider that skips network setup.
+
+    We instantiate without importing optional deps at class-level by
+    bypassing __init__ and seeding the attributes on_session_switch
+    reads/writes. This keeps the test hermetic.
+    """
+    hindsight_mod = pytest.importorskip("plugins.memory.hindsight")
+    provider = object.__new__(hindsight_mod.HindsightMemoryProvider)
+    provider._session_id = "old-sid"
+    provider._parent_session_id = ""
+    provider._document_id = "old-sid-20260101_000000_000000"
+    provider._session_turns = ["turn-1", "turn-2"]
+    provider._turn_counter = 2
+    provider._turn_index = 2
+    return provider
+
+
+def test_hindsight_on_session_switch_updates_session_id_and_mints_fresh_doc():
+    provider = _make_hindsight_provider()
+    old_doc = provider._document_id
+
+    provider.on_session_switch(
+        "new-sid", parent_session_id="old-sid", reset=False, reason="resume"
+    )
+
+    assert provider._session_id == "new-sid"
+    assert provider._parent_session_id == "old-sid"
+    # Document id MUST be fresh — else next retain overwrites old session doc
+    assert provider._document_id != old_doc
+    assert provider._document_id.startswith("new-sid-")
+
+
+def test_hindsight_on_session_switch_clears_turn_buffers():
+    """Accumulated _session_turns must not leak into the next session.
+
+    Hindsight batches turns under a single _document_id. If the buffer
+    isn't cleared on switch, the next retain under the new _document_id
+    flushes turns that belong to the previous session.
+    """
+    provider = _make_hindsight_provider()
+    provider.on_session_switch("new-sid", parent_session_id="old-sid")
+    assert provider._session_turns == []
+    assert provider._turn_counter == 0
+    assert provider._turn_index == 0
+
+
+def test_hindsight_on_session_switch_clears_on_reset_true():
+    """reset=True (from /new, /reset) must also flush buffers."""
+    provider = _make_hindsight_provider()
+    provider.on_session_switch("new-sid", reset=True, reason="new_session")
+    assert provider._session_id == "new-sid"
+    assert provider._session_turns == []
+    assert provider._turn_counter == 0
+
+
+def test_hindsight_on_session_switch_ignores_empty_id():
+    """Empty new_session_id must be a no-op to avoid corrupting state."""
+    provider = _make_hindsight_provider()
+    before = (
+        provider._session_id,
+        provider._document_id,
+        list(provider._session_turns),
+        provider._turn_counter,
+    )
+    provider.on_session_switch("")
+    provider.on_session_switch(None)  # type: ignore[arg-type]
+    after = (
+        provider._session_id,
+        provider._document_id,
+        list(provider._session_turns),
+        provider._turn_counter,
+    )
+    assert before == after
+
+
+def test_hindsight_preserves_parent_across_empty_parent_arg():
+    """Omitting parent_session_id must NOT overwrite an existing one."""
+    provider = _make_hindsight_provider()
+    provider._parent_session_id = "original-parent"
+    provider.on_session_switch("new-sid")  # no parent passed
+    assert provider._parent_session_id == "original-parent"
diff --git a/tests/cli/test_branch_command.py b/tests/cli/test_branch_command.py
index 581cdbdb6a..5e78815b8f 100644
--- a/tests/cli/test_branch_command.py
+++ b/tests/cli/test_branch_command.py
@@ -192,6 +192,33 @@ class TestBranchCommandCLI:
 
         assert cli_instance._resumed is True
 
+    def test_branch_fires_on_session_switch_hook(self, cli_instance, session_db):
+        """The /branch command must notify memory providers of the rotation.
+
+        Without this, providers that cache per-session state in
+        initialize() keep writing under the old session_id. See #6672.
+        """
+        from cli import HermesCLI
+
+        # Wire a real-ish agent object with a MagicMock memory_manager
+        agent = MagicMock()
+        mm = MagicMock()
+        agent._memory_manager = mm
+        cli_instance.agent = agent
+        original_id = cli_instance.session_id
+
+        HermesCLI._handle_branch_command(cli_instance, "/branch")
+
+        # Hook must have been called exactly once with the new session_id,
+        # parent pointing at the branched-from session, reset=False, and
+        # reason="branch" for diagnostics.
+        assert mm.on_session_switch.call_count == 1
+        _, kwargs = mm.on_session_switch.call_args
+        assert mm.on_session_switch.call_args.args[0] == cli_instance.session_id
+        assert kwargs["parent_session_id"] == original_id
+        assert kwargs["reset"] is False
+        assert kwargs["reason"] == "branch"
+
     def test_fork_alias(self):
         """The /fork alias should resolve to 'branch'."""
         from hermes_cli.commands import resolve_command
diff --git a/tests/gateway/test_resume_command.py b/tests/gateway/test_resume_command.py
index 42377325e9..0d2060ef31 100644
--- a/tests/gateway/test_resume_command.py
+++ b/tests/gateway/test_resume_command.py
@@ -230,3 +230,30 @@ class TestHandleResumeCommand:
 
         assert real_key not in runner._running_agents
         db.close()
+
+    @pytest.mark.asyncio
+    async def test_resume_evicts_cached_agent(self, tmp_path):
+        """Gateway /resume evicts the cached AIAgent so the next message
+        rebuilds with the correct session_id end-to-end — mirrors /branch
+        and /reset. Without this, the cached agent's memory provider keeps
+        writing into the wrong session. See #6672.
+        """
+        import threading
+        from hermes_state import SessionDB
+        db = SessionDB(db_path=tmp_path / "state.db")
+        db.create_session("old_session", "telegram")
+        db.set_session_title("old_session", "Old Work")
+        db.create_session("current_session_001", "telegram")
+
+        event = _make_event(text="/resume Old Work")
+        runner = _make_runner(session_db=db, current_session_id="current_session_001",
+                              event=event)
+        # Seed the cache with a fake agent
+        real_key = _session_key_for_event(event)
+        runner._agent_cache = {real_key: (MagicMock(), object())}
+        runner._agent_cache_lock = threading.RLock()
+
+        await runner._handle_resume_command(event)
+
+        assert real_key not in runner._agent_cache
+        db.close()
diff --git a/tests/run_agent/test_memory_sync_interrupted.py b/tests/run_agent/test_memory_sync_interrupted.py
index 32313740dc..feeb028927 100644
--- a/tests/run_agent/test_memory_sync_interrupted.py
+++ b/tests/run_agent/test_memory_sync_interrupted.py
@@ -31,6 +31,10 @@ def _bare_agent():
 
     agent = AIAgent.__new__(AIAgent)
     agent._memory_manager = MagicMock()
+    # session_id is now propagated into sync_all / queue_prefetch_all so
+    # providers that cache per-session state can update it mid-process
+    # (see #6672).
+    agent.session_id = "test_session_001"
     return agent
 
 
@@ -80,9 +84,11 @@ class TestSyncExternalMemoryForTurn:
         )
         agent._memory_manager.sync_all.assert_called_once_with(
             "What's the weather in Paris?", "It's sunny and 22°C.",
+            session_id="test_session_001",
         )
         agent._memory_manager.queue_prefetch_all.assert_called_once_with(
             "What's the weather in Paris?",
+            session_id="test_session_001",
         )
 
     # --- Edge cases (pre-existing behaviour preserved) ------------------