diff --git a/agent/memory_manager.py b/agent/memory_manager.py index b67724159..6cd1c860b 100644 --- a/agent/memory_manager.py +++ b/agent/memory_manager.py @@ -281,22 +281,6 @@ class MemoryManager: provider.name, e, ) - def on_session_reset(self, new_session_id: str) -> None: - """Notify all providers of a session reset. - - Called after on_session_end() has committed the previous session. - Providers with per-session state override on_session_reset to rebind - it cheaply (default is a no-op on the base class). - """ - for provider in self._providers: - try: - provider.on_session_reset(new_session_id) - except Exception as e: - logger.debug( - "Memory provider '%s' on_session_reset failed: %s", - provider.name, e, - ) - def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str: """Notify all providers before context compression. diff --git a/agent/memory_provider.py b/agent/memory_provider.py index 9c6f0225c..24593e334 100644 --- a/agent/memory_provider.py +++ b/agent/memory_provider.py @@ -160,15 +160,6 @@ class MemoryProvider(ABC): (CLI exit, /reset, gateway session expiry). """ - def on_session_reset(self, new_session_id: str) -> None: - """Transition to a new session without full teardown. - - Called after on_session_end() has committed the previous session - (e.g. /new, context compression). Providers with per-session state - override to rebind counters/IDs while keeping HTTP clients alive. - Default: no-op. - """ - def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str: """Called before context compression discards old messages. diff --git a/cli.py b/cli.py index a00eaf970..fbc8f8525 100644 --- a/cli.py +++ b/cli.py @@ -4095,12 +4095,13 @@ class HermesCLI: def new_session(self, silent=False): """Start a fresh session with a new session ID and cleared agent state.""" - old_history = self.conversation_history - if self.agent and old_history: + if self.agent and self.conversation_history: try: - self.agent.flush_memories(old_history) + self.agent.flush_memories(self.conversation_history) except (Exception, KeyboardInterrupt): pass + # Trigger memory extraction on the old session before session_id rotates. + self.agent.commit_memory_session(self.conversation_history) self._notify_session_boundary("on_session_finalize") elif self.agent: # First session or empty history — still finalize the old session @@ -4149,9 +4150,6 @@ class HermesCLI: ) except Exception: pass - # Commit the old session and rebind memory providers to the - # new session_id so subsequent turns are tracked correctly. - self.agent.rotate_memory_session(self.session_id, old_history) self._notify_session_boundary("on_session_reset") if not silent: diff --git a/plugins/memory/openviking/__init__.py b/plugins/memory/openviking/__init__.py index 4251927cc..86d7ad5ef 100644 --- a/plugins/memory/openviking/__init__.py +++ b/plugins/memory/openviking/__init__.py @@ -516,22 +516,6 @@ class OpenVikingMemoryProvider(MemoryProvider): except Exception as e: return tool_error(str(e)) - def on_session_reset(self, new_session_id: str) -> None: - """Rebind per-session state to new_session_id. OV auto-creates the - session when the first message is added, so no create call here.""" - for t in (self._sync_thread, self._prefetch_thread): - if t and t.is_alive(): - t.join(timeout=5.0) - - self._session_id = new_session_id - self._turn_count = 0 - self._prefetch_result = "" - self._sync_thread = None - self._prefetch_thread = None - - global _last_active_provider - _last_active_provider = self - def shutdown(self) -> None: # Wait for background threads to finish for t in (self._sync_thread, self._prefetch_thread): diff --git a/run_agent.py b/run_agent.py index a19857bc4..d7d1249be 100644 --- a/run_agent.py +++ b/run_agent.py @@ -3040,15 +3040,15 @@ class AIAgent: except Exception: pass - def rotate_memory_session(self, new_session_id: str, messages: list = None) -> None: - """Commit the current memory session, then rebind providers to - new_session_id. Keeps HTTP clients/state alive across the transition. - Called when session_id rotates (e.g. /new, context compression).""" + def commit_memory_session(self, messages: list = None) -> None: + """Trigger end-of-session extraction without tearing providers down. + Called when session_id rotates (e.g. /new, context compression); + providers keep their state and continue running under the old + session_id — they just flush pending extraction now.""" if not self._memory_manager: return try: self._memory_manager.on_session_end(messages or []) - self._memory_manager.on_session_reset(new_session_id) except Exception: pass @@ -6838,11 +6838,11 @@ class AIAgent: try: # Propagate title to the new session with auto-numbering old_title = self._session_db.get_session_title(self.session_id) + # Trigger memory extraction on the old session before it rotates. + self.commit_memory_session(messages) self._session_db.end_session(self.session_id, "compression") old_session_id = self.session_id self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" - # Commit the old memory session and rebind providers to the new one. - self.rotate_memory_session(self.session_id, messages) # Update session_log_file to point to the new session's JSON file self.session_log_file = self.logs_dir / f"session_{self.session_id}.json" self._session_db.create_session( diff --git a/tests/agent/test_memory_provider.py b/tests/agent/test_memory_provider.py index dc7f4b032..505f40bd5 100644 --- a/tests/agent/test_memory_provider.py +++ b/tests/agent/test_memory_provider.py @@ -698,124 +698,41 @@ class TestMemoryContextFencing: # --------------------------------------------------------------------------- -# MemoryManager.on_session_reset() tests +# AIAgent.commit_memory_session — routes to MemoryManager.on_session_end # --------------------------------------------------------------------------- -class ResettableProvider(FakeMemoryProvider): - """Provider that records on_session_reset calls for assertions.""" +class _CommitRecorder(FakeMemoryProvider): + """Provider that records on_session_end calls for assertions.""" - def __init__(self, name="resettable"): + def __init__(self, name="recorder"): super().__init__(name) - self.reset_session_calls = [] + self.end_calls = [] - def on_session_reset(self, new_session_id: str) -> None: - self.reset_session_calls.append(new_session_id) + def on_session_end(self, messages): + self.end_calls.append(list(messages or [])) -class TestMemoryManagerOnSessionReset: - def test_fans_out_to_all_providers(self): +class TestCommitMemorySessionRouting: + def test_on_session_end_fans_out(self): mgr = MemoryManager() - builtin = ResettableProvider("builtin") - external = ResettableProvider("openviking") + builtin = _CommitRecorder("builtin") + external = _CommitRecorder("openviking") mgr.add_provider(builtin) mgr.add_provider(external) - mgr.on_session_reset("new-session-123") + msgs = [{"role": "user", "content": "hi"}] + mgr.on_session_end(msgs) - assert builtin.reset_session_calls == ["new-session-123"] - assert external.reset_session_calls == ["new-session-123"] + assert builtin.end_calls == [msgs] + assert external.end_calls == [msgs] - def test_base_default_is_noop(self): - """Providers that don't override on_session_reset get the base no-op.""" + def test_on_session_end_tolerates_failure(self): mgr = MemoryManager() builtin = FakeMemoryProvider("builtin") - external = FakeMemoryProvider("honcho") - mgr.add_provider(builtin) - mgr.add_provider(external) - - # Must not raise — default is a no-op - mgr.on_session_reset("noop-session") - assert not external.initialized - - def test_tolerates_provider_failure(self): - mgr = MemoryManager() - builtin = FakeMemoryProvider("builtin") - bad = ResettableProvider("bad-provider") - - def _explode(new_sid): - raise RuntimeError("network error") - - bad.on_session_reset = _explode + bad = _CommitRecorder("bad-provider") + bad.on_session_end = lambda m: (_ for _ in ()).throw(RuntimeError("boom")) mgr.add_provider(builtin) mgr.add_provider(bad) - mgr.on_session_reset("safe-session") # must not raise - - def test_no_providers_is_noop(self): - mgr = MemoryManager() - mgr.on_session_reset("empty-session") # must not raise - - -# --------------------------------------------------------------------------- -# OpenVikingMemoryProvider.on_session_reset() tests -# --------------------------------------------------------------------------- - - -class TestOpenVikingOnSessionReset: - """Unit tests for the cheap session-transition path in the OV plugin.""" - - def _make_provider(self): - try: - from plugins.memory.openviking import OpenVikingMemoryProvider - except ImportError: - pytest.skip("openviking plugin not importable") - - provider = OpenVikingMemoryProvider() - provider._session_id = "old-session" - provider._turn_count = 5 - provider._prefetch_result = "cached result" - provider._sync_thread = None - provider._prefetch_thread = None - - mock_client = MagicMock() - mock_client.post.return_value = {} - provider._client = mock_client - return provider, mock_client - - def test_reset_updates_session_id(self): - provider, _ = self._make_provider() - provider.on_session_reset("new-session-abc") - assert provider._session_id == "new-session-abc" - - def test_reset_clears_per_session_state(self): - provider, _ = self._make_provider() - provider.on_session_reset("new-session-xyz") - assert provider._turn_count == 0 - assert provider._prefetch_result == "" - assert provider._sync_thread is None - assert provider._prefetch_thread is None - - def test_reset_does_not_create_ov_session(self): - """OV auto-creates on first message; reset must not POST /sessions.""" - provider, mock_client = self._make_provider() - provider.on_session_reset("new-session-post") - mock_client.post.assert_not_called() - - def test_reset_without_client_is_safe(self): - try: - from plugins.memory.openviking import OpenVikingMemoryProvider - except ImportError: - pytest.skip("openviking plugin not importable") - - provider = OpenVikingMemoryProvider() - provider._client = None - provider._session_id = "old" - provider._turn_count = 3 - provider._sync_thread = None - provider._prefetch_thread = None - provider._prefetch_result = "" - - provider.on_session_reset("new-no-client") - assert provider._session_id == "new-no-client" - assert provider._turn_count == 0 + mgr.on_session_end([]) # must not raise