mirror of
https://github.com/NousResearch/hermes-agent.git
synced 2026-04-25 00:51:20 +00:00
Merge 501ac3ff0a into 4fade39c90
This commit is contained in:
commit
c053c0aabe
14 changed files with 2060 additions and 22 deletions
117
agent/vault_injection.py
Normal file
117
agent/vault_injection.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
"""Vault injection — auto-load Obsidian vault files into the system prompt.
|
||||
|
||||
Reads working-context.md and user-profile.md from a configured vault path
|
||||
at session start and injects them into the system prompt alongside Layer 1
|
||||
memory (MEMORY.md / USER.md). This is a structural fix for vault neglect:
|
||||
the agent no longer needs to remember to read these files — they're injected
|
||||
automatically, the same way Layer 1 memory is.
|
||||
|
||||
The vault is Layer 3 in the memory architecture. Files injected here are
|
||||
read-only in the system prompt (frozen at session start). Mid-session
|
||||
writes to vault files require the read_file/write_file tools or the
|
||||
memory-vault skill.
|
||||
|
||||
Config (in config.yaml under 'vault'):
|
||||
enabled: true # enable vault injection
|
||||
path: /path/to/vault # absolute path to the Obsidian vault root
|
||||
|
||||
Files read (relative to vault path):
|
||||
Agent-Hermes/working-context.md — what the agent is actively doing
|
||||
Agent-Shared/user-profile.md — who the user is (durable facts)
|
||||
|
||||
If either file doesn't exist or is empty, it's silently skipped.
|
||||
If the vault path doesn't exist or isn't configured, vault injection is
|
||||
silently disabled.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Character limits for vault injection blocks (to prevent prompt bloat)
|
||||
WORKING_CONTEXT_CHAR_LIMIT = 4000
|
||||
USER_PROFILE_CHAR_LIMIT = 4000
|
||||
|
||||
SEPARATOR = "\u2550" * 46 # ═ same as memory_tool uses
|
||||
|
||||
|
||||
def _read_vault_file(path: Path, char_limit: int) -> Optional[str]:
    """Read a vault file and return its cleaned content, or None.

    Strips YAML frontmatter and truncates (with a notice) if the content
    exceeds *char_limit*.

    Args:
        path: Path to the vault markdown file.
        char_limit: Maximum number of characters to keep before truncating.

    Returns:
        The cleaned file content, or None when the file is missing,
        unreadable, or effectively empty after stripping.
    """
    # EAFP: read directly and let OSError cover the missing-file case too —
    # this avoids the exists()/read() TOCTOU race of a separate existence
    # check. (IOError has been an alias of OSError since Python 3.3, so a
    # single except clause suffices.)
    try:
        content = path.read_text(encoding="utf-8").strip()
    except OSError as e:
        logger.debug("Could not read vault file %s: %s", path, e)
        return None

    if not content:
        return None

    # Strip YAML frontmatter (same as prompt_builder does for context files)
    content = _strip_yaml_frontmatter(content)

    if not content:
        return None

    if len(content) > char_limit:
        truncated = content[:char_limit]
        # Find last newline to avoid cutting mid-line, but only back off if
        # we retain at least half the budget (guards against degenerate
        # single-line files where rfind would discard nearly everything).
        last_nl = truncated.rfind("\n")
        if last_nl > char_limit // 2:
            truncated = truncated[:last_nl]
        content = truncated + f"\n[... truncated at {char_limit} chars ...]"

    return content
|
||||
|
||||
|
||||
def _strip_yaml_frontmatter(content: str) -> str:
|
||||
"""Remove optional YAML frontmatter (--- delimited) from content."""
|
||||
if content.startswith("---"):
|
||||
end = content.find("\n---", 3)
|
||||
if end != -1:
|
||||
body = content[end + 4:].lstrip("\n")
|
||||
return body if body else content
|
||||
return content
|
||||
|
||||
|
||||
def build_vault_system_prompt(vault_path: str) -> str:
    """Build the vault injection block for the system prompt.

    Reads working-context.md and user-profile.md from the vault and formats
    them with headers matching the style of Layer 1 memory blocks.

    Returns an empty string if vault is disabled, path is missing, or
    all files are empty.
    """
    if not vault_path:
        return ""

    root = Path(vault_path)
    if not root.is_dir():
        logger.debug("Vault path does not exist or is not a directory: %s", vault_path)
        return ""

    # (file path, char budget, injected header) for each vault source:
    # the agent's current working state, then the shared user profile.
    sources = [
        (
            root / "Agent-Hermes" / "working-context.md",
            WORKING_CONTEXT_CHAR_LIMIT,
            "VAULT: WORKING CONTEXT (what you're doing right now)",
        ),
        (
            root / "Agent-Shared" / "user-profile.md",
            USER_PROFILE_CHAR_LIMIT,
            "VAULT: USER PROFILE (durable facts from Obsidian vault)",
        ),
    ]

    blocks = []
    for file_path, limit, header in sources:
        body = _read_vault_file(file_path, limit)
        if body:
            blocks.append(f"{SEPARATOR}\n{header}\n{SEPARATOR}\n{body}")

    return "\n\n".join(blocks)
|
||||
|
|
@ -754,6 +754,16 @@ DEFAULT_CONFIG = {
|
|||
"provider": "",
|
||||
},
|
||||
|
||||
# Obsidian vault auto-injection — Layer 3 persistent memory.
|
||||
# When enabled, working-context.md and user-profile.md are read from
|
||||
# the vault path at session start and injected into the system prompt
|
||||
# alongside Layer 1 memory. This is a structural fix for vault neglect:
|
||||
# the agent no longer needs to remember to read these files manually.
|
||||
"vault": {
|
||||
"enabled": False, # set true to activate
|
||||
"path": "", # absolute path to the Obsidian vault root
|
||||
},
|
||||
|
||||
# Subagent delegation — override the provider:model used by delegate_task
|
||||
# so child agents can run on a different (cheaper/faster) provider and model.
|
||||
# Uses the same runtime provider resolution as CLI/gateway startup, so all
|
||||
|
|
|
|||
276
hermes_state.py
276
hermes_state.py
|
|
@ -31,7 +31,7 @@ T = TypeVar("T")
|
|||
|
||||
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
|
||||
|
||||
SCHEMA_VERSION = 8
|
||||
SCHEMA_VERSION = 9
|
||||
|
||||
SCHEMA_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS schema_version (
|
||||
|
|
@ -95,6 +95,13 @@ CREATE INDEX IF NOT EXISTS idx_sessions_source ON sessions(source);
|
|||
CREATE INDEX IF NOT EXISTS idx_sessions_parent ON sessions(parent_session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC);
|
||||
CREATE INDEX IF NOT EXISTS idx_messages_session ON messages(session_id, timestamp);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS term_index (
|
||||
term TEXT NOT NULL,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(id),
|
||||
session_id TEXT NOT NULL REFERENCES sessions(id),
|
||||
PRIMARY KEY (term, message_id)
|
||||
) WITHOUT ROWID;
|
||||
"""
|
||||
|
||||
FTS_SQL = """
|
||||
|
|
@ -164,7 +171,14 @@ class SessionDB:
|
|||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||
|
||||
self._init_schema()
|
||||
needs_reindex = self._init_schema()
|
||||
|
||||
# If we just migrated to v7, backfill the term index from existing
|
||||
# messages. This runs outside _init_schema so we can use
|
||||
# _execute_write (which manages its own transactions).
|
||||
if needs_reindex:
|
||||
logger.info("v7 migration detected — backfilling term index")
|
||||
self.reindex_term_index()
|
||||
|
||||
# ── Core write helper ──
|
||||
|
||||
|
|
@ -257,11 +271,17 @@ class SessionDB:
|
|||
self._conn = None
|
||||
|
||||
def _init_schema(self):
|
||||
"""Create tables and FTS if they don't exist, run migrations."""
|
||||
"""Create tables and FTS if they don't exist, run migrations.
|
||||
|
||||
Returns True if a v7 migration was performed (term_index created
|
||||
and needs backfill), False otherwise.
|
||||
"""
|
||||
cursor = self._conn.cursor()
|
||||
|
||||
cursor.executescript(SCHEMA_SQL)
|
||||
|
||||
needs_reindex = False
|
||||
|
||||
# Check schema version and run migrations
|
||||
cursor.execute("SELECT version FROM schema_version LIMIT 1")
|
||||
row = cursor.fetchone()
|
||||
|
|
@ -356,6 +376,24 @@ class SessionDB:
|
|||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 8")
|
||||
if current_version < 9:
|
||||
# v9: add term_index table for inverted-index session search.
|
||||
# This is the clustered (term, message_id) WITHOUT ROWID table
|
||||
# used by the session_search fast path. After creating the
|
||||
# table, we backfill from existing messages.
|
||||
try:
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS term_index (
|
||||
term TEXT NOT NULL,
|
||||
message_id INTEGER NOT NULL REFERENCES messages(id),
|
||||
session_id TEXT NOT NULL REFERENCES sessions(id),
|
||||
PRIMARY KEY (term, message_id)
|
||||
) WITHOUT ROWID
|
||||
""")
|
||||
except sqlite3.OperationalError:
|
||||
pass # Table already exists
|
||||
cursor.execute("UPDATE schema_version SET version = 9")
|
||||
needs_reindex = True
|
||||
|
||||
# Unique title index — always ensure it exists (safe to run after migrations
|
||||
# since the title column is guaranteed to exist at this point)
|
||||
|
|
@ -375,6 +413,8 @@ class SessionDB:
|
|||
|
||||
self._conn.commit()
|
||||
|
||||
return needs_reindex
|
||||
|
||||
# =========================================================================
|
||||
# Session lifecycle
|
||||
# =========================================================================
|
||||
|
|
@ -569,6 +609,23 @@ class SessionDB:
|
|||
row = cursor.fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
def get_child_session_ids(self, *parent_ids: str) -> List[str]:
|
||||
"""Return IDs of sessions whose parent_session_id is in *parent_ids*.
|
||||
|
||||
Useful for finding delegation/compression child sessions that
|
||||
belong to a parent conversation. Does NOT recurse — only
|
||||
direct children are returned.
|
||||
"""
|
||||
if not parent_ids:
|
||||
return []
|
||||
placeholders = ",".join("?" for _ in parent_ids)
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
f"SELECT id FROM sessions WHERE parent_session_id IN ({placeholders})",
|
||||
list(parent_ids),
|
||||
)
|
||||
return [row["id"] for row in cursor.fetchall()]
|
||||
|
||||
def resolve_session_id(self, session_id_or_prefix: str) -> Optional[str]:
|
||||
"""Resolve an exact or uniquely prefixed session ID to the full ID.
|
||||
|
||||
|
|
@ -1015,6 +1072,24 @@ class SessionDB:
|
|||
"UPDATE sessions SET message_count = message_count + 1 WHERE id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
|
||||
# Insert terms into inverted index (delayed import avoids
|
||||
# circular dependency — hermes_state is imported by nearly
|
||||
# everything at startup, term_index must not be top-level)
|
||||
# Skip tool-role messages: their structured JSON output produces
|
||||
# noise terms (field names, numeric values) with no search value.
|
||||
if content and role != "tool":
|
||||
try:
|
||||
from term_index import extract_terms
|
||||
terms = extract_terms(content)
|
||||
if terms:
|
||||
conn.executemany(
|
||||
"INSERT OR IGNORE INTO term_index (term, message_id, session_id) VALUES (?, ?, ?)",
|
||||
[(t, msg_id, session_id) for t in terms],
|
||||
)
|
||||
except Exception:
|
||||
pass # Term indexing is best-effort; never block a message insert
|
||||
|
||||
return msg_id
|
||||
|
||||
return self._execute_write(_do)
|
||||
|
|
@ -1395,6 +1470,191 @@ class SessionDB:
|
|||
|
||||
return matches
|
||||
|
||||
# =========================================================================
|
||||
# Term index search (inverted index fast path)
|
||||
# =========================================================================
|
||||
|
||||
    def search_by_terms(
        self,
        terms: List[str],
        exclude_sources: Optional[List[str]] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """
        Search sessions using the term_index inverted index.

        Takes a list of query terms, finds all message IDs for each term,
        intersects them (AND logic), and returns matching sessions with
        metadata and match counts.

        This is the fast path for session search -- no LLM calls needed.

        Args:
            terms: Raw query terms; stop words are filtered out here.
            exclude_sources: Session sources to exclude from results.
            limit: Maximum number of session rows to return.

        Returns:
            List of dicts (session_id, source, model, session_started, title,
            match_count, and term_count in the multi-term case), ordered by
            match_count then recency. Empty list on empty/all-stop-word
            queries or if the term_index table is unavailable.
        """
        if not terms:
            return []

        # Filter out stop words from the query (belt and suspenders)
        # Delayed import avoids circular dependency (same pattern as append_message)
        from stop_words import is_stop_word
        filtered = [t.lower() for t in terms if t and not is_stop_word(t.lower())]
        if not filtered:
            return []

        # Build the query. For single terms, a simple WHERE suffices.
        # For multiple terms, we use GROUP BY session_id + HAVING COUNT(DISTINCT term) = N
        # to enforce AND semantics: only sessions containing ALL query terms match.
        if len(filtered) == 1:
            term = filtered[0]
            params: list = [term]
            where_sql = "ti.term = ?"

            if exclude_sources:
                exclude_placeholders = ",".join("?" for _ in exclude_sources)
                where_sql += f" AND s.source NOT IN ({exclude_placeholders})"
                params.extend(exclude_sources)

            sql = f"""
                SELECT ti.session_id,
                       s.source,
                       s.model,
                       s.started_at AS session_started,
                       s.title,
                       COUNT(DISTINCT ti.message_id) AS match_count
                FROM term_index ti
                JOIN sessions s ON s.id = ti.session_id
                WHERE {where_sql}
                GROUP BY ti.session_id
                ORDER BY match_count DESC, s.started_at DESC
                LIMIT ?
            """
            # LIMIT placeholder binds last, after term + any exclusions.
            params.append(limit)
        else:
            # Multi-term: GROUP BY + HAVING COUNT(DISTINCT term) = N enforces AND
            term_placeholders = ",".join("?" for _ in filtered)
            params = list(filtered)
            exclude_sql = ""
            if exclude_sources:
                exclude_sql = f" AND s.source NOT IN ({','.join('?' for _ in exclude_sources)})"
                params.extend(exclude_sources)
            # Bind order matches placeholder order: terms, exclusions,
            # HAVING count, LIMIT.
            params.extend([len(filtered), limit])

            sql = f"""
                SELECT ti.session_id,
                       s.source,
                       s.model,
                       s.started_at AS session_started,
                       s.title,
                       COUNT(DISTINCT ti.term) AS term_count,
                       COUNT(DISTINCT ti.message_id) AS match_count
                FROM term_index ti
                JOIN sessions s ON s.id = ti.session_id
                WHERE ti.term IN ({term_placeholders})
                {exclude_sql}
                GROUP BY ti.session_id
                HAVING COUNT(DISTINCT ti.term) = ?
                ORDER BY match_count DESC, s.started_at DESC
                LIMIT ?
            """

        with self._lock:
            try:
                cursor = self._conn.execute(sql, params)
                results = [dict(row) for row in cursor.fetchall()]
            except sqlite3.OperationalError:
                # term_index may not exist yet (pre-v9 DB); degrade to no results.
                logger.debug("term_index query failed", exc_info=True)
                return []

        return results
|
||||
|
||||
def reindex_term_index(self, batch_size: int = 500) -> int:
|
||||
"""
|
||||
Rebuild the term_index from existing messages.
|
||||
|
||||
Processes messages in batches to avoid holding the write lock too
|
||||
long. Returns the total number of term entries inserted.
|
||||
|
||||
Uses a swap strategy: builds a temporary table, then swaps it
|
||||
into place in a single transaction. This avoids the empty-index
|
||||
window that would occur with a simple clear-and-repopulate.
|
||||
"""
|
||||
from term_index import extract_terms
|
||||
|
||||
# Count total messages to index
|
||||
with self._lock:
|
||||
total = self._conn.execute("SELECT COUNT(*) FROM messages").fetchone()[0]
|
||||
|
||||
if total == 0:
|
||||
return 0
|
||||
|
||||
inserted = 0
|
||||
offset = 0
|
||||
|
||||
# Create a temporary table with the same schema as term_index.
|
||||
# We'll populate this, then swap it into place in a single
|
||||
# transaction — no empty-index window for concurrent readers.
|
||||
def _create_temp(conn):
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS _term_index_new (
|
||||
term TEXT NOT NULL,
|
||||
message_id INTEGER NOT NULL,
|
||||
session_id TEXT NOT NULL,
|
||||
PRIMARY KEY (term, message_id)
|
||||
) WITHOUT ROWID
|
||||
""")
|
||||
conn.execute("DELETE FROM _term_index_new")
|
||||
self._execute_write(_create_temp)
|
||||
|
||||
while offset < total:
|
||||
# Read batch outside write lock
|
||||
with self._lock:
|
||||
cursor = self._conn.execute(
|
||||
"SELECT id, session_id, role, content FROM messages ORDER BY id LIMIT ? OFFSET ?",
|
||||
(batch_size, offset),
|
||||
)
|
||||
rows = cursor.fetchall()
|
||||
|
||||
if not rows:
|
||||
break
|
||||
|
||||
# Extract terms for the batch, skipping tool-role messages
|
||||
entries = []
|
||||
for row in rows:
|
||||
msg_id = row["id"]
|
||||
session_id = row["session_id"]
|
||||
# Skip tool messages — structured JSON output produces noise terms
|
||||
if row["role"] == "tool":
|
||||
continue
|
||||
content = row["content"] or ""
|
||||
terms = extract_terms(content)
|
||||
for term in terms:
|
||||
entries.append((term, msg_id, session_id))
|
||||
|
||||
# Write batch to temp table
|
||||
if entries:
|
||||
def _insert(conn, _entries=entries):
|
||||
conn.executemany(
|
||||
"INSERT OR IGNORE INTO _term_index_new (term, message_id, session_id) VALUES (?, ?, ?)",
|
||||
_entries,
|
||||
)
|
||||
return len(_entries)
|
||||
inserted += self._execute_write(_insert)
|
||||
|
||||
offset += batch_size
|
||||
|
||||
# Swap: replace term_index with the new table in one transaction.
|
||||
# Concurrent readers see either the old index or the new one —
|
||||
# never an empty table.
|
||||
def _swap(conn):
|
||||
conn.execute("DELETE FROM term_index")
|
||||
conn.execute("""
|
||||
INSERT INTO term_index (term, message_id, session_id)
|
||||
SELECT term, message_id, session_id FROM _term_index_new
|
||||
""")
|
||||
conn.execute("DROP TABLE _term_index_new")
|
||||
self._execute_write(_swap)
|
||||
|
||||
logger.info("Reindexed term_index: %d entries from %d messages", inserted, total)
|
||||
return inserted
|
||||
|
||||
def search_sessions(
|
||||
self,
|
||||
source: str = None,
|
||||
|
|
@ -1466,8 +1726,14 @@ class SessionDB:
|
|||
return results
|
||||
|
||||
def clear_messages(self, session_id: str) -> None:
|
||||
"""Delete all messages for a session and reset its counters."""
|
||||
"""Delete all messages for a session and reset its counters.
|
||||
|
||||
Also removes stale term_index entries that reference the deleted messages.
|
||||
"""
|
||||
def _do(conn):
|
||||
conn.execute(
|
||||
"DELETE FROM term_index WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
conn.execute(
|
||||
"DELETE FROM messages WHERE session_id = ?", (session_id,)
|
||||
)
|
||||
|
|
@ -1496,6 +1762,7 @@ class SessionDB:
|
|||
"WHERE parent_session_id = ?",
|
||||
(session_id,),
|
||||
)
|
||||
conn.execute("DELETE FROM term_index WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (session_id,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (session_id,))
|
||||
return True
|
||||
|
|
@ -1536,6 +1803,7 @@ class SessionDB:
|
|||
)
|
||||
|
||||
for sid in session_ids:
|
||||
conn.execute("DELETE FROM term_index WHERE session_id = ?", (sid,))
|
||||
conn.execute("DELETE FROM messages WHERE session_id = ?", (sid,))
|
||||
conn.execute("DELETE FROM sessions WHERE id = ?", (sid,))
|
||||
return len(session_ids)
|
||||
|
|
|
|||
35
run_agent.py
35
run_agent.py
|
|
@ -1566,7 +1566,8 @@ class AIAgent:
|
|||
try:
|
||||
from hermes_cli.config import load_config as _load_agent_config
|
||||
_agent_cfg = _load_agent_config()
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning("Agent init: load_config() failed: %s: %s — using empty config", type(e).__name__, e)
|
||||
_agent_cfg = {}
|
||||
# Cache only the derived auxiliary compression context override that is
|
||||
# needed later by the startup feasibility check. Avoid exposing a
|
||||
|
|
@ -1597,8 +1598,20 @@ class AIAgent:
|
|||
self._memory_store.load_from_disk()
|
||||
except Exception:
|
||||
pass # Memory is optional -- don't break agent init
|
||||
|
||||
|
||||
# Obsidian vault auto-injection (Layer 3) — structural fix for
|
||||
# vault neglect. Reads working-context.md and user-profile.md
|
||||
# from the configured vault path and injects them into the system
|
||||
# prompt at session start, just like Layer 1 memory.
|
||||
self._vault_enabled = False
|
||||
self._vault_path = ""
|
||||
try:
|
||||
vault_config = _agent_cfg.get("vault", {})
|
||||
self._vault_enabled = vault_config.get("enabled", False)
|
||||
self._vault_path = vault_config.get("path", "")
|
||||
logging.getLogger("agent.vault").info("Vault config: enabled=%s path=%s", self._vault_enabled, self._vault_path)
|
||||
except Exception as e:
|
||||
logging.getLogger("agent.vault").warning("Vault config read failed: %s: %s", type(e).__name__, e)
|
||||
|
||||
# Memory provider plugin (external — one at a time, alongside built-in)
|
||||
# Reads memory.provider from config to select which plugin to activate.
|
||||
|
|
@ -4448,6 +4461,24 @@ class AIAgent:
|
|||
if user_block:
|
||||
prompt_parts.append(user_block)
|
||||
|
||||
# Vault auto-injection (Layer 3) — reads working-context.md and
|
||||
# user-profile.md from the Obsidian vault and injects them into
|
||||
# the system prompt. Structural fix for vault neglect.
|
||||
_vault_log = logging.getLogger("agent.vault")
|
||||
if self._vault_enabled and self._vault_path:
|
||||
try:
|
||||
from agent.vault_injection import build_vault_system_prompt
|
||||
_vault_block = build_vault_system_prompt(self._vault_path)
|
||||
if _vault_block:
|
||||
prompt_parts.append(_vault_block)
|
||||
_vault_log.info("Injection succeeded: %d chars from %s", len(_vault_block), self._vault_path)
|
||||
else:
|
||||
_vault_log.warning("Injection returned empty for path %s", self._vault_path)
|
||||
except Exception as e:
|
||||
_vault_log.warning("Injection failed: %s: %s", type(e).__name__, e)
|
||||
else:
|
||||
_vault_log.info("Injection skipped: enabled=%s path=%s", self._vault_enabled, self._vault_path)
|
||||
|
||||
# External memory provider system prompt block (additive to built-in)
|
||||
if self._memory_manager:
|
||||
try:
|
||||
|
|
|
|||
79
stop_words.py
Normal file
79
stop_words.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Stop word list for term index extraction.

Uses the well-known NLTK English stop word list (179 words) as a baseline,
plus common JSON schema keys from tool output and pure-numeric filter.

This module is self-contained -- no external dependencies.
"""

import re

# Standard English stop words (NLTK list, public domain)
# Covers articles, conjunctions, prepositions, pronouns, auxiliary verbs,
# and common function words. Intentionally excludes short tech terms
# that overlap (e.g., "go", "it" as in IT/InfoTech handled by context).
_ENGLISH_STOP_WORDS = frozenset(
    w.lower() for w in [
        "i", "me", "my", "myself", "we", "our", "ours", "ourselves", "you",
        "your", "yours", "yourself", "yourselves", "he", "him", "his",
        "himself", "she", "her", "hers", "herself", "it", "its", "itself",
        "they", "them", "their", "theirs", "themselves", "what", "which",
        "who", "whom", "this", "that", "these", "those", "am", "is", "are",
        "was", "were", "be", "been", "being", "have", "has", "had", "having",
        "do", "does", "did", "doing", "a", "an", "the", "and", "but", "if",
        "or", "because", "as", "until", "while", "of", "at", "by", "for",
        "with", "about", "against", "between", "through", "during", "before",
        "after", "above", "below", "to", "from", "up", "down", "in", "out",
        "on", "off", "over", "under", "again", "further", "then", "once",
        "here", "there", "when", "where", "why", "how", "all", "both", "each",
        "few", "more", "most", "other", "some", "such", "no", "nor", "not",
        "only", "own", "same", "so", "than", "too", "very", "s", "t", "can",
        "will", "just", "don", "should", "now", "d", "ll", "m", "o", "re",
        "ve", "y", "ain", "aren", "couldn", "didn", "doesn", "hadn", "hasn",
        "haven", "isn", "ma", "mightn", "mustn", "needn", "shan", "shouldn",
        "wasn", "weren", "won", "wouldn",
    ]
)

# JSON schema keys that appear constantly in tool output.
# These are field names from structured tool responses, not semantic content.
# Nobody searches for "exit_code" to find a past session.
_JSON_KEY_STOP_WORDS = frozenset([
    "output",
    "exit_code",
    "error",
    "null",
    "true",
    "false",
    "status",
    "content",
    "message",
    "cleared",
    "success",
])

# Combined stop word set
_STOP_WORDS = _ENGLISH_STOP_WORDS | _JSON_KEY_STOP_WORDS

# Pattern to detect pure numeric tokens (integers, floats, hex).
# Fixed: the previous pattern (^[0-9]+$) only matched integers, contradicting
# its own comment — floats ("3.14") and hex ("0x1f") slipped into the index.
# Matched against the lowercased token (see is_noise_term).
_NUMERIC_RE = re.compile(r"^(?:[0-9]+(?:\.[0-9]+)?|0x[0-9a-f]+)$")


def is_stop_word(word: str) -> bool:
    """Check if a word is a stop word. Case-insensitive."""
    return word.lower() in _STOP_WORDS


def is_noise_term(word: str) -> bool:
    """Check if a term is noise that should be excluded from the index.

    This covers stop words AND pure numeric tokens (integer, float, or hex),
    which provide zero search value. Nobody searches for '0', '1', or '42'
    to find a session.
    """
    lower = word.lower()
    return lower in _STOP_WORDS or _NUMERIC_RE.match(lower) is not None


def get_stop_words() -> frozenset:
    """Return the full stop word set (for inspection/bulk use)."""
    return _STOP_WORDS
|
||||
47
term_index.py
Normal file
47
term_index.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""Term index — inverted index extraction for session search fast path.

Extracts non-stop-word terms from message content for insertion into the
term_index table in SessionDB. Terms are lowercased, punctuation-stripped
(with preservation of path-like strings), and deduplicated per message.

Noise filtering:
  - English stop words (NLTK list)
  - JSON schema keys from tool output (output, exit_code, error, etc.)
  - Pure numeric tokens (0, 1, 42, etc.)
"""

import re
from stop_words import is_noise_term

# Matches "words" including paths (foo/bar), filenames (file.py), and
# hyphenated terms (self-hosted). Filters out most punctuation but
# preserves dots in filenames and slashes in paths.
# Strategy: split on whitespace first, then strip leading/trailing punctuation.
_TERM_RE = re.compile(r"[a-zA-Z0-9][\w./\-]*[a-zA-Z0-9]|[a-zA-Z0-9]")


def extract_terms(content: str) -> list[str]:
    """Extract non-noise terms from message content.

    Returns a deduplicated, lowercased list of terms in first-seen order.
    Stop words, JSON keys, pure numerics, and empty strings are excluded.
    """
    if not content:
        return []

    # Dict keys preserve insertion order, so this doubles as an ordered set
    # for per-message deduplication.
    unique: dict = {}
    for match in _TERM_RE.finditer(content):
        candidate = match.group(0).lower()
        # Drop noise (stop words, JSON schema keys, pure numerics) and
        # anything we've already recorded for this message.
        if not is_noise_term(candidate):
            unique.setdefault(candidate, None)

    return list(unique)
|
||||
174
tests/agent/test_vault_injection.py
Normal file
174
tests/agent/test_vault_injection.py
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
"""Tests for agent/vault_injection.py — Obsidian vault auto-injection into system prompt."""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from agent.vault_injection import (
|
||||
build_vault_system_prompt,
|
||||
_read_vault_file,
|
||||
_strip_yaml_frontmatter,
|
||||
WORKING_CONTEXT_CHAR_LIMIT,
|
||||
USER_PROFILE_CHAR_LIMIT,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _strip_yaml_frontmatter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStripYamlFrontmatter:
    """Unit tests for _strip_yaml_frontmatter (--- delimited header removal)."""

    def test_strips_simple_frontmatter(self):
        # A well-formed header is removed, leaving only the body.
        content = "---\ndate: 2026-04-22\n---\nHello world"
        assert _strip_yaml_frontmatter(content) == "Hello world"

    def test_no_frontmatter(self):
        # Content that doesn't start with "---" passes through unchanged.
        content = "Just some text"
        assert _strip_yaml_frontmatter(content) == "Just some text"

    def test_frontmatter_with_blank_lines(self):
        # Blank lines between the closing delimiter and the body are stripped.
        content = "---\ndate: 2026-04-22\nprojects: [X]\n---\n\nActual content here"
        result = _strip_yaml_frontmatter(content)
        assert result == "Actual content here"

    def test_unclosed_frontmatter_returns_original(self):
        # An opening "---" with no closing delimiter is not treated as
        # frontmatter; the document is returned verbatim.
        content = "---\ndate: 2026-04-22\nNo closing dashes"
        assert _strip_yaml_frontmatter(content) == content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _read_vault_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestReadVaultFile:
    """Unit tests for _read_vault_file (read + frontmatter strip + truncate)."""

    def test_reads_existing_file(self, tmp_path):
        f = tmp_path / "test.md"
        f.write_text("some content", encoding="utf-8")
        result = _read_vault_file(f, 4000)
        assert result == "some content"

    def test_returns_none_for_missing_file(self, tmp_path):
        # Missing files are silently skipped (None), never raised.
        f = tmp_path / "nonexistent.md"
        result = _read_vault_file(f, 4000)
        assert result is None

    def test_returns_none_for_empty_file(self, tmp_path):
        f = tmp_path / "empty.md"
        f.write_text("", encoding="utf-8")
        result = _read_vault_file(f, 4000)
        assert result is None

    def test_returns_none_for_whitespace_only(self, tmp_path):
        # Whitespace-only content strips to empty and is treated as missing.
        f = tmp_path / "ws.md"
        f.write_text("  \n\n  ", encoding="utf-8")
        result = _read_vault_file(f, 4000)
        assert result is None

    def test_strips_frontmatter(self, tmp_path):
        # YAML frontmatter is removed before injection.
        f = tmp_path / "frontmatter.md"
        f.write_text("---\ndate: 2026-04-22\n---\nReal content", encoding="utf-8")
        result = _read_vault_file(f, 4000)
        assert result == "Real content"

    def test_truncates_long_file(self, tmp_path):
        # Over-limit content is cut and a truncation notice appended.
        f = tmp_path / "long.md"
        long_content = "x" * 5000
        f.write_text(long_content, encoding="utf-8")
        result = _read_vault_file(f, 100)
        assert len(result) < 200  # truncation + notice
        assert "truncated" in result

    def test_truncation_at_newline(self, tmp_path):
        # Truncation backs off to the last newline so no line is cut mid-way.
        f = tmp_path / "multiline.md"
        lines = ["line " + str(i) for i in range(100)]
        content = "\n".join(lines)
        f.write_text(content, encoding="utf-8")
        # Small limit, should truncate at a newline boundary
        result = _read_vault_file(f, 50)
        assert "truncated" in result
        # Should not cut mid-line
        for line in result.split("\n"):
            if line and "truncated" not in line:
                assert line.startswith("line")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_vault_system_prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBuildVaultSystemPrompt:
    """Tests for build_vault_system_prompt: empty/missing paths yield an empty
    string; present vault files are injected as labelled VAULT blocks."""

    def test_empty_path_returns_empty(self):
        assert build_vault_system_prompt("") == ""

    def test_nonexistent_path_returns_empty(self, tmp_path):
        assert build_vault_system_prompt(str(tmp_path / "nope")) == ""

    def test_empty_vault_dir_returns_empty(self, tmp_path):
        # An existing directory with no vault files still produces nothing.
        assert build_vault_system_prompt(str(tmp_path)) == ""

    def test_injects_working_context(self, tmp_path):
        # Agent-Hermes/working-context.md appears under "VAULT: WORKING CONTEXT",
        # with frontmatter stripped and body text preserved.
        vault = tmp_path / "vault"
        agent_dir = vault / "Agent-Hermes"
        agent_dir.mkdir(parents=True)
        wc = agent_dir / "working-context.md"
        wc.write_text("---\ndate: 2026-04-22\n---\n## Current Status\n- Status: Active", encoding="utf-8")

        result = build_vault_system_prompt(str(vault))
        assert "VAULT: WORKING CONTEXT" in result
        assert "Status: Active" in result

    def test_injects_user_profile(self, tmp_path):
        # Agent-Shared/user-profile.md appears under "VAULT: USER PROFILE".
        vault = tmp_path / "vault"
        shared_dir = vault / "Agent-Shared"
        shared_dir.mkdir(parents=True)
        up = shared_dir / "user-profile.md"
        up.write_text("# User Profile\n\nName: AJ", encoding="utf-8")

        result = build_vault_system_prompt(str(vault))
        assert "VAULT: USER PROFILE" in result
        assert "Name: AJ" in result

    def test_injects_both_files(self, tmp_path):
        # Both vault files present -> both blocks injected.
        vault = tmp_path / "vault"
        agent_dir = vault / "Agent-Hermes"
        shared_dir = vault / "Agent-Shared"
        agent_dir.mkdir(parents=True)
        shared_dir.mkdir(parents=True)

        (agent_dir / "working-context.md").write_text(
            "---\ndate: 2026-04-22\n---\nWorking on X", encoding="utf-8"
        )
        (shared_dir / "user-profile.md").write_text(
            "# User Profile\n\nName: AJ", encoding="utf-8"
        )

        result = build_vault_system_prompt(str(vault))
        assert "VAULT: WORKING CONTEXT" in result
        assert "VAULT: USER PROFILE" in result
        assert "Working on X" in result
        assert "Name: AJ" in result

    def test_skips_empty_working_context(self, tmp_path):
        # An empty working-context.md is silently skipped; the other file
        # is still injected.
        vault = tmp_path / "vault"
        agent_dir = vault / "Agent-Hermes"
        shared_dir = vault / "Agent-Shared"
        agent_dir.mkdir(parents=True)
        shared_dir.mkdir(parents=True)

        (agent_dir / "working-context.md").write_text("", encoding="utf-8")
        (shared_dir / "user-profile.md").write_text("Name: AJ", encoding="utf-8")

        result = build_vault_system_prompt(str(vault))
        assert "VAULT: WORKING CONTEXT" not in result
        assert "VAULT: USER PROFILE" in result

    def test_format_matches_memory_block_style(self, tmp_path):
        vault = tmp_path / "vault"
        agent_dir = vault / "Agent-Hermes"
        agent_dir.mkdir(parents=True)
        (agent_dir / "working-context.md").write_text("Active task", encoding="utf-8")

        result = build_vault_system_prompt(str(vault))
        # Should use the same separator as memory_tool (═══)
        assert "\u2550" in result  # ═ character
        assert "VAULT: WORKING CONTEXT" in result
|
||||
161
tests/run_agent/test_vault_injection.py
Normal file
161
tests/run_agent/test_vault_injection.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
"""Tests for vault auto-injection integration with _build_system_prompt.
|
||||
|
||||
Verifies that vault content appears in the system prompt when vault is
|
||||
configured, and is absent otherwise.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _make_minimal_agent(**overrides):
    """Create a minimal AIAgent for testing, with vault attrs settable.

    Tool discovery, toolset checks, and the OpenAI client are patched out so
    construction needs no network or real configuration. Keyword overrides
    are set as attributes on the instance after construction (e.g.
    _vault_enabled, _vault_path).
    """
    from run_agent import AIAgent

    with (
        patch("run_agent.get_tool_definitions", return_value=[]),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        a = AIAgent(
            api_key="test-k...7890",
            base_url="https://openrouter.ai/api/v1",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        a.client = MagicMock()

        # Apply overrides after creation
        for k, v in overrides.items():
            setattr(a, k, v)

        return a
|
||||
|
||||
|
||||
class TestVaultSystemPromptIntegration:
    """Test that _build_system_prompt injects vault content when configured.

    Fixes over the original: removed a dead ``monkeypatch_env = {}`` local
    that was never used, removed an unused ``MemoryStore`` import, and made
    the HERMES_HOME environment handling save/restore any pre-existing value
    instead of unconditionally deleting it (the old ``del`` would both leak
    our tmp value on KeyError semantics and destroy a caller's setting).
    """

    def test_vault_not_injected_when_disabled(self, tmp_path):
        """Vault content should not appear when vault_enabled=False."""
        vault_dir = tmp_path / "vault"
        agent_dir = vault_dir / "Agent-Hermes"
        agent_dir.mkdir(parents=True)
        (agent_dir / "working-context.md").write_text("Active task X", encoding="utf-8")

        agent = _make_minimal_agent(
            _vault_enabled=False,
            _vault_path=str(vault_dir),
        )

        prompt = agent._build_system_prompt()
        assert "VAULT: WORKING CONTEXT" not in prompt
        assert "Active task X" not in prompt

    def test_vault_injected_when_enabled(self, tmp_path):
        """Vault content should appear in system prompt when vault_enabled=True."""
        vault_dir = tmp_path / "vault"
        agent_dir = vault_dir / "Agent-Hermes"
        shared_dir = vault_dir / "Agent-Shared"
        agent_dir.mkdir(parents=True)
        shared_dir.mkdir(parents=True)
        (agent_dir / "working-context.md").write_text(
            "---\ndate: 2026-04-22\n---\n## Status\nActive: vault fix",
            encoding="utf-8",
        )
        (shared_dir / "user-profile.md").write_text(
            "# User Profile\n\nName: Test User",
            encoding="utf-8",
        )

        agent = _make_minimal_agent(
            _vault_enabled=True,
            _vault_path=str(vault_dir),
        )

        prompt = agent._build_system_prompt()
        assert "VAULT: WORKING CONTEXT" in prompt
        assert "Active: vault fix" in prompt
        assert "VAULT: USER PROFILE" in prompt
        assert "Name: Test User" in prompt

    def test_vault_after_memory_blocks(self, tmp_path):
        """Vault injection should appear after Layer 1 memory blocks."""
        # Set up memory files
        mem_dir = tmp_path / "memories"
        mem_dir.mkdir(parents=True)
        (mem_dir / "MEMORY.md").write_text("Layer 1 memory note", encoding="utf-8")

        # Set up vault files
        vault_dir = tmp_path / "vault"
        agent_dir = vault_dir / "Agent-Hermes"
        agent_dir.mkdir(parents=True)
        (agent_dir / "working-context.md").write_text("Vault content", encoding="utf-8")

        # Create agent with memory enabled
        from run_agent import AIAgent

        with (
            patch("run_agent.get_tool_definitions", return_value=[]),
            patch("run_agent.check_toolset_requirements", return_value={}),
            patch("run_agent.OpenAI"),
            patch(
                "hermes_cli.config.load_config",
                return_value={
                    "memory": {
                        "memory_enabled": True,
                        "user_profile_enabled": False,
                        "memory_char_limit": 2200,
                        "user_char_limit": 1375,
                    },
                },
            ),
        ):
            # Point HERMES_HOME at tmp_path so memory is read from our files;
            # preserve any pre-existing value instead of clobbering it.
            prev_home = os.environ.get("HERMES_HOME")
            os.environ["HERMES_HOME"] = str(tmp_path)
            try:
                a = AIAgent(
                    api_key="test-k...7890",
                    base_url="https://openrouter.ai/api/v1",
                    quiet_mode=True,
                    skip_context_files=True,
                    skip_memory=False,
                )
                a.client = MagicMock()
            finally:
                if prev_home is None:
                    del os.environ["HERMES_HOME"]
                else:
                    os.environ["HERMES_HOME"] = prev_home

        a._vault_enabled = True
        a._vault_path = str(vault_dir)

        prompt = a._build_system_prompt()
        mem_pos = prompt.find("MEMORY (your personal notes)")
        vault_pos = prompt.find("VAULT: WORKING CONTEXT")
        assert mem_pos > 0, "Layer 1 memory block not found in prompt"
        assert vault_pos > 0, "Vault block not found in prompt"
        assert mem_pos < vault_pos, "Vault should appear after Layer 1 memory"

    def test_missing_vault_path_graceful(self, tmp_path):
        """Agent works fine even if vault path doesn't exist."""
        agent = _make_minimal_agent(
            _vault_enabled=True,
            _vault_path="/nonexistent/vault/path",
        )

        # Should not crash
        prompt = agent._build_system_prompt()
        assert "VAULT:" not in prompt

    def test_no_vault_config_graceful(self):
        """Agent works fine with no vault set (defaults)."""
        agent = _make_minimal_agent(
            _vault_enabled=False,
            _vault_path="",
        )

        prompt = agent._build_system_prompt()
        assert "VAULT:" not in prompt
|
||||
|
|
@ -1173,7 +1173,7 @@ class TestSchemaInit:
|
|||
def test_schema_version(self, db):
|
||||
cursor = db._conn.execute("SELECT version FROM schema_version")
|
||||
version = cursor.fetchone()[0]
|
||||
assert version == 8
|
||||
assert version == 9
|
||||
|
||||
def test_title_column_exists(self, db):
|
||||
"""Verify the title column was created in the sessions table."""
|
||||
|
|
@ -1229,12 +1229,12 @@ class TestSchemaInit:
|
|||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
# Open with SessionDB — should migrate to v8
|
||||
# Open with SessionDB — should migrate to v9
|
||||
migrated_db = SessionDB(db_path=db_path)
|
||||
|
||||
# Verify migration
|
||||
cursor = migrated_db._conn.execute("SELECT version FROM schema_version")
|
||||
assert cursor.fetchone()[0] == 8
|
||||
assert cursor.fetchone()[0] == 9
|
||||
|
||||
# Verify title column exists and is NULL for existing sessions
|
||||
session = migrated_db.get_session("existing")
|
||||
|
|
|
|||
715
tests/test_term_index.py
Normal file
715
tests/test_term_index.py
Normal file
|
|
@ -0,0 +1,715 @@
|
|||
"""Tests for term_index — inverted index for session search fast path.
|
||||
|
||||
Covers: stop word filtering, term extraction, term insertion at write time,
|
||||
term-based search with session-level results, multi-term intersection.
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from hermes_state import SessionDB
|
||||
|
||||
|
||||
@pytest.fixture()
def db(tmp_path):
    """Create a SessionDB with a temp database file.

    Yields the open SessionDB; the post-yield close() runs as teardown
    after each test.
    """
    db_path = tmp_path / "test_state.db"
    session_db = SessionDB(db_path=db_path)
    yield session_db
    session_db.close()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Stop word filtering
|
||||
# =========================================================================
|
||||
|
||||
class TestStopWords:
    """Tests for stop_words.is_stop_word: common English words are filtered,
    matching is case-insensitive, and real terms pass through."""

    def test_common_english_words_are_stopped(self):
        from stop_words import is_stop_word
        for w in ["the", "and", "is", "in", "it", "of", "to", "a", "was", "for"]:
            assert is_stop_word(w), f"'{w}' should be a stop word"

    def test_case_insensitive_stop_words(self):
        from stop_words import is_stop_word
        assert is_stop_word("The")
        assert is_stop_word("AND")
        assert is_stop_word("Is")

    def test_non_stop_words_pass(self):
        from stop_words import is_stop_word
        for w in ["docker", "kubernetes", "python", "hermes", "session"]:
            assert not is_stop_word(w), f"'{w}' should NOT be a stop word"

    def test_short_words_not_auto_stopped(self):
        """Single letters and 2-letter words that aren't in the list should pass."""
        from stop_words import is_stop_word
        # 'go' is a real tech term, 'I' is a stop word
        assert not is_stop_word("go")
        assert is_stop_word("I")
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Term extraction
|
||||
# =========================================================================
|
||||
|
||||
class TestTermExtraction:
    """Tests for term_index.extract_terms: tokenization, punctuation
    stripping, stop-word filtering, case folding, and empty/None input."""

    def test_extracts_words_from_content(self):
        from term_index import extract_terms
        terms = extract_terms("docker compose up -d")
        assert "docker" in terms
        assert "compose" in terms

    def test_strips_punctuation(self):
        from term_index import extract_terms
        terms = extract_terms("It's working! Check the file.py, okay?")
        assert "working" in terms
        assert "file.py" in terms  # dots in filenames preserved
        assert "okay" in terms

    def test_filters_stop_words(self):
        from term_index import extract_terms
        terms = extract_terms("the docker container is running in the background")
        assert "the" not in terms
        assert "is" not in terms
        assert "in" not in terms
        assert "docker" in terms
        assert "container" in terms
        assert "running" in terms

    def test_case_folded(self):
        from term_index import extract_terms
        terms = extract_terms("Docker DOCKER docker")
        # Should be case-folded to single term
        assert len(terms) == len(set(terms)), "Terms should be deduplicated after case folding"

    def test_empty_content(self):
        from term_index import extract_terms
        terms = extract_terms("")
        assert terms == []

    def test_none_content(self):
        # None is tolerated and treated like empty content.
        from term_index import extract_terms
        terms = extract_terms(None)
        assert terms == []

    def test_preserves_paths_and_commands(self):
        from term_index import extract_terms
        terms = extract_terms("edited /etc/hosts and ran git push origin main")
        assert "/etc/hosts" in terms or "etc/hosts" in terms  # path fragment
        assert "git" in terms
        assert "push" in terms
|
||||
|
||||
# =========================================================================
|
||||
# Term index insertion
|
||||
# =========================================================================
|
||||
|
||||
class TestTermIndexInsertion:
    """Tests that append_message indexes terms at write time: stop words are
    excluded, sessions are deduplicated, and the index spans sessions."""

    def test_terms_inserted_on_append_message(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            session_id="s1",
            role="user",
            content="I need to deploy the docker container",
        )

        # Should be able to find the message by term
        results = db.search_by_terms(["docker"])
        assert len(results) >= 1
        assert any(r["session_id"] == "s1" for r in results)

    def test_stop_words_not_indexed(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            session_id="s1",
            role="user",
            content="the and is in of to a",
        )

        # All stop words — should find nothing
        results = db.search_by_terms(["the", "and", "is"])
        assert len(results) == 0

    def test_same_term_multiple_messages_same_session(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="docker is great")
        db.append_message(session_id="s1", role="assistant", content="docker compose ready")

        results = db.search_by_terms(["docker"])
        # Should return session once, not twice
        sids = [r["session_id"] for r in results]
        assert sids.count("s1") == 1

    def test_term_indexed_across_sessions(self, db):
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="telegram")
        db.append_message(session_id="s1", role="user", content="fix the docker bug")
        db.append_message(session_id="s2", role="user", content="docker pull failed")

        results = db.search_by_terms(["docker"])
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" in sids
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Term-based search
|
||||
# =========================================================================
|
||||
|
||||
class TestTermSearch:
    """Tests for SessionDB.search_by_terms: single-term lookup, AND semantics
    for multi-term queries, source exclusion, limits, and result metadata."""

    def test_single_term_search(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            session_id="s1",
            role="user",
            content="I need to configure kubernetes",
        )

        results = db.search_by_terms(["kubernetes"])
        assert len(results) >= 1
        assert results[0]["session_id"] == "s1"
        # Should include session metadata
        assert "source" in results[0]
        assert "started_at" in results[0] or "session_started" in results[0]

    def test_multi_term_intersection(self, db):
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.create_session(session_id="s3", source="cli")

        db.append_message(session_id="s1", role="user", content="docker networking issue")
        db.append_message(session_id="s2", role="user", content="docker container running")
        db.append_message(session_id="s3", role="user", content="kubernetes networking problem")

        # Both "docker" AND "networking" should only match s1
        results = db.search_by_terms(["docker", "networking"])
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" not in sids
        assert "s3" not in sids

    def test_search_returns_empty_for_stop_words_only(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="the and is")

        results = db.search_by_terms(["the", "and"])
        assert results == []

    def test_search_excludes_hidden_sources(self, db):
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="tool")
        db.append_message(session_id="s1", role="user", content="docker deployment")
        db.append_message(session_id="s2", role="user", content="docker deployment tool")

        results = db.search_by_terms(["docker"], exclude_sources=["tool"])
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" not in sids

    def test_search_with_limit(self, db):
        # Five matching sessions; limit=3 caps the result count.
        for i in range(5):
            sid = f"s{i}"
            db.create_session(session_id=sid, source="cli")
            db.append_message(session_id=sid, role="user", content="python script")

        results = db.search_by_terms(["python"], limit=3)
        assert len(results) <= 3

    def test_nonexistent_term_returns_empty(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="hello world")

        results = db.search_by_terms(["nonexistent_xyzzy"])
        assert results == []

    def test_term_result_includes_match_count(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="docker docker docker")
        db.append_message(session_id="s1", role="assistant", content="docker ready")

        results = db.search_by_terms(["docker"])
        assert len(results) >= 1
        # Should tell us how many messages matched in the session
        assert "match_count" in results[0]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Schema and migration
|
||||
# =========================================================================
|
||||
|
||||
class TestTermIndexSchema:
    """Tests for the term_index table schema and the v9 migration path."""

    def test_term_index_table_exists(self, db):
        cursor = db._conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='term_index'"
        )
        assert cursor.fetchone() is not None

    def test_term_index_is_without_rowid(self, db):
        # WITHOUT ROWID keeps the table clustered on its primary key.
        cursor = db._conn.execute(
            "SELECT sql FROM sqlite_master WHERE type='table' AND name='term_index'"
        )
        row = cursor.fetchone()
        assert row is not None
        assert "WITHOUT ROWID" in row[0]

    def test_schema_version_bumped(self, db):
        cursor = db._conn.execute("SELECT version FROM schema_version LIMIT 1")
        version = cursor.fetchone()[0]
        assert version >= 9

    def test_existing_data_survives_migration(self, tmp_path):
        """Create a DB with data, close it, then re-open -- data should survive
        and the term_index table should exist afterwards."""
        db_path = tmp_path / "migrate.db"
        db = SessionDB(db_path=db_path)
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="hello world")
        db.close()

        # Re-open -- migration should run, data intact
        db2 = SessionDB(db_path=db_path)
        session = db2.get_session("s1")
        assert session is not None
        assert session["source"] == "cli"
        # term_index should now exist
        cursor = db2._conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='term_index'"
        )
        assert cursor.fetchone() is not None
        db2.close()

    def test_v9_migration_auto_reindexes(self, tmp_path):
        """When a v6 DB with existing messages is opened, the v9 migration
        should create the term_index and backfill it automatically."""
        db_path = tmp_path / "migrate_v9.db"

        # Step 1: Create a fresh DB, add messages, then manually downgrade
        # to v6 so the next open triggers the migration path.
        db = SessionDB(db_path=db_path)
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.append_message(session_id="s1", role="user", content="deploy the kubernetes cluster")
        db.append_message(session_id="s2", role="user", content="debug docker networking issue")
        db.close()

        # Step 2: Re-open raw, manually set version to 6 and wipe term_index
        # to simulate a pre-v7 DB.
        import sqlite3
        conn = sqlite3.connect(str(db_path))
        conn.execute("UPDATE schema_version SET version = 6")
        conn.execute("DROP TABLE IF EXISTS term_index")
        conn.commit()
        conn.close()

        # Step 3: Open with SessionDB — should migrate to v9 and auto-reindex.
        db2 = SessionDB(db_path=db_path)
        # Verify version is now 9
        cursor = db2._conn.execute("SELECT version FROM schema_version")
        assert cursor.fetchone()[0] == 9

        # Verify term_index is populated — search should find the terms
        results = db2.search_by_terms(["kubernetes"])
        assert len(results) >= 1
        assert results[0]["session_id"] == "s1"

        results2 = db2.search_by_terms(["docker"])
        assert len(results2) >= 1
        assert results2[0]["session_id"] == "s2"

        db2.close()
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Regression tests for red-team QA bugs
|
||||
# =========================================================================
|
||||
|
||||
class TestClearMessagesCleansTermIndex:
    """BUG 3: clear_messages() left stale term_index entries.

    After clearing messages, search_by_terms should return zero results
    for that session, not ghost matches pointing to deleted message IDs.
    """

    def test_clear_messages_removes_term_entries(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="docker networking issue")

        # Confirm indexed
        results = db.search_by_terms(["docker"])
        assert len(results) >= 1

        # Clear messages
        db.clear_messages(session_id="s1")

        # Term entries should be gone
        results = db.search_by_terms(["docker"])
        assert results == []

    def test_clear_messages_does_not_affect_other_sessions(self, db):
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.append_message(session_id="s1", role="user", content="docker test")
        db.append_message(session_id="s2", role="user", content="docker prod")

        db.clear_messages(session_id="s1")

        # s2 should still be searchable
        results = db.search_by_terms(["docker"])
        sids = [r["session_id"] for r in results]
        assert "s2" in sids
        assert "s1" not in sids

    def test_clear_messages_no_stray_term_rows(self, db):
        # Check the table directly: no orphaned rows for the cleared session.
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="kubernetes deployment")

        db.clear_messages(session_id="s1")

        cursor = db._conn.execute(
            "SELECT COUNT(*) FROM term_index WHERE session_id = 's1'"
        )
        assert cursor.fetchone()[0] == 0
|
||||
|
||||
|
||||
class TestSearchByTermsParamBinding:
    """BUG 1: search_by_terms() had dead code with wrong param binding.

    The multi-term GROUP BY + HAVING path is the one that actually runs.
    These tests verify parameter binding is correct for both single and
    multi-term queries, including with exclude_sources.
    """

    def test_single_term_with_exclude_sources(self, db):
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="tool")
        db.append_message(session_id="s1", role="user", content="docker deploy")
        db.append_message(session_id="s2", role="user", content="docker deploy")

        results = db.search_by_terms(["docker"], exclude_sources=["tool"])
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" not in sids

    def test_multi_term_and_semantics(self, db):
        """Multi-term search should use AND: only sessions with ALL terms match."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.create_session(session_id="s3", source="cli")
        db.append_message(session_id="s1", role="user", content="docker networking issue")
        db.append_message(session_id="s2", role="user", content="docker container only")
        db.append_message(session_id="s3", role="user", content="networking problem only")

        results = db.search_by_terms(["docker", "networking"])
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" not in sids
        assert "s3" not in sids

    def test_multi_term_with_exclude_sources(self, db):
        """Multi-term + exclude_sources: param binding must be correct."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="tool")
        db.append_message(session_id="s1", role="user", content="docker networking setup")
        db.append_message(session_id="s2", role="user", content="docker networking deploy")

        results = db.search_by_terms(
            ["docker", "networking"], exclude_sources=["tool"]
        )
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" not in sids

    def test_three_term_intersection(self, db):
        """Three-term AND: all three must be present in the session."""
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.append_message(session_id="s1", role="user", content="docker kubernetes aws deployment")
        db.append_message(session_id="s2", role="user", content="docker kubernetes only two terms")

        results = db.search_by_terms(["docker", "kubernetes", "aws"])
        sids = [r["session_id"] for r in results]
        assert "s1" in sids
        assert "s2" not in sids
|
||||
|
||||
|
||||
class TestDeleteSessionCleansTermIndex:
    """Verify delete_session() and prune_sessions() clean term_index."""

    def test_delete_session_removes_term_entries(self, db):
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="docker deploy")
        db.append_message(session_id="s1", role="assistant", content="docker is running")

        db.delete_session(session_id="s1")

        # No term rows may remain for the deleted session.
        cursor = db._conn.execute(
            "SELECT COUNT(*) FROM term_index WHERE session_id = 's1'"
        )
        assert cursor.fetchone()[0] == 0

    def test_delete_session_does_not_affect_other_sessions(self, db):
        db.create_session(session_id="s1", source="cli")
        db.create_session(session_id="s2", source="cli")
        db.append_message(session_id="s1", role="user", content="docker one")
        db.append_message(session_id="s2", role="user", content="docker two")

        db.delete_session(session_id="s1")

        results = db.search_by_terms(["docker"])
        sids = [r["session_id"] for r in results]
        assert "s2" in sids
        assert "s1" not in sids
|
||||
|
||||
|
||||
class TestFastSearchSessionResolution:
    """BUG 2: _fast_search didn't resolve child sessions to parent.

    A delegation child and its parent both containing "docker" would appear
    as two separate results. They should be resolved to the parent session.
    Also, current session lineage exclusion must cover the entire chain.
    """

    def test_child_resolved_to_parent(self, db):
        """Parent + child matching same term should return 1 result (parent)."""
        import json
        from tools.session_search_tool import _fast_search

        db.create_session(session_id="parent-1", source="cli")
        db.create_session(session_id="child-1", source="cli", parent_session_id="parent-1")
        db.append_message(session_id="parent-1", role="user", content="docker setup question")
        db.append_message(session_id="child-1", role="assistant", content="docker setup done")

        result = json.loads(_fast_search(query="docker", db=db, limit=5, current_session_id=None))
        assert result["success"]
        sids = [e["session_id"] for e in result["results"]]
        # Should collapse to parent, not show both
        assert "child-1" not in sids, "Child should be resolved to parent"
        assert "parent-1" in sids
        assert len(result["results"]) == 1

    def test_match_count_accumulates_from_children(self, db):
        """Match_count should sum parent + child matches."""
        import json
        from tools.session_search_tool import _fast_search

        db.create_session(session_id="p", source="cli")
        db.create_session(session_id="c", source="cli", parent_session_id="p")
        db.append_message(session_id="p", role="user", content="docker question")
        db.append_message(session_id="c", role="assistant", content="docker answer")

        result = json.loads(_fast_search(query="docker", db=db, limit=5, current_session_id=None))
        entry = result["results"][0]
        assert entry["session_id"] == "p"
        assert entry["match_count"] >= 2, f"Expected accumulated count >= 2, got {entry['match_count']}"

    def test_current_session_lineage_excludes_children(self, db):
        """When current session is a child, parent should also be excluded."""
        import json
        from tools.session_search_tool import _fast_search

        db.create_session(session_id="parent-2", source="cli")
        db.create_session(session_id="child-2", source="cli", parent_session_id="parent-2")
        db.create_session(session_id="unrelated", source="cli")
        db.append_message(session_id="parent-2", role="user", content="docker deploy")
        db.append_message(session_id="child-2", role="assistant", content="docker deployed")
        db.append_message(session_id="unrelated", role="user", content="docker build")

        # Current session = child -> should exclude parent-2 AND child-2, keep unrelated
        result = json.loads(_fast_search(query="docker", db=db, limit=5, current_session_id="child-2"))
        sids = [e["session_id"] for e in result["results"]]
        assert "parent-2" not in sids, "Parent of current should be excluded"
        assert "child-2" not in sids, "Current child should be excluded"
        assert "unrelated" in sids, "Unrelated session should appear"
|
||||
|
||||
|
||||
class TestGetChildSessionIds:
    """Tests for SessionDB.get_child_session_ids -- public API replacing
    direct db._lock/db._conn access in _fast_search."""

    def test_returns_child_ids(self, db):
        # Two children under one parent; an orphan session must not show up.
        db.create_session(session_id="parent", source="cli")
        db.create_session(session_id="child-1", source="delegation", parent_session_id="parent")
        db.create_session(session_id="child-2", source="compression", parent_session_id="parent")
        db.create_session(session_id="orphan", source="cli")

        assert set(db.get_child_session_ids("parent")) == {"child-1", "child-2"}

    def test_returns_empty_for_leaf_session(self, db):
        # A session with no children yields an empty list, not None.
        db.create_session(session_id="leaf", source="cli")
        assert db.get_child_session_ids("leaf") == []

    def test_returns_empty_for_no_args(self, db):
        # Calling with zero parent ids is a harmless no-op.
        assert db.get_child_session_ids() == []

    def test_multiple_parent_ids(self, db):
        # Children of several parents can be fetched in one call.
        for parent_id in ("p1", "p2"):
            db.create_session(session_id=parent_id, source="cli")
        db.create_session(session_id="c1", source="delegation", parent_session_id="p1")
        db.create_session(session_id="c2", source="delegation", parent_session_id="p2")

        assert set(db.get_child_session_ids("p1", "p2")) == {"c1", "c2"}

    def test_does_not_recurse(self, db):
        """Only direct children, not grandchildren."""
        db.create_session(session_id="root", source="cli")
        db.create_session(session_id="child", source="delegation", parent_session_id="root")
        db.create_session(session_id="grandchild", source="delegation", parent_session_id="child")

        assert db.get_child_session_ids("root") == ["child"]
class TestNoiseReduction:
    """Tests for noise reduction in term indexing.

    Tool-role messages (structured JSON output) produce junk terms like
    'output', 'exit_code', 'null', 'true', 'false'. Pure numeric tokens
    ('0', '1', '2') are never useful search targets. JSON key names that
    appear in tool output schemas should be treated as stop words.
    """

    def test_tool_role_messages_not_indexed(self, db):
        """Tool-role messages should be skipped entirely during indexing."""
        db.create_session(session_id="s1", source="cli")
        db.append_message(
            session_id="s1",
            role="tool",
            content='{"output": "docker is running", "exit_code": 0}',
            tool_name="terminal",
        )

        # Tool output should NOT index any terms from the JSON blob
        # Even though 'docker' appears in the output string, it's inside
        # structured JSON from a tool call, not natural language
        # NOTE(review): inspects the private db._conn directly to read the raw
        # term_index table — assumes the sqlite column is named session_id.
        cursor = db._conn.execute(
            "SELECT COUNT(*) FROM term_index WHERE session_id = 's1'"
        )
        assert cursor.fetchone()[0] == 0

    def test_assistant_role_still_indexed(self, db):
        """Non-tool messages should still be indexed normally."""
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="docker deploy")
        db.append_message(
            session_id="s1", role="assistant", content="docker is now running"
        )

        # Both the user and assistant messages mention "docker", so the
        # public search API must return at least one hit.
        results = db.search_by_terms(["docker"])
        assert len(results) >= 1

    def test_pure_numeric_tokens_filtered(self):
        """Pure numeric tokens should be excluded from term extraction."""
        from term_index import extract_terms

        terms = extract_terms("exit code 0 with 42 errors in 123 steps")
        # These numeric tokens provide zero search value
        for num in ["0", "42", "123"]:
            assert num not in terms, f"Pure numeric '{num}' should be filtered"

        # But word tokens should survive
        assert "exit" in terms
        assert "code" in terms
        assert "errors" in terms
        assert "steps" in terms

    def test_json_key_stopwords_filtered(self):
        """Common JSON schema keys from tool output should be stop words."""
        from stop_words import is_stop_word

        # Key names commonly emitted by tool-result JSON payloads.
        json_keys = [
            "output",
            "exit_code",
            "error",
            "null",
            "true",
            "false",
            "status",
            "content",
            "message",
            "cleared",
            "success",
        ]
        for key in json_keys:
            assert is_stop_word(key), f"JSON key '{key}' should be a stop word"

    def test_json_key_stopwords_in_extract_terms(self):
        """JSON key stop words should be filtered by extract_terms."""
        from term_index import extract_terms

        # Simulates typical tool output content
        terms = extract_terms(
            '{"output": "hello world", "exit_code": 0, "error": null, "success": true}'
        )
        for junk in ["output", "exit_code", "error", "null", "success", "true", "false"]:
            assert junk not in terms, f"JSON key '{junk}' should be filtered"

        # Actual content words should survive
        assert "hello" in terms
        assert "world" in terms

    def test_reindex_skips_tool_messages(self, db):
        """reindex_term_index should not index tool-role messages."""
        db.create_session(session_id="s1", source="cli")
        db.append_message(session_id="s1", role="user", content="deploy docker")
        db.append_message(
            session_id="s1",
            role="tool",
            content='{"output": "docker running", "exit_code": 0}',
        )

        # Clear and reindex
        # NOTE(review): wiping the table via the private connection means the
        # rebuild path — not the append-time indexing — is what's under test.
        db._conn.execute("DELETE FROM term_index")
        db._conn.commit()
        db.reindex_term_index()

        # Tool message terms should not be in index
        cursor = db._conn.execute(
            "SELECT term FROM term_index WHERE session_id = 's1'"
        )
        indexed_terms = [row[0] for row in cursor.fetchall()]
        for junk in ["output", "exit_code", "0"]:
            assert junk not in indexed_terms, f"'{junk}' should not be indexed from tool messages"
class TestCJKFallbackInFastSearch:
    """CJK queries should fall through to the slow path even when fast=True.

    The term index can't handle CJK because extract_terms() splits on
    whitespace, and CJK languages don't use spaces between words.
    session_search should detect this and use the FTS5+LIKE fallback.
    """

    def test_cjk_query_bypasses_fast_path(self, db):
        """A CJK query with fast=True should be downgraded to fast=False."""
        import json

        from tools.session_search_tool import session_search

        db.create_session(session_id="cjk-1", source="cli")
        db.append_message(session_id="cjk-1", role="user", content="测试中文搜索")

        # Even with fast=True requested, a CJK query must be routed to the
        # full FTS5+LIKE search instead of the term-index path.
        payload = json.loads(session_search(
            query="中文", db=db, limit=3, fast=True, current_session_id=None
        ))
        assert payload["success"]
        # The forced fallback means the reported mode cannot be "fast".
        assert payload.get("mode") != "fast"

    def test_english_query_stays_fast(self, db):
        """Non-CJK queries should still use the fast path."""
        import json

        from tools.session_search_tool import session_search

        db.create_session(session_id="eng-1", source="cli")
        db.append_message(session_id="eng-1", role="user", content="deploy the server")

        # Plain ASCII keywords keep the default term-index route.
        payload = json.loads(session_search(
            query="deploy", db=db, limit=3, fast=True, current_session_id=None
        ))
        assert payload["success"]
        assert payload.get("mode") == "fast"
@ -656,10 +656,12 @@ class TestDelegationCredentialResolution(unittest.TestCase):
|
|||
self.assertEqual(creds["provider"], "custom")
|
||||
|
||||
def test_direct_endpoint_does_not_fall_back_to_openrouter_api_key_env(self):
|
||||
"""Remote endpoint without OPENAI_API_KEY should raise ValueError,
|
||||
even if OPENROUTER_API_KEY is set (only OPENAI_API_KEY is checked)."""
|
||||
parent = _make_mock_parent(depth=0)
|
||||
cfg = {
|
||||
"model": "qwen2.5-coder",
|
||||
"base_url": "http://localhost:1234/v1",
|
||||
"base_url": "https://api.example.com/v1", # remote, not localhost
|
||||
}
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
|
|
|
|||
251
tests/tools/test_delegation_local_provider.py
Normal file
251
tests/tools/test_delegation_local_provider.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for delegation with local/Ollama providers that don't require API keys.
|
||||
|
||||
Ollama and other local model servers run on localhost and accept requests
|
||||
without authentication. The delegation credential resolver should allow these
|
||||
endpoints to work without requiring an API key.
|
||||
|
||||
Run with: python -m pytest tests/tools/test_delegation_local_provider.py -v
|
||||
"""
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from tools.delegate_tool import (
|
||||
_resolve_delegation_credentials,
|
||||
)
|
||||
|
||||
|
||||
def _make_mock_parent(depth=0):
|
||||
"""Create a mock parent agent with the fields delegate_task expects."""
|
||||
parent = MagicMock()
|
||||
parent.base_url = "http://localhost:11434/v1"
|
||||
parent.api_key = "ollama"
|
||||
parent.provider = "custom"
|
||||
parent.api_mode = "chat_completions"
|
||||
parent.model = "glm-5.1:cloud"
|
||||
parent.platform = "cli"
|
||||
parent.providers_allowed = None
|
||||
parent.providers_ignored = None
|
||||
parent.providers_order = None
|
||||
parent.provider_sort = None
|
||||
parent._session_db = None
|
||||
parent._delegate_depth = depth
|
||||
parent._active_children = []
|
||||
parent._active_children_lock = __import__("threading").Lock()
|
||||
parent._print_fn = None
|
||||
parent.tool_progress_callback = None
|
||||
parent.thinking_callback = None
|
||||
parent._credential_pool = None
|
||||
parent.reasoning_config = None
|
||||
parent.max_tokens = None
|
||||
parent.prefill_messages = None
|
||||
parent.acp_command = None
|
||||
parent.acp_args = []
|
||||
parent.valid_tool_names = ["terminal", "file", "web"]
|
||||
parent.enabled_toolsets = None # None = all tools
|
||||
return parent
|
||||
|
||||
|
||||
class TestLocalProviderCredentials(unittest.TestCase):
    """Tests for _resolve_delegation_credentials with local providers."""

    # --- base_url path (localhost) ---

    def test_localhost_base_url_no_api_key_allowed(self):
        """localhost base_url should work without an API key (Ollama, LM Studio, etc.)."""
        parent = _make_mock_parent()
        cfg = {
            "model": "devstral-small-2:24b-cloud",
            "provider": "custom",
            "base_url": "http://localhost:11434/v1",
            "api_key": "",
        }
        creds = _resolve_delegation_credentials(cfg, parent)
        self.assertEqual(creds["base_url"], "http://localhost:11434/v1")
        self.assertIsNotNone(creds["api_key"])
        # API key should be a harmless placeholder, not None
        self.assertNotEqual(creds["api_key"], "")

    def test_127_base_url_no_api_key_allowed(self):
        """127.0.0.1 base_url should work without an API key."""
        parent = _make_mock_parent()
        # provider left blank: only the base_url drives the local detection
        cfg = {
            "model": "devstral-small-2:24b-cloud",
            "provider": "",
            "base_url": "http://127.0.0.1:11434/v1",
            "api_key": "",
        }
        creds = _resolve_delegation_credentials(cfg, parent)
        self.assertEqual(creds["base_url"], "http://127.0.0.1:11434/v1")
        self.assertIsNotNone(creds["api_key"])

    def test_dotlocal_base_url_no_api_key_allowed(self):
        """.local mDNS hostnames (e.g. studio.local) should work without an API key."""
        parent = _make_mock_parent()
        cfg = {
            "model": "devstral-small-2:24b-cloud",
            "provider": "",
            "base_url": "http://studio.local:11434/v1",
            "api_key": "",
        }
        creds = _resolve_delegation_credentials(cfg, parent)
        self.assertEqual(creds["base_url"], "http://studio.local:11434/v1")
        self.assertIsNotNone(creds["api_key"])

    def test_localhost_base_url_with_explicit_api_key_preserved(self):
        """If user provides an API key for localhost, it should be preserved as-is."""
        parent = _make_mock_parent()
        cfg = {
            "model": "devstral-small-2:24b-cloud",
            "provider": "custom",
            "base_url": "http://localhost:11434/v1",
            "api_key": "my-secret-key",
        }
        creds = _resolve_delegation_credentials(cfg, parent)
        # User-supplied key wins over the "ollama" placeholder
        self.assertEqual(creds["api_key"], "my-secret-key")

    def test_localhost_base_url_whitespace_api_key_gets_placeholder(self):
        """Whitespace-only api_key should be treated as absent and get placeholder."""
        parent = _make_mock_parent()
        cfg = {
            "model": "devstral-small-2:24b-cloud",
            "provider": "custom",
            "base_url": "http://localhost:11434/v1",
            "api_key": " ",
        }
        creds = _resolve_delegation_credentials(cfg, parent)
        self.assertEqual(creds["api_key"], "ollama")

    # --- base_url path (remote) should still require API key ---

    def test_remote_base_url_still_requires_api_key(self):
        """Non-localhost base_url without API key should still raise ValueError."""
        parent = _make_mock_parent()
        cfg = {
            "model": "gpt-4o-mini",
            "provider": "",
            "base_url": "https://api.openai.com/v1",
            "api_key": "",
        }
        # Blank out the env fallback so the resolver has nothing to find
        with patch.dict(os.environ, {"OPENAI_API_KEY": ""}, clear=False):
            with self.assertRaises(ValueError) as ctx:
                _resolve_delegation_credentials(cfg, parent)
            self.assertIn("API key", str(ctx.exception))

    # --- provider path with custom/local ---

    @patch("hermes_cli.runtime_provider.resolve_runtime_provider")
    def test_custom_provider_resolving_to_localhost_no_api_key(self, mock_resolve):
        """When delegation.provider='custom' resolves to localhost, empty API key should be allowed."""
        mock_resolve.return_value = {
            "provider": "custom",
            "base_url": "http://localhost:11434/v1",
            "api_key": "",
            "api_mode": "chat_completions",
        }
        parent = _make_mock_parent()
        cfg = {"model": "devstral-small-2:24b-cloud", "provider": "custom"}
        creds = _resolve_delegation_credentials(cfg, parent)
        self.assertEqual(creds["provider"], "custom")
        self.assertEqual(creds["base_url"], "http://localhost:11434/v1")
        # Should get a placeholder key, not raise ValueError
        self.assertIsNotNone(creds["api_key"])
        self.assertNotEqual(creds["api_key"], "")

    @patch("hermes_cli.runtime_provider.resolve_runtime_provider")
    def test_remote_provider_still_requires_api_key(self, mock_resolve):
        """Provider resolving to a remote endpoint without API key should still raise."""
        mock_resolve.return_value = {
            "provider": "openrouter",
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "",
            "api_mode": "chat_completions",
        }
        parent = _make_mock_parent()
        cfg = {"model": "some-model", "provider": "openrouter"}
        with self.assertRaises(ValueError) as ctx:
            _resolve_delegation_credentials(cfg, parent)
        self.assertIn("no API key", str(ctx.exception))

    # --- Integration: child agent gets local placeholder key ---

    @patch("tools.delegate_tool._load_config")
    def test_local_delegation_uses_placeholder_key(self, mock_cfg):
        """Delegation with localhost base_url should get 'ollama' placeholder API key."""
        mock_cfg.return_value = {
            "model": "devstral-small-2:24b-cloud",
            "provider": "custom",
            "base_url": "http://localhost:11434/v1",
            "api_key": "",
            "max_iterations": 10,
            "max_concurrent_children": 1,
        }
        parent = _make_mock_parent()
        # NOTE(review): the patched _load_config is never triggered here —
        # its return_value is passed to the resolver directly, so this is a
        # unit test of the resolver rather than of the delegation entry point.
        creds = _resolve_delegation_credentials(mock_cfg.return_value, parent)
        self.assertEqual(creds["base_url"], "http://localhost:11434/v1")
        self.assertEqual(creds["api_key"], "ollama")
class TestIsLocalBaseUrlHelper(unittest.TestCase):
|
||||
"""Tests for the _is_local_base_url helper function."""
|
||||
|
||||
def test_localhost_with_port(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://localhost:11434/v1"))
|
||||
|
||||
def test_localhost_no_port(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://localhost/v1"))
|
||||
|
||||
def test_127_ip(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://127.0.0.1:11434/v1"))
|
||||
|
||||
def test_dotlocal(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://studio.local:11434/v1"))
|
||||
|
||||
def test_remote_url(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertFalse(_is_local_base_url("https://api.openai.com/v1"))
|
||||
|
||||
def test_openrouter(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertFalse(_is_local_base_url("https://openrouter.ai/api/v1"))
|
||||
|
||||
def test_192_168_private(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://192.168.1.100:11434/v1"))
|
||||
|
||||
def test_10_private(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://10.0.0.5:11434/v1"))
|
||||
|
||||
def test_172_private(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://172.16.0.1:11434/v1"))
|
||||
|
||||
def test_empty_string(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertFalse(_is_local_base_url(""))
|
||||
|
||||
def test_none(self):
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertFalse(_is_local_base_url(None))
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
"""IPv6 loopback [::1] should be recognized as local."""
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertTrue(_is_local_base_url("http://[::1]:11434/v1"))
|
||||
|
||||
def test_172_outside_private_range(self):
|
||||
"""172.32.x.x is NOT in 172.16/12 and should not be treated as local."""
|
||||
from tools.delegate_tool import _is_local_base_url
|
||||
self.assertFalse(_is_local_base_url("http://172.32.0.1:11434/v1"))
|
||||
|
||||
|
||||
# Allow running this test module directly, outside pytest.
if __name__ == "__main__":
    unittest.main()
|
@ -2098,6 +2098,46 @@ def _resolve_child_credential_pool(effective_provider: Optional[str], parent_age
|
|||
return None
|
||||
|
||||
|
||||
def _is_local_base_url(base_url: Optional[str]) -> bool:
    """Return True if base_url points to a local/private network address.

    Local model servers (Ollama, LM Studio, llama.cpp server, ...) usually
    accept unauthenticated requests, so delegation may substitute a
    placeholder API key for them. Treated as local:

    - loopback names and addresses (localhost, 127.0.0.1, ::1)
    - mDNS ``.local`` hostnames (e.g. studio.local)
    - RFC 1918 private ranges (10/8, 172.16/12, 192.168/16)

    .. note::
        A "local" verdict means the caller will use the placeholder
        ``"ollama"`` key. A private-network service that actually enforces
        auth (e.g. a corporate gateway on 192.168.x.x) will then reject the
        request with 401/403 — a deliberately loud failure, while true
        no-auth local servers work out of the box.
    """
    import ipaddress

    if not base_url:
        return False
    host = base_url_hostname(base_url)
    if not host:
        return False
    # Well-known loopback spellings.
    if host == "localhost" or host == "127.0.0.1" or host == "::1":
        return True
    # mDNS hostnames such as studio.local.
    if host.endswith(".local"):
        return True
    # Literal IPs: private or loopback ranges count as local.
    try:
        parsed = ipaddress.ip_address(host)
    except ValueError:
        return False  # a regular public DNS hostname
    return parsed.is_private or parsed.is_loopback
|
||||
|
||||
def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
|
||||
"""Resolve credentials for subagent delegation.
|
||||
|
||||
|
|
@ -2111,6 +2151,10 @@ def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
|
|||
If neither base_url nor provider is configured, returns None values so the
|
||||
child inherits everything from the parent agent.
|
||||
|
||||
Local endpoints (localhost, 127.0.0.1, .local, RFC 1918 private nets)
|
||||
don't require API keys — a placeholder "ollama" key is used when none
|
||||
is provided, since these servers accept any or no authentication.
|
||||
|
||||
Raises ValueError with a user-friendly message on credential failure.
|
||||
"""
|
||||
configured_model = str(cfg.get("model") or "").strip() or None
|
||||
|
|
@ -2120,6 +2164,10 @@ def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
|
|||
|
||||
if configured_base_url:
|
||||
api_key = configured_api_key or os.getenv("OPENAI_API_KEY", "").strip()
|
||||
# Local endpoints (Ollama, LM Studio, etc.) don't require auth.
|
||||
# Use a dummy key so the OpenAI client doesn't reject the request.
|
||||
if not api_key and _is_local_base_url(configured_base_url):
|
||||
api_key = "ollama"
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"Delegation base_url is configured but no API key was found. "
|
||||
|
|
@ -2175,10 +2223,15 @@ def _resolve_delegation_credentials(cfg: dict, parent_agent) -> dict:
|
|||
|
||||
api_key = runtime.get("api_key", "")
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
f"Delegation provider '{configured_provider}' resolved but has no API key. "
|
||||
f"Set the appropriate environment variable or run 'hermes auth'."
|
||||
)
|
||||
# Local providers don't require real API keys.
|
||||
resolved_base = runtime.get("base_url", "")
|
||||
if _is_local_base_url(resolved_base):
|
||||
api_key = "ollama"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Delegation provider '{configured_provider}' resolved but has no API key. "
|
||||
f"Set the appropriate environment variable or run 'hermes auth'."
|
||||
)
|
||||
|
||||
return {
|
||||
"model": configured_model,
|
||||
|
|
|
|||
|
|
@ -321,11 +321,14 @@ def session_search(
|
|||
limit: int = 3,
|
||||
db=None,
|
||||
current_session_id: str = None,
|
||||
fast: bool = True,
|
||||
) -> str:
|
||||
"""
|
||||
Search past sessions and return focused summaries of matching conversations.
|
||||
|
||||
Uses FTS5 to find matches, then summarizes the top sessions with Gemini Flash.
|
||||
By default (fast=True), uses the term_index inverted index for instant
|
||||
results with session metadata and match counts — no LLM calls needed.
|
||||
Set fast=False to use FTS5 + LLM summarization for detailed summaries.
|
||||
The current session is excluded from results since the agent already has that context.
|
||||
"""
|
||||
if db is None:
|
||||
|
|
@ -348,6 +351,120 @@ def session_search(
|
|||
|
||||
query = query.strip()
|
||||
|
||||
# CJK queries can't be handled by the term index (no word boundaries
|
||||
# for extract_terms to split on). Fall through to FTS5 + LIKE which
|
||||
# has a CJK bigram/LIKE fallback.
|
||||
if fast and db._contains_cjk(query):
|
||||
fast = False
|
||||
|
||||
# ── Fast path: term index (no LLM, ~1ms) ──────────────────────────
|
||||
if fast:
|
||||
return _fast_search(query, db, limit, current_session_id)
|
||||
|
||||
# ── Slow path: FTS5 + LLM summarization (~5-15s) ───────────────────
|
||||
return _full_search(query, role_filter, limit, db, current_session_id)
|
||||
|
||||
|
||||
def _fast_search(query: str, db, limit: int, current_session_id: str = None) -> str:
    """Term index fast path: instant search, no LLM calls.

    Extracts keyword terms from ``query``, looks them up in the inverted
    term index, folds child sessions into their parent root (accumulating
    match counts), excludes the current session's lineage, and returns a
    JSON string with up to ``limit`` session entries.

    Args:
        query: Raw user query; reduced to terms via ``extract_terms``.
        db: Session database providing ``search_by_terms``, ``get_session``
            and ``get_child_session_ids``.
        limit: Maximum number of distinct (parent-resolved) sessions returned.
        current_session_id: Session to exclude, along with its root and the
            root's children.

    Returns:
        JSON string with keys: success, query, mode ("fast"), results,
        count, message. On an all-stop-word query, results is empty and
        mode is omitted.
    """
    from term_index import extract_terms

    terms = extract_terms(query)
    if not terms:
        # Nothing searchable (all stop words / empty) — report gracefully.
        return json.dumps({
            "success": True,
            "query": query,
            "results": [],
            "count": 0,
            "message": "No searchable terms in query (all stop words or empty).",
        }, ensure_ascii=False)

    # Fetch extra results so we have room after dedup/lineage exclusion
    raw_results = db.search_by_terms(
        terms=terms,
        exclude_sources=list(_HIDDEN_SESSION_SOURCES),
        limit=limit * 3,
    )

    # Resolve child sessions to their parent root, just like _full_search.
    # Delegation stores detailed content in child sessions, but the user
    # sees the parent conversation. Without this, parent + child both
    # containing "docker" would appear as two separate results.
    def _resolve_to_parent(session_id: str) -> str:
        # Walk up parent_session_id links; the visited set guards against
        # accidental cycles in the parent chain.
        visited = set()
        sid = session_id
        while sid and sid not in visited:
            visited.add(sid)
            try:
                session = db.get_session(sid)
                if not session:
                    break
                parent = session.get("parent_session_id")
                if parent:
                    sid = parent
                else:
                    break
            except Exception:
                # Best-effort: on any lookup failure, keep the id we have.
                break
        return sid

    # Determine current session lineage for exclusion
    current_lineage = set()
    if current_session_id:
        # Walk parent chain AND collect all children
        root = _resolve_to_parent(current_session_id)
        current_lineage.add(root)
        current_lineage.add(current_session_id)
        # Also find any child sessions of the current root
        try:
            children = db.get_child_session_ids(root, current_session_id)
            current_lineage.update(children)
        except Exception:
            # Lineage exclusion is best-effort; search still proceeds.
            pass

    # Deduplicate by resolved parent id, summing match counts of siblings.
    seen_sessions = {}
    for r in raw_results:
        raw_sid = r.get("session_id", "")
        resolved_sid = _resolve_to_parent(raw_sid)
        if resolved_sid in current_lineage or raw_sid in current_lineage:
            continue
        if resolved_sid not in seen_sessions:
            # Sum match_count from child into parent
            seen_sessions[resolved_sid] = dict(r)
            seen_sessions[resolved_sid]["session_id"] = resolved_sid
        else:
            # Accumulate match_count from child sessions
            seen_sessions[resolved_sid]["match_count"] = (
                seen_sessions[resolved_sid].get("match_count", 0)
                + r.get("match_count", 0)
            )
        # NOTE(review): breaking as soon as `limit` distinct sessions exist
        # means later raw rows that map to an already-seen parent no longer
        # add to its match_count — counts can undercount near the limit.
        if len(seen_sessions) >= limit:
            break

    # Shape each surviving session into a compact result entry.
    entries = []
    for sid, r in seen_sessions.items():
        entries.append({
            "session_id": sid,
            "when": _format_timestamp(r.get("session_started")),
            "source": r.get("source", "unknown"),
            "model": r.get("model"),
            "title": r.get("title"),
            "match_count": r.get("match_count", 0),
        })

    return json.dumps({
        "success": True,
        "query": query,
        "mode": "fast",
        "results": entries,
        "count": len(entries),
        "message": f"Found {len(entries)} matching sessions via term index (instant, no LLM)."
                   f" Use fast=False for LLM-summarized results.",
    }, ensure_ascii=False)
|
||||
def _full_search(query: str, role_filter: str, limit: int, db, current_session_id: str = None) -> str:
|
||||
"""FTS5 + LLM summarization path (original behavior)."""
|
||||
try:
|
||||
# Parse role filter
|
||||
role_list = None
|
||||
|
|
@ -367,6 +484,7 @@ def session_search(
|
|||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"mode": "full",
|
||||
"results": [],
|
||||
"count": 0,
|
||||
"message": "No matching sessions found.",
|
||||
|
|
@ -506,6 +624,7 @@ def session_search(
|
|||
return json.dumps({
|
||||
"success": True,
|
||||
"query": query,
|
||||
"mode": "full",
|
||||
"results": summaries,
|
||||
"count": len(summaries),
|
||||
"sessions_searched": len(seen_sessions),
|
||||
|
|
@ -535,7 +654,8 @@ SESSION_SEARCH_SCHEMA = {
|
|||
"Returns titles, previews, and timestamps. Zero LLM cost, instant. "
|
||||
"Start here when the user asks what were we working on or what did we do recently.\n"
|
||||
"2. Keyword search (with query): Search for specific topics across all past sessions. "
|
||||
"Returns LLM-generated summaries of matching sessions.\n\n"
|
||||
"By default uses the term index for instant results (no LLM). "
|
||||
"Set fast=False for detailed LLM-generated summaries.\n\n"
|
||||
"USE THIS PROACTIVELY when:\n"
|
||||
"- The user says 'we did this before', 'remember when', 'last time', 'as I mentioned'\n"
|
||||
"- The user asks about a topic you worked on before but don't have in current context\n"
|
||||
|
|
@ -544,11 +664,15 @@ SESSION_SEARCH_SCHEMA = {
|
|||
"- The user asks 'what did we do about X?' or 'how did we fix Y?'\n\n"
|
||||
"Don't hesitate to search when it is actually cross-session -- it's fast and cheap. "
|
||||
"Better to search and confirm than to guess or ask the user to repeat themselves.\n\n"
|
||||
"Search syntax: keywords joined with OR for broad recall (elevenlabs OR baseten OR funding), "
|
||||
"phrases for exact match (\"docker networking\"), boolean (python NOT java), prefix (deploy*). "
|
||||
"IMPORTANT: Use OR between keywords for best results — FTS5 defaults to AND which misses "
|
||||
"sessions that only mention some terms. If a broad OR query returns nothing, try individual "
|
||||
"keyword searches in parallel. Returns summaries of the top matching sessions."
|
||||
"Search syntax depends on the mode:\n"
|
||||
"- fast=True (default): Simple keyword search with AND semantics. Multiple words "
|
||||
"all must appear in a session. No boolean operators or phrase matching. "
|
||||
"Instant, zero LLM cost.\n"
|
||||
"- fast=False: FTS5 full syntax — OR for broad recall (elevenlabs OR baseten OR funding), "
|
||||
"phrases for exact match (\\\"docker networking\\\"), boolean (python NOT java), "
|
||||
"prefix (deploy*). IMPORTANT: Use OR between keywords for best FTS5 results — "
|
||||
"it defaults to AND which misses sessions that only mention some terms. "
|
||||
"Slower (5-15s) but returns LLM-summarized results."
|
||||
),
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
|
|
@ -559,13 +683,18 @@ SESSION_SEARCH_SCHEMA = {
|
|||
},
|
||||
"role_filter": {
|
||||
"type": "string",
|
||||
"description": "Optional: only search messages from specific roles (comma-separated). E.g. 'user,assistant' to skip tool outputs.",
|
||||
"description": "Optional: only search messages from specific roles (comma-separated). E.g. 'user,assistant' to skip tool outputs. Only used when fast=False.",
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max sessions to summarize (default: 3, max: 5).",
|
||||
"default": 3,
|
||||
},
|
||||
"fast": {
|
||||
"type": "boolean",
|
||||
"description": "When true (default), use the term index for instant results with no LLM cost. When false, use FTS5 + LLM summarization for detailed summaries.",
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
|
|
@ -583,6 +712,7 @@ registry.register(
|
|||
query=args.get("query") or "",
|
||||
role_filter=args.get("role_filter"),
|
||||
limit=args.get("limit", 3),
|
||||
fast=args.get("fast", True),
|
||||
db=kw.get("db"),
|
||||
current_session_id=kw.get("current_session_id")),
|
||||
check_fn=check_session_search_requirements,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue