refactor(memory): drop on_session_reset — commit-only is enough

OV transparently handles message history across /new and /compress: old
messages stay in the same session and extraction is idempotent, so there's
no need to rebind providers to a new session_id. The only thing the
session boundary actually needs is to trigger extraction.

- MemoryProvider / MemoryManager: remove on_session_reset hook
- OpenViking: remove on_session_reset override (nothing to do)
- AIAgent: replace rotate_memory_session with commit_memory_session
  (just calls on_session_end, no rebind)
- cli.py / run_agent.py: single commit_memory_session call at the
  session boundary before session_id rotates
- tests: replace on_session_reset coverage with routing tests for
  MemoryManager.on_session_end

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
zhiheng.liu 2026-04-16 00:38:19 +08:00 committed by Teknium
parent 8275fa597a
commit 7cb06e3bb3
6 changed files with 30 additions and 156 deletions

View file

@ -281,22 +281,6 @@ class MemoryManager:
provider.name, e,
)
def on_session_reset(self, new_session_id: str) -> None:
"""Notify all providers of a session reset.
Called after on_session_end() has committed the previous session.
Providers with per-session state override on_session_reset to rebind
it cheaply (default is a no-op on the base class).
"""
for provider in self._providers:
try:
provider.on_session_reset(new_session_id)
except Exception as e:
logger.debug(
"Memory provider '%s' on_session_reset failed: %s",
provider.name, e,
)
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
"""Notify all providers before context compression.

View file

@ -160,15 +160,6 @@ class MemoryProvider(ABC):
(CLI exit, /reset, gateway session expiry).
"""
def on_session_reset(self, new_session_id: str) -> None:
"""Transition to a new session without full teardown.
Called after on_session_end() has committed the previous session
(e.g. /new, context compression). Providers with per-session state
override to rebind counters/IDs while keeping HTTP clients alive.
Default: no-op.
"""
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
"""Called before context compression discards old messages.

10
cli.py
View file

@ -4095,12 +4095,13 @@ class HermesCLI:
def new_session(self, silent=False):
"""Start a fresh session with a new session ID and cleared agent state."""
old_history = self.conversation_history
if self.agent and old_history:
if self.agent and self.conversation_history:
try:
self.agent.flush_memories(old_history)
self.agent.flush_memories(self.conversation_history)
except (Exception, KeyboardInterrupt):
pass
# Trigger memory extraction on the old session before session_id rotates.
self.agent.commit_memory_session(self.conversation_history)
self._notify_session_boundary("on_session_finalize")
elif self.agent:
# First session or empty history — still finalize the old session
@ -4149,9 +4150,6 @@ class HermesCLI:
)
except Exception:
pass
# Commit the old session and rebind memory providers to the
# new session_id so subsequent turns are tracked correctly.
self.agent.rotate_memory_session(self.session_id, old_history)
self._notify_session_boundary("on_session_reset")
if not silent:

View file

@ -516,22 +516,6 @@ class OpenVikingMemoryProvider(MemoryProvider):
except Exception as e:
return tool_error(str(e))
def on_session_reset(self, new_session_id: str) -> None:
"""Rebind per-session state to new_session_id. OV auto-creates the
session when the first message is added, so no create call here."""
for t in (self._sync_thread, self._prefetch_thread):
if t and t.is_alive():
t.join(timeout=5.0)
self._session_id = new_session_id
self._turn_count = 0
self._prefetch_result = ""
self._sync_thread = None
self._prefetch_thread = None
global _last_active_provider
_last_active_provider = self
def shutdown(self) -> None:
# Wait for background threads to finish
for t in (self._sync_thread, self._prefetch_thread):

View file

@ -3040,15 +3040,15 @@ class AIAgent:
except Exception:
pass
def rotate_memory_session(self, new_session_id: str, messages: list = None) -> None:
"""Commit the current memory session, then rebind providers to
new_session_id. Keeps HTTP clients/state alive across the transition.
Called when session_id rotates (e.g. /new, context compression)."""
def commit_memory_session(self, messages: list = None) -> None:
"""Trigger end-of-session extraction without tearing providers down.
Called when session_id rotates (e.g. /new, context compression);
providers keep their state and continue running under the old
session_id they just flush pending extraction now."""
if not self._memory_manager:
return
try:
self._memory_manager.on_session_end(messages or [])
self._memory_manager.on_session_reset(new_session_id)
except Exception:
pass
@ -6838,11 +6838,11 @@ class AIAgent:
try:
# Propagate title to the new session with auto-numbering
old_title = self._session_db.get_session_title(self.session_id)
# Trigger memory extraction on the old session before it rotates.
self.commit_memory_session(messages)
self._session_db.end_session(self.session_id, "compression")
old_session_id = self.session_id
self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
# Commit the old memory session and rebind providers to the new one.
self.rotate_memory_session(self.session_id, messages)
# Update session_log_file to point to the new session's JSON file
self.session_log_file = self.logs_dir / f"session_{self.session_id}.json"
self._session_db.create_session(

View file

@ -698,124 +698,41 @@ class TestMemoryContextFencing:
# ---------------------------------------------------------------------------
# MemoryManager.on_session_reset() tests
# AIAgent.commit_memory_session — routes to MemoryManager.on_session_end
# ---------------------------------------------------------------------------
class ResettableProvider(FakeMemoryProvider):
"""Provider that records on_session_reset calls for assertions."""
class _CommitRecorder(FakeMemoryProvider):
"""Provider that records on_session_end calls for assertions."""
def __init__(self, name="resettable"):
def __init__(self, name="recorder"):
super().__init__(name)
self.reset_session_calls = []
self.end_calls = []
def on_session_reset(self, new_session_id: str) -> None:
self.reset_session_calls.append(new_session_id)
def on_session_end(self, messages):
self.end_calls.append(list(messages or []))
class TestMemoryManagerOnSessionReset:
def test_fans_out_to_all_providers(self):
class TestCommitMemorySessionRouting:
def test_on_session_end_fans_out(self):
mgr = MemoryManager()
builtin = ResettableProvider("builtin")
external = ResettableProvider("openviking")
builtin = _CommitRecorder("builtin")
external = _CommitRecorder("openviking")
mgr.add_provider(builtin)
mgr.add_provider(external)
mgr.on_session_reset("new-session-123")
msgs = [{"role": "user", "content": "hi"}]
mgr.on_session_end(msgs)
assert builtin.reset_session_calls == ["new-session-123"]
assert external.reset_session_calls == ["new-session-123"]
assert builtin.end_calls == [msgs]
assert external.end_calls == [msgs]
def test_base_default_is_noop(self):
"""Providers that don't override on_session_reset get the base no-op."""
def test_on_session_end_tolerates_failure(self):
mgr = MemoryManager()
builtin = FakeMemoryProvider("builtin")
external = FakeMemoryProvider("honcho")
mgr.add_provider(builtin)
mgr.add_provider(external)
# Must not raise — default is a no-op
mgr.on_session_reset("noop-session")
assert not external.initialized
def test_tolerates_provider_failure(self):
mgr = MemoryManager()
builtin = FakeMemoryProvider("builtin")
bad = ResettableProvider("bad-provider")
def _explode(new_sid):
raise RuntimeError("network error")
bad.on_session_reset = _explode
bad = _CommitRecorder("bad-provider")
bad.on_session_end = lambda m: (_ for _ in ()).throw(RuntimeError("boom"))
mgr.add_provider(builtin)
mgr.add_provider(bad)
mgr.on_session_reset("safe-session") # must not raise
def test_no_providers_is_noop(self):
mgr = MemoryManager()
mgr.on_session_reset("empty-session") # must not raise
# ---------------------------------------------------------------------------
# OpenVikingMemoryProvider.on_session_reset() tests
# ---------------------------------------------------------------------------
class TestOpenVikingOnSessionReset:
"""Unit tests for the cheap session-transition path in the OV plugin."""
def _make_provider(self):
try:
from plugins.memory.openviking import OpenVikingMemoryProvider
except ImportError:
pytest.skip("openviking plugin not importable")
provider = OpenVikingMemoryProvider()
provider._session_id = "old-session"
provider._turn_count = 5
provider._prefetch_result = "cached result"
provider._sync_thread = None
provider._prefetch_thread = None
mock_client = MagicMock()
mock_client.post.return_value = {}
provider._client = mock_client
return provider, mock_client
def test_reset_updates_session_id(self):
provider, _ = self._make_provider()
provider.on_session_reset("new-session-abc")
assert provider._session_id == "new-session-abc"
def test_reset_clears_per_session_state(self):
provider, _ = self._make_provider()
provider.on_session_reset("new-session-xyz")
assert provider._turn_count == 0
assert provider._prefetch_result == ""
assert provider._sync_thread is None
assert provider._prefetch_thread is None
def test_reset_does_not_create_ov_session(self):
"""OV auto-creates on first message; reset must not POST /sessions."""
provider, mock_client = self._make_provider()
provider.on_session_reset("new-session-post")
mock_client.post.assert_not_called()
def test_reset_without_client_is_safe(self):
try:
from plugins.memory.openviking import OpenVikingMemoryProvider
except ImportError:
pytest.skip("openviking plugin not importable")
provider = OpenVikingMemoryProvider()
provider._client = None
provider._session_id = "old"
provider._turn_count = 3
provider._sync_thread = None
provider._prefetch_thread = None
provider._prefetch_result = ""
provider.on_session_reset("new-no-client")
assert provider._session_id == "new-no-client"
assert provider._turn_count == 0
mgr.on_session_end([]) # must not raise