mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-05-18 04:41:56 +00:00
fix: make session search initialize session db
This commit is contained in:
parent
9c26297c80
commit
840ebe063e
8 changed files with 220 additions and 11 deletions
|
|
@ -601,6 +601,7 @@ class SessionManager:
|
||||||
),
|
),
|
||||||
"quiet_mode": True,
|
"quiet_mode": True,
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
|
"session_db": self._get_db(),
|
||||||
"model": model or default_model,
|
"model": model or default_model,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -199,6 +199,22 @@ def run_oneshot(
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _create_session_db_for_oneshot():
|
||||||
|
"""Best-effort SessionDB for ``hermes -z`` / oneshot mode.
|
||||||
|
|
||||||
|
Oneshot bypasses ``HermesCLI._init_agent()``, so it must wire the SQLite
|
||||||
|
session store itself. Without this, the ``session_search``/recall tool is
|
||||||
|
advertised but every call returns "Session database not available.".
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
return SessionDB()
|
||||||
|
except Exception as exc:
|
||||||
|
logging.debug("SQLite session store not available for oneshot mode: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _run_agent(
|
def _run_agent(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
|
|
@ -284,6 +300,8 @@ def _run_agent(
|
||||||
if toolsets_list is None and use_config_toolsets:
|
if toolsets_list is None and use_config_toolsets:
|
||||||
toolsets_list = sorted(_get_platform_tools(cfg, "cli"))
|
toolsets_list = sorted(_get_platform_tools(cfg, "cli"))
|
||||||
|
|
||||||
|
session_db = _create_session_db_for_oneshot()
|
||||||
|
|
||||||
agent = AIAgent(
|
agent = AIAgent(
|
||||||
api_key=runtime.get("api_key"),
|
api_key=runtime.get("api_key"),
|
||||||
base_url=runtime.get("base_url"),
|
base_url=runtime.get("base_url"),
|
||||||
|
|
@ -293,6 +311,7 @@ def _run_agent(
|
||||||
enabled_toolsets=toolsets_list,
|
enabled_toolsets=toolsets_list,
|
||||||
quiet_mode=True,
|
quiet_mode=True,
|
||||||
platform="cli",
|
platform="cli",
|
||||||
|
session_db=session_db,
|
||||||
credential_pool=runtime.get("credential_pool"),
|
credential_pool=runtime.get("credential_pool"),
|
||||||
# Interactive callbacks are intentionally NOT wired beyond this
|
# Interactive callbacks are intentionally NOT wired beyond this
|
||||||
# one. In oneshot mode there's no user sitting at a terminal:
|
# one. In oneshot mode there's no user sitting at a terminal:
|
||||||
|
|
|
||||||
29
run_agent.py
29
run_agent.py
|
|
@ -2396,6 +2396,25 @@ class AIAgent:
|
||||||
"is_anthropic_oauth": self._is_anthropic_oauth,
|
"is_anthropic_oauth": self._is_anthropic_oauth,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def _get_session_db_for_recall(self):
|
||||||
|
"""Return a SessionDB for recall, lazily creating it if an entrypoint forgot.
|
||||||
|
|
||||||
|
Most frontends pass ``session_db`` into ``AIAgent`` explicitly, but recall
|
||||||
|
is important enough that a missing constructor argument should degrade by
|
||||||
|
opening the default state DB instead of making the advertised
|
||||||
|
``session_search`` tool unusable.
|
||||||
|
"""
|
||||||
|
if self._session_db is not None:
|
||||||
|
return self._session_db
|
||||||
|
try:
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
self._session_db = SessionDB()
|
||||||
|
return self._session_db
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug("SessionDB unavailable for recall", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
def _ensure_db_session(self) -> None:
|
def _ensure_db_session(self) -> None:
|
||||||
"""Create session DB row on first use. Disables _session_db on failure."""
|
"""Create session DB row on first use. Disables _session_db on failure."""
|
||||||
if self._session_db_created or not self._session_db:
|
if self._session_db_created or not self._session_db:
|
||||||
|
|
@ -9920,7 +9939,8 @@ class AIAgent:
|
||||||
store=self._todo_store,
|
store=self._todo_store,
|
||||||
)
|
)
|
||||||
elif function_name == "session_search":
|
elif function_name == "session_search":
|
||||||
if not self._session_db:
|
session_db = self._get_session_db_for_recall()
|
||||||
|
if not session_db:
|
||||||
from hermes_state import format_session_db_unavailable
|
from hermes_state import format_session_db_unavailable
|
||||||
return json.dumps({"success": False, "error": format_session_db_unavailable()})
|
return json.dumps({"success": False, "error": format_session_db_unavailable()})
|
||||||
from tools.session_search_tool import session_search as _session_search
|
from tools.session_search_tool import session_search as _session_search
|
||||||
|
|
@ -9928,7 +9948,7 @@ class AIAgent:
|
||||||
query=function_args.get("query", ""),
|
query=function_args.get("query", ""),
|
||||||
role_filter=function_args.get("role_filter"),
|
role_filter=function_args.get("role_filter"),
|
||||||
limit=function_args.get("limit", 3),
|
limit=function_args.get("limit", 3),
|
||||||
db=self._session_db,
|
db=session_db,
|
||||||
current_session_id=self.session_id,
|
current_session_id=self.session_id,
|
||||||
)
|
)
|
||||||
elif function_name == "memory":
|
elif function_name == "memory":
|
||||||
|
|
@ -10544,7 +10564,8 @@ class AIAgent:
|
||||||
if self._should_emit_quiet_tool_messages():
|
if self._should_emit_quiet_tool_messages():
|
||||||
self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
|
self._vprint(f" {_get_cute_tool_message_impl('todo', function_args, tool_duration, result=function_result)}")
|
||||||
elif function_name == "session_search":
|
elif function_name == "session_search":
|
||||||
if not self._session_db:
|
session_db = self._get_session_db_for_recall()
|
||||||
|
if not session_db:
|
||||||
from hermes_state import format_session_db_unavailable
|
from hermes_state import format_session_db_unavailable
|
||||||
function_result = json.dumps({"success": False, "error": format_session_db_unavailable()})
|
function_result = json.dumps({"success": False, "error": format_session_db_unavailable()})
|
||||||
else:
|
else:
|
||||||
|
|
@ -10553,7 +10574,7 @@ class AIAgent:
|
||||||
query=function_args.get("query", ""),
|
query=function_args.get("query", ""),
|
||||||
role_filter=function_args.get("role_filter"),
|
role_filter=function_args.get("role_filter"),
|
||||||
limit=function_args.get("limit", 3),
|
limit=function_args.get("limit", 3),
|
||||||
db=self._session_db,
|
db=session_db,
|
||||||
current_session_id=self.session_id,
|
current_session_id=self.session_id,
|
||||||
)
|
)
|
||||||
tool_duration = time.time() - tool_start_time
|
tool_duration = time.time() - tool_start_time
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
from types import SimpleNamespace
|
import sys
|
||||||
|
from types import ModuleType, SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from acp.schema import TextContentBlock
|
from acp.schema import TextContentBlock
|
||||||
|
|
@ -66,6 +67,53 @@ def make_agent_and_state():
|
||||||
return acp_agent, state, fake, conn
|
return acp_agent, state, fake, conn
|
||||||
|
|
||||||
|
|
||||||
|
def test_acp_real_agent_gets_session_db_for_recall(monkeypatch):
|
||||||
|
"""ACP sessions persist to SessionDB; recall must receive the same DB handle."""
|
||||||
|
captured = {}
|
||||||
|
sentinel_db = NoopDb()
|
||||||
|
|
||||||
|
class CapturingAgent(FakeAgent):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
captured.update(kwargs)
|
||||||
|
|
||||||
|
def mod(name, **attrs):
|
||||||
|
module = ModuleType(name)
|
||||||
|
for key, value in attrs.items():
|
||||||
|
setattr(module, key, value)
|
||||||
|
return module
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "run_agent", mod("run_agent", AIAgent=CapturingAgent))
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"hermes_cli.config",
|
||||||
|
mod("hermes_cli.config", load_config=lambda: {"model": {"default": "m", "provider": "p"}}),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"hermes_cli.runtime_provider",
|
||||||
|
mod(
|
||||||
|
"hermes_cli.runtime_provider",
|
||||||
|
resolve_runtime_provider=lambda **_kwargs: {
|
||||||
|
"provider": "p",
|
||||||
|
"api_mode": "chat_completions",
|
||||||
|
"base_url": "u",
|
||||||
|
"api_key": "k",
|
||||||
|
"command": None,
|
||||||
|
"args": [],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
manager = SessionManager(db=sentinel_db)
|
||||||
|
agent = manager._make_agent(session_id="acp-session", cwd=".")
|
||||||
|
|
||||||
|
assert isinstance(agent, CapturingAgent)
|
||||||
|
assert captured["session_db"] is sentinel_db
|
||||||
|
assert captured["platform"] == "acp"
|
||||||
|
assert captured["session_id"] == "acp-session"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_acp_steer_slash_command_injects_into_running_agent():
|
async def test_acp_steer_slash_command_injects_into_running_agent():
|
||||||
acp_agent, state, fake, _conn = make_agent_and_state()
|
acp_agent, state, fake, _conn = make_agent_and_state()
|
||||||
|
|
|
||||||
|
|
@ -419,6 +419,72 @@ def test_oneshot_distinguishes_disabled_mcp_from_unknown(monkeypatch, capsys):
|
||||||
assert "mcp-off" in err
|
assert "mcp-off" in err
|
||||||
|
|
||||||
|
|
||||||
|
def test_oneshot_wires_session_db_for_recall(monkeypatch):
|
||||||
|
"""hermes -z bypasses HermesCLI, but recall still needs SessionDB."""
|
||||||
|
from hermes_cli.oneshot import _run_agent
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
sentinel_db = object()
|
||||||
|
|
||||||
|
class FakeAgent:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
self.suppress_status_output = False
|
||||||
|
self.stream_delta_callback = object()
|
||||||
|
self.tool_gen_callback = object()
|
||||||
|
|
||||||
|
def chat(self, prompt):
|
||||||
|
captured["prompt"] = prompt
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
class FakeSessionDB:
|
||||||
|
def __new__(cls):
|
||||||
|
return sentinel_db
|
||||||
|
|
||||||
|
def mod(name, **attrs):
|
||||||
|
module = types.ModuleType(name)
|
||||||
|
for key, value in attrs.items():
|
||||||
|
setattr(module, key, value)
|
||||||
|
return module
|
||||||
|
|
||||||
|
monkeypatch.setitem(sys.modules, "run_agent", mod("run_agent", AIAgent=FakeAgent))
|
||||||
|
monkeypatch.setitem(sys.modules, "hermes_state", mod("hermes_state", SessionDB=FakeSessionDB))
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"hermes_cli.config",
|
||||||
|
mod("hermes_cli.config", load_config=lambda: {"model": {"default": "m"}}),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"hermes_cli.models",
|
||||||
|
mod("hermes_cli.models", detect_provider_for_model=lambda *_args, **_kwargs: None),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"hermes_cli.runtime_provider",
|
||||||
|
mod(
|
||||||
|
"hermes_cli.runtime_provider",
|
||||||
|
resolve_runtime_provider=lambda **_kwargs: {
|
||||||
|
"api_key": "k",
|
||||||
|
"base_url": "u",
|
||||||
|
"provider": "p",
|
||||||
|
"api_mode": "chat_completions",
|
||||||
|
"credential_pool": None,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules,
|
||||||
|
"hermes_cli.tools_config",
|
||||||
|
mod("hermes_cli.tools_config", _get_platform_tools=lambda *_args, **_kwargs: {"session_search"}),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert _run_agent("recall this") == "ok"
|
||||||
|
assert captured["session_db"] is sentinel_db
|
||||||
|
assert captured["enabled_toolsets"] == ["session_search"]
|
||||||
|
assert captured["prompt"] == "recall this"
|
||||||
|
|
||||||
|
|
||||||
def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod):
|
def test_launch_tui_exports_model_provider_and_toolsets(monkeypatch, main_mod):
|
||||||
captured = {}
|
captured = {}
|
||||||
active_path_during_call = None
|
active_path_during_call = None
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
from types import SimpleNamespace
|
from types import ModuleType, SimpleNamespace
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
|
||||||
from run_agent import AIAgent
|
from run_agent import AIAgent
|
||||||
|
|
||||||
|
|
@ -61,3 +63,33 @@ def test_run_conversation_persists_tokens_for_cron_sessions():
|
||||||
assert result["final_response"] == "done"
|
assert result["final_response"] == "done"
|
||||||
session_db.update_token_counts.assert_called_once()
|
session_db.update_token_counts.assert_called_once()
|
||||||
assert session_db.update_token_counts.call_args.args[0] == "cron-session"
|
assert session_db.update_token_counts.call_args.args[0] == "cron-session"
|
||||||
|
|
||||||
|
|
||||||
|
def test_session_search_lazily_opens_db_when_entrypoint_did_not_pass_one(monkeypatch):
|
||||||
|
sentinel_db = object()
|
||||||
|
captured = {}
|
||||||
|
|
||||||
|
class FakeSessionDB:
|
||||||
|
def __new__(cls):
|
||||||
|
return sentinel_db
|
||||||
|
|
||||||
|
hermes_state = ModuleType("hermes_state")
|
||||||
|
hermes_state.SessionDB = FakeSessionDB
|
||||||
|
monkeypatch.setitem(sys.modules, "hermes_state", hermes_state)
|
||||||
|
|
||||||
|
session_search_mod = ModuleType("tools.session_search_tool")
|
||||||
|
|
||||||
|
def fake_session_search(**kwargs):
|
||||||
|
captured.update(kwargs)
|
||||||
|
return json.dumps({"success": True, "results": []})
|
||||||
|
|
||||||
|
session_search_mod.session_search = fake_session_search
|
||||||
|
monkeypatch.setitem(sys.modules, "tools.session_search_tool", session_search_mod)
|
||||||
|
|
||||||
|
agent = _make_agent(None, platform="acp")
|
||||||
|
result = json.loads(agent._invoke_tool("session_search", {"query": "Hermes"}, "task-id"))
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert captured["db"] is sentinel_db
|
||||||
|
assert captured["query"] == "Hermes"
|
||||||
|
assert agent._session_db is sentinel_db
|
||||||
|
|
|
||||||
|
|
@ -309,11 +309,27 @@ class TestRecentSessionListing:
|
||||||
# =========================================================================
|
# =========================================================================
|
||||||
|
|
||||||
class TestSessionSearch:
|
class TestSessionSearch:
|
||||||
def test_no_db_returns_error(self):
|
def test_no_db_lazily_opens_default_session_db(self, monkeypatch):
|
||||||
|
from unittest.mock import MagicMock
|
||||||
from tools.session_search_tool import session_search
|
from tools.session_search_tool import session_search
|
||||||
|
|
||||||
|
mock_db = MagicMock()
|
||||||
|
mock_db.search_messages.return_value = []
|
||||||
|
|
||||||
|
class FakeSessionDB:
|
||||||
|
def __new__(cls):
|
||||||
|
return mock_db
|
||||||
|
|
||||||
|
import types
|
||||||
|
import sys
|
||||||
|
|
||||||
|
fake_state = types.ModuleType("hermes_state")
|
||||||
|
fake_state.SessionDB = FakeSessionDB
|
||||||
|
monkeypatch.setitem(sys.modules, "hermes_state", fake_state)
|
||||||
|
|
||||||
result = json.loads(session_search(query="test"))
|
result = json.loads(session_search(query="test"))
|
||||||
assert result["success"] is False
|
assert result["success"] is True
|
||||||
assert "not available" in result["error"].lower()
|
mock_db.search_messages.assert_called_once()
|
||||||
|
|
||||||
def test_empty_query_returns_error(self):
|
def test_empty_query_returns_error(self):
|
||||||
from tools.session_search_tool import session_search
|
from tools.session_search_tool import session_search
|
||||||
|
|
|
||||||
|
|
@ -337,8 +337,14 @@ def session_search(
|
||||||
The current session is excluded from results since the agent already has that context.
|
The current session is excluded from results since the agent already has that context.
|
||||||
"""
|
"""
|
||||||
if db is None:
|
if db is None:
|
||||||
from hermes_state import format_session_db_unavailable
|
try:
|
||||||
return tool_error(format_session_db_unavailable(), success=False)
|
from hermes_state import SessionDB
|
||||||
|
|
||||||
|
db = SessionDB()
|
||||||
|
except Exception:
|
||||||
|
logging.debug("SessionDB unavailable for session_search", exc_info=True)
|
||||||
|
from hermes_state import format_session_db_unavailable
|
||||||
|
return tool_error(format_session_db_unavailable(), success=False)
|
||||||
|
|
||||||
# Defensive: models (especially open-source) may send non-int limit values
|
# Defensive: models (especially open-source) may send non-int limit values
|
||||||
# (None when JSON null, string "int", or even a type object). Coerce to a
|
# (None when JSON null, string "int", or even a type object). Coerce to a
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue