fix: make session search initialize session db

This commit is contained in:
HenkDz 2026-05-03 17:43:53 +01:00 committed by Teknium
parent 9c26297c80
commit 840ebe063e
8 changed files with 220 additions and 11 deletions

View file

@ -601,6 +601,7 @@ class SessionManager:
), ),
"quiet_mode": True, "quiet_mode": True,
"session_id": session_id, "session_id": session_id,
"session_db": self._get_db(),
"model": model or default_model, "model": model or default_model,
} }

View file

@ -199,6 +199,22 @@ def run_oneshot(
return 0 return 0
def _create_session_db_for_oneshot():
"""Best-effort SessionDB for ``hermes -z`` / oneshot mode.
Oneshot bypasses ``HermesCLI._init_agent()``, so it must wire the SQLite
session store itself. Without this, the ``session_search``/recall tool is
advertised but every call returns "Session database not available.".
"""
try:
from hermes_state import SessionDB
return SessionDB()
except Exception as exc:
logging.debug("SQLite session store not available for oneshot mode: %s", exc)
return None
def _run_agent( def _run_agent(
prompt: str, prompt: str,
model: Optional[str] = None, model: Optional[str] = None,
@ -284,6 +300,8 @@ def _run_agent(
if toolsets_list is None and use_config_toolsets: if toolsets_list is None and use_config_toolsets:
toolsets_list = sorted(_get_platform_tools(cfg, "cli")) toolsets_list = sorted(_get_platform_tools(cfg, "cli"))
session_db = _create_session_db_for_oneshot()
agent = AIAgent( agent = AIAgent(
api_key=runtime.get("api_key"), api_key=runtime.get("api_key"),
base_url=runtime.get("base_url"), base_url=runtime.get("base_url"),
@ -293,6 +311,7 @@ def _run_agent(
enabled_toolsets=toolsets_list, enabled_toolsets=toolsets_list,
quiet_mode=True, quiet_mode=True,
platform="cli", platform="cli",
session_db=session_db,
credential_pool=runtime.get("credential_pool"), credential_pool=runtime.get("credential_pool"),
# Interactive callbacks are intentionally NOT wired beyond this # Interactive callbacks are intentionally NOT wired beyond this
# one. In oneshot mode there's no user sitting at a terminal: # one. In oneshot mode there's no user sitting at a terminal:

View file

@ -2396,6 +2396,25 @@ class AIAgent:
"is_anthropic_oauth": self._is_anthropic_oauth, "is_anthropic_oauth": self._is_anthropic_oauth,
}) })
def _get_session_db_for_recall(self):
"""Return a SessionDB for recall, lazily creating it if an entrypoint forgot.
Most frontends pass ``session_db`` into ``AIAgent`` explicitly, but recall
is important enough that a missing constructor argument should degrade by
opening the default state DB instead of making the advertised
``session_search`` tool unusable.
"""
if self._session_db is not None:
return self._session_db
try:
from hermes_state import SessionDB
self._session_db = SessionDB()
return self._session_db
except Exception as exc:
logger.debug("SessionDB unavailable for recall", exc_info=True)
return None
def _ensure_db_session(self) -> None: def _ensure_db_session(self) -> None:
"""Create session DB row on first use. Disables _session_db on failure.""" """Create session DB row on first use. Disables _session_db on failure."""
if self._session_db_created or not self._session_db: if self._session_db_created or not self._session_db:
@ -9920,7 +9939,8 @@ class AIAgent:
store=self._todo_store, store=self._todo_store,
) )
elif function_name == "session_search": elif function_name == "session_search":
if not self._session_db: session_db = self._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable from hermes_state import format_session_db_unavailable
return json.dumps({"success": False, "error": format_session_db_unavailable()}) return json.dumps({"success": False, "error": format_session_db_unavailable()})
from tools.session_search_tool import session_search as _session_search from tools.session_search_tool import session_search as _session_search
@ -9928,7 +9948,7 @@ class AIAgent:
query=function_args.get("query", ""), query=function_args.get("query", ""),
role_filter=function_args.get("role_filter"), role_filter=function_args.get("role_filter"),
limit=function_args.get("limit", 3), limit=function_args.get("limit", 3),
db=self._session_db, db=session_db,
current_session_id=self.session_id, current_session_id=self.session_id,
) )
elif function_name == "memory": elif function_name == "memory":
@ -10544,7 +10564,8 @@ class AIAgent:
if self._should_emit_quiet_tool_messages(): if self._should_emit_quiet_tool_messages():
self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}") self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
elif function_name == "session_search": elif function_name == "session_search":
if not self._session_db: session_db = self._get_session_db_for_recall()
if not session_db:
from hermes_state import format_session_db_unavailable from hermes_state import format_session_db_unavailable
function_result = json.dumps({"success": False, "error": format_session_db_unavailable()}) function_result = json.dumps({"success": False, "error": format_session_db_unavailable()})
else: else:
@ -10553,7 +10574,7 @@ class AIAgent:
query=function_args.get("query", ""), query=function_args.get("query", ""),
role_filter=function_args.get("role_filter"), role_filter=function_args.get("role_filter"),
limit=function_args.get("limit", 3), limit=function_args.get("limit", 3),
db=self._session_db, db=session_db,
current_session_id=self.session_id, current_session_id=self.session_id,
) )
tool_duration = time.time() - tool_start_time tool_duration = time.time() - tool_start_time

View file

@ -1,4 +1,5 @@
from types import SimpleNamespace import sys
from types import ModuleType, SimpleNamespace
import pytest import pytest
from acp.schema import TextContentBlock from acp.schema import TextContentBlock
@ -66,6 +67,53 @@ def make_agent_and_state():
return acp_agent, state, fake, conn return acp_agent, state, fake, conn
def test_acp_real_agent_gets_session_db_for_recall(monkeypatch):
"""ACP sessions persist to SessionDB; recall must receive the same DB handle."""
captured = {}
sentinel_db = NoopDb()
class CapturingAgent(FakeAgent):
def __init__(self, **kwargs):
super().__init__()
captured.update(kwargs)
def mod(name, **attrs):
module = ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
monkeypatch.setitem(sys.modules, "run_agent", mod("run_agent", AIAgent=CapturingAgent))
monkeypatch.setitem(
sys.modules,
"hermes_cli.config",
mod("hermes_cli.config", load_config=lambda: {"model": {"default": "m", "provider": "p"}}),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.runtime_provider",
mod(
"hermes_cli.runtime_provider",
resolve_runtime_provider=lambda **_kwargs: {
"provider": "p",
"api_mode": "chat_completions",
"base_url": "u",
"api_key": "k",
"command": None,
"args": [],
},
),
)
manager = SessionManager(db=sentinel_db)
agent = manager._make_agent(session_id="acp-session", cwd=".")
assert isinstance(agent, CapturingAgent)
assert captured["session_db"] is sentinel_db
assert captured["platform"] == "acp"
assert captured["session_id"] == "acp-session"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_acp_steer_slash_command_injects_into_running_agent(): async def test_acp_steer_slash_command_injects_into_running_agent():
acp_agent, state, fake, _conn = make_agent_and_state() acp_agent, state, fake, _conn = make_agent_and_state()

View file

@ -419,6 +419,72 @@ def test_oneshot_distinguishes_disabled_mcp_from_unknown(monkeypatch, capsys):
assert "mcp-off" in err assert "mcp-off" in err
def test_oneshot_wires_session_db_for_recall(monkeypatch):
"""hermes -z bypasses HermesCLI, but recall still needs SessionDB."""
from hermes_cli.oneshot import _run_agent
captured = {}
sentinel_db = object()
class FakeAgent:
def __init__(self, **kwargs):
captured.update(kwargs)
self.suppress_status_output = False
self.stream_delta_callback = object()
self.tool_gen_callback = object()
def chat(self, prompt):
captured["prompt"] = prompt
return "ok"
class FakeSessionDB:
def __new__(cls):
return sentinel_db
def mod(name, **attrs):
module = types.ModuleType(name)
for key, value in attrs.items():
setattr(module, key, value)
return module
monkeypatch.setitem(sys.modules, "run_agent", mod("run_agent", AIAgent=FakeAgent))
monkeypatch.setitem(sys.modules, "hermes_state", mod("hermes_state", SessionDB=FakeSessionDB))
monkeypatch.setitem(
sys.modules,
"hermes_cli.config",
mod("hermes_cli.config", load_config=lambda: {"model": {"default": "m"}}),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.models",
mod("hermes_cli.models", detect_provider_for_model=lambda *_args, **_kwargs: None),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.runtime_provider",
mod(
"hermes_cli.runtime_provider",
resolve_runtime_provider=lambda **_kwargs: {
"api_key": "k",
"base_url": "u",
"provider": "p",
"api_mode": "chat_completions",
"credential_pool": None,
},
),
)
monkeypatch.setitem(
sys.modules,
"hermes_cli.tools_config",
mod("hermes_cli.tools_config", _get_platform_tools=lambda *_args, **_kwargs: {"session_search"}),
)
assert _run_agent("recall this") == "ok"
assert captured["session_db"] is sentinel_db
assert captured["enabled_toolsets"] == ["session_search"]
assert captured["prompt"] == "recall this"
def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod): def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod):
captured = {} captured = {}
active_path_during_call = None active_path_during_call = None

View file

@ -1,5 +1,7 @@
from types import SimpleNamespace from types import ModuleType, SimpleNamespace
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import json
import sys
from run_agent import AIAgent from run_agent import AIAgent
@ -61,3 +63,33 @@ def test_run_conversation_persists_tokens_for_cron_sessions():
assert result["final_response"] == "done" assert result["final_response"] == "done"
session_db.update_token_counts.assert_called_once() session_db.update_token_counts.assert_called_once()
assert session_db.update_token_counts.call_args.args[0] == "cron-session" assert session_db.update_token_counts.call_args.args[0] == "cron-session"
def test_session_search_lazily_opens_db_when_entrypoint_did_not_pass_one(monkeypatch):
sentinel_db = object()
captured = {}
class FakeSessionDB:
def __new__(cls):
return sentinel_db
hermes_state = ModuleType("hermes_state")
hermes_state.SessionDB = FakeSessionDB
monkeypatch.setitem(sys.modules, "hermes_state", hermes_state)
session_search_mod = ModuleType("tools.session_search_tool")
def fake_session_search(**kwargs):
captured.update(kwargs)
return json.dumps({"success": True, "results": []})
session_search_mod.session_search = fake_session_search
monkeypatch.setitem(sys.modules, "tools.session_search_tool", session_search_mod)
agent = _make_agent(None, platform="acp")
result = json.loads(agent._invoke_tool("session_search", {"query": "Hermes"}, "task-id"))
assert result["success"] is True
assert captured["db"] is sentinel_db
assert captured["query"] == "Hermes"
assert agent._session_db is sentinel_db

View file

@ -309,11 +309,27 @@ class TestRecentSessionListing:
# ========================================================================= # =========================================================================
class TestSessionSearch: class TestSessionSearch:
def test_no_db_returns_error(self): def test_no_db_lazily_opens_default_session_db(self, monkeypatch):
from unittest.mock import MagicMock
from tools.session_search_tool import session_search from tools.session_search_tool import session_search
mock_db = MagicMock()
mock_db.search_messages.return_value = []
class FakeSessionDB:
def __new__(cls):
return mock_db
import types
import sys
fake_state = types.ModuleType("hermes_state")
fake_state.SessionDB = FakeSessionDB
monkeypatch.setitem(sys.modules, "hermes_state", fake_state)
result = json.loads(session_search(query="test")) result = json.loads(session_search(query="test"))
assert result["success"] is False assert result["success"] is True
assert "not available" in result["error"].lower() mock_db.search_messages.assert_called_once()
def test_empty_query_returns_error(self): def test_empty_query_returns_error(self):
from tools.session_search_tool import session_search from tools.session_search_tool import session_search

View file

@ -337,8 +337,14 @@ def session_search(
The current session is excluded from results since the agent already has that context. The current session is excluded from results since the agent already has that context.
""" """
if db is None: if db is None:
from hermes_state import format_session_db_unavailable try:
return tool_error(format_session_db_unavailable(), success=False) from hermes_state import SessionDB
db = SessionDB()
except Exception:
logging.debug("SessionDB unavailable for session_search", exc_info=True)
from hermes_state import format_session_db_unavailable
return tool_error(format_session_db_unavailable(), success=False)
# Defensive: models (especially open-source) may send non-int limit values # Defensive: models (especially open-source) may send non-int limit values
# (None when JSON null, string "int", or even a type object). Coerce to a # (None when JSON null, string "int", or even a type object). Coerce to a