fix: make session search initialize session db

This commit is contained in:
HenkDz 2026-05-03 17:43:53 +01:00 committed by Teknium
parent 9c26297c80
commit 840ebe063e
8 changed files with 220 additions and 11 deletions

View file

@ -601,6 +601,7 @@ class SessionManager:
),
"quiet_mode": True,
"session_id": session_id,
"session_db": self._get_db(),
"model": model or default_model,
}

View file

@ -199,6 +199,22 @@ def run_oneshot(
return 0
def _create_session_db_for_oneshot():
"""Best-effort SessionDB for ``hermes -z`` / oneshot mode.
Oneshot bypasses ``HermesCLI._init_agent()``, so it must wire the SQLite
session store itself. Without this, the ``session_search``/recall tool is
advertised but every call returns "Session database not available.".
"""
try:
from hermes_state import SessionDB
return SessionDB()
except Exception as exc:
logging.debug("SQLite session store not available for oneshot mode: %s", exc)
return None
def _run_agent(
prompt: str,
model: Optional[str] = None,
@ -284,6 +300,8 @@ def _run_agent(
if toolsets_list is None and use_config_toolsets:
toolsets_list = sorted(_get_platform_tools(cfg, "cli"))
session_db = _create_session_db_for_oneshot()
agent = AIAgent(
api_key=runtime.get("api_key"),
base_url=runtime.get("base_url"),
@ -293,6 +311,7 @@ def _run_agent(
enabled_toolsets=toolsets_list,
quiet_mode=True,
platform="cli",
session_db=session_db,
credential_pool=runtime.get("credential_pool"),
# Interactive callbacks are intentionally NOT wired beyond this
# one. In oneshot mode there's no user sitting at a terminal:

View file

@ -2396,6 +2396,25 @@ class AIAgent:
"is_anthropic_oauth": self._is_anthropic_oauth,
})
def _get_session_db_for_recall(self):
"""Return a SessionDB for recall, lazily creating it if an entrypoint forgot.
Most frontends pass ``session_db`` into ``AIAgent`` explicitly, but recall
is important enough that a missing constructor argument should degrade by
opening the default state DB instead of making the advertised
``session_search`` tool unusable.
"""
if self._session_db is not None:
return self._session_db
try:
from hermes_state import SessionDB
self._session_db = SessionDB()
return self._session_db
except Exception as exc:
logger.debug("SessionDB unavailable for recall", exc_info=True)
return None
def _ensure_db_session(self) -> None:
"""Create session DB row on first use. Disables _session_db on failure."""
if self._session_db_created or not self._session_db:
@ -9920,7 +9939,8 @@ class AIAgent:
store=self._todo_store,
)
elif function_name == "session_search":
if not self._session_db:
session_db = self._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable
return json.dumps({"success": False, "error": format_session_db_unavailable()})
from tools.session_search_tool import session_search as _session_search
@ -9928,7 +9948,7 @@ class AIAgent:
query=function_args.get("query", ""),
role_filter=function_args.get("role_filter"),
limit=function_args.get("limit", 3),
db=self._session_db,
db=session_db,
current_session_id=self.session_id,
)
elif function_name == "memory":
@ -10544,7 +10564,8 @@ class AIAgent:
if self._should_emit_quiet_tool_messages():
self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
elif function_name == "session_search":
if not self._session_db:
session_db = self._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable
function_result = json.dumps({"success": False, "error": format_session_db_unavailable()})
else:
@ -10553,7 +10574,7 @@ class AIAgent:
query=function_args.get("query", ""),
role_filter=function_args.get("role_filter"),
limit=function_args.get("limit", 3),
db=self._session_db,
db=session_db,
current_session_id=self.session_id,
)
tool_duration = time.time() - tool_start_time

View file

@ -1,4 +1,5 @@
from types import SimpleNamespace
import sys
from types import ModuleType, SimpleNamespace
import pytest
from acp.schema import TextContentBlock
@ -66,6 +67,53 @@ def make_agent_and_state():
return acp_agent, state, fake, conn
def test_acp_real_agent_gets_session_db_for_recall(monkeypatch):
"""ACP sessions persist to SessionDB; recall must receive the same DB handle."""
captured = {}
sentinel_db = NoopDb()
class CapturingAgent(FakeAgent):
def __init__(self, **kwargs):
super().__init__()
captured.update(kwargs)
def mod(name, **attrs):
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
monkeypatch.setitem(sys.modules, "run_agent", mod("run_agent", AIAgent=CapturingAgent))
monkeypatch.setitem(
sys.modules,
"hermes_cli.config",
mod("hermes_cli.config", load_config=lambda: {"model": {"default": "m", "provider": "p"}}),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.runtime_provider",
mod(
"hermes_cli.runtime_provider",
resolve_runtime_provider=lambda **_kwargs: {
"provider": "p",
"api_mode": "chat_completions",
"base_url": "u",
"api_key": "k",
"command": None,
"args": [],
},
),
)
manager = SessionManager(db=sentinel_db)
agent = manager._make_agent(session_id="acp-session", cwd=".")
assert isinstance(agent, CapturingAgent)
assert captured["session_db"] is sentinel_db
assert captured["platform"] == "acp"
assert captured["session_id"] == "acp-session"
@pytest.mark.asyncio
async def test_acp_steer_slash_command_injects_into_running_agent():
acp_agent, state, fake, _conn = make_agent_and_state()

View file

@ -419,6 +419,72 @@ def test_oneshot_distinguishes_disabled_mcp_from_unknown(monkeypatch, capsys):
assert "mcp-off" in err
def test_oneshot_wires_session_db_for_recall(monkeypatch):
"""hermes -z bypasses HermesCLI, but recall still needs SessionDB."""
from hermes_cli.oneshot import _run_agent
captured = {}
sentinel_db = object()
class FakeAgent:
def __init__(self, **kwargs):
captured.update(kwargs)
self.suppress_status_output = False
self.stream_delta_callback = object()
self.tool_gen_callback = object()
def chat(self, prompt):
captured["prompt"] = prompt
return "ok"
class FakeSessionDB:
def __new__(cls):
return sentinel_db
def mod(name, **attrs):
module = types.ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
monkeypatch.setitem(sys.modules, "run_agent", mod("run_agent", AIAgent=FakeAgent))
monkeypatch.setitem(sys.modules, "hermes_state", mod("hermes_state", SessionDB=FakeSessionDB))
monkeypatch.setitem(
sys.modules,
"hermes_cli.config",
mod("hermes_cli.config", load_config=lambda: {"model": {"default": "m"}}),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.models",
mod("hermes_cli.models", detect_provider_for_model=lambda *_args, **_kwargs: None),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.runtime_provider",
mod(
"hermes_cli.runtime_provider",
resolve_runtime_provider=lambda **_kwargs: {
"api_key": "k",
"base_url": "u",
"provider": "p",
"api_mode": "chat_completions",
"credential_pool": None,
},
),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.tools_config",
mod("hermes_cli.tools_config", _get_platform_tools=lambda *_args, **_kwargs: {"session_search"}),
)
assert _run_agent("recall this") == "ok"
assert captured["session_db"] is sentinel_db
assert captured["enabled_toolsets"] == ["session_search"]
assert captured["prompt"] == "recall this"
def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod):
captured = {}
active_path_during_call = None

View file

@ -1,5 +1,7 @@
from types import SimpleNamespace
from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch
import json
import sys
from run_agent import AIAgent
@ -61,3 +63,33 @@ def test_run_conversation_persists_tokens_for_cron_sessions():
assert result["final_response"] == "done"
session_db.update_token_counts.assert_called_once()
assert session_db.update_token_counts.call_args.args[0] == "cron-session"
def test_session_search_lazily_opens_db_when_entrypoint_did_not_pass_one(monkeypatch):
sentinel_db = object()
captured = {}
class FakeSessionDB:
def __new__(cls):
return sentinel_db
hermes_state = ModuleType("hermes_state")
hermes_state.SessionDB = FakeSessionDB
monkeypatch.setitem(sys.modules, "hermes_state", hermes_state)
session_search_mod = ModuleType("tools.session_search_tool")
def fake_session_search(**kwargs):
captured.update(kwargs)
return json.dumps({"success": True, "results": []})
session_search_mod.session_search = fake_session_search
monkeypatch.setitem(sys.modules, "tools.session_search_tool", session_search_mod)
agent = _make_agent(None, platform="acp")
result = json.loads(agent._invoke_tool("session_search", {"query": "Hermes"}, "task-id"))
assert result["success"] is True
assert captured["db"] is sentinel_db
assert captured["query"] == "Hermes"
assert agent._session_db is sentinel_db

View file

@ -309,11 +309,27 @@ class TestRecentSessionListing:
# =========================================================================
class TestSessionSearch:
def test_no_db_returns_error(self):
def test_no_db_lazily_opens_default_session_db(self, monkeypatch):
from unittest.mock import MagicMock
from tools.session_search_tool import session_search
mock_db = MagicMock()
mock_db.search_messages.return_value = []
class FakeSessionDB:
def __new__(cls):
return mock_db
import types
import sys
fake_state = types.ModuleType("hermes_state")
fake_state.SessionDB = FakeSessionDB
monkeypatch.setitem(sys.modules, "hermes_state", fake_state)
result = json.loads(session_search(query="test"))
assert result["success"] is False
assert "not available" in result["error"].lower()
assert result["success"] is True
mock_db.search_messages.assert_called_once()
def test_empty_query_returns_error(self):
from tools.session_search_tool import session_search

View file

@ -337,8 +337,14 @@ def session_search(
The current session is excluded from results since the agent already has that context.
"""
if db is None:
from hermes_state import format_session_db_unavailable
return tool_error(format_session_db_unavailable(), success=False)
try:
from hermes_state import SessionDB
db = SessionDB()
except Exception:
logging.debug("SessionDB unavailable for session_search", exc_info=True)
from hermes_state import format_session_db_unavailable
return tool_error(format_session_db_unavailable(), success=False)
# Defensive: models (especially open-source) may send non-int limit values
# (None when JSON null, string "int", or even a type object). Coerce to a