refactor(memory): promote on_session_reset to base provider hook

Replace hasattr-forked OpenViking-specific paths with a proper base-class
hook. Collapse the two agent wrappers into a single rotate_memory_session
so callers don't orchestrate commit + rebind themselves.

- MemoryProvider: add on_session_reset(new_session_id) as a default no-op
- MemoryManager: on_session_reset fans out unconditionally (no hasattr,
  no builtin skip — base no-op covers it)
- OpenViking: rename reset_session -> on_session_reset; drop the explicit
  POST /api/v1/sessions (OV auto-creates on first message) and the two
  debug raise_for_status wrappers
- AIAgent: collapse commit_memory_session + reinitialize_memory_session
  into rotate_memory_session(new_sid, messages)
- cli.py / run_agent.py: replace hasattr blocks and the split calls with
  a single unconditional rotate_memory_session call; compression path
  now passes the real messages list instead of []
- tests: align with on_session_reset, assert reset does NOT POST /sessions

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
zhiheng.liu 2026-04-16 00:31:48 +08:00 committed by Teknium
parent 7856d304f2
commit 8275fa597a
6 changed files with 68 additions and 152 deletions

View file

@ -281,25 +281,19 @@ class MemoryManager:
provider.name, e, provider.name, e,
) )
def restart_session(self, new_session_id: str) -> None: def on_session_reset(self, new_session_id: str) -> None:
"""Transition external providers to a new session without full teardown. """Notify all providers of a session reset.
Must be called AFTER on_session_end() has committed the old session. Called after on_session_end() has committed the previous session.
Providers that implement reset_session() are transitioned cheaply Providers with per-session state override on_session_reset to rebind
(HTTP client kept alive); others fall back to a full initialize(). it cheaply (default is a no-op on the base class).
The builtin provider is skipped it has no per-session state.
""" """
for provider in self._providers: for provider in self._providers:
if provider.name == "builtin":
continue
try: try:
if hasattr(provider, "reset_session"): provider.on_session_reset(new_session_id)
provider.reset_session(new_session_id)
else:
provider.initialize(session_id=new_session_id)
except Exception as e: except Exception as e:
logger.debug( logger.debug(
"Memory provider '%s' restart_session failed: %s", "Memory provider '%s' on_session_reset failed: %s",
provider.name, e, provider.name, e,
) )

View file

@ -160,6 +160,15 @@ class MemoryProvider(ABC):
(CLI exit, /reset, gateway session expiry). (CLI exit, /reset, gateway session expiry).
""" """
def on_session_reset(self, new_session_id: str) -> None:
"""Transition to a new session without full teardown.
Called after on_session_end() has committed the previous session
(e.g. /new, context compression). Providers with per-session state
override to rebind counters/IDs while keeping HTTP clients alive.
Default: no-op.
"""
def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str: def on_pre_compress(self, messages: List[Dict[str, Any]]) -> str:
"""Called before context compression discards old messages. """Called before context compression discards old messages.

22
cli.py
View file

@ -4095,18 +4095,12 @@ class HermesCLI:
def new_session(self, silent=False): def new_session(self, silent=False):
"""Start a fresh session with a new session ID and cleared agent state.""" """Start a fresh session with a new session ID and cleared agent state."""
if self.agent and self.conversation_history: old_history = self.conversation_history
if self.agent and old_history:
try: try:
self.agent.flush_memories(self.conversation_history) self.agent.flush_memories(old_history)
except (Exception, KeyboardInterrupt): except (Exception, KeyboardInterrupt):
pass pass
# Commit external memory providers (e.g. OpenViking) BEFORE
# session_id changes so extraction runs on the correct session.
if hasattr(self.agent, "commit_memory_session"):
try:
self.agent.commit_memory_session(self.conversation_history)
except Exception:
pass
self._notify_session_boundary("on_session_finalize") self._notify_session_boundary("on_session_finalize")
elif self.agent: elif self.agent:
# First session or empty history — still finalize the old session # First session or empty history — still finalize the old session
@ -4155,13 +4149,9 @@ class HermesCLI:
) )
except Exception: except Exception:
pass pass
# Reinitialize external memory providers with the new session_id # Commit the old session and rebind memory providers to the
# so subsequent turns are tracked under the new session. # new session_id so subsequent turns are tracked correctly.
if hasattr(self.agent, "reinitialize_memory_session"): self.agent.rotate_memory_session(self.session_id, old_history)
try:
self.agent.reinitialize_memory_session(self.session_id)
except Exception:
pass
self._notify_session_boundary("on_session_reset") self._notify_session_boundary("on_session_reset")
if not silent: if not silent:

View file

@ -109,12 +109,7 @@ class _VikingClient:
resp = self._httpx.get( resp = self._httpx.get(
self._url(path), headers=self._headers(), timeout=_TIMEOUT, **kwargs self._url(path), headers=self._headers(), timeout=_TIMEOUT, **kwargs
) )
try: resp.raise_for_status()
resp.raise_for_status()
except Exception as e:
logger.debug("OpenViking request failed: %s %s, status: %s, response: %s",
"GET", path, resp.status_code, resp.text)
raise
return resp.json() return resp.json()
def post(self, path: str, payload: dict = None, **kwargs) -> dict: def post(self, path: str, payload: dict = None, **kwargs) -> dict:
@ -122,12 +117,7 @@ class _VikingClient:
self._url(path), json=payload or {}, headers=self._headers(), self._url(path), json=payload or {}, headers=self._headers(),
timeout=_TIMEOUT, **kwargs timeout=_TIMEOUT, **kwargs
) )
try: resp.raise_for_status()
resp.raise_for_status()
except Exception as e:
logger.debug("OpenViking request failed: %s %s, status: %s, response: %s",
"POST", path, resp.status_code, resp.text)
raise
return resp.json() return resp.json()
def health(self) -> bool: def health(self) -> bool:
@ -336,13 +326,6 @@ class OpenVikingMemoryProvider(MemoryProvider):
if not self._client.health(): if not self._client.health():
logger.warning("OpenViking server at %s is not reachable", self._endpoint) logger.warning("OpenViking server at %s is not reachable", self._endpoint)
self._client = None self._client = None
else:
# Explicitly create the session to ensure it exists
try:
self._client.post("/api/v1/sessions", {"session_id": self._session_id})
logger.info("OpenViking session %s created", self._session_id)
except Exception as e:
logger.debug("OpenViking session creation failed (may already exist): %s", e)
except ImportError: except ImportError:
logger.warning("httpx not installed — OpenViking plugin disabled") logger.warning("httpx not installed — OpenViking plugin disabled")
self._client = None self._client = None
@ -533,14 +516,9 @@ class OpenVikingMemoryProvider(MemoryProvider):
except Exception as e: except Exception as e:
return tool_error(str(e)) return tool_error(str(e))
def reset_session(self, new_session_id: str) -> None: def on_session_reset(self, new_session_id: str) -> None:
"""Transition to a new session without tearing down the HTTP client. """Rebind per-session state to new_session_id. OV auto-creates the
session when the first message is added, so no create call here."""
Called by MemoryManager.restart_session() after on_session_end() has
committed the old session (e.g. after CLI /new or context compression).
Lighter than shutdown() + initialize(): keeps the httpx client alive,
resets per-session counters, and creates the new OV session.
"""
for t in (self._sync_thread, self._prefetch_thread): for t in (self._sync_thread, self._prefetch_thread):
if t and t.is_alive(): if t and t.is_alive():
t.join(timeout=5.0) t.join(timeout=5.0)
@ -551,13 +529,6 @@ class OpenVikingMemoryProvider(MemoryProvider):
self._sync_thread = None self._sync_thread = None
self._prefetch_thread = None self._prefetch_thread = None
if self._client:
try:
self._client.post("/api/v1/sessions", {"session_id": self._session_id})
logger.info("OpenViking new session %s created", self._session_id)
except Exception as e:
logger.debug("OpenViking session creation on reset: %s", e)
global _last_active_provider global _last_active_provider
_last_active_provider = self _last_active_provider = self

View file

@ -3040,33 +3040,17 @@ class AIAgent:
except Exception: except Exception:
pass pass
def commit_memory_session(self, messages: list = None) -> None: def rotate_memory_session(self, new_session_id: str, messages: list = None) -> None:
"""Commit external memory providers for the current session. """Commit the current memory session, then rebind providers to
new_session_id. Keeps HTTP clients/state alive across the transition.
Calls on_session_end() WITHOUT shutting down providers the session Called when session_id rotates (e.g. /new, context compression)."""
data (e.g. OpenViking) is committed for extraction, but the HTTP if not self._memory_manager:
client and provider state remain alive for the next session. return
Called before session_id changes (e.g. /new, context compression). try:
""" self._memory_manager.on_session_end(messages or [])
if self._memory_manager: self._memory_manager.on_session_reset(new_session_id)
try: except Exception:
self._memory_manager.on_session_end(messages or []) pass
except Exception:
pass
def reinitialize_memory_session(self, new_session_id: str) -> None:
"""Transition memory providers to a new session after commit.
Calls restart_session() which uses reset_session() on providers that
support it (cheap: keeps HTTP client, resets per-session counters) or
falls back to initialize() for providers that don't.
Called after session_id has been assigned (e.g. /new, compression).
"""
if self._memory_manager:
try:
self._memory_manager.restart_session(new_session_id)
except Exception:
pass
def close(self) -> None: def close(self) -> None:
"""Release all resources held by this agent instance. """Release all resources held by this agent instance.
@ -6854,14 +6838,11 @@ class AIAgent:
try: try:
# Propagate title to the new session with auto-numbering # Propagate title to the new session with auto-numbering
old_title = self._session_db.get_session_title(self.session_id) old_title = self._session_db.get_session_title(self.session_id)
# Commit external memory (e.g. OpenViking) before session_id
# changes so extraction runs on the correct session.
self.commit_memory_session([])
self._session_db.end_session(self.session_id, "compression") self._session_db.end_session(self.session_id, "compression")
old_session_id = self.session_id old_session_id = self.session_id
self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}" self.session_id = f"{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:6]}"
# Transition external memory providers to the new session_id. # Commit the old memory session and rebind providers to the new one.
self.reinitialize_memory_session(self.session_id) self.rotate_memory_session(self.session_id, messages)
# Update session_log_file to point to the new session's JSON file # Update session_log_file to point to the new session's JSON file
self.session_log_file = self.logs_dir / f"session_{self.session_id}.json" self.session_log_file = self.logs_dir / f"session_{self.session_id}.json"
self._session_db.create_session( self._session_db.create_session(

View file

@ -698,61 +698,47 @@ class TestMemoryContextFencing:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# MemoryManager.restart_session() tests # MemoryManager.on_session_reset() tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class ResettableProvider(FakeMemoryProvider): class ResettableProvider(FakeMemoryProvider):
"""Provider that implements reset_session() for cheap session transitions.""" """Provider that records on_session_reset calls for assertions."""
def __init__(self, name="resettable"): def __init__(self, name="resettable"):
super().__init__(name) super().__init__(name)
self.reset_session_calls = [] self.reset_session_calls = []
def reset_session(self, new_session_id: str) -> None: def on_session_reset(self, new_session_id: str) -> None:
self.reset_session_calls.append(new_session_id) self.reset_session_calls.append(new_session_id)
class TestMemoryManagerRestartSession: class TestMemoryManagerOnSessionReset:
def test_restart_calls_reset_session_on_external(self): def test_fans_out_to_all_providers(self):
"""restart_session() calls reset_session() on external providers that have it."""
mgr = MemoryManager() mgr = MemoryManager()
builtin = FakeMemoryProvider("builtin") builtin = ResettableProvider("builtin")
external = ResettableProvider("openviking") external = ResettableProvider("openviking")
mgr.add_provider(builtin) mgr.add_provider(builtin)
mgr.add_provider(external) mgr.add_provider(external)
mgr.restart_session("new-session-123") mgr.on_session_reset("new-session-123")
assert builtin.reset_session_calls == ["new-session-123"]
assert external.reset_session_calls == ["new-session-123"] assert external.reset_session_calls == ["new-session-123"]
# builtin is skipped — it has no per-session state
assert not hasattr(builtin, "reset_session_calls")
def test_restart_skips_builtin(self): def test_base_default_is_noop(self):
"""restart_session() does not call anything on the builtin provider.""" """Providers that don't override on_session_reset get the base no-op."""
mgr = MemoryManager()
builtin = ResettableProvider("builtin")
mgr.add_provider(builtin)
mgr.restart_session("new-session-456")
assert builtin.reset_session_calls == []
def test_restart_falls_back_to_initialize(self):
"""restart_session() calls initialize() when provider has no reset_session()."""
mgr = MemoryManager() mgr = MemoryManager()
builtin = FakeMemoryProvider("builtin") builtin = FakeMemoryProvider("builtin")
external = FakeMemoryProvider("honcho") external = FakeMemoryProvider("honcho")
mgr.add_provider(builtin) mgr.add_provider(builtin)
mgr.add_provider(external) mgr.add_provider(external)
mgr.restart_session("fallback-session") # Must not raise — default is a no-op
mgr.on_session_reset("noop-session")
assert not external.initialized
assert external.initialized def test_tolerates_provider_failure(self):
assert external._init_kwargs["session_id"] == "fallback-session"
def test_restart_tolerates_provider_failure(self):
"""restart_session() swallows failures so other providers are still called."""
mgr = MemoryManager() mgr = MemoryManager()
builtin = FakeMemoryProvider("builtin") builtin = FakeMemoryProvider("builtin")
bad = ResettableProvider("bad-provider") bad = ResettableProvider("bad-provider")
@ -760,32 +746,26 @@ class TestMemoryManagerRestartSession:
def _explode(new_sid): def _explode(new_sid):
raise RuntimeError("network error") raise RuntimeError("network error")
bad.reset_session = _explode bad.on_session_reset = _explode
good = ResettableProvider("good-provider")
# Register bad provider first, but only one external is allowed —
# so test both providers by using the fallback path.
mgr.add_provider(builtin) mgr.add_provider(builtin)
mgr.add_provider(bad) mgr.add_provider(bad)
# Calling restart_session should not raise even though the provider fails. mgr.on_session_reset("safe-session") # must not raise
mgr.restart_session("safe-session")
def test_restart_no_providers_is_noop(self): def test_no_providers_is_noop(self):
"""restart_session() on an empty manager does not raise."""
mgr = MemoryManager() mgr = MemoryManager()
mgr.restart_session("empty-session") # must not raise mgr.on_session_reset("empty-session") # must not raise
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# OpenVikingMemoryProvider.reset_session() tests # OpenVikingMemoryProvider.on_session_reset() tests
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
class TestOpenVikingResetSession: class TestOpenVikingOnSessionReset:
"""Unit tests for the cheap session-transition path in the OV plugin.""" """Unit tests for the cheap session-transition path in the OV plugin."""
def _make_provider(self): def _make_provider(self):
"""Return an OpenVikingMemoryProvider with a mock _client."""
try: try:
from plugins.memory.openviking import OpenVikingMemoryProvider from plugins.memory.openviking import OpenVikingMemoryProvider
except ImportError: except ImportError:
@ -805,33 +785,24 @@ class TestOpenVikingResetSession:
def test_reset_updates_session_id(self): def test_reset_updates_session_id(self):
provider, _ = self._make_provider() provider, _ = self._make_provider()
provider.reset_session("new-session-abc") provider.on_session_reset("new-session-abc")
assert provider._session_id == "new-session-abc" assert provider._session_id == "new-session-abc"
def test_reset_clears_per_session_state(self): def test_reset_clears_per_session_state(self):
provider, _ = self._make_provider() provider, _ = self._make_provider()
provider.reset_session("new-session-xyz") provider.on_session_reset("new-session-xyz")
assert provider._turn_count == 0 assert provider._turn_count == 0
assert provider._prefetch_result == "" assert provider._prefetch_result == ""
assert provider._sync_thread is None assert provider._sync_thread is None
assert provider._prefetch_thread is None assert provider._prefetch_thread is None
def test_reset_creates_new_ov_session(self): def test_reset_does_not_create_ov_session(self):
"""OV auto-creates on first message; reset must not POST /sessions."""
provider, mock_client = self._make_provider() provider, mock_client = self._make_provider()
provider.reset_session("new-session-post") provider.on_session_reset("new-session-post")
mock_client.post.assert_called_once_with( mock_client.post.assert_not_called()
"/api/v1/sessions", {"session_id": "new-session-post"}
)
def test_reset_tolerates_ov_api_failure(self): def test_reset_without_client_is_safe(self):
provider, mock_client = self._make_provider()
mock_client.post.side_effect = RuntimeError("connection refused")
# Must not raise — OV API failure is non-fatal for the reset path
provider.reset_session("no-server-session")
assert provider._session_id == "no-server-session"
def test_reset_without_client_is_noop(self):
"""reset_session() works even if provider was never initialized (no client)."""
try: try:
from plugins.memory.openviking import OpenVikingMemoryProvider from plugins.memory.openviking import OpenVikingMemoryProvider
except ImportError: except ImportError:
@ -845,6 +816,6 @@ class TestOpenVikingResetSession:
provider._prefetch_thread = None provider._prefetch_thread = None
provider._prefetch_result = "" provider._prefetch_result = ""
provider.reset_session("new-no-client") provider.on_session_reset("new-no-client")
assert provider._session_id == "new-no-client" assert provider._session_id == "new-no-client"
assert provider._turn_count == 0 assert provider._turn_count == 0